add some tests

This commit is contained in:
Valentin Samir 2015-06-21 18:56:16 +02:00
parent c0d8550120
commit 50781dba18
13 changed files with 195 additions and 78 deletions

View file

@ -7,6 +7,8 @@ env:
matrix: matrix:
- TOX_ENV=py27-django17 - TOX_ENV=py27-django17
- TOX_ENV=py27-django18 - TOX_ENV=py27-django18
- TOX_ENV=py34-django17
- TOX_ENV=py34-django18
- TOX_ENV=flake8 - TOX_ENV=flake8
cache: cache:
directories: directories:

View file

@ -14,8 +14,8 @@ from .default_settings import settings
from django import forms from django import forms
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
import utils import cas_server.utils as utils
import models import cas_server.models as models
class UserCredential(forms.Form): class UserCredential(forms.Form):

View file

@ -27,7 +27,7 @@ from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from requests_futures.sessions import FuturesSession from requests_futures.sessions import FuturesSession
import utils import cas_server.utils as utils
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore SessionStore = import_module(settings.SESSION_ENGINE).SessionStore

View file

@ -16,12 +16,17 @@ from django.utils.importlib import import_module
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
import urlparse
import urllib
import random import random
import string import string
try:
from urlparse import urlparse, urlunparse, parse_qsl
from urllib import urlencode
except ImportError:
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def import_attr(path): def import_attr(path):
"""transform a python module.attr path to the attr""" """transform a python module.attr path to the attr"""
if not isinstance(path, str): if not isinstance(path, str):
@ -33,26 +38,29 @@ def import_attr(path):
def redirect_params(url_name, params=None): def redirect_params(url_name, params=None):
"""Redirect to `url_name` with `params` as querystring""" """Redirect to `url_name` with `params` as querystring"""
url = reverse(url_name) url = reverse(url_name)
params = urllib.urlencode(params if params else {}) params = urlencode(params if params else {})
return HttpResponseRedirect(url + "?%s" % params) return HttpResponseRedirect(url + "?%s" % params)
def update_url(url, params): def update_url(url, params):
"""update params in the `url` query string""" """update params in the `url` query string"""
if isinstance(url, unicode): if not isinstance(url, bytes):
url = url.encode('utf-8') url = url.encode('utf-8')
for key, value in params.items(): for key, value in list(params.items()):
if isinstance(key, unicode): if not isinstance(key, bytes):
del params[key] del params[key]
key = key.encode('utf-8') key = key.encode('utf-8')
if isinstance(value, unicode): if not isinstance(value, bytes):
value = value.encode('utf-8') value = value.encode('utf-8')
params[key] = value params[key] = value
url_parts = list(urlparse.urlparse(url)) url_parts = list(urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4])) query = dict(parse_qsl(url_parts[4]))
query.update(params) query.update(params)
url_parts[4] = urllib.urlencode(query) url_parts[4] = urlencode(query)
return urlparse.urlunparse(url_parts).decode('utf-8') for i in range(len(url_parts)):
if not isinstance(url_parts[i], bytes):
url_parts[i] = url_parts[i].encode('utf-8')
return urlunparse(url_parts).decode('utf-8')
def unpack_nested_exception(error): def unpack_nested_exception(error):

View file

@ -26,9 +26,9 @@ import requests
from lxml import etree from lxml import etree
from datetime import timedelta from datetime import timedelta
import utils import cas_server.utils as utils
import forms import cas_server.forms as forms
import models import cas_server.models as models
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
from .models import ServicePattern from .models import ServicePattern
@ -633,7 +633,7 @@ class Proxy(View):
self.target_service, self.target_service,
pattern, pattern,
renew=False) renew=False)
pticket.proxies.create(url=ticket.service) models.Proxy.objects.create(proxy_ticket=pticket, url=ticket.service)
return render( return render(
self.request, self.request,
"cas_server/proxy.xml", "cas_server/proxy.xml",

View file

@ -1,3 +1,4 @@
import functools
from cas_server import models from cas_server import models
class DummyUserManager(object): class DummyUserManager(object):
@ -10,6 +11,75 @@ class DummyUserManager(object):
else: else:
raise models.User.DoesNotExist() raise models.User.DoesNotExist()
def dummy(*args, **kwds):
pass
def dummy_service_pattern(**kwargs):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
service_validate = models.ServicePattern.validate
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
ret = func(*args, **kwds)
models.ServicePattern.validate = service_validate
return ret
return wrapper
return decorator
def dummy_user(username, session_key):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
user_manager = models.User.objects
user_save = models.User.save
user_delete = models.User.delete
models.User.objects = DummyUserManager(username, session_key)
models.User.save = dummy
models.User.delete = dummy
ret = func(*args, **kwds)
models.User.objects = user_manager
models.User.save = user_save
models.User.delete = user_delete
return ret
return wrapper
return decorator
def dummy_ticket(ticket_class, service, ticket):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
ticket_manager = ticket_class.objects
ticket_save = ticket_class.save
ticket_delete = ticket_class.delete
ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
ticket_class.save = dummy
ticket_class.delete = dummy
ret = func(*args, **kwds)
ticket_class.objects = ticket_manager
ticket_class.save = ticket_save
ticket_class.delete = ticket_delete
return ret
return wrapper
return decorator
def dummy_proxy(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
proxy_manager = models.Proxy.objects
models.Proxy.objects = DummyProxyManager()
ret = func(*args, **kwds)
models.Proxy.objects = proxy_manager
return ret
return wrapper
class DummyProxyManager(object):
def create(self, **kwargs):
for field in models.Proxy._meta.fields:
field.allow_unsaved_instance_assignment = True
return models.Proxy(**kwargs)
class DummyTicketManager(object): class DummyTicketManager(object):
def __init__(self, ticket_class, service, ticket): def __init__(self, ticket_class, service, ticket):
self.ticket_class = ticket_class self.ticket_class = ticket_class
@ -17,7 +87,7 @@ class DummyTicketManager(object):
self.ticket = ticket self.ticket = ticket
def create(self, **kwargs): def create(self, **kwargs):
for field in models.ServiceTicket._meta.fields: for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True field.allow_unsaved_instance_assignment = True
return self.ticket_class(**kwargs) return self.ticket_class(**kwargs)
@ -25,6 +95,8 @@ class DummyTicketManager(object):
return DummyQuerySet() return DummyQuerySet()
def get(self, **kwargs): def get(self, **kwargs):
for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True
if 'value' in kwargs: if 'value' in kwargs:
if kwargs['value'] != self.ticket: if kwargs['value'] != self.ticket:
raise self.ticket_class.DoesNotExist() raise self.ticket_class.DoesNotExist()
@ -41,7 +113,7 @@ class DummyTicketManager(object):
for field in models.ServiceTicket._meta.fields: for field in models.ServiceTicket._meta.fields:
field.allow_unsaved_instance_assignment = True field.allow_unsaved_instance_assignment = True
for key in kwargs.keys(): for key in list(kwargs):
if '__' in key: if '__' in key:
del kwargs[key] del kwargs[key]
kwargs['attributs'] = {'mail': 'test@example.com'} kwargs['attributs'] = {'mail': 'test@example.com'}

52
tests/test_proxy.py Normal file
View file

@ -0,0 +1,52 @@
from __future__ import absolute_import
from tests.init import *
from django.test import RequestFactory
import os
import pytest
from lxml import etree
from cas_server.views import ValidateService, Proxy
from cas_server import models
from tests.dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
@dummy_service_pattern(proxy=True)
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
@dummy_proxy
def test_proxy_ok():
factory = RequestFactory()
request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
request.session = DummySession()
proxy = Proxy()
response = proxy.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(proxy_tickets) == 1
factory = RequestFactory()
request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
validate = ValidateService()
validate.allow_proxy_ticket = True
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(users) == 1
assert users[0].text == "test"

View file

@ -12,15 +12,13 @@ from cas_server import models
from .dummy import * from .dummy import *
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_service_view_ok(): def test_validate_service_view_ok():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com') request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService() validate = ValidateService()
validate.allow_proxy_ticket = False validate.allow_proxy_ticket = False
response = validate.get(request) response = validate.get(request)
@ -47,15 +45,13 @@ def test_validate_service_view_ok():
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
def test_validate_service_view_badservice(): def test_validate_service_view_badservice():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com') request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService() validate = ValidateService()
validate.allow_proxy_ticket = False validate.allow_proxy_ticket = False
response = validate.get(request) response = validate.get(request)
@ -70,15 +66,13 @@ def test_validate_service_view_badservice():
assert error[0].attrib['code'] == 'INVALID_SERVICE' assert error[0].attrib['code'] == 'INVALID_SERVICE'
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
def test_validate_service_view_badticket(): def test_validate_service_view_badticket():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com') request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
models.ServiceTicket.save = lambda x:None
validate = ValidateService() validate = ValidateService()
validate.allow_proxy_ticket = False validate.allow_proxy_ticket = False
response = validate.get(request) response = validate.get(request)

View file

@ -14,36 +14,33 @@ from .dummy import *
settings.CAS_AUTH_SHARED_SECRET = "test" settings.CAS_AUTH_SHARED_SECRET = "test"
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
@dummy_service_pattern()
def test_auth_view_goodpass(): def test_auth_view_goodpass():
factory = RequestFactory() factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'}) request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession() request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth() auth = Auth()
response = auth.post(request) response = auth.post(request)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == "yes\n" assert response.content == b"yes\n"
@dummy_service_pattern()
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
def test_auth_view_badpass(): def test_auth_view_badpass():
factory = RequestFactory() factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'}) request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession() request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth() auth = Auth()
response = auth.post(request) response = auth.post(request)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == "no\n" assert response.content == b"no\n"

View file

@ -92,6 +92,7 @@ def test_view_login_get_unauth():
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.django_db @pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_view_login_get_auth(): def test_view_login_get_auth():
factory = RequestFactory() factory = RequestFactory()
request = factory.post('/login') request = factory.post('/login')
@ -107,14 +108,15 @@ def test_view_login_get_auth():
assert ret == LoginView.USER_AUTHENTICATED assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
login = LoginView() login = LoginView()
response = login.get(request) response = login.get(request)
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.django_db @pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service(): def test_view_login_get_auth_service():
factory = RequestFactory() factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com') request = factory.post('/login?service=https://www.example.com')
@ -130,12 +132,6 @@ def test_view_login_get_auth_service():
assert ret == LoginView.USER_AUTHENTICATED assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.User.save = lambda x:None
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
models.ServiceTicket.save = lambda x:None
login = LoginView() login = LoginView()
response = login.get(request) response = login.get(request)
@ -143,6 +139,9 @@ def test_view_login_get_auth_service():
assert response['Location'].startswith('https://www.example.com?ticket=ST-') assert response['Location'].startswith('https://www.example.com?ticket=ST-')
@pytest.mark.django_db @pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service_warn(): def test_view_login_get_auth_service_warn():
factory = RequestFactory() factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com') request = factory.post('/login?service=https://www.example.com')
@ -158,12 +157,6 @@ def test_view_login_get_auth_service_warn():
assert ret == LoginView.USER_AUTHENTICATED assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.User.save = lambda x:None
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
models.ServiceTicket.save = lambda x:None
login = LoginView() login = LoginView()
response = login.get(request) response = login.get(request)

View file

@ -13,6 +13,7 @@ from .dummy import *
@pytest.mark.django_db @pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view(): def test_logout_view():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/logout') request = factory.get('/logout')
@ -23,21 +24,17 @@ def test_logout_view():
request.session["username"] = "test" request.session["username"] = "test"
request.session["warn"] = False request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView() logout = LogoutView()
response = logout.get(request) response = logout.get(request)
assert response.status_code == 200 assert response.status_code == 200
assert dlist == []
assert not request.session.get("authenticated") assert not request.session.get("authenticated")
assert not request.session.get("username") assert not request.session.get("username")
assert not request.session.get("warn") assert not request.session.get("warn")
@pytest.mark.django_db @pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_url(): def test_logout_view_url():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/logout?url=https://www.example.com') request = factory.get('/logout?url=https://www.example.com')
@ -48,16 +45,11 @@ def test_logout_view_url():
request.session["username"] = "test" request.session["username"] = "test"
request.session["warn"] = False request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView() logout = LogoutView()
response = logout.get(request) response = logout.get(request)
assert response.status_code == 302 assert response.status_code == 302
assert response['Location'] == 'https://www.example.com' assert response['Location'] == 'https://www.example.com'
assert dlist == []
assert not request.session.get("authenticated") assert not request.session.get("authenticated")
assert not request.session.get("username") assert not request.session.get("username")
assert not request.session.get("warn") assert not request.session.get("warn")
@ -65,6 +57,7 @@ def test_logout_view_url():
@pytest.mark.django_db @pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_service(): def test_logout_view_service():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/logout?service=https://www.example.com') request = factory.get('/logout?service=https://www.example.com')
@ -75,16 +68,11 @@ def test_logout_view_service():
request.session["username"] = "test" request.session["username"] = "test"
request.session["warn"] = False request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView() logout = LogoutView()
response = logout.get(request) response = logout.get(request)
assert response.status_code == 302 assert response.status_code == 302
assert response['Location'] == 'https://www.example.com' assert response['Location'] == 'https://www.example.com'
assert dlist == []
assert not request.session.get("authenticated") assert not request.session.get("authenticated")
assert not request.session.get("username") assert not request.session.get("username")
assert not request.session.get("warn") assert not request.session.get("warn")

View file

@ -12,50 +12,47 @@ from cas_server import models
from .dummy import * from .dummy import *
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_ok(): def test_validate_view_ok():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com') request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
validate = Validate() validate = Validate()
response = validate.get(request) response = validate.get(request)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == "yes\n" assert response.content == b"yes\n"
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_badservice(): def test_validate_view_badservice():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com') request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
validate = Validate() validate = Validate()
response = validate.get(request) response = validate.get(request)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == "no\n" assert response.content == b"no\n"
@pytest.mark.django_db @pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
def test_validate_view_badticket(): def test_validate_view_badticket():
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com') request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
request.session = DummySession() request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1")
validate = Validate() validate = Validate()
response = validate.get(request) response = validate.get(request)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == "no\n" assert response.content == b"no\n"

14
tox.ini
View file

@ -2,6 +2,8 @@
envlist= envlist=
py27-django17, py27-django17,
py27-django18, py27-django18,
py34-django17,
py34-django18,
flake8, flake8,
[flake8] [flake8]
@ -27,6 +29,18 @@ deps =
Django>=1.8,<1.9 Django>=1.8,<1.9
{[base]deps} {[base]deps}
[testenv:py34-django17]
basepython=python3.4
deps =
Django>=1.7,<1.8
{[base]deps}
[testenv:py34-django18]
basepython=python3.4
deps =
Django>=1.8,<1.9
{[base]deps}
[testenv:flake8] [testenv:flake8]
basepython=python basepython=python
deps=flake8 deps=flake8