source: main/waeup.cas/trunk/waeup/cas/server.py @ 10606

Last change on this file since 10606 was 10605, checked in by uli, 11 years ago

Add a smarter message box handling.

File size: 12.4 KB
Line 
1"""A WSGI app for serving CAS.
2"""
import datetime
import os
import random
import re
import time
try:
    from html import escape as html_escape  # Python 3.x
except ImportError:                         # pragma: no cover
    from cgi import escape as html_escape   # Python 2.x
try:
    from urllib import urlencode        # Python 2.x
except ImportError:                     # pragma: no cover
    from urllib.parse import urlencode  # Python 3.x
try:
    from urlparse import urlparse, parse_qsl, urlunparse       # Python 2.x
except ImportError:                                     # pragma: no cover
    from urllib.parse import urlparse, parse_qsl, urlunparse  # Python 3.x
from webob import exc, Response
from webob.dec import wsgify
from waeup.cas.authenticators import get_authenticator
from waeup.cas.db import (
    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
21
#: Directory holding the templates that are looked up by name.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Random source backed by the operating system. `SystemRandom`
#: ignores any seed value, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values. Each char must appear exactly once, otherwise random
#: strings picked from it are biased.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')

#: A regular expression that matches a div tag around a MSG_TEXT
RE_MSG_TAG = re.compile(r'<div id="msg".*MSG_TEXT[^<]*</div>', re.DOTALL)
34
35
def get_random_string(length):
    """Return `length` chars picked at random from `ALPHABET`.

    The result is meant to be hard to guess; it is not guaranteed to
    be unique.
    """
    picked = []
    for _ in range(length):
        picked.append(RANDOM.choice(ALPHABET))
    return ''.join(picked)
43
44
def get_unique_string():
    """Return a string derived from the current time.

    The current timestamp is rendered with 16 decimal places and the
    decimal point is replaced by a dash, so that the result contains
    only chars from `ALPHABET`.

    No database or other external source is consulted, which keeps
    this cheap. Two calls in very quick succession might still yield
    the same value on a fast machine.

    The result is easy to guess for anybody who knows the time of
    creation; do not use it where secrecy is required.
    """
    stamp = '{0:.16f}'.format(time.time())
    seconds, _, fraction = stamp.partition('.')
    return seconds + '-' + fraction
62
63
def create_service_ticket(user, service=None, sso=True):
    """Build a new :class:`ServiceTicket` for `user` and `service`.

    The ticket id is the prefix ``ST-`` followed by 29 random chars
    from `ALPHABET`, i.e. 32 chars in total. `sso` is passed through
    to the ticket unchanged.
    """
    ticket_id = 'ST-%s' % get_random_string(29)
    return ServiceTicket(ticket_id, user, service, sso)
71
72
def check_service_ticket(db, ticket, service, renew=False):
    """Check whether (`ticket`, `service`) represents a valid service
    ticket in `db`.

    `ticket` and `service` are converted to strings before lookup.

    If `renew` is true, tickets flagged as `sso` are rejected.

    Returns the matching :class:`ServiceTicket` or ``None`` if either
    argument is missing, no ticket matches, or the ticket is rejected
    because of `renew`.
    """
    if ticket is None or service is None:
        return None
    ticket, service = str(ticket), str(service)
    st = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == ticket).filter(
        ServiceTicket.service == service).first()
    if st is None:
        # unknown ticket/service pair: nothing to inspect further
        return None
    if renew and st.sso:
        return None
    return st
88
89
def create_login_ticket():
    """Build a new :class:`LoginTicket`.

    The ticket id is ``LT-`` followed by the result of
    `get_unique_string`. According to the protocol specification
    login tickets must be unique, but they need not be hard to guess.
    """
    return LoginTicket('LT-' + get_unique_string())
98
99
def check_login_ticket(db, lt_string):
    """Tell whether `lt_string` is a login ticket stored in `db`.

    A ticket that is found gets deleted from `db` (only the first
    match is removed). Returns ``True`` if at least one matching
    ticket existed, ``False`` otherwise or if `lt_string` is ``None``.
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    db.delete(matches[0])
    return True
110
111
def create_tgc_value():
    """Build a new :class:`TicketGrantingCookie`.

    The cookie value is ``TGC-`` followed by 128 random chars from
    `ALPHABET`.
    """
    return TicketGrantingCookie('TGC-%s' % get_random_string(128))
117
118
def set_session_cookie(db, response):
    """Attach a fresh ticket granting cookie to `response`.

    A new cookie value is created and added to `db`. It is then set
    on `response` as cookie ``cas-tgc`` for path ``/`` with the
    ``secure`` and ``httponly`` flags.

    Returns `response`.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_opts = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_opts)
    return response
130
131
def delete_session_cookie(db, response, old_value=None):
    """Expire the ``cas-tgc`` cookie on `response`.

    If `old_value` is given and exactly one ticket granting cookie
    with that value is stored in `db`, that entry is deleted.

    In any case an empty ``cas-tgc`` cookie with an expiration date
    in the past is set on `response`, which is then returned.
    """
    if old_value is not None:
        # remove the stored cookie that belongs to the old value
        stored = list(db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == old_value))
        if len(stored) == 1:
            db.delete(stored[0])
    long_ago = datetime.datetime(1970, 1, 1, 0, 0, 0)
    response.set_cookie(
        'cas-tgc', '', path='/', secure=True, httponly=True,
        expires=long_ago)
    return response
149
150
def check_session_cookie(db, cookie_value):
    """Look up the ticket granting cookie with value `cookie_value`.

    `cookie_value` is the string taken from a ticket granting cookie.
    If it provides a ``decode`` method, it is decoded as UTF-8 before
    the lookup.

    Returns the first matching
    :class:`waeup.cas.db.TicketGrantingCookie` stored in `db`, or
    ``None`` if there is none or if `cookie_value` is ``None``.
    """
    if cookie_value is None:
        return None
    try:
        # bytes (py2.x str) -> unicode; text values have no decode()
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        pass
    matches = list(db.query(TicketGrantingCookie).filter(
        TicketGrantingCookie.value == cookie_value))
    return matches[0] if matches else None
175
176
def get_template(name):
    """Return the content of template `name` as string.

    The template is looked up in `template_dir`. If no such file
    exists, ``None`` is returned.
    """
    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    # use a context manager so that the file handle is closed reliably
    with open(path, 'r') as fd:
        return fd.read()
182
183
def update_url(url, params_dict):
    """Update query params of an url.

    The `url` is modified to have the query parameters set to
    keys/values in `params_dict`, preserving any different existing
    keys/values (including those with an empty value) and overwriting
    any existing keys/values that are also in `params_dict`.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.

    If a key appears several times in `url`, only its last value is
    kept.
    """
    parts = list(urlparse(url))
    # keep_blank_values: do not silently drop params like ``a=``
    params = dict(parse_qsl(parts[4], keep_blank_values=True))
    params.update(params_dict)
    parts[4] = urlencode(params)
    return urlunparse(parts)
201
202
def login_redirect_service(db, user, service, sso=True,
                           create_ticket=True, warn=False):
    """Return a response redirecting to a service via HTTP 303 See Other.

    If `create_ticket` is true, a service ticket for (`user`,
    `service`) is created, added to `db` and appended to the service
    URL as ``ticket`` query parameter.

    If `warn` is true, no redirect is issued. Instead a plain response
    with the confirmation page is returned.

    If `sso` is false, a ticket granting cookie is set on the response.

    The service URL is HTML-escaped before it is inserted into the
    page, as it normally stems from request parameters.
    """
    if create_ticket:
        st = create_service_ticket(user, service, sso)
        db.add(st)
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` is untrusted input: never put it into markup verbatim.
    # NOTE(review): assumes SERVICE_URL sits in an HTML text or attribute
    # context of the templates (not inside a script) -- confirm.
    html = html.replace('SERVICE_URL', html_escape(service, True))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        resp = set_session_cookie(db, resp)
    return resp
227
228
def login_success_no_service(db, msg='', sso=False):
    """Build the response shown after a successful login.

    The ``login_successful.html`` template is rendered with `msg` as
    message text.

    `sso` tells how the login happened: ``False`` means via
    credentials, ``True`` means via cookie. Only in the former case a
    new session cookie is set on the response.

    Returns a response.
    """
    page = set_message(msg, get_template('login_successful.html'))
    resp = Response(page)
    if sso:
        # logged in by cookie already, no new cookie needed
        return resp
    return set_session_cookie(db, resp)
244
245
def set_message(msg, html):
    """Fill the message placeholder of an html template.

    A non-empty `msg` simply replaces each occurrence of the string
    `MSG_TEXT` in `html`.

    With an empty `msg` the placeholder is removed. If `html` contains
    a ``<div>`` tag with id ``msg``, everything matched by
    `RE_MSG_TAG` (the tag including the placeholder) is removed as
    well.

    That way templates can style their message boxes without showing
    an empty box when there is nothing to tell.
    """
    if msg:
        return html.replace('MSG_TEXT', msg)
    if '<div id="msg"' in html:
        return RE_MSG_TAG.sub('', html)
    return html.replace('MSG_TEXT', '')
261
262
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database.

    `db` -
       The connection string of the database to use, passed to
       :class:`waeup.cas.db.DB`. By default
       ``'sqlite:///:memory:'`` is used. Please note that with an
       in-memory database credentials are presumably lost after a CAS
       server restart.

    `auth` -
       The authenticator used to check credentials. It must provide a
       method ``check_credentials(username, password)`` returning a
       tuple ``(<ok>, <reason>)``.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        """Dispatch `req` to the handler for the requested path.

        Serves ``/style.css``, ``/login``, ``/validate`` and
        ``/logout``. Any other path results in HTTP 404.
        """
        if req.path == '/style.css':
            return Response(get_template('style.css'), content_type='text/css')
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    def _get_template(self, name):
        """Return content of template `name` or ``None`` if missing.
        """
        path = os.path.join(template_dir, name)
        if not os.path.isfile(path):
            return None
        # use a context manager so that the file handle is closed reliably
        with open(path, 'r') as fd:
            return fd.read()

    def login(self, req):
        """Handle requests to ``/login``.

        Acts as credential requestor (renders the login form) and as
        credential acceptor (checks submitted username/password).

        Returns a response.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # if both are set, `renew` wins and `gateway` is ignored
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            return login_redirect_service(
                self.db, username, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            if service:
                # NOTE(review): `username` stems from POST data only and
                # is presumably ``None`` for cookie-based logins, so the
                # ticket would carry no user -- confirm.
                return login_redirect_service(
                    self.db, username, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok, reason = self.auth.check_credentials(
                username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, username, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        if service is not None:
            # `service` is untrusted input, escape it (including
            # quotes) before putting it into an attribute value
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    html_escape(service, True))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = set_message(msg, html)
        return Response(html)

    def validate(self, req):
        """Handle requests to ``/validate``.

        Responds with ``yes``, a linefeed, the user of the ticket and
        another linefeed if `ticket` and `service` denote a valid
        service ticket. Otherwise ``no`` followed by two linefeeds is
        returned.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        # the mere presence of `renew` counts, whatever its value
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            return Response('yes' + chr(0x0a) + st.user + chr(0x0a))
        return Response('no' + chr(0x0a) + chr(0x0a))

    def logout(self, req):
        """Handle requests to ``/logout``.

        Deletes the session cookie and shows the logout page. If an
        `url` parameter is given, a page containing that URL is shown
        instead.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        old_val = req.cookies.get('cas-tgc', None)
        html = self._get_template('logout.html')
        if url is not None:
            html = self._get_template('logout_url.html')
            # `url` is untrusted input, never insert it verbatim
            html = html.replace('URL_HREF', html_escape(url, True))
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
371
372
#: Alias for :class:`CASServer`.
cas_server = CASServer


def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from configuration values.

    `local_conf` is processed by `get_authenticator` and the result is
    passed as keyword arguments to :class:`CASServer`. `global_conf`
    is not used.
    """
    return CASServer(**get_authenticator(local_conf))
Note: See TracBrowser for help on using the repository browser.