source: main/waeup.cas/trunk/waeup/cas/server.py @ 10522

Last change on this file since 10522 was 10499, checked in by uli, 11 years ago

Handle service ticket generation correctly.

File size: 11.6 KB
RevLine 
[10321]1"""A WSGI app for serving CAS.
2"""
[10414]3import datetime
[10335]4import os
[10394]5import random
6import time
[10494]7try:
8    from urllib import urlencode        # Python 2.x
9except ImportError:                     # pragma: no cover
10    from urllib.parse import urlencode  # Python 3.x
11try:
12    from urlparse import urlparse, parse_qsl, urlunparse       # Python 2.x
13except ImportError:                                     # pragma: no cover
14    from urllib.parse import urlparse, parse_qsl, urlunparse  # Python 3.x
[10327]15from webob import exc, Response
[10321]16from webob.dec import wsgify
[10394]17from waeup.cas.authenticators import get_authenticator
18from waeup.cas.db import (
19    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
[10321]20
#: Directory holding the HTML templates shipped with this module.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Source of randomness for tickets and cookie values. ``SystemRandom``
#: draws from ``os.urandom()`` and ignores any seed, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values. Each char must appear exactly once: duplicates would make
#: random picks from this alphabet biased (and tickets easier to guess).
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')
30
31
def get_random_string(length):
    """Return `length` chars picked at random from `ALPHABET`.

    Chars are drawn via `RANDOM`, so the result is hard to guess, but
    nothing prevents two calls from returning the same string.
    """
    pick = RANDOM.choice
    return ''.join(pick(ALPHABET) for _ in range(length))
39
40
def get_unique_string():
    """Return a string derived from the current time.

    The float timestamp from ``time.time()`` is formatted with 16
    decimal places and the decimal point is replaced by a dash, so the
    result consists of digits and ``-`` only (all members of
    `ALPHABET`).

    Uniqueness relies on the clock having advanced between two calls.
    A ``float`` timestamp carries far fewer than 16 meaningful
    fractional digits, so calls within the same clock tick produce
    equal strings.

    Computing the value is cheap (no foreign data sources, no database
    lookups). The result is easy to guess for anyone who can read a
    clock.
    """
    stamp = '%.16f' % time.time()
    whole, _, fraction = stamp.partition('.')
    return whole + '-' + fraction
58
59
def create_service_ticket(user, service=None, sso=True):
    """Build a new :class:`waeup.cas.db.ServiceTicket` for `user`.

    The ticket id is ``ST-`` followed by 29 random chars from
    `ALPHABET`, 32 chars in total. `user`, `service` and `sso` are
    passed on to the ``ServiceTicket`` constructor. The ticket is not
    added to any database here.
    """
    prefix, random_len = 'ST-', 29
    ticket_id = prefix + get_random_string(random_len)
    return ServiceTicket(ticket_id, user, service, sso)
[10394]67
68
def check_service_ticket(db, ticket, service, renew=False):
    """Look up service ticket `ticket` issued for `service` in `db`.

    `ticket` and `service` are converted to ``str`` before querying.
    If `renew` is true, tickets whose ``sso`` flag is set (i.e. tickets
    issued on behalf of a ticket granting cookie rather than freshly
    entered credentials) are rejected.

    Returns the matching :class:`waeup.cas.db.ServiceTicket`, or
    ``None`` in any of these cases:

    - `ticket` or `service` is ``None``;
    - no matching ticket exists;
    - the ticket was rejected because of `renew`.
    """
    if ticket is None or service is None:
        return None
    st = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == str(ticket)).filter(
        ServiceTicket.service == str(service)).first()
    if st is None:
        # unknown ticket, or ticket issued for a different service;
        # must be checked before touching `st.sso` below
        return None
    if renew and st.sso:
        return None
    # NOTE(review): the CAS protocol wants service tickets to be valid
    # for a single validation attempt; this ticket is not invalidated
    # here -- confirm whether that is handled elsewhere.
    return st
84
85
def create_login_ticket():
    """Return a new :class:`waeup.cas.db.LoginTicket`.

    The ticket id is ``LT-`` plus the time based string returned by
    :func:`get_unique_string`. The protocol requires login tickets to
    be unique, not hard to guess. The ticket is not added to any
    database here.
    """
    return LoginTicket('LT-' + get_unique_string())
94
95
def check_login_ticket(db, lt_string):
    """Consume login ticket `lt_string` if it is stored in `db`.

    Returns ``True`` if a matching :class:`waeup.cas.db.LoginTicket`
    exists. The first match is then deleted from `db`, so the same
    login ticket cannot be used twice. Returns ``False`` if
    `lt_string` is ``None`` or unknown.
    """
    if lt_string is None:
        return False
    matches = list(db.query(LoginTicket).filter(
        LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    db.delete(matches[0])
    return True
106
107
def create_tgc_value():
    """Return a new :class:`waeup.cas.db.TicketGrantingCookie`.

    Its value is ``TGC-`` followed by 128 random chars from `ALPHABET`.
    The cookie is not added to any database here.
    """
    return TicketGrantingCookie('TGC-%s' % get_random_string(128))
113
114
def set_session_cookie(db, response):
    """Attach a fresh ticket granting cookie to `response`.

    A new :class:`waeup.cas.db.TicketGrantingCookie` is created and
    added to `db` so it persists. Its value is sent as cookie
    ``cas-tgc`` with path ``/`` and the ``secure`` and ``httponly``
    flags set. Returns `response`.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_opts = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_opts)
    return response
126
127
def delete_session_cookie(db, response, old_value=None):
    """Expire the ``cas-tgc`` cookie on `response`.

    If `old_value` is given and exactly one
    :class:`waeup.cas.db.TicketGrantingCookie` with that value exists
    in `db`, that entry is deleted. In every case the cookie is
    overwritten on `response` with an empty value and an expiration
    date of 1970-01-01. Returns `response`.
    """
    if old_value is not None:
        stale = list(db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == old_value))
        # zero or several matches: leave the database untouched
        if len(stale) == 1:
            db.delete(stale[0])
    epoch = datetime.datetime(1970, 1, 1, 0, 0, 0)
    response.set_cookie(
        'cas-tgc', '', path='/', secure=True, httponly=True, expires=epoch)
    return response
145
146
def check_session_cookie(db, cookie_value):
    """Find the ticket granting cookie stored as `cookie_value` in `db`.

    `cookie_value` is the raw value of the ``cas-tgc`` cookie, or
    ``None`` if there is none. Values that have a ``decode`` method
    (byte strings) are decoded as UTF-8 before the lookup.

    Returns the first matching :class:`waeup.cas.db.TicketGrantingCookie`
    or ``None`` if there is no match.
    """
    if cookie_value is None:
        return None
    try:
        # bytes (py2 ``str``) -> text; py3 ``str`` has no ``decode``
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        pass
    matches = list(db.query(TicketGrantingCookie).filter(
        TicketGrantingCookie.value == cookie_value))
    return matches[0] if matches else None
171
172
def get_template(name):
    """Return the contents of template file `name` in `template_dir`.

    Returns ``None`` if no regular file of that name exists there.
    """
    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    # close the handle right away instead of leaving it to the garbage
    # collector (the previous `open(...).read()` leaked it)
    with open(path, 'r') as template_file:
        return template_file.read()
178
179
def update_url(url, params_dict):
    """Update query params of an url.

    The `url` is modified to have the query parameters set to
    keys/values in `params_dict`:

    - Existing parameters whose keys are not in `params_dict` are kept
      unchanged, in their original order. This includes parameters
      with blank values and repeated keys.
    - An existing key that also appears in `params_dict` gets the new
      value at the position of its first occurrence; any further
      occurrences of that key are dropped.
    - Keys from `params_dict` that were not present yet are appended.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.
    """
    parts = list(urlparse(url))
    pending = dict(params_dict)
    pairs = []
    # keep_blank_values: do not silently drop params like ``?debug=``
    for key, value in parse_qsl(parts[4], keep_blank_values=True):
        if key in params_dict:
            if key in pending:
                pairs.append((key, pending.pop(key)))
            continue
        pairs.append((key, value))
    # remaining new params, in `params_dict` order
    pairs.extend(pending.items())
    parts[4] = urlencode(pairs)
    return urlunparse(parts)
197
198
def login_redirect_service(db, user, service, sso=True,
                           create_ticket=True, warn=False):
    """Return a response sending the browser on to `service`.

    If `create_ticket` is true, a new service ticket is created for
    `user` (passing `sso` along) and added to `db`. Its id is then
    appended to the `service` URL as ``ticket`` query parameter.

    Normally the response is an HTTP 303 See Other redirect to
    `service`, rendered with the ``login_service_redirect.html``
    template. If `warn` is true, a plain response rendered with the
    ``login_service_confirm.html`` template is returned instead, and
    no redirect is issued.

    If `sso` is false (login happened via credentials), a new ticket
    granting cookie is set on the response. The response is marked as
    not cacheable.
    """
    try:
        from html import escape as escape_html   # Python 3.x
    except ImportError:                          # pragma: no cover
        from cgi import escape as escape_html    # Python 2.x
    # NOTE(review): `service` is not checked against any list of
    # allowed services, so this works as an open redirect -- confirm.
    if create_ticket:
        st = create_service_ticket(user, service, sso)
        db.add(st)
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` comes from the request: escape it before putting it into
    # HTML to prevent markup/script injection. The Location header above
    # gets the raw URL.
    html = html.replace('SERVICE_URL', escape_html(service, True))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        # credentials were just accepted: start a new SSO session
        resp = set_session_cookie(db, resp)
    return resp
223
224
def login_success_no_service(db, msg='', sso=False):
    """Return the logged-in page shown when no service was requested.

    `msg` replaces ``MSG_TEXT`` in the ``login_successful.html``
    template verbatim (it may contain HTML). `sso` must be a boolean
    telling whether login happened via credentials (``False``) or via
    cookie (``True``). Only in the credentials case is a new ticket
    granting cookie set on the response.
    """
    page = get_template('login_successful.html').replace('MSG_TEXT', msg)
    response = Response(page)
    if sso:
        return response
    return set_session_cookie(db, response)
240
241
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database managed by :class:`waeup.cas.db.DB`.

    `db` -
       Connection string for the database, handed to
       :class:`waeup.cas.db.DB` (presumably an SQLAlchemy URL -- TODO
       confirm). The default, ``'sqlite:///:memory:'``, names an
       in-memory SQLite database. Credentials are then lost after a
       CAS server restart.

    `auth` -
       Authenticator used to check credentials posted to ``/login``.
       It must provide ``check_credentials(username, password)``
       returning a ``(cred_ok, reason)`` pair.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        # dispatch `/login`, `/validate` and `/logout` to the methods of
        # the same name; everything else is answered with 404
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    @staticmethod
    def _escape_html(value):
        """Return `value` escaped for HTML text and attribute values."""
        try:
            from html import escape            # Python 3.x
        except ImportError:                    # pragma: no cover
            from cgi import escape             # Python 2.x
        return escape(value, True)

    def _get_template(self, name):
        """Return contents of template `name` or ``None`` if missing."""
        # share the module-level helper instead of duplicating it
        return get_template(name)

    def login(self, req):
        """Handle ``/login``: request or accept credentials.

        ``service``, ``renew``, ``warn`` and ``gateway`` are read from
        POST data first, then from the query string. The outcome
        depends on those, on the posted ``username``, ``password`` and
        ``lt`` (login ticket), and on the ``cas-tgc`` cookie. The user
        is either redirected to the service, shown a logged-in page, or
        shown the login form with a fresh login ticket.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # `renew` takes precedence over `gateway`
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # a valid login ticket is consumed (deleted) by this check
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            # gateway mode without SSO session: back to the service,
            # without asking for credentials and without a ticket
            return login_redirect_service(
                self.db, username, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            # already logged in via ticket granting cookie
            # NOTE(review): no credentials are posted in this case, so
            # `username` is usually None, and as far as visible here the
            # ticket granting cookie stores no user either. The issued
            # service ticket then carries no user -- confirm and fix.
            if service:
                return login_redirect_service(
                    self.db, username, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok, reason = self.auth.check_credentials(
                username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, username, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        # act as credentials requestor: render the login form
        if service is not None:
            # `service` comes from the request: escape it for the HTML
            # attribute to prevent markup/script injection
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    self._escape_html(service))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = html.replace('MSG_TEXT', msg)
        return Response(html)

    def validate(self, req):
        """Handle ``/validate``: check a service ticket.

        ``service``, ``ticket`` and ``renew`` are read from POST data
        first, then from the query string. For a valid ticket the body
        is ``yes`` and the ticket's user name, each followed by a
        newline. Otherwise the body is ``no`` followed by an empty line.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            # NOTE(review): fails with TypeError if `st.user` is None
            # (see the SSO note in `login`)
            return Response('yes' + chr(0x0a) + st.user + chr(0x0a))
        return Response('no' + chr(0x0a) + chr(0x0a))

    def logout(self, req):
        """Handle ``/logout``: end the SSO session.

        Deletes the ticket granting cookie (from the browser and from
        the database). If an ``http``/``https`` ``url`` parameter is
        given, the logout page links to it.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        if url is not None and urlparse(url).scheme.lower() not in (
                'http', 'https'):
            # never render links to e.g. ``javascript:`` URLs
            url = None
        old_val = req.cookies.get('cas-tgc', None)
        if url is None:
            html = self._get_template('logout.html')
        else:
            html = self._get_template('logout_url.html')
            # `url` comes from the request: escape it for HTML
            html = html.replace('URL_HREF', self._escape_html(url))
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
[10327]348
[10415]349
def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from configuration settings.

    The signature has the shape of a PasteDeploy app factory.
    `global_conf` is not used. `local_conf` is run through
    :func:`waeup.cas.authenticators.get_authenticator` (presumably
    turning the auth setting into an authenticator object -- see that
    module), and the result is passed to :class:`CASServer` as keyword
    arguments.
    """
    settings = get_authenticator(local_conf)
    return CASServer(**settings)


#: Alternative name for :class:`CASServer`.
cas_server = CASServer
Note: See TracBrowser for help on using the repository browser.