source: main/waeup.cas/trunk/waeup/cas/server.py @ 10396

Last change on this file since 10396 was 10394, checked in by uli, 12 years ago

Store work from last days.

File size: 6.1 KB
Line 
1"""A WSGI app for serving CAS.
2"""
3import os
4import random
5import time
6from webob import exc, Response
7from webob.dec import wsgify
8from waeup.cas.authenticators import get_authenticator
9from waeup.cas.db import (
10    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
11
#: Directory holding the HTML templates shipped next to this module.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Random source backed by the operating system, used for tickets and
#: cookie values. `SystemRandom` ignores any seed, so none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values. Every char must be listed exactly once, otherwise random
#: picks from this alphabet are biased towards the duplicated char.
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')
21
22
def get_random_string(length):
    """Return `length` chars picked at random from `ALPHABET`.

    Every char is chosen independently via `RANDOM`, so the result is
    meant to be hard to guess. Nothing guarantees that two calls
    return different strings.
    """
    picked = []
    for _ in range(length):
        picked.append(RANDOM.choice(ALPHABET))
    return ''.join(picked)
30
31
def get_unique_string():
    """Return the current time rendered as a ticket-safe string.

    The current timestamp is formatted with 16 decimal places and the
    decimal point is replaced by a hyphen, so the result consists of
    digits and a single ``-`` only.

    The high resolution is meant to make strings created shortly
    after another differ; on very fast machines this might still not
    be enough.

    No foreign data source or database lookup is involved, which
    keeps this cheap. The value is derived from the clock alone, so
    it is easy to guess for anybody able to read a clock.
    """
    now = time.time()
    stamp = '{0:.16f}'.format(now)
    return stamp.replace('.', '-')
49
50
def create_service_ticket(user, service=None):
    """Build a `ServiceTicket` for `user` with a random ticket id.

    The id is the prefix ``ST-`` followed by 29 random chars from
    `ALPHABET`, i.e. 32 chars in total. `user` and `service` are
    handed on to `ServiceTicket` unchanged.
    """
    ticket_id = 'ST-%s' % get_random_string(29)
    return ServiceTicket(ticket_id, user, service)
58
59
def create_login_ticket():
    """Build a `LoginTicket` with a time-based ticket id.

    The id is the prefix ``LT-`` followed by the result of
    `get_unique_string`. According to the protocol specification,
    login tickets have to be unique but not necessarily hard to
    guess.
    """
    ticket_id = 'LT-' + get_unique_string()
    return LoginTicket(ticket_id)
68
69
def check_login_ticket(db, lt_string):
    """Tell whether `lt_string` is a login ticket stored in `db`.

    Returns `False` if `lt_string` is `None` or if no `LoginTicket`
    with that ticket string is found. Otherwise the first matching
    ticket is deleted from `db` and `True` is returned, so a ticket
    string is accepted only as long as it is stored.
    """
    if lt_string is None:
        return False
    query = db.query(LoginTicket).filter(LoginTicket.ticket == lt_string)
    found = [ticket for ticket in query]
    if not found:
        return False
    # a login ticket is removed as soon as it was presented
    db.delete(found[0])
    return True
80
81
def create_tgc_value():
    """Build a `TicketGrantingCookie` with a random value.

    The value is the prefix ``TGC-`` followed by 128 random chars
    from `ALPHABET`.
    """
    cookie_value = 'TGC-%s' % get_random_string(128)
    return TicketGrantingCookie(cookie_value)
87
88
def set_session_cookie(response, db):
    """Set a new ticket granting cookie on `response` and return it.

    A fresh cookie object is created and added to `db`, which makes
    the value persistent. Its value is then set on `response` as
    cookie ``cas-tgc`` for path ``/`` with the `secure` and
    `httponly` flags enabled.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_options = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_options)
    return response
100
101
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, etc.) in the
    database opened via `DB` for the given connection string.

    `db` -
       The connection string of the database to use. The default,
       ``sqlite:///:memory:``, names an in-memory SQLite database.
       NOTE(review): presumably credentials kept there are lost after
       a CAS server restart -- confirm against `waeup.cas.db`.

    `auth` -
       The authenticator asked by `login` to verify credentials. It
       must provide ``check_credentials(username, password)``. With
       the default `None`, submitting credentials to `login` fails.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        """Dispatch `req` to the handler named after its path.

        ``/login``, ``/validate`` and ``/logout`` are handled by the
        methods of the same name, run inside a `DBSessionContext`.
        Any other path yields `HTTPNotFound`.
        """
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                return getattr(self, req.path[1:])(req)
        return exc.HTTPNotFound()

    def _get_template(self, name):
        """Return the contents of template `name` or `None`.

        Templates are looked up in `template_dir`. `None` is returned
        if no regular file of that name exists there.
        """
        path = os.path.join(template_dir, name)
        if os.path.isfile(path):
            # close the file deterministically instead of leaking it
            with open(path, 'r') as fd:
                return fd.read()
        return None

    def login(self, req):
        """Handle requests to ``/login``.

        If `req` carries a username, a password and a valid login
        ticket, the credentials are checked by `self.auth`. On
        success the client either gets the logged-in page together
        with a session cookie (no service given) or is redirected to
        the requested service with a fresh service ticket appended.

        In all other cases the login form is returned, containing a
        new login ticket and, if a service was requested, the service
        as hidden form field.
        """
        # Imported locally: only needed to escape user-supplied text.
        from xml.sax.saxutils import escape
        service = req.POST.get('service', req.GET.get('service', None))
        service_field = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # a login ticket found by this check is deleted from the db
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        if username and password and valid_lt:
            # act as credentials acceptor
            cred_ok = self.auth.check_credentials(username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    html = self._get_template('login_successful.html')
                    resp = Response(html)
                    resp = set_session_cookie(resp, self.db)
                    return resp
                else:
                    # safely redirect to service given; the ticket
                    # is issued for the user that just logged in
                    st = create_service_ticket(username, service)
                    self.db.add(st)
                    # do not start a second query string if the
                    # service URL already contains one
                    separator = '&' if '?' in service else '?'
                    service = '%s%sticket=%s' % (
                        service, separator, st.ticket)
                    html = self._get_template('login_service_redirect.html')
                    # NOTE(review): `service` is inserted unescaped
                    # here; whether HTML escaping is correct depends
                    # on where the template uses SERVICE_URL.
                    html = html.replace('SERVICE_URL', service)
                    resp = exc.HTTPSeeOther(location=service)
                    resp.cache_control = 'no-store'
                    resp.pragma = 'no-cache'
                    # some arbitrary date in the past
                    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
                    resp.text = html
                    return resp
        if service is not None:
            # `service` is user input: escape it so that it cannot
            # break out of the attribute value and inject markup
            service_field = (
                '<input type="hidden" name="service" value="%s" />' % (
                    escape(service, {'"': '&quot;'}))
                )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        return Response(html)

    def validate(self, req):
        """Handle requests to ``/validate`` (not implemented yet)."""
        return exc.HTTPNotImplemented()

    def logout(self, req):
        """Handle requests to ``/logout`` (not implemented yet)."""
        return exc.HTTPNotImplemented()
184
#: Lower-case alias under which `CASServer` is available as well.
cas_server = CASServer
186
187
def make_cas_server(global_conf, **local_conf):
    """Create a `CASServer` from configuration keywords.

    `local_conf` is passed through `get_authenticator` first; the
    result is used as keyword arguments for `CASServer`.
    `global_conf` is accepted but not used.
    """
    server_options = get_authenticator(local_conf)
    return CASServer(**server_options)
Note: See TracBrowser for help on using the repository browser.