add some tests
This commit is contained in:
parent
c0d8550120
commit
50781dba18
13 changed files with 195 additions and 78 deletions
|
@ -7,6 +7,8 @@ env:
|
||||||
matrix:
|
matrix:
|
||||||
- TOX_ENV=py27-django17
|
- TOX_ENV=py27-django17
|
||||||
- TOX_ENV=py27-django18
|
- TOX_ENV=py27-django18
|
||||||
|
- TOX_ENV=py34-django17
|
||||||
|
- TOX_ENV=py34-django18
|
||||||
- TOX_ENV=flake8
|
- TOX_ENV=flake8
|
||||||
cache:
|
cache:
|
||||||
directories:
|
directories:
|
||||||
|
|
|
@ -14,8 +14,8 @@ from .default_settings import settings
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
import utils
|
import cas_server.utils as utils
|
||||||
import models
|
import cas_server.models as models
|
||||||
|
|
||||||
|
|
||||||
class UserCredential(forms.Form):
|
class UserCredential(forms.Form):
|
||||||
|
|
|
@ -27,7 +27,7 @@ from datetime import timedelta
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from requests_futures.sessions import FuturesSession
|
from requests_futures.sessions import FuturesSession
|
||||||
|
|
||||||
import utils
|
import cas_server.utils as utils
|
||||||
|
|
||||||
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,17 @@ from django.utils.importlib import import_module
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
from django.http import HttpResponseRedirect
|
from django.http import HttpResponseRedirect
|
||||||
|
|
||||||
import urlparse
|
|
||||||
import urllib
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urlparse import urlparse, urlunparse, parse_qsl
|
||||||
|
from urllib import urlencode
|
||||||
|
except ImportError:
|
||||||
|
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
|
||||||
|
|
||||||
|
|
||||||
def import_attr(path):
|
def import_attr(path):
|
||||||
"""transform a python module.attr path to the attr"""
|
"""transform a python module.attr path to the attr"""
|
||||||
if not isinstance(path, str):
|
if not isinstance(path, str):
|
||||||
|
@ -33,26 +38,29 @@ def import_attr(path):
|
||||||
def redirect_params(url_name, params=None):
|
def redirect_params(url_name, params=None):
|
||||||
"""Redirect to `url_name` with `params` as querystring"""
|
"""Redirect to `url_name` with `params` as querystring"""
|
||||||
url = reverse(url_name)
|
url = reverse(url_name)
|
||||||
params = urllib.urlencode(params if params else {})
|
params = urlencode(params if params else {})
|
||||||
return HttpResponseRedirect(url + "?%s" % params)
|
return HttpResponseRedirect(url + "?%s" % params)
|
||||||
|
|
||||||
|
|
||||||
def update_url(url, params):
|
def update_url(url, params):
|
||||||
"""update params in the `url` query string"""
|
"""update params in the `url` query string"""
|
||||||
if isinstance(url, unicode):
|
if not isinstance(url, bytes):
|
||||||
url = url.encode('utf-8')
|
url = url.encode('utf-8')
|
||||||
for key, value in params.items():
|
for key, value in list(params.items()):
|
||||||
if isinstance(key, unicode):
|
if not isinstance(key, bytes):
|
||||||
del params[key]
|
del params[key]
|
||||||
key = key.encode('utf-8')
|
key = key.encode('utf-8')
|
||||||
if isinstance(value, unicode):
|
if not isinstance(value, bytes):
|
||||||
value = value.encode('utf-8')
|
value = value.encode('utf-8')
|
||||||
params[key] = value
|
params[key] = value
|
||||||
url_parts = list(urlparse.urlparse(url))
|
url_parts = list(urlparse(url))
|
||||||
query = dict(urlparse.parse_qsl(url_parts[4]))
|
query = dict(parse_qsl(url_parts[4]))
|
||||||
query.update(params)
|
query.update(params)
|
||||||
url_parts[4] = urllib.urlencode(query)
|
url_parts[4] = urlencode(query)
|
||||||
return urlparse.urlunparse(url_parts).decode('utf-8')
|
for i in range(len(url_parts)):
|
||||||
|
if not isinstance(url_parts[i], bytes):
|
||||||
|
url_parts[i] = url_parts[i].encode('utf-8')
|
||||||
|
return urlunparse(url_parts).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
def unpack_nested_exception(error):
|
def unpack_nested_exception(error):
|
||||||
|
|
|
@ -26,9 +26,9 @@ import requests
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
import utils
|
import cas_server.utils as utils
|
||||||
import forms
|
import cas_server.forms as forms
|
||||||
import models
|
import cas_server.models as models
|
||||||
|
|
||||||
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
|
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
|
||||||
from .models import ServicePattern
|
from .models import ServicePattern
|
||||||
|
@ -633,7 +633,7 @@ class Proxy(View):
|
||||||
self.target_service,
|
self.target_service,
|
||||||
pattern,
|
pattern,
|
||||||
renew=False)
|
renew=False)
|
||||||
pticket.proxies.create(url=ticket.service)
|
models.Proxy.objects.create(proxy_ticket=pticket, url=ticket.service)
|
||||||
return render(
|
return render(
|
||||||
self.request,
|
self.request,
|
||||||
"cas_server/proxy.xml",
|
"cas_server/proxy.xml",
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import functools
|
||||||
from cas_server import models
|
from cas_server import models
|
||||||
|
|
||||||
class DummyUserManager(object):
|
class DummyUserManager(object):
|
||||||
|
@ -10,6 +11,75 @@ class DummyUserManager(object):
|
||||||
else:
|
else:
|
||||||
raise models.User.DoesNotExist()
|
raise models.User.DoesNotExist()
|
||||||
|
|
||||||
|
|
||||||
|
def dummy(*args, **kwds):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def dummy_service_pattern(**kwargs):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwds):
|
||||||
|
service_validate = models.ServicePattern.validate
|
||||||
|
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
|
||||||
|
ret = func(*args, **kwds)
|
||||||
|
models.ServicePattern.validate = service_validate
|
||||||
|
return ret
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def dummy_user(username, session_key):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwds):
|
||||||
|
user_manager = models.User.objects
|
||||||
|
user_save = models.User.save
|
||||||
|
user_delete = models.User.delete
|
||||||
|
models.User.objects = DummyUserManager(username, session_key)
|
||||||
|
models.User.save = dummy
|
||||||
|
models.User.delete = dummy
|
||||||
|
ret = func(*args, **kwds)
|
||||||
|
models.User.objects = user_manager
|
||||||
|
models.User.save = user_save
|
||||||
|
models.User.delete = user_delete
|
||||||
|
return ret
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def dummy_ticket(ticket_class, service, ticket):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwds):
|
||||||
|
ticket_manager = ticket_class.objects
|
||||||
|
ticket_save = ticket_class.save
|
||||||
|
ticket_delete = ticket_class.delete
|
||||||
|
ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
|
||||||
|
ticket_class.save = dummy
|
||||||
|
ticket_class.delete = dummy
|
||||||
|
ret = func(*args, **kwds)
|
||||||
|
ticket_class.objects = ticket_manager
|
||||||
|
ticket_class.save = ticket_save
|
||||||
|
ticket_class.delete = ticket_delete
|
||||||
|
return ret
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_proxy(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwds):
|
||||||
|
proxy_manager = models.Proxy.objects
|
||||||
|
models.Proxy.objects = DummyProxyManager()
|
||||||
|
ret = func(*args, **kwds)
|
||||||
|
models.Proxy.objects = proxy_manager
|
||||||
|
return ret
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
class DummyProxyManager(object):
|
||||||
|
def create(self, **kwargs):
|
||||||
|
for field in models.Proxy._meta.fields:
|
||||||
|
field.allow_unsaved_instance_assignment = True
|
||||||
|
return models.Proxy(**kwargs)
|
||||||
|
|
||||||
class DummyTicketManager(object):
|
class DummyTicketManager(object):
|
||||||
def __init__(self, ticket_class, service, ticket):
|
def __init__(self, ticket_class, service, ticket):
|
||||||
self.ticket_class = ticket_class
|
self.ticket_class = ticket_class
|
||||||
|
@ -17,7 +87,7 @@ class DummyTicketManager(object):
|
||||||
self.ticket = ticket
|
self.ticket = ticket
|
||||||
|
|
||||||
def create(self, **kwargs):
|
def create(self, **kwargs):
|
||||||
for field in models.ServiceTicket._meta.fields:
|
for field in self.ticket_class._meta.fields:
|
||||||
field.allow_unsaved_instance_assignment = True
|
field.allow_unsaved_instance_assignment = True
|
||||||
return self.ticket_class(**kwargs)
|
return self.ticket_class(**kwargs)
|
||||||
|
|
||||||
|
@ -25,6 +95,8 @@ class DummyTicketManager(object):
|
||||||
return DummyQuerySet()
|
return DummyQuerySet()
|
||||||
|
|
||||||
def get(self, **kwargs):
|
def get(self, **kwargs):
|
||||||
|
for field in self.ticket_class._meta.fields:
|
||||||
|
field.allow_unsaved_instance_assignment = True
|
||||||
if 'value' in kwargs:
|
if 'value' in kwargs:
|
||||||
if kwargs['value'] != self.ticket:
|
if kwargs['value'] != self.ticket:
|
||||||
raise self.ticket_class.DoesNotExist()
|
raise self.ticket_class.DoesNotExist()
|
||||||
|
@ -41,7 +113,7 @@ class DummyTicketManager(object):
|
||||||
|
|
||||||
for field in models.ServiceTicket._meta.fields:
|
for field in models.ServiceTicket._meta.fields:
|
||||||
field.allow_unsaved_instance_assignment = True
|
field.allow_unsaved_instance_assignment = True
|
||||||
for key in kwargs.keys():
|
for key in list(kwargs):
|
||||||
if '__' in key:
|
if '__' in key:
|
||||||
del kwargs[key]
|
del kwargs[key]
|
||||||
kwargs['attributs'] = {'mail': 'test@example.com'}
|
kwargs['attributs'] = {'mail': 'test@example.com'}
|
||||||
|
|
52
tests/test_proxy.py
Normal file
52
tests/test_proxy.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from tests.init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from lxml import etree
|
||||||
|
from cas_server.views import ValidateService, Proxy
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from tests.dummy import *
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
|
||||||
|
@dummy_service_pattern(proxy=True)
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
|
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
|
||||||
|
@dummy_proxy
|
||||||
|
def test_proxy_ok():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
proxy = Proxy()
|
||||||
|
response = proxy.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(proxy_tickets) == 1
|
||||||
|
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
|
||||||
|
|
||||||
|
validate = ValidateService()
|
||||||
|
validate.allow_proxy_ticket = True
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].text == "test"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,15 +12,13 @@ from cas_server import models
|
||||||
from .dummy import *
|
from .dummy import *
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
def test_validate_service_view_ok():
|
def test_validate_service_view_ok():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
|
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
models.ServiceTicket.save = lambda x:None
|
|
||||||
|
|
||||||
validate = ValidateService()
|
validate = ValidateService()
|
||||||
validate.allow_proxy_ticket = False
|
validate.allow_proxy_ticket = False
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
@ -47,15 +45,13 @@ def test_validate_service_view_ok():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
|
||||||
def test_validate_service_view_badservice():
|
def test_validate_service_view_badservice():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
|
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
|
|
||||||
models.ServiceTicket.save = lambda x:None
|
|
||||||
|
|
||||||
validate = ValidateService()
|
validate = ValidateService()
|
||||||
validate.allow_proxy_ticket = False
|
validate.allow_proxy_ticket = False
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
@ -70,15 +66,13 @@ def test_validate_service_view_badservice():
|
||||||
assert error[0].attrib['code'] == 'INVALID_SERVICE'
|
assert error[0].attrib['code'] == 'INVALID_SERVICE'
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
|
||||||
def test_validate_service_view_badticket():
|
def test_validate_service_view_badticket():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
|
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
|
|
||||||
models.ServiceTicket.save = lambda x:None
|
|
||||||
|
|
||||||
validate = ValidateService()
|
validate = ValidateService()
|
||||||
validate.allow_proxy_ticket = False
|
validate.allow_proxy_ticket = False
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
|
|
@ -14,36 +14,33 @@ from .dummy import *
|
||||||
settings.CAS_AUTH_SHARED_SECRET = "test"
|
settings.CAS_AUTH_SHARED_SECRET = "test"
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
|
@dummy_service_pattern()
|
||||||
def test_auth_view_goodpass():
|
def test_auth_view_goodpass():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
|
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
response = auth.post(request)
|
response = auth.post(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == "yes\n"
|
assert response.content == b"yes\n"
|
||||||
|
|
||||||
|
|
||||||
|
@dummy_service_pattern()
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
def test_auth_view_badpass():
|
def test_auth_view_badpass():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
|
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
response = auth.post(request)
|
response = auth.post(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == "no\n"
|
assert response.content == b"no\n"
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,7 @@ def test_view_login_get_unauth():
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
def test_view_login_get_auth():
|
def test_view_login_get_auth():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.post('/login')
|
request = factory.post('/login')
|
||||||
|
@ -107,14 +108,15 @@ def test_view_login_get_auth():
|
||||||
|
|
||||||
assert ret == LoginView.USER_AUTHENTICATED
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
|
|
||||||
login = LoginView()
|
login = LoginView()
|
||||||
response = login.get(request)
|
response = login.get(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_service_pattern()
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
def test_view_login_get_auth_service():
|
def test_view_login_get_auth_service():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.post('/login?service=https://www.example.com')
|
request = factory.post('/login?service=https://www.example.com')
|
||||||
|
@ -130,12 +132,6 @@ def test_view_login_get_auth_service():
|
||||||
|
|
||||||
assert ret == LoginView.USER_AUTHENTICATED
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
models.User.save = lambda x:None
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
|
||||||
models.ServiceTicket.save = lambda x:None
|
|
||||||
|
|
||||||
login = LoginView()
|
login = LoginView()
|
||||||
response = login.get(request)
|
response = login.get(request)
|
||||||
|
|
||||||
|
@ -143,6 +139,9 @@ def test_view_login_get_auth_service():
|
||||||
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
|
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_service_pattern()
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
def test_view_login_get_auth_service_warn():
|
def test_view_login_get_auth_service_warn():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.post('/login?service=https://www.example.com')
|
request = factory.post('/login?service=https://www.example.com')
|
||||||
|
@ -158,12 +157,6 @@ def test_view_login_get_auth_service_warn():
|
||||||
|
|
||||||
assert ret == LoginView.USER_AUTHENTICATED
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
models.User.save = lambda x:None
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
|
||||||
models.ServiceTicket.save = lambda x:None
|
|
||||||
|
|
||||||
login = LoginView()
|
login = LoginView()
|
||||||
response = login.get(request)
|
response = login.get(request)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from .dummy import *
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
def test_logout_view():
|
def test_logout_view():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/logout')
|
request = factory.get('/logout')
|
||||||
|
@ -23,21 +24,17 @@ def test_logout_view():
|
||||||
request.session["username"] = "test"
|
request.session["username"] = "test"
|
||||||
request.session["warn"] = False
|
request.session["warn"] = False
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
dlist = [None]
|
|
||||||
models.User.delete = lambda x:dlist.pop()
|
|
||||||
|
|
||||||
logout = LogoutView()
|
logout = LogoutView()
|
||||||
response = logout.get(request)
|
response = logout.get(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert dlist == []
|
|
||||||
assert not request.session.get("authenticated")
|
assert not request.session.get("authenticated")
|
||||||
assert not request.session.get("username")
|
assert not request.session.get("username")
|
||||||
assert not request.session.get("warn")
|
assert not request.session.get("warn")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
def test_logout_view_url():
|
def test_logout_view_url():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/logout?url=https://www.example.com')
|
request = factory.get('/logout?url=https://www.example.com')
|
||||||
|
@ -48,16 +45,11 @@ def test_logout_view_url():
|
||||||
request.session["username"] = "test"
|
request.session["username"] = "test"
|
||||||
request.session["warn"] = False
|
request.session["warn"] = False
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
dlist = [None]
|
|
||||||
models.User.delete = lambda x:dlist.pop()
|
|
||||||
|
|
||||||
logout = LogoutView()
|
logout = LogoutView()
|
||||||
response = logout.get(request)
|
response = logout.get(request)
|
||||||
|
|
||||||
assert response.status_code == 302
|
assert response.status_code == 302
|
||||||
assert response['Location'] == 'https://www.example.com'
|
assert response['Location'] == 'https://www.example.com'
|
||||||
assert dlist == []
|
|
||||||
assert not request.session.get("authenticated")
|
assert not request.session.get("authenticated")
|
||||||
assert not request.session.get("username")
|
assert not request.session.get("username")
|
||||||
assert not request.session.get("warn")
|
assert not request.session.get("warn")
|
||||||
|
@ -65,6 +57,7 @@ def test_logout_view_url():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_user(username="test", session_key="test_session")
|
||||||
def test_logout_view_service():
|
def test_logout_view_service():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/logout?service=https://www.example.com')
|
request = factory.get('/logout?service=https://www.example.com')
|
||||||
|
@ -75,16 +68,11 @@ def test_logout_view_service():
|
||||||
request.session["username"] = "test"
|
request.session["username"] = "test"
|
||||||
request.session["warn"] = False
|
request.session["warn"] = False
|
||||||
|
|
||||||
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
|
||||||
dlist = [None]
|
|
||||||
models.User.delete = lambda x:dlist.pop()
|
|
||||||
|
|
||||||
logout = LogoutView()
|
logout = LogoutView()
|
||||||
response = logout.get(request)
|
response = logout.get(request)
|
||||||
|
|
||||||
assert response.status_code == 302
|
assert response.status_code == 302
|
||||||
assert response['Location'] == 'https://www.example.com'
|
assert response['Location'] == 'https://www.example.com'
|
||||||
assert dlist == []
|
|
||||||
assert not request.session.get("authenticated")
|
assert not request.session.get("authenticated")
|
||||||
assert not request.session.get("username")
|
assert not request.session.get("username")
|
||||||
assert not request.session.get("warn")
|
assert not request.session.get("warn")
|
||||||
|
|
|
@ -12,50 +12,47 @@ from cas_server import models
|
||||||
from .dummy import *
|
from .dummy import *
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
def test_validate_view_ok():
|
def test_validate_view_ok():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
|
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
|
|
||||||
validate = Validate()
|
validate = Validate()
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == "yes\n"
|
assert response.content == b"yes\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
def test_validate_view_badservice():
|
def test_validate_view_badservice():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
|
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
|
||||||
|
|
||||||
validate = Validate()
|
validate = Validate()
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == "no\n"
|
assert response.content == b"no\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
|
||||||
def test_validate_view_badticket():
|
def test_validate_view_badticket():
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
|
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
|
||||||
|
|
||||||
request.session = DummySession()
|
request.session = DummySession()
|
||||||
|
|
||||||
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1")
|
|
||||||
|
|
||||||
validate = Validate()
|
validate = Validate()
|
||||||
response = validate.get(request)
|
response = validate.get(request)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == "no\n"
|
assert response.content == b"no\n"
|
||||||
|
|
14
tox.ini
14
tox.ini
|
@ -2,6 +2,8 @@
|
||||||
envlist=
|
envlist=
|
||||||
py27-django17,
|
py27-django17,
|
||||||
py27-django18,
|
py27-django18,
|
||||||
|
py34-django17,
|
||||||
|
py34-django18,
|
||||||
flake8,
|
flake8,
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
|
@ -27,6 +29,18 @@ deps =
|
||||||
Django>=1.8,<1.9
|
Django>=1.8,<1.9
|
||||||
{[base]deps}
|
{[base]deps}
|
||||||
|
|
||||||
|
[testenv:py34-django17]
|
||||||
|
basepython=python3.4
|
||||||
|
deps =
|
||||||
|
Django>=1.7,<1.8
|
||||||
|
{[base]deps}
|
||||||
|
|
||||||
|
[testenv:py34-django18]
|
||||||
|
basepython=python3.4
|
||||||
|
deps =
|
||||||
|
Django>=1.8,<1.9
|
||||||
|
{[base]deps}
|
||||||
|
|
||||||
[testenv:flake8]
|
[testenv:flake8]
|
||||||
basepython=python
|
basepython=python
|
||||||
deps=flake8
|
deps=flake8
|
||||||
|
|
Loading…
Reference in a new issue