Changeset 10394 for main/waeup.cas


Ignore:
Timestamp:
4 Jul 2013, 09:04:41 (11 years ago)
Author:
uli
Message:

Store work from last days.

Location:
main/waeup.cas/trunk
Files:
4 added
7 edited

Legend:

Unmodified
Added
Removed
  • main/waeup.cas/trunk/setup.py

    r10357 r10394  
    2222    'setuptools',
    2323    'webob',
     24    'SQLAlchemy',
    2425    ]
    25 
    26 if sys.version_info < (3, 1):
    27     install_requires.append('pysqlite')  # included in Python >= 3.1
    2826
    2927tests_require = [
     
    3129    'pytest-cov',
    3230    'PasteDeploy',
     31    'WebTest',
    3332    ]
    3433
     
    6766    entry_points="""[paste.app_factory]
    6867    server = waeup.cas:make_cas_server
     68    [waeup.cas.authenticators]
     69    dummy = waeup.cas.authenticators:DummyAuthenticator
    6970    """,
    7071)
  • main/waeup.cas/trunk/waeup/cas/db.py

    r10344 r10394  
    11# components to store tickets in a database
    22# we use sqlite3
    3 import os
    4 import sqlite3
     3import time
     4from sqlalchemy import (
     5    create_engine, Column, Integer, Float, String, MetaData)
     6from sqlalchemy.ext.declarative import declarative_base
     7from sqlalchemy.orm import sessionmaker, scoped_session
    58
    69
    7 def create_db(path):
    8     if os.path.exists(path):
    9         return
    10     conn = sqlite3.connect(path)
    11     conn.close()
    12     return
     10Base = declarative_base()
     11
     12
     13class ServiceTicket(Base):
     14    __tablename__ = 'service_tickets'
     15
     16    id = Column(Integer, primary_key=True)
     17    ticket = Column(String(255), nullable=False)
     18    ts = Column(Float, nullable=False)
     19    service = Column(String(2048), nullable=True)
     20    user = Column(String(512), nullable=False)
     21
     22    def __init__(self, ticket, user, service=None, timestamp=None):
     23        if timestamp is None:
     24            timestamp = time.time()
     25        self.ticket = ticket
     26        self.ts = timestamp
     27        self.service = service
     28        self.user = user
     29
     30    def __repr__(self):
     31        return "ServiceTicket('%s', '%s', '%s', %s)" % (
     32            self.ticket, self.user, self.service, self.ts)
     33
     34
     35class LoginTicket(Base):
     36    __tablename__ = 'login_tickets'
     37
     38    ticket = Column(String(255), primary_key=True)
     39    ts = Column(Float, nullable=False)
     40
     41    def __init__(self, ticket, timestamp=None):
     42        if timestamp is None:
     43            timestamp = time.time()
     44        self.ticket = ticket
     45        self.ts = timestamp
     46
     47    def __repr__(self):
     48        return "LoginTicket('%s', %s)" % (self.ticket, self.ts)
     49
     50
     51class TicketGrantingCookie(Base):
     52    __tablename__ = 'ticket_granting_cookies'
     53
     54    value = Column(String(255), primary_key=True)
     55    ts = Column(Float, nullable=False)
     56
     57    def __init__(self, value, timestamp=None):
     58        if timestamp is None:
     59            timestamp = time.time()
     60        self.value = value
     61        self.ts = timestamp
     62
     63    def __repr__(self):
     64        return "TicketGrantingCookie('%s', %s)" % (self.value, self.ts)
     65
     66
     67Session = scoped_session(sessionmaker())
     68
     69
     70class DB(object):
     71    """Abstract database to make data persistent.
     72
     73    Creates a database identified by `connection_string` if it does
     74    not exist and initializes sessions.
     75    """
     76    @property
     77    def session(self):
     78        return Session
     79
     80    def __init__(self, connection_string):
     81        self.engine = create_engine(connection_string)
     82        self.metadata = MetaData()
     83        Base.metadata.create_all(self.engine)
     84        Session.configure(bind=self.engine)
     85
     86    def add(self, item):
     87        """Insert `item` into database.
     88
     89        `item` must be an instance of the database content types
     90        defined in this module, normally some `Ticket` instance.
     91        """
     92        Session.add(item)
     93        Session.commit()
     94
     95    def delete(self, item):
     96        """Delete `item` from database.
     97
     98        `item must be an instance of the database content types
     99        defined in this module, normally some `Ticket` instance.
     100        """
     101        Session.delete(item)
     102        Session.commit()
     103
     104    def query(self, *args, **kw):
     105        return Session.query(*args, **kw)
     106
     107
     108class DBSessionContext(object):
     109    """A context manager providing database sessions.
     110
     111    Creates a new (threadlocal) database session when entering and
     112    tears down the session after leaving the context.
     113
     114    Tearing down includes committing transactions and closing the
     115    connection.
     116
     117    Meant to be used as a wrapper for web request handlers, such, that
     118    a session is created when a request comes in and released when
     119    the response is ready to be delivered.
     120
     121    This context manager does *not* catch any exceptions.
     122    """
     123    def __enter__(self):
     124        return Session()
     125
     126    def __exit__(self, *args, **kw):
     127        if args or kw:
     128            Session.rollback()
     129        Session.remove()
     130        return False
  • main/waeup.cas/trunk/waeup/cas/server.py

    r10351 r10394  
    22"""
    33import os
    4 import tempfile
     4import random
     5import time
    56from webob import exc, Response
    67from webob.dec import wsgify
     8from waeup.cas.authenticators import get_authenticator
     9from waeup.cas.db import (
     10    DB, DBSessionContext, LoginTicket, ServiceTicket, TicketGrantingCookie)
    711
    812template_dir = os.path.join(os.path.dirname(__file__), 'templates')
     13
     14RANDOM = random.SystemRandom(os.urandom(1024))
     15
     16#: The chars allowed by protocol specification for tickets and cookie
     17#: values.
     18ALPHABET = ('abcdefghijklmnopqrstuvwxyz'
     19            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
     20            '01234567789-')
     21
     22
     23def get_random_string(length):
     24    """Get a random string of length `length`.
     25
     26    The returned string should be hard to guess but is not
     27    neccessarily unique.
     28    """
     29    return ''.join([RANDOM.choice(ALPHABET) for x in range(length)])
     30
     31
     32def get_unique_string():
     33    """Get a unique string based on current time.
     34
     35    The returned string contains only chars from `ALPHABET`.
     36
     37    We try to be unique by using a timestamp in high resolution, so
     38    that even tickets created shortly after another should differ. On
     39    very fast machines, however, this might be not enough (currently
     40    we use 16 decimal places).
     41
     42    This is fast because we don't have to fetch foreign data sources
     43    nor have to do database lookups.
     44
     45    The returned string will be unique but it won't be hard to guess
     46    for people able to read a clock.
     47    """
     48    return ('%.16f' % time.time()).replace('.', '-')
     49
     50
     51def create_service_ticket(user, service=None):
     52    """Get a service ticket.
     53
     54    Ticket length will be 32 chars, randomly picked from `ALPHABET`.
     55    """
     56    t_id = 'ST-' + get_random_string(29)
     57    return ServiceTicket(t_id, user, service)
     58
     59
     60def create_login_ticket():
     61    """Create a unique login ticket.
     62
     63    Login tickets are required to be unique (but not neccessarily hard
     64    to guess), according to protocol specification.
     65    """
     66    t_id = 'LT-%s' % get_unique_string()
     67    return LoginTicket(t_id)
     68
     69
     70def check_login_ticket(db, lt_string):
     71    """Check whether `lt_string` represents a valid login ticket in `db`.
     72    """
     73    if lt_string is None:
     74        return False
     75    q = db.query(LoginTicket).filter(LoginTicket.ticket == lt_string)
     76    result = [x for x in q]
     77    if result:
     78        db.delete(result[0])
     79    return len(result) > 0
     80
     81
     82def create_tgc_value():
     83    """Get a ticket granting cookie value.
     84    """
     85    value = 'TGC-' + get_random_string(128)
     86    return TicketGrantingCookie(value)
     87
     88
     89def set_session_cookie(response, db):
     90    """Create a session cookie (ticket granting cookie) on `response`.
     91
     92    The `db` database is used to make the created cookie value
     93    persistent.
     94    """
     95    tgc = create_tgc_value()
     96    db.add(tgc)
     97    response.set_cookie(
     98        'cas-tgc', tgc.value, path='/', secure=True, httponly=True)
     99    return response
    9100
    10101
     
    26117       If the database file does not exist, it will be created.
    27118    """
    28     def __init__(self, db_path=None):
    29         if db_path is None:
    30             db_path = os.path.join(tempfile.mkdtemp(), 'cas.db')
    31         self.db_path = db_path
     119    def __init__(self, db='sqlite:///:memory:', auth=None):
     120        self.db_connection_string = db
     121        self.db = DB(self.db_connection_string)
     122        self.auth = auth
    32123
    33124    @wsgify
    34125    def __call__(self, req):
    35         if req.path in ['/login', '/validate', '/logout']:
    36             return getattr(self, req.path[1:])(req)
     126        with DBSessionContext():
     127            if req.path in ['/login', '/validate', '/logout']:
     128                return getattr(self, req.path[1:])(req)
    37129        return exc.HTTPNotFound()
    38130
     131    def _get_template(self, name):
     132        path = os.path.join(template_dir, name)
     133        if os.path.isfile(path):
     134            return open(path, 'r').read()
     135        return None
     136
    39137    def login(self, req):
    40         return Response(
    41             open(os.path.join(template_dir, 'login.html'), 'r').read())
     138        service = req.POST.get('service', req.GET.get('service', None))
     139        service_field = ''
     140        username = req.POST.get('username', None)
     141        password = req.POST.get('password', None)
     142        valid_lt = check_login_ticket(self.db, req.POST.get('lt'))
     143        if username and password and valid_lt:
     144            # act as credentials acceptor
     145            cred_ok = self.auth.check_credentials(username, password)
     146            if cred_ok:
     147                if service is None:
     148                    # show logged-in screen
     149                    html = self._get_template('login_successful.html')
     150                    resp = Response(html)
     151                    resp = set_session_cookie(resp, self.db)
     152                    return resp
     153                else:
     154                    # safely redirect to service given
     155                    st = create_service_ticket(service)
     156                    self.db.add(st)
     157                    service = '%s?ticket=%s' % (service, st.ticket)
     158                    html = self._get_template('login_service_redirect.html')
     159                    html = html.replace('SERVICE_URL', service)
     160                    resp = exc.HTTPSeeOther(location=service)
     161                    resp.cache_control = 'no-store'
     162                    resp.pragma = 'no-cache'
     163                    # some arbitrary date in the past
     164                    resp.expires = 'Thu, 01 Dec 1994 16:00:00 GMT'
     165                    resp.text = html
     166                    return resp
     167        if service is not None:
     168            service_field = (
     169                '<input type="hidden" name="service" value="%s" />' % (
     170                    service)
     171                )
     172        lt = create_login_ticket()
     173        self.db.add(lt)
     174        html = self._get_template('login.html')
     175        html = html.replace('LT_VALUE', lt.ticket)
     176        html = html.replace('SERVICE_FIELD_VALUE', service_field)
     177        return Response(html)
    42178
    43179    def validate(self, req):
     
    51187
    52188def make_cas_server(global_conf, **local_conf):
     189    local_conf = get_authenticator(local_conf)
    53190    return CASServer(**local_conf)
  • main/waeup.cas/trunk/waeup/cas/templates/login.html

    r10335 r10394  
    88      <input type="text" name="username" />
    99      <input type="password" name="password" />
    10       <input type="submit" value="Login" />
     10      <input type="hidden" name="lt" value="LT_VALUE" />
     11      SERVICE_FIELD_VALUE
     12      <input type="submit" name="AUTHENTICATE" value="Login" />
    1113    </form>
    1214  </body>
  • main/waeup.cas/trunk/waeup/cas/tests/sample2.ini

    r10352 r10394  
    33[app:main]
    44use = egg:waeup.cas#server
    5 # what file to use to store tickets, credentials, etc.
    6 db_path = /a/path/to/cas.db
     5# SQLAlchemy connection string
     6# By default we get an in-memory SQLite database
     7# Samples:
     8#   sqlite in-memory db:  sqlite://:memory:
     9#   sqlite file-based db: sqlite:////path/to/db
     10db = sqlite:///:memory:
     11auth = dummy
  • main/waeup.cas/trunk/waeup/cas/tests/test_db.py

    r10349 r10394  
    55import tempfile
    66import unittest
    7 from waeup.cas.db import create_db
     7from sqlalchemy.engine import Engine
     8from waeup.cas.db import (
     9    DB, LoginTicket, ServiceTicket, TicketGrantingCookie)
    810
    911
     
    1719    def test_create_db(self):
    1820        # we can create a database
    19         db_path = os.path.join(self.workdir, 'sample-cas.db')
    20         create_db(db_path)
    21         assert os.path.isfile(db_path)
    22         conn = sqlite3.connect(db_path)
    23         assert conn is not None
    24         conn.close()
     21        conn_string = 'sqlite:///'
     22        db = DB(conn_string)
     23        assert hasattr(db, 'engine')
     24        assert isinstance(db.engine, Engine)
    2525
    2626    def test_create_db_exists(self):
    2727        # an already existing db will be left untouched
    2828        db_path = os.path.join(self.workdir, 'sample-cas.db')
    29         create_db(db_path)
     29        conn_string = 'sqlite:///%s' % db_path
     30        assert os.path.isdir(self.workdir)
     31        DB(conn_string)  # create database
    3032        conn = sqlite3.connect(db_path)
    31         cursor = conn.cursor()
    32         cursor.execute('''CREATE TABLE mytest (name text, forname text)''')
    33         cursor.execute('''INSERT INTO mytest VALUES ("Foo", "Bar")''')
    34         conn.commit()
     33        with conn:
     34            conn.execute('''CREATE TABLE mytest (name text, forname text)''')
     35            conn.execute('''INSERT INTO mytest VALUES ("Foo", "Bar")''')
    3536        conn.close()
    3637
    37         create_db(db_path)
    3838        conn = sqlite3.connect(db_path)
    39         cursor = conn.cursor()
    40         cursor.execute('''SELECT * FROM mytest''')
    41         assert cursor.fetchall() == [('Foo', 'Bar')]
     39        with conn:
     40            result = [x for x in conn.execute('''SELECT * FROM mytest''')]
     41        conn.close()
     42        assert result == [('Foo', 'Bar')]
     43
     44    def test_create_db_populates_db(self):
     45        # created DBs will be populated with the required tables
     46        db_path = os.path.join(self.workdir, 'sample-cas.db')
     47        conn_string = 'sqlite:///%s' % db_path
     48        DB(conn_string)  # creates database
     49        conn = sqlite3.connect(db_path)
     50        with conn:
     51            result = [x[:2] for x in conn.execute(
     52                '''SELECT * FROM sqlite_master ORDER BY type''')]
     53        conn.close()
     54        assert ('table', 'service_tickets') in result
     55        assert ('table', 'login_tickets') in result
     56
     57
     58class TicketTests(unittest.TestCase):
     59
     60    def setUp(self):
     61        self.db = DB('sqlite:///')
     62        self.db.session()
     63
     64    def tearDown(self):
     65        self.db.session.remove()
     66
     67    def test_login_ticket_add(self):
     68        # we can add login tickets
     69        assert self.db.engine.has_table('login_tickets')
     70
     71        self.db.add(LoginTicket('foo'))
     72        assert [x.ticket for x in self.db.query(LoginTicket)] == ['foo']
     73
     74    def test_login_ticket_delete(self):
     75        # we can delete single login tickets
     76        assert self.db.engine.has_table('login_tickets')
     77
     78        self.db.add(LoginTicket('foo'))
     79        contents = [x for x in self.db.query(LoginTicket)]
     80        assert len(contents) == 1
     81        lt = contents[0]
     82        self.db.delete(lt)
     83        assert [x.ticket for x in self.db.query(LoginTicket)] == []
     84
     85    def test_login_ticket_repr(self):
     86        # we can get a proper LoginTicket representation
     87        # (i.e. one that can be fed to `eval`)
     88        ticket = LoginTicket('foo', 12.1)
     89        assert ticket.__repr__() == "LoginTicket('foo', 12.1)"
     90
     91    def test_login_ticket_timestamp(self):
     92        # we get a timestamp stored, if none is passed to init
     93        lticket = LoginTicket('foo')
     94        assert isinstance(lticket.ts, float)
     95
     96    def test_add_service_ticket(self):
     97        # we can add service tickets
     98        self.db.add(ServiceTicket('foo', 'bar', 'baz', 12.1))
     99        result = [(x.ticket, x.user, x.service, x.ts)
     100                  for x in self.db.query(ServiceTicket)]
     101        assert result == [('foo', 'bar', 'baz', 12.1)]
     102
     103    def test_service_ticket_repr(self):
     104        # we can get a proper ServiceTicket representation
     105        # (i.e. one that can be fed to `eval`)
     106        sticket = ServiceTicket('foo', 'bar', 'baz', 12.1)
     107        st_repr = sticket.__repr__()
     108        assert st_repr == "ServiceTicket('foo', 'bar', 'baz', 12.1)"
     109
     110    def test_service_ticket_timestamp(self):
     111        # we get a timestamp stored, if none is passed to init
     112        sticket = ServiceTicket('foo', 'bar', 'baz')
     113        assert isinstance(sticket.ts, float)
     114
     115    def test_ticket_granting_cookie_add(self):
     116        # we can add ticket granting cookies
     117        assert self.db.engine.has_table('ticket_granting_cookies')
     118
     119        self.db.add(TicketGrantingCookie('foo'))
     120        assert [x.value for x in self.db.query(
     121            TicketGrantingCookie)] == ['foo']
     122
     123    def test_ticket_granting_cookie_repr(self):
     124        # we can get a proper ticket-granting cookie representation
     125        # (i.e. one that can be fed to `eval`)
     126        tgc = TicketGrantingCookie('foo', 12.1)
     127        assert tgc.__repr__() == "TicketGrantingCookie('foo', 12.1)"
     128
     129    def test_ticket_granting_cookie_timestamp(self):
     130        # we get a timestamp stored, if none is passed to init
     131        tgc = TicketGrantingCookie('foo')
     132        assert isinstance(tgc.ts, float)
  • main/waeup.cas/trunk/waeup/cas/tests/test_server.py

    r10352 r10394  
    11import os
     2import re
    23import shutil
    34import tempfile
    45import unittest
    56from paste.deploy import loadapp
    6 from webob import Request
    7 from waeup.cas.server import CASServer
     7from webob import Request, Response
     8from webtest import TestApp as WebTestApp  # avoid py.test skip message
     9from waeup.cas.authenticators import DummyAuthenticator
     10from waeup.cas.db import DB, LoginTicket, ServiceTicket, TicketGrantingCookie
     11from waeup.cas.server import (
     12    CASServer, create_service_ticket, create_login_ticket,
     13    create_tgc_value, check_login_ticket, set_session_cookie,
     14    )
     15
     16
     17RE_ALPHABET = re.compile('^[a-zA-Z0-9\-]*$')
     18RE_COOKIE = re.compile('^cas-tgc=[A-Za-z0-9\-]+; Path=/; secure; HttpOnly$')
    819
    920
     
    1930        self.workdir = os.path.join(self._new_tmpdir, 'home')
    2031        self.db_path = os.path.join(self.workdir, 'mycas.db')
     32        os.mkdir(self.workdir)
    2133        self.paste_conf1 = os.path.join(
    2234            os.path.dirname(__file__), 'sample1.ini')
     
    3446        app = loadapp('config:%s' % self.paste_conf1)
    3547        assert isinstance(app, CASServer)
    36         assert hasattr(app, 'db_path')
    37         assert app.db_path is not None
     48        assert hasattr(app, 'db')
     49        assert isinstance(app.db, DB)
     50        assert hasattr(app, 'auth')
    3851
    3952    def test_paste_deploy_options(self):
     
    4154        app = loadapp('config:%s' % self.paste_conf2)
    4255        assert isinstance(app, CASServer)
    43         assert app.db_path == '/a/path/to/cas.db'
     56        assert app.db_connection_string == 'sqlite:///:memory:'
     57        assert isinstance(app.auth, DummyAuthenticator)
    4458
    4559    def test_init(self):
    46         # we get a db dir set automatically
    47         app = CASServer()
    48         assert hasattr(app, 'db_path')
    49         assert app.db_path is not None
    50         assert os.path.exists(os.path.dirname(app.db_path))
     60        # we get a `DB` instance created automatically
     61        app = CASServer()
     62        assert hasattr(app, 'db')
     63        assert app.db is not None
    5164
    5265    def test_init_explicit_db_path(self):
    5366        # we can set a db_path explicitly
    54         app = CASServer(db_path=self.db_path)
    55         assert hasattr(app, 'db_path')
    56         assert app.db_path == self.db_path
     67        app = CASServer(db='sqlite:///%s' % self.db_path)
     68        assert hasattr(app, 'db')
     69        assert isinstance(app.db, DB)
     70        assert os.path.isfile(self.db_path)
     71
     72    def test_get_template(self):
     73        app = CASServer()
     74        assert app._get_template('login.html') is not None
     75        assert app._get_template('not-existent.html') is None
    5776
    5877    def test_call_root(self):
     
    93112        assert resp.content_type == 'text/html'
    94113        assert b'<form ' in resp.body
     114
     115
     116class BrowserTests(unittest.TestCase):
     117
     118    def setUp(self):
     119        self.raw_app = CASServer(auth=DummyAuthenticator())
     120        self.app = WebTestApp(self.raw_app)
     121
     122    def test_login(self):
     123        resp = self.app.get('/login')
     124        assert resp.status == '200 OK'
     125        form = resp.forms[0]
     126        # 2.1.3: form must be submitted by POST
     127        assert form.method == 'post'
     128        fieldnames = form.fields.keys()
     129        # 2.1.3: form must contain: username, password, lt
     130        assert 'username' in fieldnames
     131        assert 'password' in fieldnames
     132        assert 'lt' in fieldnames
     133        assert RE_ALPHABET.match(form['lt'].value)
     134
     135    def test_login_no_service(self):
     136        # w/o a service passed in, the form should not contain service
     137        # (not a strict protocol requirement, but handy)
     138        resp = self.app.get('/login')
     139        assert 'service' not in resp.forms[0].fields.keys()
     140
     141    def test_login_service_replayed(self):
     142        # 2.1.3: the login form must contain the service param sent
     143        resp = self.app.get('/login?service=http%3A%2F%2Fwww.service.com')
     144        form = resp.forms[0]
     145        assert resp.status == '200 OK'
     146        assert 'service' in form.fields.keys()
     147        assert form['service'].value == 'http://www.service.com'
     148
     149    def test_login_cred_acceptor_valid_no_service(self):
     150        # 2.2.4: successful login w/o service yields a message
     151        lt = create_login_ticket()
     152        self.raw_app.db.add(lt)
     153        lt_string = lt.ticket
     154        resp = self.app.post('/login', dict(
     155            username='bird', password='bebop', lt=lt_string))
     156        assert resp.status == '200 OK'
     157        assert b'successful' in resp.body
     158        # single-sign-on session initiated
     159        assert 'Set-Cookie' in resp.headers
     160        cookie = resp.headers['Set-Cookie']
     161        assert cookie.startswith('cas-tgc=')
     162
     163    def test_login_cred_acceptor_valid_w_service(self):
     164        # 2.2.4: successful login with service makes a redirect
     165        # Appendix B: safe redirect
     166        lt = create_login_ticket()
     167        self.raw_app.db.add(lt)
     168        lt_string = lt.ticket
     169        resp = self.app.post('/login', dict(
     170            username='bird', password='bebop', lt=lt_string,
     171            service='http://example.com/Login'))
     172        assert resp.status == '303 See Other'
     173        assert 'Location' in resp.headers
     174        assert resp.headers['Location'].startswith(
     175            'http://example.com/Login?ticket=ST-')
     176        assert 'Pragma' in resp.headers
     177        assert resp.headers['Pragma'] == 'no-cache'
     178        assert 'Cache-Control' in resp.headers
     179        assert resp.headers['Cache-Control'] == 'no-store'
     180        assert 'Expires' in resp.headers
     181        assert resp.headers['Expires'] == 'Thu, 01 Dec 1994 16:00:00 GMT'
     182        assert b'window.location.href' in resp.body
     183        assert b'noscript' in resp.body
     184        assert b'ticket=ST-' in resp.body
     185
     186
     187class CASServerHelperTests(unittest.TestCase):
     188
     189    def setUp(self):
     190        self.workdir = tempfile.mkdtemp()
     191        self.db_file = os.path.join(self.workdir, 'mycas.db')
     192        self.conn_string = 'sqlite:///%s' % self.db_file
     193        self.db = DB(self.conn_string)
     194
     195    def tearDown(self):
     196        shutil.rmtree(self.workdir)
     197
     198    def test_create_service_ticket(self):
     199        # we can create service tickets
     200        st = create_service_ticket(
     201            user='bob', service='http://www.example.com')
     202        assert isinstance(st, ServiceTicket)
     203        # 3.1.1: service not part of ticket
     204        assert 'example.com' not in st.ticket
     205        # 3.1.1: ticket must start with 'ST-'
     206        assert st.ticket.startswith('ST-')
     207        # 3.1.1: min. ticket length clients must be able to process is 32
     208        assert len(st.ticket) < 33
     209        # 3.7: allowed character set == [a-zA-Z0-9\-]
     210        assert RE_ALPHABET.match(st.ticket), (
     211            'Ticket contains forbidden chars: %s' % st)
     212
     213    def test_create_login_ticket(self):
     214        # we can create login tickets
     215        lt = create_login_ticket()
     216        # 3.5.1: ticket should start with 'LT-'
     217        assert lt.ticket.startswith('LT-')
     218        # 3.7: allowed character set == [a-zA-Z0-9\-]
     219        assert RE_ALPHABET.match(lt.ticket), (
     220            'Ticket contains forbidden chars: %s' % lt)
     221
     222    def test_create_login_ticket_unique(self):
     223        # 3.5.1: login tickets are unique (although not hard to guess)
     224        ticket_num = 1000  # increase to test more thoroughly
     225        lt_list = [create_login_ticket() for x in range(ticket_num)]
     226        assert len(set(lt_list)) == ticket_num
     227
     228    def test_create_tgc_value(self):
     229        # we can create ticket granting cookies
     230        tgc = create_tgc_value()
     231        assert isinstance(tgc, TicketGrantingCookie)
     232        # 3.6.1: cookie value should start with 'TGC-'
     233        assert tgc.value.startswith('TGC-')
     234        # 3.7: allowed character set == [a-zA-Z0-9\-]
     235        assert RE_ALPHABET.match(tgc.value), (
     236            'Cookie value contains forbidden chars: %s' % tgc)
     237
     238    def test_check_login_ticket(self):
     239        db = DB('sqlite:///')
     240        lt = LoginTicket('LT-123456')
     241        db.add(lt)
     242        assert check_login_ticket(db, None) is False
     243        assert check_login_ticket(db, 'LT-123456') is True
     244        # the ticket will be removed after check
     245        assert check_login_ticket(db, 'LT-123456') is False
     246        assert check_login_ticket(db, 'LT-654321') is False
     247
     248    def test_set_session_cookie(self):
     249        # make sure we can add session cookies to responses
     250        db = DB('sqlite:///')
     251        resp = set_session_cookie(Response(), db)
     252        assert 'Set-Cookie' in resp.headers
     253        cookie = resp.headers['Set-Cookie']
     254        assert RE_COOKIE.match(cookie), (
     255            'Cookie in unexpected format: %s' % cookie)
     256        # the cookie is stored in database
     257        value = cookie.split('=')[1].split(';')[0]
     258        q = db.query(TicketGrantingCookie).filter(
     259            TicketGrantingCookie.value == value)
     260        assert len(list(q)) == 1
Note: See TracChangeset for help on using the changeset viewer.