source: main/waeup.cas/trunk/waeup/cas/server.py @ 10608

Last change on this file since 10608 was 10607, checked in by uli, 11 years ago

Enable shared header/footer partial templates.

File size: 12.8 KB
Line 
1"""A WSGI app for serving CAS.
2"""
3import datetime
4import os
5import random
6import re
7import time
8try:
9    from urllib import urlencode        # Python 2.x
10except ImportError:                     # pragma: no cover
11    from urllib.parse import urlencode  # Python 3.x
12try:
13    from urlparse import urlparse, parse_qsl, urlunparse       # Python 2.x
14except ImportError:                                     # pragma: no cover
15    from urllib.parse import urlparse, parse_qsl, urlunparse  # Python 3.x
16from webob import exc, Response
17from webob.dec import wsgify
18from waeup.cas.authenticators import get_authenticator
19from waeup.cas.db import (
20    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
21
#: Directory (next to this module) where templates are looked up.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Random source used for tickets and cookie values.
#: `random.SystemRandom` draws from the operating system and ignores
#: any seed, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values. Every char must appear exactly once; a duplicate would make
#: that char more likely to be picked by `RANDOM.choice`.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')

#: A regular expression that matches a div tag (with id ``msg``) around
#: a MSG_TEXT marker. Written as a raw string so that no backslash
#: escapes are interpreted by the Python parser.
RE_MSG_TAG = re.compile(r'<div id="msg".*MSG_TEXT[^<]*</div>', re.DOTALL)
34
35
def get_random_string(length):
    """Return `length` chars picked at random from `ALPHABET`.

    The result is meant to be hard to guess; it is not guaranteed to
    be unique.
    """
    picked = []
    for _ in range(length):
        picked.append(RANDOM.choice(ALPHABET))
    return ''.join(picked)
43
44
def get_unique_string():
    """Return a string derived from the current time.

    The current `time.time()` value is rendered with 16 decimal
    places and the decimal point is swapped for a dash, so the result
    consists of digits and a single ``-`` only (all of which are part
    of `ALPHABET`).

    No database or other external source is consulted, which keeps
    this cheap. Two calls made at exactly the same timestamp would
    yield the same string, and the value is easy to guess for anybody
    who knows the time of creation.
    """
    stamp = '%.16f' % time.time()
    return '-'.join(stamp.split('.'))
62
63
def create_service_ticket(user, service=None, sso=True):
    """Build a new service ticket for `user`.

    The ticket id is the prefix ``ST-`` followed by 29 random chars
    from `ALPHABET`, i.e. 32 chars overall. `user`, `service` and
    `sso` are handed to `ServiceTicket` unchanged.
    """
    ticket_id = 'ST-%s' % get_random_string(29)
    return ServiceTicket(ticket_id, user, service, sso)
71
72
def check_service_ticket(db, ticket, service, renew=False):
    """Check whether (`ticket`, `service`) represents a valid service
    ticket in `db`.

    `db` must provide a ``query()`` method as used below. `ticket` and
    `service` are converted with `str` before being compared to the
    stored values.

    If `renew` is true, tickets flagged as ``sso`` are rejected.

    Returns the matching database entry or ``None`` if no ticket
    matches (or one of `ticket`, `service` is ``None``).
    """
    if ticket is None or service is None:
        return None
    ticket, service = str(ticket), str(service)
    found = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == ticket).filter(
        ServiceTicket.service == service).first()
    if found is None:
        # no such ticket; must be checked before looking at `sso`
        return None
    if renew and found.sso:
        return None
    return found
88
89
def create_login_ticket():
    """Build a new login ticket.

    The protocol specification asks for login tickets to be unique,
    not for them to be hard to guess. The ticket id therefore is
    ``LT-`` followed by the time-based `get_unique_string` value.
    """
    return LoginTicket('LT-' + get_unique_string())
98
99
def check_login_ticket(db, lt_string):
    """Tell whether `lt_string` is a login ticket stored in `db`.

    If at least one matching ticket is found, the first match is
    deleted from `db` and ``True`` is returned. Otherwise (also if
    `lt_string` is ``None``) ``False`` is returned.
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    db.delete(matches[0])
    return True
110
111
def create_tgc_value():
    """Build a new ticket granting cookie.

    The cookie value is ``TGC-`` followed by 128 random chars from
    `ALPHABET`.
    """
    cookie_value = 'TGC-%s' % get_random_string(128)
    return TicketGrantingCookie(cookie_value)
117
118
def set_session_cookie(db, response):
    """Put a fresh ticket granting cookie on `response`.

    A new cookie is created, added to `db` to make it persistent and
    set on `response` under the name ``cas-tgc``.

    Returns `response`.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_opts = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_opts)
    return response
130
131
def delete_session_cookie(db, response, old_value=None):
    """Remove the ticket granting cookie.

    If `old_value` is given and exactly one entry with that value is
    stored in `db`, this entry is deleted.

    In any case an empty ``cas-tgc`` cookie with an expiration date in
    the past (1970-01-01) is set on `response`.

    Returns `response`.
    """
    if old_value is not None:
        # drop the old tgc from db
        stored = [
            row for row in db.query(TicketGrantingCookie).filter(
                TicketGrantingCookie.value == old_value)]
        if len(stored) == 1:
            db.delete(stored[0])
    long_ago = datetime.datetime(1970, 1, 1, 0, 0, 0)
    response.set_cookie(
        'cas-tgc', '', path='/', secure=True, httponly=True,
        expires=long_ago)
    return response
149
150
def check_session_cookie(db, cookie_value):
    """Look up the ticket granting cookie with value `cookie_value`.

    `cookie_value` is the string value of a ticket granting cookie
    that may or may not be stored in `db`.

    Returns the first matching
    :class:`waeup.cas.db.TicketGrantingCookie` or ``None`` if there is
    none (or `cookie_value` is ``None``).
    """
    if cookie_value is None:
        return None
    try:
        # values providing `decode` are turned into text first
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        # nothing to decode
        pass
    rows = list(
        db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == cookie_value))
    return rows[0] if rows else None
175
176
def get_template(name):
    """Read template named `name`.

    Templates are looked up in the local `templates` dir.

    In the result any 'PART_HEADER' and 'PART_FOOTER' parts are
    replaced by the contents of ``part_header.tpl`` and
    ``part_footer.tpl`` respectively. These partial templates are only
    read if the respective marker is present.

    Returns the HTML template or ``None`` if no file named `name`
    exists in the templates dir.
    """
    def _read(path):
        # close the file handle right after reading
        with open(path, 'r') as fd:
            return fd.read()

    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    html = _read(path)
    # header first, footer second: the footer marker is also replaced
    # if it was brought in by the header partial.
    for marker, partial in (
            ('PART_HEADER', 'part_header.tpl'),
            ('PART_FOOTER', 'part_footer.tpl')):
        if marker in html:
            html = html.replace(
                marker, _read(os.path.join(template_dir, partial)))
    return html
198
199
def update_url(url, params_dict):
    """Set query parameters of `url` from `params_dict`.

    The query string of `url` is parsed, merged with `params_dict`
    and written back. Keys only present in the old query string are
    kept, keys present in `params_dict` are added or overwritten. All
    other parts of `url` stay as they are.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.
    """
    scheme, netloc, path, params, query, fragment = urlparse(url)
    merged = dict(parse_qsl(query))
    merged.update(params_dict)
    return urlunparse(
        [scheme, netloc, path, params, urlencode(merged), fragment])
217
218
def login_redirect_service(db, user, service, sso=True,
                           create_ticket=True, warn=False):
    """Return a response sending the user on to `service`.

    If `create_ticket` is true, a new service ticket for `user` is
    stored in `db` and appended to `service` as ``ticket`` query
    parameter.

    If `warn` is false, the response is a redirect via HTTP 303 See
    Other, rendered from ``login_service_redirect.html``. If `warn` is
    true, no redirect happens; instead a plain response rendered from
    ``login_service_confirm.html`` is returned.

    In both templates the marker ``SERVICE_URL`` is replaced by the
    HTML-escaped service URL.

    If `sso` is false, a new session cookie is set on the response.
    """
    # stdlib, available in Python 2.x and 3.x
    from xml.sax.saxutils import escape

    if create_ticket:
        st = create_service_ticket(user, service, sso)
        db.add(st)
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` may stem from request parameters: never put it into
    # markup verbatim. The `Location` header above gets the raw URL.
    # NOTE(review): assumes SERVICE_URL markers sit in HTML text or
    # attribute context (not inside a script) -- confirm in templates.
    html = html.replace(
        'SERVICE_URL', escape(service, {'"': '&quot;', "'": '&#39;'}))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        resp = set_session_cookie(db, resp)
    return resp
243
244
def login_success_no_service(db, msg='', sso=False):
    """Build the logged-in screen shown after successful auth.

    The page is rendered from ``login_successful.html`` with `msg`
    inserted as message.

    `sso` must be a boolean indicating whether login happened via
    credentials (``False``) or via cookie (``True``). Only for
    credential logins a new session cookie is set.

    Returns a response.
    """
    page = set_message(msg, get_template('login_successful.html'))
    resp = Response(page)
    if sso:
        return resp
    return set_session_cookie(db, resp)
260
261
def set_message(msg, html):
    """Fill the message box of an html template.

    A non-empty `msg` simply replaces the string `MSG_TEXT` in
    `html`.

    With an empty `msg` the string `MSG_TEXT` is removed. If `html`
    contains a ``<div>`` tag with id ``msg``, then everything matched
    by `RE_MSG_TAG` (the tag enclosing `MSG_TEXT`) is removed as well.

    That way message boxes can carry additional styles that do not
    show up when there is nothing to display.
    """
    if msg:
        return html.replace('MSG_TEXT', msg)
    if '<div id="msg"' in html:
        return RE_MSG_TAG.sub('', html)
    return html.replace('MSG_TEXT', '')
277
278
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database.

    `db` -
       The connection string passed to :class:`waeup.cas.db.DB`. By
       default ``sqlite:///:memory:`` is used.

    `auth` -
       The authenticator used to check credentials. It must provide a
       method ``check_credentials(username, password)`` returning a
       tuple ``(<credentials_ok>, <reason>)``. If no authenticator is
       given, every login with credentials fails.

    The server handles the paths ``/login``, ``/validate``,
    ``/logout`` and ``/style.css``. Any other path results in a 404
    Not Found response.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        """Dispatch `req` to the handler responsible for its path.
        """
        if req.path == '/style.css':
            return Response(
                get_template('style.css'), content_type='text/css')
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                # handler methods are named like the path
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    def _get_template(self, name):
        """Get template `name` (see `get_template`).
        """
        return get_template(name)

    @staticmethod
    def _escape_html(value):
        """Escape `value` for use in HTML text and attribute values.
        """
        # stdlib, available in Python 2.x and 3.x
        from xml.sax.saxutils import escape
        return escape(value, {'"': '&quot;', "'": '&#39;'})

    def login(self, req):
        """Handle ``/login`` requests.

        Depending on request parameters and cookies, this either

        - redirects to the given service without ticket (`gateway`
          requested, no valid session cookie),
        - accepts a valid session cookie (unless `renew` is given),
        - checks posted credentials (requires a valid login ticket),
        - or renders the login form with a fresh login ticket.

        Returns a response.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # if both are given, `renew` takes precedence
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # note: a posted login ticket is consumed by this check
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            return login_redirect_service(
                self.db, username, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            if service:
                return login_redirect_service(
                    self.db, username, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            if self.auth is None:
                # without authenticator nobody can log in
                cred_ok, reason = False, 'No authenticator configured'
            else:
                cred_ok, reason = self.auth.check_credentials(
                    username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, username, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        if service is not None:
            # `service` is request data: escape it before it becomes
            # part of an attribute value.
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    self._escape_html(service))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = set_message(msg, html)
        return Response(html)

    def validate(self, req):
        """Handle ``/validate`` requests.

        Responds with ``yes``, a linefeed, the user stored with the
        ticket and another linefeed if the given `ticket` is valid for
        the given `service`. Otherwise responds with ``no`` followed
        by two linefeeds.

        The mere presence of a `renew` parameter turns on renew mode
        of the ticket check.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            return Response('yes' + chr(0x0a) + st.user + chr(0x0a))
        return Response('no' + chr(0x0a) + chr(0x0a))

    def logout(self, req):
        """Handle ``/logout`` requests.

        Deletes the session cookie and renders ``logout.html`` or, if
        an `url` parameter was sent, ``logout_url.html`` with the
        marker ``URL_HREF`` replaced by the HTML-escaped url.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        old_val = req.cookies.get('cas-tgc', None)
        html = self._get_template('logout.html')
        if url is not None:
            html = self._get_template('logout_url.html')
            # `url` is request data: never put it into markup verbatim
            html = html.replace('URL_HREF', self._escape_html(url))
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
384
385
#: Lowercase alias for :class:`CASServer`.
cas_server = CASServer


def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from configuration values.

    `local_conf` is passed through `get_authenticator` and the result
    is used as keyword arguments for :class:`CASServer`. `global_conf`
    is accepted but not used.
    """
    server_options = get_authenticator(local_conf)
    return CASServer(**server_options)
Note: See TracBrowser for help on using the repository browser.