source: main/waeup.cas/trunk/waeup/cas/server.py @ 10492

Last change on this file since 10492 was 10491, checked in by uli, 11 years ago

Add a helper to update URLs query strings correctly.

File size: 11.2 KB
Line 
1"""A WSGI app for serving CAS.
2"""
3import datetime
4import os
5import random
6import time
7import urllib
8from urlparse import urlparse, parse_qsl, urlunparse
9from webob import exc, Response
10from webob.dec import wsgify
11from waeup.cas.authenticators import get_authenticator
12from waeup.cas.db import (
13    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
14
#: Directory holding the HTML templates that ship next to this module.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Source of randomness for tickets and cookie values.
#: `SystemRandom` reads from the operating system's entropy source and
#: ignores any seed, therefore none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values. Every char must be listed exactly once, otherwise
#: ``RANDOM.choice(ALPHABET)`` would favour the duplicated char.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')
24
25
def get_random_string(length):
    """Build a random string of `length` chars taken from `ALPHABET`.

    Each char is picked independently via `RANDOM`. The result should
    be hard to guess, but it is not necessarily unique.
    """
    return ''.join(RANDOM.choice(ALPHABET) for _ in range(length))
33
34
def get_unique_string():
    """Return a string derived from the current time.

    The current timestamp is rendered with 16 decimal places and the
    decimal point is replaced by a dash, so the result consists of
    chars from `ALPHABET` only.

    Using a high-resolution timestamp is meant to make strings differ
    even if they are created shortly after another. On very fast
    machines this might not be enough.

    No foreign data sources or database lookups are involved, which
    keeps this fast.

    The result is not hard to guess for anybody able to read a clock.
    """
    stamp = '{0:.16f}'.format(time.time())
    return stamp.replace('.', '-')
52
53
def create_service_ticket(user, service=None, sso=True):
    """Create a new service ticket.

    The ticket id consists of the prefix ``ST-`` followed by 29 random
    chars from `ALPHABET`, i.e. 32 chars in total. `user`, `service`
    and `sso` are handed over to :class:`ServiceTicket` unchanged.
    """
    ticket_id = 'ST-%s' % get_random_string(29)
    return ServiceTicket(ticket_id, user, service, sso)
61
62
def check_service_ticket(db, ticket, service, renew=False):
    """Check whether (`ticket`, `service`) represents a valid service
    ticket in `db`.

    `ticket` and `service` are turned into strings before the lookup.

    If `renew` is true, tickets flagged as `sso` are rejected.

    Returns the first matching :class:`ServiceTicket` entry or ``None``
    if either value is missing, no entry matches, or the entry is
    rejected because of `renew`.
    """
    if ticket is None or service is None:
        return None
    ticket, service = str(ticket), str(service)
    st = db.query(ServiceTicket).filter(
        ServiceTicket.ticket == ticket).filter(
        ServiceTicket.service == service).first()
    if st is None:
        # unknown ticket/service pair: nothing to inspect any further
        return None
    if renew and st.sso:
        return None
    return st
78
79
def create_login_ticket():
    """Create a new login ticket with a unique id.

    According to protocol specification login tickets must be unique,
    but they are not necessarily hard to guess. The id is the prefix
    ``LT-`` followed by the result of `get_unique_string`.
    """
    ticket_id = 'LT-' + get_unique_string()
    return LoginTicket(ticket_id)
88
89
def check_login_ticket(db, lt_string):
    """Tell whether `lt_string` names a login ticket stored in `db`.

    If matching tickets are found, the first one is deleted from `db`
    and ``True`` is returned. ``False`` is returned for a missing
    (``None``) or unknown ticket string.
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == str(lt_string)))
    if not matches:
        return False
    # a found ticket is removed as part of the check
    db.delete(matches[0])
    return True
100
101
def create_tgc_value():
    """Create a new ticket granting cookie.

    The cookie value is the prefix ``TGC-`` followed by 128 random
    chars from `ALPHABET`.
    """
    cookie_value = 'TGC-%s' % get_random_string(128)
    return TicketGrantingCookie(cookie_value)
107
108
def set_session_cookie(db, response):
    """Put a fresh session cookie (ticket granting cookie) on `response`.

    The new cookie is added to `db` to make its value persistent and
    is set on `response` under the name ``cas-tgc``.

    Returns `response`.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_options = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_options)
    return response
120
121
def delete_session_cookie(db, response, old_value=None):
    """Remove the session cookie.

    If `old_value` is given and exactly one ticket granting cookie
    with that value is stored in `db`, that entry is deleted.

    In any case the ``cas-tgc`` cookie on `response` is overwritten
    with an empty value and an expiration date in the past.

    Returns `response`.
    """
    if old_value is not None:
        # look up the old tgc in db
        stored = list(
            db.query(TicketGrantingCookie).filter(
                TicketGrantingCookie.value == old_value))
        if len(stored) == 1:
            (stale_cookie,) = stored
            db.delete(stale_cookie)
    long_ago = datetime.datetime(1970, 1, 1)
    response.set_cookie(
        'cas-tgc', '',
        path='/', secure=True, httponly=True, expires=long_ago)
    return response
139
140
def check_session_cookie(db, cookie_value):
    """Look up the ticket granting cookie for `cookie_value` in `db`.

    `cookie_value` is the string value of a ticket granting cookie
    that may or may not be stored in `db`.

    Returns the first matching
    :class:`waeup.cas.db.TicketGrantingCookie` or ``None`` if
    `cookie_value` is ``None`` or no entry matches.
    """
    if cookie_value is None:
        return None
    try:
        # turn value into unicode (py2.x) / str (py3.x)
        cookie_value = cookie_value.decode('utf-8')
    except AttributeError:                         # pragma: no cover
        pass
    matches = list(
        db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == cookie_value))
    return matches[0] if matches else None
165
166
def get_template(name):
    """Return the content of template `name` as a string.

    `name` is looked up in `template_dir`. If no such file exists,
    ``None`` is returned.
    """
    path = os.path.join(template_dir, name)
    if not os.path.isfile(path):
        return None
    # make sure the file handle is closed again after reading
    with open(path, 'r') as template:
        return template.read()
172
173
def update_url(url, params_dict):
    """Update query params of an url.

    The `url` is modified to have the query parameters set to
    keys/values in `params_dict`, preserving any different existing
    keys/values and overwriting any existing keys/values that are also
    in `params_dict`.

    Existing parameters keep their position, including parameters with
    blank values. An overwritten parameter takes the position of its
    first occurrence. Parameters not yet part of `url` are appended in
    iteration order of `params_dict`.

    Thus, ``'http://sample?a=1', dict(b='1')`` will result in
    ``'http://sample?a=1&b=1'`` and similar.

    Returns the updated URL as a string.
    """
    # local import to support both, Python 3.x and Python 2.x
    try:
        from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
    except ImportError:                            # pragma: no cover
        from urllib import urlencode
        from urlparse import parse_qsl, urlparse, urlunparse
    parts = list(urlparse(url))
    pending = dict(params_dict)    # params not yet written to the query
    merged = []
    for key, value in parse_qsl(parts[4], keep_blank_values=True):
        if key not in params_dict:
            merged.append((key, value))
        elif key in pending:
            # overwrite in place, keeping the position of the param
            merged.append((key, pending.pop(key)))
        # else: duplicate of a param that was overwritten already
    for key in params_dict:
        if key in pending:
            merged.append((key, pending[key]))
    parts[4] = urlencode(merged)
    return urlunparse(parts)
191
192
def login_redirect_service(db, service, sso=True, create_ticket=True,
                           warn=False):
    """Return a response leading the user agent to `service`.

    If `create_ticket` is true, a new service ticket is created,
    added to `db` and passed to `service` as ``ticket`` query
    parameter.

    If `warn` is true, a plain response showing the
    ``login_service_confirm.html`` template is returned instead of a
    redirect. Otherwise the response is an HTTP 303 See Other pointing
    to `service` with the ``login_service_redirect.html`` template as
    body.

    If `sso` is false, a new session cookie is set on the response.

    Caching of the response is discouraged via respective headers.
    """
    # local import: needed in this function only
    from xml.sax.saxutils import escape
    if create_ticket:
        # NOTE(review): `create_service_ticket` is declared as
        # ``(user, service=None, sso=True)``, so this call binds
        # `service` to `user` and `sso` to `service`. This looks
        # unintended, but no user name is available here -- confirm
        # and fix together with the callers.
        st = create_service_ticket(service, sso)
        db.add(st)
        # merge the ticket into the query string. Simply appending
        # '?ticket=...' breaks service URLs that already carry a query.
        service = update_url(service, dict(ticket=st.ticket))
    if warn:
        html = get_template('login_service_confirm.html')
        resp = Response()
    else:
        html = get_template('login_service_redirect.html')
        resp = exc.HTTPSeeOther(location=service)
    # `service` is passed in from outside and must not be able to
    # inject markup into the page.
    # NOTE(review): assumes the templates use SERVICE_URL in HTML text
    # or attribute context only -- confirm against the templates.
    html = html.replace(
        'SERVICE_URL', escape(service, {'"': '&quot;', "'": '&#39;'}))
    # try to forbid caching of any type
    resp.cache_control = 'no-store'
    resp.pragma = 'no-cache'
    # some arbitrary date in the past
    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
    resp.text = html
    if not sso:
        resp = set_session_cookie(db, resp)
    return resp
217
218
def login_success_no_service(db, msg='', sso=False):
    """Return a response showing the logged-in screen.

    The ``login_successful.html`` template is used with `msg` put in
    place of ``MSG_TEXT``.

    `sso` must be a boolean indicating whether login happened via
    credentials (``False``) or via cookie (``True``). Only in the
    former case a new session cookie is set on the response.
    """
    template = get_template('login_successful.html')
    page = Response(template.replace('MSG_TEXT', msg))
    if sso:
        # logged in via cookie: there is a session cookie already
        return page
    return set_session_cookie(db, page)
234
235
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in a
    database and serves the paths ``/login``, ``/validate`` and
    ``/logout``. Any other path results in HTTP 404 Not Found.

    `db` -
       The connection string passed to :class:`waeup.cas.db.DB`. By
       default ``sqlite:///:memory:`` is used (presumably an in-memory
       database, so credentials will be lost after a CAS server
       restart).

    `auth` -
       The authenticator used to check credentials posted to
       ``/login``. It must provide a method
       ``check_credentials(username, password)`` that returns a tuple
       ``(success, reason)``.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        """Dispatch `req` to the handler method named after its path.
        """
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    @staticmethod
    def _escape(value):
        """Escape `value` for use in HTML text and attribute values.
        """
        # local import: needed in this helper only
        from xml.sax.saxutils import escape
        return escape(value, {'"': '&quot;', "'": '&#39;'})

    def _get_template(self, name):
        """Return the content of template `name` from `template_dir`.

        Returns ``None`` if no such file exists.
        """
        path = os.path.join(template_dir, name)
        if not os.path.isfile(path):
            return None
        # make sure the file handle is closed again after reading
        with open(path, 'r') as template:
            return template.read()

    def login(self, req):
        """Handle requests to ``/login``.

        Depending on request parameters, session cookie and posted
        credentials this redirects to the requested service, shows the
        logged-in screen, or shows the login form with a fresh login
        ticket.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        warn = req.POST.get('warn', req.GET.get('warn', False))
        gateway = req.POST.get('gateway', req.GET.get('gateway', None))
        if renew is not None and gateway is not None:
            # if both are set, `renew` wins and `gateway` is ignored
            gateway = None
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # checking a login ticket also removes it from the database
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = check_session_cookie(self.db, req.cookies.get('cas-tgc', None))
        if gateway and (not tgc) and service:
            return login_redirect_service(
                self.db, service, sso=True, create_ticket=False)
        if tgc and (renew is None):
            if service:
                return login_redirect_service(
                    self.db, service, sso=True, warn=warn)
            else:
                return login_success_no_service(
                    self.db, 'You logged in already.', True)
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok, reason = self.auth.check_credentials(
                username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    return login_success_no_service(self.db, msg, False)
                else:
                    # safely redirect to service given
                    return login_redirect_service(
                        self.db, service, sso=False, warn=warn)
            else:
                # login failed
                msg = '<i>Login failed</i><br />Reason: %s' % reason
        if service is not None:
            # `service` is request data and must not be able to break
            # out of the attribute value
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    self._escape(service))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = html.replace('MSG_TEXT', msg)
        return Response(html)

    def validate(self, req):
        """Handle requests to ``/validate``.

        The response body is ``yes``, a linefeed, the user of the
        service ticket and another linefeed if the (`ticket`,
        `service`) pair of the request is valid. Otherwise the body is
        ``no`` followed by two linefeeds.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        ticket = req.POST.get('ticket', req.GET.get('ticket', None))
        renew = req.POST.get('renew', req.GET.get('renew', None))
        # the mere presence of `renew` counts, not its value
        renew = renew is not None
        st = check_service_ticket(self.db, ticket, service, renew)
        if st is not None:
            return Response('yes' + '\n' + st.user + '\n')
        return Response('no' + '\n' + '\n')

    def logout(self, req):
        """Handle requests to ``/logout``.

        Shows the logout page (including the `url` request parameter,
        if given) and removes the session cookie from user agent and
        database.
        """
        url = req.GET.get('url', req.POST.get('url', None))
        old_val = req.cookies.get('cas-tgc', None)
        html = self._get_template('logout.html')
        if url is not None:
            html = self._get_template('logout_url.html')
            # `url` is request data and must not be able to inject
            # markup into the page.
            # NOTE(review): assumes URL_HREF is used in HTML text or
            # attribute context only -- confirm against the template.
            html = html.replace('URL_HREF', self._escape(url))
        resp = Response(html)
        delete_session_cookie(self.db, resp, old_val)
        return resp
342
343
#: Alias for :class:`CASServer`.
cas_server = CASServer


def make_cas_server(global_conf, **local_conf):
    """Create a :class:`CASServer` from keyword configuration.

    `local_conf` is passed through `get_authenticator` and the result
    is used as keyword arguments for :class:`CASServer`.

    `global_conf` is accepted but not used here.
    """
    server_options = get_authenticator(local_conf)
    return CASServer(**server_options)
Note: See TracBrowser for help on using the repository browser.