1 | # tests for db.py module |
---|
2 | import os |
---|
3 | import shutil |
---|
4 | import sqlite3 |
---|
5 | import tempfile |
---|
6 | import unittest |
---|
7 | from sqlalchemy.engine import Engine |
---|
8 | from waeup.cas.db import ( |
---|
9 | DB, LoginTicket, ServiceTicket, TicketGrantingCookie) |
---|
10 | |
---|
11 | |
---|
class DBTests(unittest.TestCase):
    """Tests for creating and opening databases through `DB`."""

    def setUp(self):
        # scratch directory for file-based SQLite databases
        self.workdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.workdir)

    @staticmethod
    def _execute(db_path, *statements):
        """Run `statements` on the SQLite file `db_path` and commit.

        Uses a plain `sqlite3` connection (independent of `DB`) and
        returns the rows produced by the last statement. The connection
        is closed even when a statement raises.
        """
        conn = sqlite3.connect(db_path)
        try:
            rows = []
            with conn:  # commits on success, rolls back on error
                for statement in statements:
                    rows = conn.execute(statement).fetchall()
            return rows
        finally:
            conn.close()

    def test_create_db(self):
        # we can create a database
        conn_string = 'sqlite:///'
        db = DB(conn_string)
        assert hasattr(db, 'engine')
        assert isinstance(db.engine, Engine)

    def test_create_db_exists(self):
        # an already existing db will be left untouched
        db_path = os.path.join(self.workdir, 'sample-cas.db')
        conn_string = 'sqlite:///%s' % db_path
        assert os.path.isdir(self.workdir)
        DB(conn_string)  # create database
        # Add foreign data behind DB's back. Single quotes are standard
        # SQL string literals; double quotes denote identifiers.
        self._execute(
            db_path,
            'CREATE TABLE mytest (name text, forname text)',
            "INSERT INTO mytest VALUES ('Foo', 'Bar')")
        DB(conn_string)  # open the now already existing database again
        result = self._execute(db_path, 'SELECT * FROM mytest')
        assert result == [('Foo', 'Bar')]

    def test_create_db_populates_db(self):
        # created DBs will be populated with the required tables
        db_path = os.path.join(self.workdir, 'sample-cas.db')
        conn_string = 'sqlite:///%s' % db_path
        DB(conn_string)  # creates database
        result = self._execute(
            db_path, 'SELECT type, name FROM sqlite_master ORDER BY type')
        assert ('table', 'service_tickets') in result
        assert ('table', 'login_tickets') in result
        assert ('table', 'ticket_granting_cookies') in result
---|
56 | |
---|
57 | |
---|
class TicketTests(unittest.TestCase):
    """Tests for the ticket and cookie models persisted through `DB`."""

    def setUp(self):
        # 'sqlite:///' without a path is an in-memory SQLite database,
        # so each test gets a fresh, empty store.
        self.db = DB('sqlite:///')
        # presumably instantiates a scoped session; tearDown calls
        # `remove()` on it to discard it again -- verify against DB
        self.db.session()

    def tearDown(self):
        self.db.session.remove()

    def test_login_ticket_add(self):
        # we can add login tickets
        # NOTE(review): Engine.has_table is deprecated since SQLAlchemy
        # 1.4 and removed in 2.0; sqlalchemy.inspect(engine).has_table()
        # is the replacement once the pinned version allows it.
        assert self.db.engine.has_table('login_tickets')
        self.db.add(LoginTicket('foo'))
        assert [x.ticket for x in self.db.query(LoginTicket)] == ['foo']

    def test_login_ticket_delete(self):
        # we can delete single login tickets
        assert self.db.engine.has_table('login_tickets')
        self.db.add(LoginTicket('foo'))
        contents = list(self.db.query(LoginTicket))
        assert len(contents) == 1
        self.db.delete(contents[0])
        assert [x.ticket for x in self.db.query(LoginTicket)] == []

    def test_login_ticket_repr(self):
        # we can get a proper LoginTicket representation
        # (i.e. one that can be fed to `eval`)
        ticket_repr = repr(LoginTicket('foo', 12.1))
        assert ticket_repr == "LoginTicket('foo', 12.1)"
        # evaluating the repr must rebuild an equivalent ticket
        assert repr(eval(ticket_repr)) == ticket_repr

    def test_login_ticket_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        lticket = LoginTicket('foo')
        assert isinstance(lticket.ts, float)

    def test_add_service_ticket(self):
        # we can add service tickets
        self.db.add(ServiceTicket('foo', 'bar', 'baz', False, 12.1))
        result = [(x.ticket, x.user, x.service, x.sso, x.ts)
                  for x in self.db.query(ServiceTicket)]
        assert result == [('foo', 'bar', 'baz', False, 12.1)]

    def test_service_ticket_repr(self):
        # we can get a proper ServiceTicket representation
        # (i.e. one that can be fed to `eval`)
        st_repr = repr(ServiceTicket('foo', 'bar', 'baz', True, 12.1))
        assert st_repr == "ServiceTicket('foo', 'bar', 'baz', True, 12.1)"
        # evaluating the repr must rebuild an equivalent ticket
        assert repr(eval(st_repr)) == st_repr

    def test_service_ticket_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        sticket = ServiceTicket('foo', 'bar', 'baz')
        assert isinstance(sticket.ts, float)

    def test_service_ticket_sso(self):
        # sso is set to True by default
        sticket = ServiceTicket('foo', 'bar', 'baz')
        assert sticket.sso is True

    def test_ticket_granting_cookie_add(self):
        # we can add ticket granting cookies
        assert self.db.engine.has_table('ticket_granting_cookies')
        self.db.add(TicketGrantingCookie('foo'))
        assert [x.value for x in self.db.query(
            TicketGrantingCookie)] == ['foo']

    def test_ticket_granting_cookie_repr(self):
        # we can get a proper ticket-granting cookie representation
        # (i.e. one that can be fed to `eval`)
        tgc_repr = repr(TicketGrantingCookie('foo', 12.1))
        assert tgc_repr == "TicketGrantingCookie('foo', 12.1)"
        # evaluating the repr must rebuild an equivalent cookie
        assert repr(eval(tgc_repr)) == tgc_repr

    def test_ticket_granting_cookie_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        tgc = TicketGrantingCookie('foo')
        assert isinstance(tgc.ts, float)
---|