import os
import re
import shutil
import tempfile
import unittest
from paste.deploy import loadapp
try:
    from urlparse import parse_qsl         # Python 2.x
except ImportError:                        # pragma: no cover
    from urllib.parse import parse_qsl     # Python 3.x
from webob import Request, Response
from webtest import TestApp as WebTestApp  # avoid py.test skip message
from waeup.cas.authenticators import DummyAuthenticator
from waeup.cas.db import DB, LoginTicket, ServiceTicket, TicketGrantingCookie
from waeup.cas.server import (
    CASServer, create_service_ticket, create_login_ticket,
    create_tgc_value, check_login_ticket, set_session_cookie,
    check_session_cookie, get_template, delete_session_cookie,
    check_service_ticket, update_url, set_message
    )

# Raw strings: '\-' in a plain literal is an invalid escape sequence
# (DeprecationWarning since 3.6, SyntaxWarning since 3.12).
# Characters allowed in CAS tickets / cookie values (protocol spec 3.7).
RE_ALPHABET = re.compile(r'^[a-zA-Z0-9\-]*$')
# Expected format of a freshly set ticket granting cookie header.
RE_COOKIE = re.compile(r'^cas-tgc=[A-Za-z0-9\-]+; Path=/; secure; HttpOnly$')
# Expected format of a header that deletes the ticket granting cookie.
RE_COOKIE_DEL = re.compile(
    r'^cas-tgc=; Max-Age=\-[0-9]+; Path=/; '
    r'expires=Thu, 01-Jan-1970 00:00:00 GMT; secure; HttpOnly$')


class CASServerTests(unittest.TestCase):
    """Tests of the `CASServer` WSGI app and the CAS protocol endpoints."""

    def setUp(self):
        # Create a new location where tempfiles are created.  This way
        # also temporary dirs of local CASServers can be removed on
        # tear-down.
        self._new_tmpdir = tempfile.mkdtemp()
        # Cleanups (unlike tearDown) also run when setUp itself fails
        # later on, so the global tempdir setting can never stay
        # redirected and the directory never leaks. They run LIFO:
        # the old tempdir is restored first, then the dir is removed.
        self.addCleanup(shutil.rmtree, self._new_tmpdir, ignore_errors=True)
        self._old_tmpdir = tempfile.tempdir
        tempfile.tempdir = self._new_tmpdir
        self.addCleanup(setattr, tempfile, 'tempdir', self._old_tmpdir)
        self.workdir = os.path.join(self._new_tmpdir, 'home')
        self.db_path = os.path.join(self.workdir, 'mycas.db')
        os.mkdir(self.workdir)
        self.paste_conf1 = os.path.join(
            os.path.dirname(__file__), 'sample1.ini')
        self.paste_conf2 = os.path.join(
            os.path.dirname(__file__), 'sample2.ini')

    def test_paste_deploy_loader(self):
        # we can load the CAS server via paste.deploy plugin
        app = loadapp('config:%s' % self.paste_conf1)
        assert isinstance(app, CASServer)
        assert hasattr(app, 'db')
        assert isinstance(app.db, DB)
        assert hasattr(app, 'auth')

    def test_paste_deploy_options(self):
        # we can set CAS server-related options via paste.deploy config
        app = loadapp('config:%s' % self.paste_conf2)
        assert isinstance(app, CASServer)
        assert app.db_connection_string == 'sqlite:///:memory:'
        assert isinstance(app.auth, DummyAuthenticator)

    def test_init(self):
        # we get a `DB` instance created automatically
        app = CASServer()
        assert hasattr(app, 'db')
        assert app.db is not None

    def test_init_explicit_db_path(self):
        # we can set a db_path explicitly
        app = CASServer(db='sqlite:///%s' % self.db_path)
        assert hasattr(app, 'db')
        assert isinstance(app.db, DB)
        assert os.path.isfile(self.db_path)

    def test_get_template(self):
        app = CASServer()
        template = app._get_template('login.html')
        assert template is not None
        assert app._get_template('not-existent.html') is None
        # parts of header and footer are replaced
        assert 'PART_HEADER' not in template
        assert 'PART_FOOTER' not in template

    def test_call_root(self):
        # the CAS protocol requires no root
        app = CASServer()
        req = Request.blank('http://localhost/')
        resp = app(req)
        assert resp.status == '404 Not Found'

    def test_first_time_login(self):
        # we can get a login page
        app = CASServer()
        req = Request.blank('http://localhost/login')
        resp = app(req)
        assert resp.status == '200 OK'

    def test_validate(self):
        # we can access a validation page
        app = CASServer()
        req = Request.blank('http://localhost/validate?service=foo&ticket=bar')
        resp = app(req)
        assert resp.status == '200 OK'

    def test_logout(self):
        # we can access a logout page
        app = CASServer()
        req = Request.blank('http://localhost/logout')
        resp = app(req)
        assert resp.status == '200 OK'

    def test_login_simple(self):
        # a simple login with no service will result in login screen
        # (2.1.1#service of protocol specs)
        app = CASServer()
        req = Request.blank('http://localhost/login')
        resp = app(req)
        assert resp.status == '200 OK'
        assert resp.content_type == 'text/html'
        assert b'<form ' in resp.body

    def test_login_cred_acceptor_sso_no_service(self):
        # 2.2.4: successful login via single sign on
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        req = Request.blank('https://localhost/login')
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'already' in resp.body
        assert 'Set-Cookie' not in resp.headers

    def test_login_renew_with_cookie(self):
        # 2.1.1: cookie will be ignored when renew is set
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        req = Request.blank('https://localhost/login?renew=true')
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'username' in resp.body
        assert 'Set-Cookie' not in resp.headers

    def test_login_renew_without_cookie(self):
        # 2.1.1: with renew and no cookie, normal auth will happen
        app = CASServer()
        req = Request.blank('https://localhost/login?renew=true')
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'username' in resp.body

    def test_login_renew_as_empty_string(self):
        # `renew` is handled correctly, even with empty value
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        req = Request.blank('https://localhost/login?renew')
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'username' in resp.body
        assert 'Set-Cookie' not in resp.headers

    def test_login_gateway_no_cookie_with_service(self):
        # 2.1.1: with gateway but w/o cookie we will be redirected to service
        # no service ticket will be issued
        app = CASServer()
        params = 'gateway=true&service=http%3A%2F%2Fwww.service.com'
        req = Request.blank('https://localhost/login?%s' % params)
        resp = app(req)
        assert resp.status == '303 See Other'
        assert 'Location' in resp.headers
        assert resp.headers['Location'] == 'http://www.service.com'

    def test_login_gateway_with_cookie_and_service(self):
        # 2.1.1: with cookie and gateway we will be redirected to service
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        params = 'gateway=true&service=http%3A%2F%2Fwww.service.com'
        req = Request.blank('https://localhost/login?%s' % params)
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        assert resp.status == '303 See Other'
        assert 'Location' in resp.headers
        assert resp.headers['Location'].startswith(
            'http://www.service.com?ticket=ST-')

    def test_login_gateway_and_renew(self):
        # 2.1.1 if both, gateway and renew are specified, only renew is valid
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        req = Request.blank('https://localhost/login?renew=true&gateway=true')
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        # with only gateway true, this would lead to a redirect
        assert resp.status == '200 OK'
        assert b'username' in resp.body
        assert 'Set-Cookie' not in resp.headers

    def test_login_warn(self):
        # 2.2.1 as a credential acceptor, with `warn` set we require confirm
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        params = 'warn=true&service=http%3A%2F%2Fwww.service.com'
        req = Request.blank('https://localhost/login?%s' % params)
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        # without warn, we would get a redirect
        assert resp.status == '200 OK'
        assert b'Logged in' in resp.body

    def test_logout_no_cookie(self):
        # 2.3 logout displays a logout page.
        app = CASServer()
        req = Request.blank('https://localhost/logout')
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'logged out' in resp.body

    def test_logout_with_cookie(self):
        # 2.3 logout destroys any existing SSO session
        app = CASServer()
        tgc = create_tgc_value()
        app.db.add(tgc)
        value = str(tgc.value)
        req = Request.blank('https://localhost/logout')
        req.headers['Cookie'] = 'cas-tgc=%s' % value
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'logged out' in resp.body
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert cookie.startswith('cas-tgc=;')
        assert 'expires' in cookie
        assert 'Max-Age' in cookie
        assert len(list(app.db.query(TicketGrantingCookie))) == 0

    def test_logout_url(self):
        # 2.3.1 with an `url` given we provide a link on logout
        app = CASServer()
        params = 'url=http%3A%2F%2Fwww.logout.com'
        req = Request.blank('https://localhost/logout?%s' % params)
        resp = app(req)
        assert resp.status == '200 OK'
        assert b'logged out' in resp.body
        assert b'like you to' in resp.body
        assert b'http://www.logout.com' in resp.body

    def test_validate_invalid(self):
        # 2.4.2 validation failures are indicated by a given format
        app = CASServer()
        params = 'ticket=foo&service=bar'
        req = Request.blank('https://localhost/validate?%s' % params)
        resp = app(req)
        assert resp.body == b'no\n\n'

    def test_validate_valid(self):
        # 2.4 validation success is indicated by a given format
        app = CASServer()
        sticket = create_service_ticket(
            'someuser', 'http://service.com/', sso=False)
        app.db.add(sticket)
        params = 'ticket=%s&service=%s' % (
            sticket.ticket, sticket.service)
        req = Request.blank('https://localhost/validate?%s' % params)
        resp = app(req)
        assert resp.body == b'yes\nsomeuser\n'

    def test_validate_renew_invalid(self):
        # 2.4.1 with `renew` we accept only non-sso issued tickets
        app = CASServer()
        sticket = create_service_ticket(
            'someuser', 'http://service.com/', sso=True)
        app.db.add(sticket)
        params = 'ticket=%s&service=%s&renew=true' % (
            sticket.ticket, sticket.service)
        req = Request.blank('https://localhost/validate?%s' % params)
        resp = app(req)
        assert resp.body == b'no\n\n'

    def test_validate_renew_valid(self):
        # 2.4.1 with `renew` we accept only non-sso issued tickets
        app = CASServer()
        sticket = create_service_ticket(
            'someuser', 'http://service.com/', sso=False)
        app.db.add(sticket)
        params = 'ticket=%s&service=%s&renew=true' % (
            sticket.ticket, sticket.service)
        req = Request.blank('https://localhost/validate?%s' % params)
        resp = app(req)
        assert resp.body == b'yes\nsomeuser\n'

    def test_style_css(self):
        # we can get a custom style sheet
        app = CASServer()
        req = Request.blank('https://localhost/style.css')
        resp = app(req)
        assert resp.status == '200 OK'
        assert resp.content_type == 'text/css'

    def test_set_message(self):
        # we can set a message
        result = set_message('Hi there', '<body>MSG_TEXT</body>')
        assert result == '<body>Hi there</body>'

    def test_set_message_empty_or_none(self):
        # without a message, no message text will be displayed
        result = set_message('', '<body>MSG_TEXT</body>')
        assert result == '<body></body>'
        result = set_message('', '<body><div id="msg">MSG_TEXT</div></body>')
        assert result == '<body></body>'
        result = set_message(None, '<body><div id="msg">MSG_TEXT</div></body>')
        assert result == '<body></body>'
        result = set_message(
            None,
            '<body><div a="1">\n<div id="msg">MSG_TEXT</div>a\n</div></body>')
        assert result == '<body><div a="1">\na\n</div></body>'
        result = set_message(
            None,
            '<body><div a="1"><div id="msg">MSG_TEXT</div>a</div></body>')
        assert result == '<body><div a="1">a</div></body>'


class BrowserTests(unittest.TestCase):
    """Browser-like (WebTest) tests of the CAS login flows."""

    def setUp(self):
        # CASServer() is created without an explicit `db`; it may create
        # temporary storage via the `tempfile` module. Redirect tempfiles
        # into a private directory that is removed after each test so
        # nothing leaks. Cleanups run LIFO and also run if setUp fails.
        self._new_tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self._new_tmpdir, ignore_errors=True)
        self.addCleanup(setattr, tempfile, 'tempdir', tempfile.tempdir)
        tempfile.tempdir = self._new_tmpdir
        self.raw_app = CASServer(auth=DummyAuthenticator())
        self.app = WebTestApp(self.raw_app)

    def test_login(self):
        resp = self.app.get('/login')
        assert resp.status == '200 OK'
        form = resp.forms[0]
        # 2.1.3: form must be submitted by POST
        assert form.method == 'post'
        fieldnames = form.fields.keys()
        # 2.1.3: form must contain: username, password, lt
        assert 'username' in fieldnames
        assert 'password' in fieldnames
        assert 'lt' in fieldnames
        assert RE_ALPHABET.match(form['lt'].value)

    def test_login_no_service(self):
        # w/o a service passed in, the form should not contain service
        # (not a strict protocol requirement, but handy)
        resp = self.app.get('/login')
        assert 'service' not in resp.forms[0].fields.keys()

    def test_login_service_replayed(self):
        # 2.1.3: the login form must contain the service param sent
        resp = self.app.get('/login?service=http%3A%2F%2Fwww.service.com')
        form = resp.forms[0]
        assert resp.status == '200 OK'
        assert 'service' in form.fields.keys()
        assert form['service'].value == 'http://www.service.com'

    def test_login_cred_acceptor_valid_no_service(self):
        # 2.2.4: successful login w/o service yields a message
        lt = create_login_ticket()
        self.raw_app.db.add(lt)
        lt_string = lt.ticket
        resp = self.app.post('/login', dict(
            username='bird', password='bebop', lt=lt_string))
        assert resp.status == '200 OK'
        assert b'successful' in resp.body
        # single-sign-on session initiated
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert cookie.startswith('cas-tgc=')

    def test_login_cred_acceptor_valid_w_service(self):
        # 2.2.4: successful login with service makes a redirect
        # Appendix B: safe redirect
        lt = create_login_ticket()
        self.raw_app.db.add(lt)
        lt_string = lt.ticket
        resp = self.app.post('/login', dict(
            username='bird', password='bebop', lt=lt_string,
            service='http://example.com/Login'))
        assert resp.status == '303 See Other'
        assert 'Location' in resp.headers
        assert resp.headers['Location'].startswith(
            'http://example.com/Login?ticket=ST-')
        assert 'Pragma' in resp.headers
        assert resp.headers['Pragma'] == 'no-cache'
        assert 'Cache-Control' in resp.headers
        assert resp.headers['Cache-Control'] == 'no-store'
        assert 'Expires' in resp.headers
        assert resp.headers['Expires'] == 'Thu, 01 Dec 1994 16:00:00 GMT'
        assert b'window.location.href' in resp.body
        assert b'noscript' in resp.body
        assert b'ticket=ST-' in resp.body
        q = self.raw_app.db.query(ServiceTicket)
        st = q.all()[0]
        assert st.user == 'bird'
        assert st.service == 'http://example.com/Login'
        assert st.ticket.startswith('ST-')

    def test_login_cred_acceptor_failed(self):
        # 2.2.4: failed login yields a message
        lt = create_login_ticket()
        self.raw_app.db.add(lt)
        lt_string = lt.ticket
        resp = self.app.post('/login', dict(
            username='bird', password='cat', lt=lt_string))
        assert resp.status == '200 OK'
        assert b'failed' in resp.body

    def test_login_sso_no_service(self):
        # we can initiate single-sign-on without service
        resp1 = self.app.get('https://localhost/login')  # HTTPS required!
        assert resp1.status == '200 OK'
        assert 'cas-tgc' not in self.app.cookies
        form = resp1.forms[0]
        form.set('username', 'bird')
        form.set('password', 'bebop')
        resp2 = form.submit('AUTHENTICATE')
        assert resp2.status == '200 OK'
        # we got a secure cookie
        assert 'cas-tgc' in self.app.cookies
        # when we get the login page again, the cookie will replace creds.
        resp3 = self.app.get('https://localhost/login')
        assert b'You logged in already' in resp3.body

    def test_login_sso_with_service(self):
        resp1 = self.app.get(
            'https://localhost/login?service=http%3A%2F%2Fservice.com%2F')
        assert resp1.status == '200 OK'
        assert 'cas-tgc' not in self.app.cookies
        form = resp1.forms[0]
        form.set('username', 'bird')
        form.set('password', 'bebop')
        resp2 = form.submit('AUTHENTICATE')
        assert resp2.status == '303 See Other'
        assert resp2.headers['Location'].startswith(
            'http://service.com/?ticket=ST-')
        # we got a secure cookie
        assert 'cas-tgc' in self.app.cookies
        resp3 = self.app.get(
            'https://localhost/login?service=http%3A%2F%2Fservice.com%2F')
        assert resp3.status == '303 See Other'
        assert resp3.headers['Location'].startswith(
            'http://service.com/?ticket=ST-')

    def test_login_sso_with_service_additional_params1(self):
        # we can get a service ticket also with a service providing
        # get params
        # this service url reads http://service.com/index.php?authCAS=CAS
        service_url = 'http%3A%2F%2Fservice.com%2Findex.php%3FauthCAS%3DCAS'
        resp1 = self.app.get(
            'https://localhost/login?service=%s' % service_url)
        assert resp1.status == '200 OK'
        assert 'cas-tgc' not in self.app.cookies
        form = resp1.forms[0]
        form.set('username', 'bird')
        form.set('password', 'bebop')
        resp2 = form.submit('AUTHENTICATE')
        assert resp2.status == '303 See Other'
        location = resp2.headers['Location']
        query_string = location.split('?', 1)[1]
        query_params = dict(parse_qsl(query_string))
        assert 'authCAS' in query_params
        assert 'ticket' in query_params
        assert len(query_params['ticket']) == 32


class CASServerHelperTests(unittest.TestCase):
    """Tests of the module-level helper functions of `waeup.cas.server`."""

    def setUp(self):
        self.workdir = tempfile.mkdtemp()
        # Registered right away so the dir is removed even if DB() fails.
        self.addCleanup(shutil.rmtree, self.workdir)
        self.db_file = os.path.join(self.workdir, 'mycas.db')
        self.conn_string = 'sqlite:///%s' % self.db_file
        self.db = DB(self.conn_string)

    def test_create_service_ticket(self):
        # we can create service tickets
        st = create_service_ticket(
            user='bob', service='http://www.example.com')
        assert isinstance(st, ServiceTicket)
        # 3.1.1: service not part of ticket
        assert 'example.com' not in st.ticket
        # 3.1.1: ticket must start with 'ST-'
        assert st.ticket.startswith('ST-')
        # 3.1.1: min. ticket length clients must be able to process is 32
        assert len(st.ticket) <= 32
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(st.ticket), (
            'Ticket contains forbidden chars: %s' % st)
        assert st.service == 'http://www.example.com'
        assert st.user == 'bob'

    def test_create_login_ticket(self):
        # we can create login tickets
        lt = create_login_ticket()
        # 3.5.1: ticket should start with 'LT-'
        assert lt.ticket.startswith('LT-')
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(lt.ticket), (
            'Ticket contains forbidden chars: %s' % lt)

    def test_create_login_ticket_unique(self):
        # 3.5.1: login tickets are unique (although not hard to guess)
        ticket_num = 1000  # increase to test more thoroughly
        # Compare the ticket strings, not the ticket objects: distinct
        # objects are (presumably) hashed by identity, which would make
        # the set size equal ticket_num no matter what the tickets are.
        tickets = {create_login_ticket().ticket for _ in range(ticket_num)}
        assert len(tickets) == ticket_num

    def test_create_tgc_value(self):
        # we can create ticket granting cookies
        tgc = create_tgc_value()
        assert isinstance(tgc, TicketGrantingCookie)
        # 3.6.1: cookie value should start with 'TGC-'
        assert tgc.value.startswith('TGC-')
        # 3.7: allowed character set == [a-zA-Z0-9\-]
        assert RE_ALPHABET.match(tgc.value), (
            'Cookie value contains forbidden chars: %s' % tgc)

    def test_check_login_ticket(self):
        db = DB('sqlite:///')
        lt = LoginTicket('LT-123456')
        db.add(lt)
        assert check_login_ticket(db, None) is False
        assert check_login_ticket(db, 'LT-123456') is True
        # the ticket will be removed after check
        assert check_login_ticket(db, 'LT-123456') is False
        assert check_login_ticket(db, 'LT-654321') is False

    def test_set_session_cookie1(self):
        # make sure we can add session cookies to responses
        db = DB('sqlite:///')
        resp = set_session_cookie(db, Response())
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert RE_COOKIE.match(cookie), (
            'Cookie in unexpected format: %s' % cookie)
        # the cookie is stored in database
        value = cookie.split('=')[1].split(';')[0]
        q = db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == value)
        assert len(list(q)) == 1

    def test_check_session_cookie2(self):
        # cookie values are looked up as str or as utf-8 encoded bytes
        db = DB('sqlite:///')
        tgc = create_tgc_value()
        db.add(tgc)
        value = tgc.value
        assert check_session_cookie(db, value) == tgc
        assert check_session_cookie(db, 'foo') is None
        assert check_session_cookie(db, b'foo') is None
        assert check_session_cookie(db, None) is None
        value2 = value.encode('utf-8')
        assert check_session_cookie(db, value2) == tgc

    def test_get_template(self):
        # we can load templates
        assert get_template('not-existing-template') is None
        template = get_template('login.html')
        assert template is not None
        # parts of header and footer are replaced
        assert 'PART_HEADER' not in template
        assert 'PART_FOOTER' not in template

    def test_delete_session_cookie(self):
        # we can unset cookies
        db = DB('sqlite:///')
        tgc = create_tgc_value()
        db.add(tgc)
        value = tgc.value
        resp = delete_session_cookie(db, Response(), old_value=value)
        assert 'Set-Cookie' in resp.headers
        cookie = resp.headers['Set-Cookie']
        assert RE_COOKIE_DEL.match(cookie), (
            'Cookie in unexpected format: %s' % cookie)
        # the cookie value was deleted from database
        q = db.query(TicketGrantingCookie).filter(
            TicketGrantingCookie.value == value)
        assert len(list(q)) == 0

    def test_check_service_ticket(self):
        db = DB('sqlite:///')
        st = ServiceTicket(
            'ST-123456', 'someuser', 'http://myservice.com', True)
        db.add(st)
        assert check_service_ticket(db, None, 'foo') is None
        assert check_service_ticket(db, 'foo', None) is None
        assert check_service_ticket(db, 'ST-123456', 'foo') is None
        assert check_service_ticket(db, 'foo', 'http://myservice.com') is None
        result = check_service_ticket(db, 'ST-123456', 'http://myservice.com')
        assert isinstance(result, ServiceTicket)
        assert result.user == 'someuser'
        assert check_service_ticket(
            db, 'ST-123456', 'http://myservice.com', True) is None
        assert check_service_ticket(
            db, 'ST-123456', 'http://myservice.com', False) is not None

    def test_update_url(self):
        # we can create valid new urls with query string params updated
        result1 = update_url('http://sample.com/index?a=1&b=2', dict(b='3'))
        assert result1 in (
            'http://sample.com/index?a=1&b=3',
            'http://sample.com/index?b=3&a=1')
        result2 = update_url('http://sample.com/index?b=2', dict(b='3'))
        assert result2 == 'http://sample.com/index?b=3'
        result3 = update_url('http://sample.com/index', dict(b='3'))
        assert result3 == 'http://sample.com/index?b=3'
        result4 = update_url('http://sample.com/index?a=2', dict(b='3'))
        assert result4 in (
            'http://sample.com/index?a=2&b=3',
            'http://sample.com/index?b=3&a=2')
