# tests for db.py module
import os
import shutil
import sqlite3
import tempfile
import unittest
from sqlalchemy.engine import Engine
from waeup.cas.db import (
    DB, LoginTicket, ServiceTicket, TicketGrantingCookie)


class DBTests(unittest.TestCase):
    # Tests for creating/opening databases through `DB`. Each test gets its
    # own scratch directory for file-based SQLite databases.

    def setUp(self):
        # scratch directory for database files; removed again in tearDown
        self.workdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.workdir)

    def test_create_db(self):
        # we can create a database
        conn_string = 'sqlite:///'
        db = DB(conn_string)
        assert hasattr(db, 'engine')
        assert isinstance(db.engine, Engine)

    def test_create_db_exists(self):
        # an already existing db will be left untouched
        db_path = os.path.join(self.workdir, 'sample-cas.db')
        conn_string = 'sqlite:///%s' % db_path
        assert os.path.isdir(self.workdir)
        DB(conn_string)  # create database

        # put some foreign data into the freshly created database file
        conn = sqlite3.connect(db_path)
        try:
            with conn:  # commits on success, rolls back on error
                conn.execute('CREATE TABLE mytest (name text, forname text)')
                # single quotes: in SQL double quotes denote identifiers;
                # SQLite only accepts them as strings via a legacy quirk
                conn.execute("INSERT INTO mytest VALUES ('Foo', 'Bar')")
        finally:
            conn.close()

        # open the now existing database again -- this must not wipe it
        DB(conn_string)

        conn = sqlite3.connect(db_path)
        try:
            result = list(conn.execute('SELECT * FROM mytest'))
        finally:
            conn.close()
        assert result == [('Foo', 'Bar')]

    def test_create_db_populates_db(self):
        # created DBs will be populated with the required tables
        db_path = os.path.join(self.workdir, 'sample-cas.db')
        conn_string = 'sqlite:///%s' % db_path
        DB(conn_string)  # creates database
        conn = sqlite3.connect(db_path)
        try:
            tables = [row[0] for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table'")]
        finally:
            conn.close()
        for name in ('service_tickets', 'login_tickets',
                     'ticket_granting_cookies'):
            assert name in tables, name


class TicketTests(unittest.TestCase):
    # Tests for the ticket/cookie models. Every test runs against a fresh
    # in-memory SQLite database.

    def setUp(self):
        self.db = DB('sqlite:///')
        # NOTE(review): presumably `session` is a scoped session registry
        # and calling it establishes the session for this test -- confirm
        # against db.py
        self.db.session()

    def tearDown(self):
        # discard the session established in setUp
        self.db.session.remove()

    def _has_table(self, name):
        # Tell whether table `name` exists in the test database.
        # `Engine.has_table()` is deprecated since SQLAlchemy 1.4 and removed
        # in 2.0; use the inspector API where it offers `has_table`, and
        # fall back to the engine method on older SQLAlchemy releases.
        from sqlalchemy import inspect
        inspector = inspect(self.db.engine)
        if hasattr(inspector, 'has_table'):
            return inspector.has_table(name)
        return self.db.engine.has_table(name)

    def test_login_ticket_add(self):
        # we can add login tickets
        assert self._has_table('login_tickets')

        self.db.add(LoginTicket('foo'))
        assert [x.ticket for x in self.db.query(LoginTicket)] == ['foo']

    def test_login_ticket_delete(self):
        # we can delete single login tickets
        assert self._has_table('login_tickets')

        self.db.add(LoginTicket('foo'))
        contents = list(self.db.query(LoginTicket))
        assert len(contents) == 1
        self.db.delete(contents[0])
        assert [x.ticket for x in self.db.query(LoginTicket)] == []

    def test_login_ticket_repr(self):
        # we can get a proper LoginTicket representation
        # (i.e. one that can be fed to `eval`)
        ticket = LoginTicket('foo', 12.1)
        lt_repr = repr(ticket)
        assert lt_repr == "LoginTicket('foo', 12.1)"
        new_lt = eval(lt_repr)
        assert new_lt.ticket == ticket.ticket
        assert new_lt.ts == ticket.ts

    def test_login_ticket_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        lticket = LoginTicket('foo')
        assert isinstance(lticket.ts, float)

    def test_add_service_ticket(self):
        # we can add service tickets
        self.db.add(ServiceTicket('foo', 'bar', 'baz', False, 12.1))
        result = [(x.ticket, x.user, x.service, x.sso, x.ts)
                  for x in self.db.query(ServiceTicket)]
        assert result == [('foo', 'bar', 'baz', False, 12.1)]

    def test_service_ticket_repr(self):
        # we can get a proper ServiceTicket representation
        # (i.e. one that can be fed to `eval`)
        sticket = ServiceTicket('foo', 'bar', 'baz', True, 12.1)
        st_repr = repr(sticket)
        assert st_repr == (
            "ServiceTicket(ticket='foo', user='bar', service='baz', "
            "sso=True, timestamp=12.1)")
        new_st = eval(st_repr)
        assert sticket.user == new_st.user
        assert sticket.ticket == new_st.ticket
        assert sticket.service == new_st.service
        assert sticket.sso == new_st.sso
        assert sticket.ts == new_st.ts

    def test_service_ticket_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        sticket = ServiceTicket('foo', 'bar', 'baz')
        assert isinstance(sticket.ts, float)

    def test_service_ticket_sso(self):
        # sso is set to True by default
        sticket = ServiceTicket('foo', 'bar', 'baz')
        assert sticket.sso is True

    def test_ticket_granting_cookie_add(self):
        # we can add ticket granting cookies
        assert self._has_table('ticket_granting_cookies')

        self.db.add(TicketGrantingCookie('foo'))
        assert [x.value for x in self.db.query(
            TicketGrantingCookie)] == ['foo']

    def test_ticket_granting_cookie_repr(self):
        # we can get a proper ticket-granting cookie representation
        # (i.e. one that can be fed to `eval`)
        tgc = TicketGrantingCookie('foo', 12.1)
        tgc_repr = repr(tgc)
        assert tgc_repr == "TicketGrantingCookie('foo', 12.1)"
        new_tgc = eval(tgc_repr)
        assert new_tgc.value == tgc.value
        assert new_tgc.ts == tgc.ts

    def test_ticket_granting_cookie_timestamp(self):
        # we get a timestamp stored, if none is passed to init
        tgc = TicketGrantingCookie('foo')
        assert isinstance(tgc.ts, float)
