source: main/waeup.cas/trunk/waeup/cas/server.py @ 10600

Last change on this file since 10600 was 10600, checked in by uli, 11 years ago

Let server provide custom style on request.

File size: 11.7 KB
Line 
1"""A WSGI app for serving CAS.
2"""
import datetime
import os
import random
import time
from xml.sax.saxutils import escape
try:
    from urllib import urlencode        # Python 2.x
except ImportError:                     # pragma: no cover
    from urllib.parse import urlencode  # Python 3.x
try:
    from urlparse import urlparse, parse_qsl, urlunparse       # Python 2.x
except ImportError:                                     # pragma: no cover
    from urllib.parse import urlparse, parse_qsl, urlunparse  # Python 3.x
from webob import exc, Response
from webob.dec import wsgify
from waeup.cas.authenticators import get_authenticator
from waeup.cas.db import (
    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
20
#: Directory holding the HTML/CSS templates shipped with this module.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Cryptographically strong random generator (backed by ``os.urandom``).
#: :class:`random.SystemRandom` ignores seeds, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values: ASCII letters, digits and the hyphen. Every char must occur
#: exactly once; a duplicated char would be picked more often by
#: ``RANDOM.choice`` and thus weaken the randomness of tickets.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')
30
31
def get_random_string(length):
    """Return a string of `length` chars drawn from `ALPHABET`.

    Every char is picked independently by the module-level `RANDOM`
    generator (a :class:`random.SystemRandom`), so the result is hard
    to guess. Uniqueness is *not* guaranteed.
    """
    picks = (RANDOM.choice(ALPHABET) for _ in range(length))
    return ''.join(picks)
39
40
def get_unique_string():
    """Return a string built from the current wall-clock time.

    The timestamp from :func:`time.time` is formatted with 16 decimal
    places and its decimal point is replaced by a hyphen, so the result
    only contains chars from `ALPHABET`.

    Distinct calls are expected to differ because of the fine-grained
    timestamp. Calls made very shortly after one another may still
    collide: the effective resolution is limited by the clock and by
    float precision, not by the number of printed decimal places.

    No database lookups or external data sources are needed, which
    makes this cheap. The result is unique-ish but trivially guessable
    for anyone who can read a clock.
    """
    timestamp = '%.16f' % time.time()
    # '.' is not allowed by the protocol alphabet; '-' is.
    return timestamp.replace('.', '-')
58
59
def create_service_ticket(user, service=None, sso=True):
    """Build a new :class:`waeup.cas.db.ServiceTicket` for `user`.

    The ticket id is ``'ST-'`` followed by 29 random chars from
    `ALPHABET`, i.e. 32 chars in total. `service` and `sso` are handed
    on unchanged. This function does not add the ticket to a database;
    callers do that themselves (e.g. via ``db.add``).
    """
    prefix, random_part = 'ST-', get_random_string(29)
    return ServiceTicket(prefix + random_part, user, service, sso)
67
68
def check_service_ticket(db, ticket, service, renew=False):
    """Look up the service ticket `ticket` issued for `service` in `db`.

    `ticket` and `service` are compared as strings. ``None`` is
    returned if either of them is ``None`` or if no matching ticket
    exists.

    If `renew` is true, only tickets issued after a credentials login
    (``sso`` flag false) are accepted; tickets issued via single
    sign-on yield ``None``.

    Returns the matching :class:`waeup.cas.db.ServiceTicket` or
    ``None``.
    """
    if ticket is None or service is None:
        return None
    ticket, service = str(ticket), str(service)
    # NOTE(review): the CAS protocol requires a service ticket to be
    # valid for one validation attempt only; this lookup does not
    # invalidate it. Confirm whether that is intended.
    st = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == ticket).filter(
        ServiceTicket.service == service).first()
    if st is None:
        # unknown ticket/service pair -- must not touch `st.sso`
        return None
    if renew and st.sso:
        return None
    return st
84
85
def create_login_ticket():
    """Return a new :class:`waeup.cas.db.LoginTicket`.

    The ticket id is ``'LT-'`` followed by the time-based string from
    :func:`get_unique_string`. The protocol specification requires
    login tickets to be unique, but they need not be hard to guess.
    """
    return LoginTicket('LT-' + get_unique_string())
94
95
def check_login_ticket(db, lt_string):
    """Consume the login ticket `lt_string` if `db` knows it.

    If a matching :class:`waeup.cas.db.LoginTicket` exists, the first
    match is deleted from `db`, so the ticket cannot be reused.

    Returns ``True`` if a matching ticket was found, ``False``
    otherwise (also when `lt_string` is ``None``).
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    db.delete(matches[0])
    return True
106
107
def create_tgc_value():
    """Return a new :class:`waeup.cas.db.TicketGrantingCookie`.

    Its value is ``'TGC-'`` followed by 128 random chars from
    `ALPHABET`.
    """
    random_part = get_random_string(128)
    return TicketGrantingCookie('TGC-%s' % random_part)
113
114
def set_session_cookie(db, response):
    """Attach a fresh ticket granting cookie ``cas-tgc`` to `response`.

    A new cookie value is created and added to `db` so it can be
    checked later. The cookie is set for path ``/`` with the
    ``Secure`` and ``HttpOnly`` flags.

    Returns `response`, which is modified in place.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    options = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **options)
    return response
126
127
def delete_session_cookie(db, response, old_value=None):
    """Expire the ``cas-tgc`` cookie on `response`.

    If `old_value` is given, the matching
    :class:`waeup.cas.db.TicketGrantingCookie` is removed from `db`.
    Removal only happens when exactly one entry matches. The cookie
    itself is overwritten with an empty value that expired on
    1970-01-01.

    Returns `response`, which is modified in place.
    """
    if old_value is not None:
        matches = list(db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == old_value))
        if len(matches) == 1:
            db.delete(matches[0])
    epoch = datetime.datetime(1970, 1, 1, 0, 0, 0)
    response.set_cookie(
        'cas-tgc', '', path='/', secure=True, httponly=True, expires=epoch)
    return response
145
146
def check_session_cookie(db, cookie_value):
    """Look up the ticket granting cookie value `cookie_value` in `db`.

    `cookie_value` may be ``None`` or a string. Values providing a
    ``decode`` method (byte strings) are decoded as UTF-8 first.

    Returns the matching :class:`waeup.cas.db.TicketGrantingCookie`,
    or ``None`` if `cookie_value` is ``None`` or unknown.
    """
    if cookie_value is None:
        return None
    try:
        # bytes/py2 str -> text; py3 str has no .decode()
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        pass
    matches = list(db.query(TicketGrantingCookie).filter(
        TicketGrantingCookie.value == cookie_value))
    return matches[0] if matches else None
171
172
def get_template(name):
    """Return the contents of template file `name` as a string.

    `name` is looked up relative to the module-level `template_dir`.
    Returns ``None`` if no regular file of that name exists.
    """
    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    # context manager closes the handle immediately (no fd leak)
    with open(path, 'r') as template_file:
        return template_file.read()
178
179
def update_url(url, params_dict):
    """Update query params of an url.

    The `url` is modified to have the query parameters set to
    keys/values in `params_dict`. Existing parameters whose keys are
    not in `params_dict` are preserved unchanged and in order,
    including blank values and repeated keys. Existing keys that also
    appear in `params_dict` get the new value at their original
    position, and any later duplicates of such keys are dropped. Keys
    of `params_dict` that were not present before are appended.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.
    """
    parts = list(urlparse(url))
    pending = dict(params_dict)
    params = []
    # keep_blank_values: do not silently lose params like ``?a=``
    for key, value in parse_qsl(parts[4], keep_blank_values=True):
        if key not in params_dict:
            params.append((key, value))
        elif key in pending:
            # first occurrence of an overridden key takes the new value
            params.append((key, pending.pop(key)))
    params.extend(
        (key, value) for key, value in params_dict.items() if key in pending)
    parts[4] = urlencode(params)
    return urlunparse(parts)
197
198
def login_redirect_service(db, user, service, sso=True,
                           create_ticket=True, warn=False):
    """Return a response forwarding the browser to `service`.

    If `create_ticket` is true, a new service ticket for `user` is
    created, added to `db` and appended to the `service` URL as
    ``ticket`` query parameter.

    Without `warn`, an HTTP 303 See Other redirect to the (updated)
    service URL is returned with the ``login_service_redirect.html``
    page as body. With `warn`, a plain response showing
    ``login_service_confirm.html`` is returned instead, so no
    automatic redirect happens.

    The response carries headers asking clients not to cache it. If
    `sso` is false (login happened via credentials), a new ticket
    granting cookie is set on the response.
    """
    if create_ticket:
        st = create_service_ticket(user, service, sso)
        db.add(st)
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` comes from the request (and contains '&' once a ticket
    # was added): escape it before inserting it into markup.
    html = html.replace(
        'SERVICE_URL', escape(service, {'"': '&quot;', "'": '&#x27;'}))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        resp = set_session_cookie(db, resp)
    return resp
223
224
def login_success_no_service(db, msg='', sso=False):
    """Return the "logged in" page used when no service was requested.

    `msg` replaces the ``MSG_TEXT`` placeholder of
    ``login_successful.html`` verbatim (it may contain markup).

    `sso` must be a boolean telling whether login happened via
    credentials (``False``) or via cookie (``True``). Only in the
    former case is a new ticket granting cookie set on the response.

    Returns a response.
    """
    page = get_template('login_successful.html').replace('MSG_TEXT', msg)
    response = Response(page)
    if sso:
        return response
    return set_session_cookie(db, response)
240
241
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database handled by :class:`waeup.cas.db.DB`.

    `db` -
       The database connection string handed to
       :class:`waeup.cas.db.DB`. The default, ``'sqlite:///:memory:'``,
       denotes an in-memory SQLite database: all credentials are lost
       when the CAS server is restarted.

    `auth` -
       The authenticator used to check credentials. It is expected to
       provide ``check_credentials(username, password)`` returning a
       tuple ``(ok, reason)``.

    Requests are dispatched by path: ``/style.css`` serves the
    stylesheet; ``/login``, ``/validate`` and ``/logout`` are handled
    by the equally named methods inside a database session context;
    any other path gives ``404 Not Found``.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        if req.path == '/style.css':
            css = self._get_template('style.css')
            if css is None:
                return exc.HTTPNotFound()
            # a proper content type is needed: browsers may refuse to
            # apply stylesheets delivered as text/html
            return Response(css, content_type='text/css', charset='utf-8')
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    @staticmethod
    def _escape(value):
        """Escape `value` for HTML text and quoted attribute values."""
        return escape(value, {'"': '&quot;', "'": '&#x27;'})

    def _get_template(self, name):
        """Return contents of template `name` or ``None`` if missing.

        Delegates to the module-level :func:`get_template`.
        """
        return get_template(name)

    def login(self, req):
        """Handle ``/login``.

        Acts as credential requestor (rendering the login form with a
        fresh login ticket) and as credential acceptor (checking a
        posted `username`/`password` pair together with a valid login
        ticket `lt`).

        `service`, `renew`, `warn` and `gateway` are taken from POST
        data, falling back to GET parameters. If `renew` is given,
        `gateway` is ignored.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # `renew` takes precedence over `gateway`
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # checking a login ticket also deletes it from the database
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            # gateway without session: send back to service, no ticket
            return login_redirect_service(
                self.db, username, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            # already logged in via ticket granting cookie
            # NOTE(review): `username` comes from POST data only and is
            # usually ``None`` here; the cookie does not appear to carry
            # the user, so issued service tickets may lack a user.
            # Confirm against waeup.cas.db.
            if service:
                return login_redirect_service(
                    self.db, username, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok, reason = self.auth.check_credentials(
                username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, username, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        if service is not None:
            # `service` is user supplied: escape it for the attribute
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    self._escape(service))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = html.replace('MSG_TEXT', msg)
        return Response(html)

    def validate(self, req):
        """Handle ``/validate``: check a service ticket for a service.

        `service`, `ticket` and `renew` are taken from POST data,
        falling back to GET parameters. Answers with ``yes`` and the
        user name on two lines if the ticket is valid for the service,
        or with ``no`` and an empty line otherwise. If `renew` is
        given, only tickets issued after a credentials login count.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            return Response('yes' + chr(0x0a) + st.user + chr(0x0a))
        return Response('no' + chr(0x0a) + chr(0x0a))

    def logout(self, req):
        """Handle ``/logout``: drop the session cookie, show a page.

        If a `url` parameter is given (GET first, then POST), the
        logout page shows a link to it.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        old_val = req.cookies.get('cas-tgc', None)
        if url is None:
            html = self._get_template('logout.html')
        else:
            html = self._get_template('logout_url.html')
            # NOTE(review): escaping prevents markup injection but not
            # e.g. ``javascript:`` URLs; consider restricting schemes.
            html = html.replace('URL_HREF', self._escape(url))
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
350
351
#: Alias for :class:`CASServer`.
#: NOTE(review): presumably kept so the server class can be referenced
#: under this name elsewhere (e.g. entry points) -- confirm before removal.
cas_server = CASServer
353
354
def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from configuration values.

    `global_conf` is accepted but not used. `local_conf` is passed
    through :func:`waeup.cas.authenticators.get_authenticator`, and the
    resulting mapping is used as keyword arguments for
    :class:`CASServer`.

    NOTE(review): the signature appears to follow the PasteDeploy app
    factory convention -- confirm.
    """
    server_kwargs = get_authenticator(local_conf)
    return CASServer(**server_kwargs)
Note: See TracBrowser for help on using the repository browser.