"""A WSGI app for serving CAS.
"""
import itertools
import os
import random
import time
from xml.sax.saxutils import escape, quoteattr

from webob import exc, Response
from webob.dec import wsgify

from waeup.cas.authenticators import get_authenticator
from waeup.cas.db import (
    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
#: Directory holding the HTML templates used by `CASServer`.
template_dir = os.path.join(os.path.dirname(__file__), 'templates')

#: Random source for tickets and cookie values. `random.SystemRandom`
#: draws its randomness from `os.urandom()` and ignores any seed, so
#: none is passed.
RANDOM = random.SystemRandom()

#: The chars allowed by protocol specification for tickets and cookie
#: values: ASCII letters, digits and the hyphen. Every char appears
#: exactly once so that each is equally likely to be picked (a
#: duplicated char would be chosen more often than the others).
ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789-')
def get_random_string(length):
    """Return a string of `length` chars picked at random from `ALPHABET`.

    The chars come from `RANDOM`, so the result is hard to guess.
    Uniqueness, however, is not guaranteed.
    """
    picks = (RANDOM.choice(ALPHABET) for _ in range(length))
    return ''.join(picks)
#: Per-process counter appended to unique strings. It keeps strings
#: created within the same clock tick distinct.
_UNIQUE_COUNTER = itertools.count()


def get_unique_string():
    """Get a unique string based on the current time.

    The returned string contains only chars from `ALPHABET` (digits
    and ``-``). It has the form ``<seconds>-<fraction>-<counter>``.

    The timestamp part alone is not guaranteed to be unique. A float
    `time.time()` value carries only about microsecond precision at
    current epoch values, so most of the 16 printed decimal places are
    float noise. Some platform clocks also tick far more coarsely. Two
    calls within one tick would therefore return the same timestamp.

    A per-process counter is appended to cover that case. It keeps
    values distinct within one process even when the clock has not
    advanced. Uniqueness across several processes still relies on the
    timestamp only.

    This is fast because no foreign data sources are fetched and no
    database lookups are done.

    The returned string is unique but not hard to guess for people
    able to read a clock.
    """
    stamp = ('%.16f' % time.time()).replace('.', '-')
    return '%s-%d' % (stamp, next(_UNIQUE_COUNTER))
def create_service_ticket(user, service=None):
    """Create a new `ServiceTicket` for `user` and (optional) `service`.

    The ticket id is the prefix ``ST-`` followed by 29 random chars
    from `ALPHABET`, which gives 32 chars in total.
    """
    ticket_id = 'ST-%s' % get_random_string(29)
    return ServiceTicket(ticket_id, user, service)
def create_login_ticket():
    """Create a new `LoginTicket` with a unique id.

    According to the protocol specification, login tickets must be
    unique. They need not be hard to guess, so the id is ``LT-``
    followed by a time-based string from `get_unique_string`.
    """
    return LoginTicket('LT-' + get_unique_string())
def check_login_ticket(db, lt_string):
    """Tell whether `lt_string` is a login ticket stored in `db`.

    A matching ticket is deleted from `db`, so each login ticket can
    be used only once. Returns True if a match was found. Returns
    False otherwise, including when `lt_string` is None.
    """
    if lt_string is None:
        return False
    matches = list(
        db.query(LoginTicket).filter(LoginTicket.ticket == lt_string))
    if not matches:
        return False
    # consume the ticket: remove the first match
    db.delete(matches[0])
    return True
def create_tgc_value():
    """Create a new `TicketGrantingCookie` with a random value.

    The value is ``TGC-`` followed by 128 random chars from
    `ALPHABET`.
    """
    cookie_value = 'TGC-%s' % get_random_string(128)
    return TicketGrantingCookie(cookie_value)
def set_session_cookie(response, db):
    """Set a new ticket granting cookie ``cas-tgc`` on `response`.

    The freshly created cookie is added to `db` to make it persistent.
    The cookie is sent for path ``/`` only over secure connections and
    is marked HTTP-only. Returns `response`.
    """
    cookie = create_tgc_value()
    db.add(cookie)
    cookie_options = dict(path='/', secure=True, httponly=True)
    response.set_cookie('cas-tgc', cookie.value, **cookie_options)
    return response
def check_session_cookie(db, cookie_value):
    """Look up the ticket granting cookie `cookie_value` in `db`.

    Returns the matching `TicketGrantingCookie` if one is stored, and
    None if not. If `cookie_value` is None, False is returned instead
    (note the asymmetry; both results are falsy).
    """
    if cookie_value is None:
        return False
    # bytes (py2.x str) get turned into text; py3.x str has no
    # `decode` and is used as-is
    decode = getattr(cookie_value, 'decode', None)
    if decode is not None:
        cookie_value = decode('utf-8')
    matches = list(db.query(TicketGrantingCookie).filter(
        TicketGrantingCookie.value == cookie_value))
    return matches[0] if matches else None
class CASServer(object):
    """A WSGI CAS server.

    This CAS server stores credential data (tickets, cookies, etc.) in
    the database described by `db`.

    `db` -
       Connection string of the database to use. It is handed to `DB`.
       Defaults to ``'sqlite:///:memory:'``, an in-memory database.
       With that default, all credentials are lost after a CAS server
       restart.

    `auth` -
       Authenticator used by `login`. Its method
       ``check_credentials(username, password)`` is expected to return
       a pair ``(ok, reason)``.
    """
    def __init__(self, db='sqlite:///:memory:', auth=None):
        self.db_connection_string = db
        self.db = DB(self.db_connection_string)
        self.auth = auth

    @wsgify
    def __call__(self, req):
        """Dispatch `req` by path to `login`, `validate` or `logout`.

        Handlers run inside a `DBSessionContext`. Any other path gets a
        404 response.
        """
        with DBSessionContext():
            if req.path in ['/login', '/validate', '/logout']:
                # '/login' -> self.login, etc.
                return getattr(self, req.path[1:])(req)
            return exc.HTTPNotFound()

    def _get_template(self, name):
        """Return the text of the template file `name`.

        Templates are looked up in `template_dir`. Returns None if no
        such file exists.
        """
        path = os.path.join(template_dir, name)
        if os.path.isfile(path):
            # `with` closes the file handle right away instead of
            # leaving that to garbage collection.
            with open(path, 'r') as template_file:
                return template_file.read()
        return None

    def login(self, req):
        """Handle ``/login`` as credential requestor and acceptor.

        The user counts as authenticated in either of these cases:

        - the request carries a valid ticket granting cookie
          (``cas-tgc``);
        - it carries a username, a password and a valid login ticket
          (``lt``), and the credentials are accepted by `self.auth`.

        On success without a `service`, a success page is rendered. A
        session cookie is set if none existed. On success with a
        `service`, the browser is redirected (303) to that service with
        a new service ticket appended as ``ticket`` query parameter.

        In all other cases, the login form is rendered with a fresh
        login ticket and, if a login attempt failed, the reason.
        """
        service = req.POST.get('service', req.GET.get('service', None))
        service_field = ''
        msg = ''
        username = req.POST.get('username', None)
        password = req.POST.get('password', None)
        # A matching login ticket is deleted by the check (single use).
        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
        tgc = req.cookies.get('cas-tgc', None)
        tgc = check_session_cookie(self.db, tgc)
        if (username and password and valid_lt) or tgc:
            # act as credentials acceptor
            if tgc:
                cred_ok, reason = True, ''
                if not service:
                    msg = 'You logged in already.'
            else:
                cred_ok, reason = self.auth.check_credentials(
                    username, password)
            if cred_ok:
                if service is None:
                    # show logged-in screen
                    html = self._get_template('login_successful.html')
                    html = html.replace('MSG_TEXT', msg)
                    resp = Response(html)
                    if not tgc:
                        resp = set_session_cookie(resp, self.db)
                    return resp
                else:
                    # Redirect to the given service, passing a new
                    # service ticket issued for the user.
                    # NOTE(review): when authenticated via cookie,
                    # `username` is None. Cookies are created from a
                    # value only (`create_tgc_value`), so no user is
                    # known here. Confirm that `ServiceTicket` accepts a
                    # None user.
                    # NOTE(review): no session cookie is set on this
                    # path; confirm that is intended.
                    st = create_service_ticket(username, service)
                    self.db.add(st)
                    # Respect a query string already present in the
                    # service URL.
                    separator = '&' if '?' in service else '?'
                    service = '%s%sticket=%s' % (
                        service, separator, st.ticket)
                    html = self._get_template('login_service_redirect.html')
                    # `service` is request data: escape it before
                    # embedding it in HTML.
                    html = html.replace(
                        'SERVICE_URL', escape(service, {'"': '&quot;'}))
                    resp = exc.HTTPSeeOther(location=service)
                    resp.cache_control = 'no-store'
                    resp.pragma = 'no-cache'
                    # some arbitrary date in the past
                    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
                    resp.text = html
                    return resp
            else:
                # login failed
                msg = 'Login failed Reason: %s' % reason
        if service is not None:
            # Keep the service across form submits in a hidden field.
            # `quoteattr` quotes and escapes the request-supplied URL.
            service_field = (
                '<input type="hidden" name="service" value=%s />' % (
                    quoteattr(service))
            )
        lt = create_login_ticket()
        self.db.add(lt)
        html = self._get_template('login.html')
        html = html.replace('LT_VALUE', lt.ticket)
        html = html.replace('SERVICE_FIELD_VALUE', service_field)
        html = html.replace('MSG_TEXT', msg)
        return Response(html)

    def validate(self, req):
        """Handle ``/validate``. Not implemented yet (501)."""
        return exc.HTTPNotImplemented()

    def logout(self, req):
        """Handle ``/logout``. Not implemented yet (501)."""
        return exc.HTTPNotImplemented()
#: Alias for `CASServer`.
#: NOTE(review): presumably referenced by name from configuration or
#: entry points; confirm before removing.
cas_server = CASServer
def make_cas_server(global_conf, **local_conf):
    """Build a `CASServer` from configuration settings.

    The signature looks like a PasteDeploy app factory (presumably it
    is used as one; confirm in the package entry points).
    `global_conf` is accepted but not used. `local_conf` is processed
    by `get_authenticator`, and the resulting mapping is passed as
    keyword arguments to `CASServer`.
    """
    return CASServer(**get_authenticator(local_conf))