source: main/waeup.cas/trunk/waeup/cas/server.py @ 14600

Last change on this file since 14600 was 10611, checked in by uli, 11 years ago

Handle msg box correctly also when logging out.

File size: 13.0 KB
Line 
1"""A WSGI app for serving CAS.
2"""
import datetime
import itertools
import os
import random
import re
import time
try:
    from urllib import urlencode        # Python 2.x
except ImportError:                     # pragma: no cover
    from urllib.parse import urlencode  # Python 3.x
try:
    from urlparse import urlparse, parse_qsl, urlunparse       # Python 2.x
except ImportError:                                     # pragma: no cover
    from urllib.parse import urlparse, parse_qsl, urlunparse  # Python 3.x

from webob import exc, Response
from webob.dec import wsgify

from waeup.cas.authenticators import get_authenticator
from waeup.cas.db import (
    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
21
#: Directory holding the HTML/CSS templates shipped with this package.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')


def _read_text_file(path):
    """Return the complete content of the text file at `path`.

    A ``with`` block is used so the file handle is closed right away
    instead of whenever the garbage collector gets around to it.
    """
    with open(path, 'r') as fd:
        return fd.read()


#: A piece of HTML that can be used in HTML headers.
SHARED_HEADER = _read_text_file(os.path.join(template_dir, 'part_header.tpl'))

#: A piece of HTML that can be used in HTML footers.
SHARED_FOOTER = _read_text_file(os.path.join(template_dir, 'part_footer.tpl'))
29
#: Cryptographically strong random source (draws from ``os.urandom``).
#: ``SystemRandom`` ignores any seed, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values: ASCII letters, digits and the hyphen. Each char appears
#: exactly once so that random picks are uniformly distributed.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')

#: A regular expression that matches a div tag with id ``msg`` around
#: a ``MSG_TEXT`` placeholder, up to and including the closing
#: ``</div>``. ``re.DOTALL`` lets the match span several lines.
RE_MSG_TAG = re.compile(r'<div id="msg".*MSG_TEXT[^<]*</div>', re.DOTALL)
41
42
def get_random_string(length):
    """Return a string of `length` chars picked at random from `ALPHABET`.

    Every char is drawn independently from the system random source,
    so the result is hard to guess; uniqueness is *not* guaranteed.
    """
    picks = (RANDOM.choice(ALPHABET) for _ in range(length))
    return ''.join(picks)
50
51
#: Per-process sequence number appended by :func:`get_unique_string`,
#: so that strings created within the same clock tick still differ.
_UNIQUE_COUNTER = itertools.count()


def get_unique_string():
    """Get a unique string based on current time.

    The returned string contains only digits and hyphens, i.e. chars
    from `ALPHABET`.

    It consists of the current timestamp (formatted with 16 decimal
    places, the dot replaced by a hyphen), a hyphen and a per-process
    sequence number. The timestamp alone is not enough: the real
    resolution of :func:`time.time` is far coarser than 16 decimal
    places, so calls made in quick succession can return the same
    value. The sequence number disambiguates those.

    This is fast because we don't have to fetch foreign data sources
    nor have to do database lookups.

    The returned string is unique within this process, but it won't
    be hard to guess for people able to read a clock.
    """
    stamp = ('%.16f' % time.time()).replace('.', '-')
    return '%s-%d' % (stamp, next(_UNIQUE_COUNTER))
69
70
def create_service_ticket(user, service=None, sso=True):
    """Create a new :class:`ServiceTicket` for `user` and `service`.

    The ticket id is the prefix ``ST-`` followed by 29 random chars
    from `ALPHABET`, i.e. 32 chars in total. `sso` is handed to the
    ticket unchanged.
    """
    prefix, random_len = 'ST-', 29
    ticket_id = prefix + get_random_string(random_len)
    return ServiceTicket(ticket_id, user, service, sso)
78
79
def check_service_ticket(db, ticket, service, renew=False):
    """Check whether (`ticket`, `service`) represents a valid service
    ticket in `db`.

    If `ticket` or `service` is ``None``, no lookup is done. Otherwise
    both are converted to ``str`` and the first matching
    :class:`ServiceTicket` is looked up.

    If `renew` is true, tickets with their ``sso`` flag set are
    rejected.

    Returns a database set or ``None`` (no match, or rejected because
    of `renew`).
    """
    if ticket is None or service is None:
        return None
    ticket, service = str(ticket), str(service)
    st = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == ticket).filter(
        ServiceTicket.service == service).first()
    # `first()` yields None for unknown tickets; check that before
    # reading `sso`, otherwise an unknown ticket plus `renew` crashes.
    if st is not None and renew and st.sso:
        return None
    # NOTE(review): the CAS protocol requires service tickets to be
    # single-use (invalidated after any validation attempt). This
    # function does not delete `st`; confirm whether that is wanted.
    return st
95
96
def create_login_ticket():
    """Create a new :class:`LoginTicket` with a unique id.

    The protocol specification requires login tickets to be unique
    but not hard to guess, hence the time-based
    :func:`get_unique_string` is used instead of random chars.
    """
    return LoginTicket('LT-' + get_unique_string())
105
106
def check_login_ticket(db, lt_string):
    """Check whether `lt_string` represents a valid login ticket in `db`.

    A found ticket is consumed: the first match is deleted from `db`,
    so it cannot be found (and thus accepted) again.

    Returns ``True`` if a matching ticket existed, ``False`` otherwise
    (including when `lt_string` is ``None``).
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    db.delete(matches[0])
    return True
117
118
def create_tgc_value():
    """Create a new :class:`TicketGrantingCookie`.

    Its value is the prefix ``TGC-`` followed by 128 random chars
    from `ALPHABET`.
    """
    random_part = get_random_string(128)
    return TicketGrantingCookie('TGC-%s' % random_part)
124
125
def set_session_cookie(db, response):
    """Create a session cookie (ticket granting cookie) on `response`.

    A fresh :class:`TicketGrantingCookie` is added to `db`, so the
    value can be looked up later. Its value is then sent as cookie
    ``cas-tgc`` for path ``/``, flagged ``secure`` and ``httponly``.

    Returns `response`, modified in place.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    response.set_cookie(
        'cas-tgc', cookie.value,
        path='/', secure=True, httponly=True)
    return response
137
138
def delete_session_cookie(db, response, old_value=None):
    """Delete session cookie.

    If `old_value` is given and exactly one
    :class:`TicketGrantingCookie` with that value exists in `db`, that
    entry is deleted. In any case the ``cas-tgc`` cookie on `response`
    is overwritten with an empty value whose expiry date
    (1970-01-01) lies in the past.

    Returns `response`, modified in place.
    """
    if old_value is not None:
        query = db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == old_value)
        found = [tgc for tgc in query]
        if len(found) == 1:
            db.delete(found[0])
    epoch = datetime.datetime(1970, 1, 1, 0, 0, 0)
    response.set_cookie(
        'cas-tgc', '', path='/', secure=True, httponly=True, expires=epoch)
    return response
156
157
def check_session_cookie(db, cookie_value):
    """Check whether `cookie_value` represents a valid ticket granting
    ticket in `db`.

    `cookie_value` is the value of a ticket granting cookie, or
    ``None`` if no cookie was sent. Values that have a ``decode``
    method (byte strings) are decoded as UTF-8 first; others are used
    as they are.

    Returns the first matching :class:`waeup.cas.db.TicketGrantingCookie`
    or ``None`` if there is none.
    """
    if cookie_value is None:
        return None
    try:
        # byte string (py2.x ``str``) -> text; py3.x ``str`` has no decode
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        pass
    matches = db.query(TicketGrantingCookie).filter(
        TicketGrantingCookie.value == cookie_value)
    for tgc in matches:
        return tgc
    return None
182
183
def get_template(name):
    """Read template named `name`.

    Templates are looked up in the local `templates` dir.

    In the result any 'PART_HEADER' and 'PART_FOOTER' parts are
    replaced by the respective shared templates.

    Returns the template text, or ``None`` if there is no file named
    `name` in the templates dir.
    """
    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    # A context manager closes the file handle right away instead of
    # leaving it to the garbage collector.
    with open(path, 'r') as fd:
        html = fd.read()
    html = html.replace('PART_HEADER', SHARED_HEADER)
    return html.replace('PART_FOOTER', SHARED_FOOTER)
201
202
def update_url(url, params_dict):
    """Update query params of an url.

    The `url` is modified to have the query parameters set to
    keys/values in `params_dict`, preserving any different existing
    keys/values and overwriting any existing keys/values that are also
    in `params_dict`.

    Existing parameters keep their position. They are preserved even
    when their value is empty or their key is repeated. An overwritten
    key stays at the position of its first occurrence; further
    occurrences of it are dropped. Keys not yet in the query are
    appended in `params_dict` iteration order.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.
    """
    parts = list(urlparse(url))
    pending = dict(params_dict)  # new values not placed yet
    pairs = []
    # keep_blank_values: do not silently drop params like ``x=``
    for key, value in parse_qsl(parts[4], keep_blank_values=True):
        if key not in params_dict:
            pairs.append((key, value))
        elif key in pending:
            pairs.append((key, pending.pop(key)))
        # else: repeated occurrence of an overwritten key -> drop it
    pairs.extend(pending.items())
    parts[4] = urlencode(pairs)
    return urlunparse(parts)
220
221
def login_redirect_service(db, user, service, sso=True,
                           create_ticket=True, warn=False):
    """Return a response that leads the browser on to `service`.

    If `create_ticket` is true, a new service ticket for `user` is
    stored in `db` and added to the service URL as ``ticket`` query
    parameter.

    Normally an HTTP 303 See Other redirect to the service URL is
    returned, with the ``login_service_redirect.html`` template as
    body. If `warn` is true, a plain response rendering the
    ``login_service_confirm.html`` template is returned instead.

    All responses are marked as non-cacheable. If `sso` is false
    (login happened via credentials), a new session cookie is set.
    """
    from xml.sax.saxutils import escape  # local import; py2 and py3

    if create_ticket:
        st = create_service_ticket(user, service, sso)
        db.add(st)
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` comes from the request: escape it before embedding it
    # in HTML to prevent cross-site scripting. The Location header
    # above keeps the raw URL.
    # NOTE(review): assumes SERVICE_URL sits in an HTML text/attribute
    # context in the templates (not inside <script>) -- confirm.
    html = html.replace(
        'SERVICE_URL', escape(service, {'"': '&quot;', "'": '&#x27;'}))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        resp = set_session_cookie(db, resp)
    return resp
246
247
def login_success_no_service(db, msg='', sso=False):
    """Render the logged-in screen shown after a login without service.

    `msg` goes into the page's message box; the box is removed when
    `msg` is empty.

    `sso` must be a boolean indicating whether login happened via
    credentials (``False``) or via cookie (``True``). Only in the
    former case is a new session cookie set.

    Returns a response.
    """
    page = set_message(msg, get_template('login_successful.html'))
    response = Response(page)
    return response if sso else set_session_cookie(db, response)
263
264
def set_message(msg, html):
    """Insert a message box in html template.

    The placeholder ``MSG_TEXT`` in `html` is replaced by `msg`. `msg`
    is inserted verbatim, so it may carry markup.

    If the message is empty, not only the string `MSG_TEXT` is removed
    from `html`, but also any encapsulating ``<div>`` tag with id
    ``msg``. This makes it possible to give message boxes certain
    additional styles that will not show up if there is no message to
    display.
    """
    if msg:
        return html.replace('MSG_TEXT', msg)
    if '<div id="msg"' in html:
        return RE_MSG_TAG.sub('', html)
    return html.replace('MSG_TEXT', '')
280
281
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database accessed through :class:`waeup.cas.db.DB`.

    `db` -
       The database connection string handed to :class:`DB`. The
       default, ``'sqlite:///:memory:'``, is an in-memory SQLite
       database; please note that credentials stored there will be
       lost after a CAS server restart.

    `auth` -
       The authenticator used to check credentials. It must provide a
       ``check_credentials(username, password)`` method returning a
       ``(success, reason)`` pair.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @staticmethod
    def _escape(text):
        """Return `text` escaped for use in HTML text and attributes.

        ``&``, ``<``, ``>``, ``"`` and ``'`` are replaced by character
        entities. Used for request values that end up in HTML pages.
        """
        from xml.sax.saxutils import escape  # local import; py2 and py3
        return escape(text, {'"': '&quot;', "'": '&#x27;'})

    @wsgify
    def __call__(self, req):
        """Dispatch `req` by path.

        ``/style.css`` is served directly. ``/login``, ``/validate`` and
        ``/logout`` are handled by the equally named methods inside a
        :class:`DBSessionContext`. Any other path gives 404 Not Found.
        """
        if req.path == '/style.css':
            return Response(get_template('style.css'), content_type='text/css')
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    def _get_template(self, name):
        """Return the template `name` (see :func:`get_template`)."""
        return get_template(name)

    def login(self, req):
        """Handle requests to ``/login``.

        Depending on the request this

        - redirects gateway requests without valid session cookie back
          to `service` (without ticket),
        - redirects to `service` (or shows a 'logged in' page) if a
          valid session cookie exists and `renew` is not set,
        - checks submitted credentials (together with a valid login
          ticket) and on success redirects to `service` or shows a
          'logged in' page,
        - otherwise renders the login form with a fresh login ticket
          and, after a failed credential check, an error message.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # `renew` takes precedence over `gateway`
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            return login_redirect_service(
                self.db, username, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            if service:
                # NOTE(review): `username` is only set for credential
                # POSTs, so a ticket issued here via the session cookie
                # presumably has no user and `validate` would fail on
                # ``st.user``. Storing the user with the ticket granting
                # cookie would fix this -- needs db changes; confirm.
                return login_redirect_service(
                    self.db, username, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok, reason = self.auth.check_credentials(
                username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, username, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        if service is not None:
            # `service` comes from the request: escape it (XSS).
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    self._escape(service))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = set_message(msg, html)
        return Response(html)

    def validate(self, req):
        """Handle requests to ``/validate``.

        Answers with a plain-text body. For a valid (`ticket`, `service`)
        pair it is ``yes``, a newline, the ticket's user and a newline;
        otherwise it is ``no`` followed by two newlines. If `renew` is
        given, tickets issued via the session cookie (``sso``) are
        rejected.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            return Response('yes' + chr(0x0a) + st.user + chr(0x0a))
        return Response('no' + chr(0x0a) + chr(0x0a))

    def logout(self, req):
        """Handle requests to ``/logout``.

        Renders a logout page (with a link to `url` if one was given)
        and deletes the session cookie on the client and in the db.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        old_val = req.cookies.get('cas-tgc', None)
        if url is None:
            html = self._get_template('logout.html')
        else:
            html = self._get_template('logout_url.html')
            # `url` comes from the request: escape it (XSS).
            # NOTE(review): escaping does not stop ``javascript:`` URLs;
            # consider accepting only http(s) schemes.
            html = html.replace('URL_HREF', self._escape(url))
        html = set_message('', html)
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
389
390
#: Alternative (lowercase) name for :class:`CASServer`.
cas_server = CASServer
392
393
def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from configuration options.

    `global_conf` is accepted but not used (the signature presumably
    follows the PasteDeploy app factory convention -- confirm).
    `local_conf` is first passed through :func:`get_authenticator`.
    The resulting mapping supplies the keyword arguments for
    :class:`CASServer`.
    """
    settings = get_authenticator(local_conf)
    return CASServer(**settings)
Note: See TracBrowser for help on using the repository browser.