1 | import os |
---|
2 | import re |
---|
3 | import shutil |
---|
4 | import tempfile |
---|
5 | import unittest |
---|
6 | from paste.deploy import loadapp |
---|
7 | from webob import Request, Response |
---|
8 | from webtest import TestApp as WebTestApp # avoid py.test skip message |
---|
9 | from waeup.cas.authenticators import DummyAuthenticator |
---|
10 | from waeup.cas.db import DB, LoginTicket, ServiceTicket, TicketGrantingCookie |
---|
11 | from waeup.cas.server import ( |
---|
12 | CASServer, create_service_ticket, create_login_ticket, |
---|
13 | create_tgc_value, check_login_ticket, set_session_cookie, |
---|
14 | ) |
---|
15 | |
---|
16 | |
---|
# Characters allowed in CAS tickets and cookie values (CAS protocol 3.7):
# ASCII letters, digits and the hyphen. Raw strings keep the ``\-`` regex
# escape from being treated as an (invalid) Python string escape.
RE_ALPHABET = re.compile(r'^[a-zA-Z0-9\-]*$')
# Expected shape of the ticket-granting cookie header produced by
# ``set_session_cookie``: a non-empty value from the ticket alphabet,
# scoped to the root path and flagged ``secure`` and ``HttpOnly``.
RE_COOKIE = re.compile(r'^cas-tgc=[A-Za-z0-9\-]+; Path=/; secure; HttpOnly$')
---|
19 | |
---|
20 | |
---|
class CASServerTests(unittest.TestCase):
    """Tests for creating, configuring and calling a plain ``CASServer``."""

    def setUp(self):
        # Create a new location where tempfiles are created. This way
        # also temporary dirs of local CASServers can be removed on
        # tear-down.
        self._new_tmpdir = tempfile.mkdtemp()
        # Cleanups are registered as soon as there is something to undo,
        # so they also run when a later step of setUp fails (tearDown is
        # skipped in that case). They run last-in-first-out and each one
        # runs even if another raised: first the global tempfile setting
        # is restored, then the scratch dir is removed.
        self.addCleanup(self._remove_tmpdir)
        self._old_tmpdir = tempfile.tempdir
        self.addCleanup(self._restore_tempdir)
        tempfile.tempdir = self._new_tmpdir
        self.workdir = os.path.join(self._new_tmpdir, 'home')
        self.db_path = os.path.join(self.workdir, 'mycas.db')
        os.mkdir(self.workdir)
        # paste.deploy config files expected next to this test module
        self.paste_conf1 = os.path.join(
            os.path.dirname(__file__), 'sample1.ini')
        self.paste_conf2 = os.path.join(
            os.path.dirname(__file__), 'sample2.ini')

    def _restore_tempdir(self):
        # Undo the redirection of ``tempfile.tempdir`` done in setUp.
        tempfile.tempdir = self._old_tmpdir

    def _remove_tmpdir(self):
        # Remove the scratch dir created in setUp if it still exists.
        if os.path.isdir(self._new_tmpdir):
            shutil.rmtree(self._new_tmpdir)

    def test_paste_deploy_loader(self):
        # we can load the CAS server via paste.deploy plugin
        app = loadapp('config:%s' % self.paste_conf1)
        assert isinstance(app, CASServer)
        assert hasattr(app, 'db')
        assert isinstance(app.db, DB)
        assert hasattr(app, 'auth')

    def test_paste_deploy_options(self):
        # we can set CAS server-related options via paste.deploy config
        app = loadapp('config:%s' % self.paste_conf2)
        assert isinstance(app, CASServer)
        assert app.db_connection_string == 'sqlite:///:memory:'
        assert isinstance(app.auth, DummyAuthenticator)

    def test_init(self):
        # we get a `DB` instance created automatically
        app = CASServer()
        assert hasattr(app, 'db')
        assert app.db is not None

    def test_init_explicit_db_path(self):
        # we can set a db_path explicitly; the sqlite file is created there
        app = CASServer(db='sqlite:///%s' % self.db_path)
        assert hasattr(app, 'db')
        assert isinstance(app.db, DB)
        assert os.path.isfile(self.db_path)

    def test_get_template(self):
        # known templates are found, unknown names yield None
        app = CASServer()
        assert app._get_template('login.html') is not None
        assert app._get_template('not-existent.html') is None

    def test_call_root(self):
        # the CAS protocol defines no root resource, so '/' is a 404
        app = CASServer()
        req = Request.blank('http://localhost/')
        resp = app(req)
        assert resp.status == '404 Not Found'

    def test_first_time_login(self):
        # we can get a login page
        app = CASServer()
        req = Request.blank('http://localhost/login')
        resp = app(req)
        assert resp.status == '200 OK'

    def test_validate(self):
        # the validation endpoint exists but is not implemented (yet)
        app = CASServer()
        req = Request.blank('http://localhost/validate')
        resp = app(req)
        assert resp.status == '501 Not Implemented'

    def test_logout(self):
        # the logout endpoint exists but is not implemented (yet)
        app = CASServer()
        req = Request.blank('http://localhost/logout')
        resp = app(req)
        assert resp.status == '501 Not Implemented'

    def test_login_simple(self):
        # a simple login with no service will result in login screen
        # (2.1.1#service of protocol specs)
        app = CASServer()
        req = Request.blank('http://localhost/login')
        resp = app(req)
        assert resp.status == '200 OK'
        assert resp.content_type == 'text/html'
        assert b'<form ' in resp.body
---|
114 | |
---|
115 | |
---|
class BrowserTests(unittest.TestCase):
    """Browser-style tests driving a ``CASServer`` through WebTest."""

    def setUp(self):
        # The server below is created without an explicit ``db``; such a
        # server presumably keeps its data below a ``tempfile``-created
        # directory (TODO confirm). Redirect ``tempfile`` to a private
        # scratch dir for each test and remove it afterwards, so nothing
        # accumulates in the system temp dir. Cleanups are registered
        # right away so they also run if a later setUp step fails; they
        # run last-in-first-out (restore setting, then remove the dir).
        self._new_tmpdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tmpdir)
        self._old_tmpdir = tempfile.tempdir
        self.addCleanup(self._restore_tempdir)
        tempfile.tempdir = self._new_tmpdir
        self.raw_app = CASServer(auth=DummyAuthenticator())
        self.app = WebTestApp(self.raw_app)

    def _restore_tempdir(self):
        # Undo the redirection of ``tempfile.tempdir`` done in setUp.
        tempfile.tempdir = self._old_tmpdir

    def _remove_tmpdir(self):
        # Remove the scratch dir created in setUp if it still exists.
        if os.path.isdir(self._new_tmpdir):
            shutil.rmtree(self._new_tmpdir)

    def test_login(self):
        resp = self.app.get('/login')
        assert resp.status == '200 OK'
        form = resp.forms[0]
        # 2.1.3: form must be submitted by POST (the HTML ``method``
        # attribute is case-insensitive, so compare normalized)
        assert form.method.lower() == 'post'
        fieldnames = form.fields.keys()
        # 2.1.3: form must contain: username, password, lt
        assert 'username' in fieldnames
        assert 'password' in fieldnames
        assert 'lt' in fieldnames
        assert RE_ALPHABET.match(form['lt'].value)

    def test_login_no_service(self):
        # w/o a service passed in, the form should not contain service
        # (not a strict protocol requirement, but handy)
        resp = self.app.get('/login')
        assert 'service' not in resp.forms[0].fields.keys()

    def test_login_service_replayed(self):
        # 2.1.3: the login form must contain the service param sent
        resp = self.app.get('/login?service=http%3A%2F%2Fwww.service.com')
        form = resp.forms[0]
        assert resp.status == '200 OK'
        assert 'service' in form.fields.keys()
        assert form['service'].value == 'http://www.service.com'

    def test_login_cred_acceptor_valid_no_service(self):
        # 2.2.4: successful login w/o service yields a message
        lt = create_login_ticket()
        self.raw_app.db.add(lt)
        lt_string = lt.ticket
        resp = self.app.post('/login', dict(
            username='bird', password='bebop', lt=lt_string))
        assert resp.status == '200 OK'
        assert b'successful' in resp.body
        # single-sign-on session initiated
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert cookie.startswith('cas-tgc=')

    def test_login_cred_acceptor_valid_w_service(self):
        # 2.2.4: successful login with service makes a redirect
        # Appendix B: safe redirect
        lt = create_login_ticket()
        self.raw_app.db.add(lt)
        lt_string = lt.ticket
        resp = self.app.post('/login', dict(
            username='bird', password='bebop', lt=lt_string,
            service='http://example.com/Login'))
        assert resp.status == '303 See Other'
        assert 'Location' in resp.headers
        assert resp.headers['Location'].startswith(
            'http://example.com/Login?ticket=ST-')
        # the redirect page must not be cached
        assert 'Pragma' in resp.headers
        assert resp.headers['Pragma'] == 'no-cache'
        assert 'Cache-Control' in resp.headers
        assert resp.headers['Cache-Control'] == 'no-store'
        assert 'Expires' in resp.headers
        assert resp.headers['Expires'] == 'Thu, 01 Dec 1994 16:00:00 GMT'
        # the body offers a JS redirect plus a <noscript> fallback
        assert b'window.location.href' in resp.body
        assert b'noscript' in resp.body
        assert b'ticket=ST-' in resp.body
---|
185 | |
---|
186 | |
---|
class CASServerHelperTests(unittest.TestCase):
    """Tests for the ticket and cookie helper functions of the CAS server."""

    def setUp(self):
        self.workdir = tempfile.mkdtemp()
        # Registered immediately so the dir is removed even if a later
        # setUp step (e.g. creating the DB) fails.
        self.addCleanup(shutil.rmtree, self.workdir)
        self.db_file = os.path.join(self.workdir, 'mycas.db')
        self.conn_string = 'sqlite:///%s' % self.db_file
        # NOTE(review): ``self.db`` is not used by the tests below; they
        # create their own ``DB('sqlite:///')`` instances.
        self.db = DB(self.conn_string)

    def test_create_service_ticket(self):
        # we can create service tickets
        st = create_service_ticket(
            user='bob', service='http://www.example.com')
        assert isinstance(st, ServiceTicket)
        # 3.1.1: service not part of ticket
        assert 'example.com' not in st.ticket
        # 3.1.1: ticket must start with 'ST-'
        assert st.ticket.startswith('ST-')
        # 3.1.1: 32 chars is the minimum ticket length every client must
        # be able to process, so stay within it
        assert len(st.ticket) <= 32
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(st.ticket), (
            'Ticket contains forbidden chars: %s' % st.ticket)

    def test_create_login_ticket(self):
        # we can create login tickets
        lt = create_login_ticket()
        # 3.5.1: ticket should start with 'LT-'
        assert lt.ticket.startswith('LT-')
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(lt.ticket), (
            'Ticket contains forbidden chars: %s' % lt.ticket)

    def test_create_login_ticket_unique(self):
        # 3.5.1: login tickets are unique (although not hard to guess)
        ticket_num = 1000  # increase to test more thoroughly
        # Compare the ticket *strings*: a set of ticket objects would
        # only deduplicate by object identity (unless the class defines
        # value-based equality), making this check pass trivially.
        tickets = {create_login_ticket().ticket for _ in range(ticket_num)}
        assert len(tickets) == ticket_num

    def test_create_tgc_value(self):
        # we can create ticket granting cookies
        tgc = create_tgc_value()
        assert isinstance(tgc, TicketGrantingCookie)
        # 3.6.1: cookie value should start with 'TGC-'
        assert tgc.value.startswith('TGC-')
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(tgc.value), (
            'Cookie value contains forbidden chars: %s' % tgc.value)

    def test_check_login_ticket(self):
        db = DB('sqlite:///')
        lt = LoginTicket('LT-123456')
        db.add(lt)
        assert check_login_ticket(db, None) is False
        assert check_login_ticket(db, 'LT-123456') is True
        # the ticket will be removed after check
        assert check_login_ticket(db, 'LT-123456') is False
        assert check_login_ticket(db, 'LT-654321') is False

    def test_set_session_cookie(self):
        # make sure we can add session cookies to responses
        db = DB('sqlite:///')
        resp = set_session_cookie(Response(), db)
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert RE_COOKIE.match(cookie), (
            'Cookie in unexpected format: %s' % cookie)
        # the cookie is stored in database; take the value of the first
        # ``name=value`` pair (everything before the first ';')
        value = cookie.split(';', 1)[0].split('=', 1)[1]
        q = db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == value)
        assert len(list(q)) == 1
---|