Add some tests using tox
This commit is contained in:
parent
39557d1942
commit
c0d8550120
15 changed files with 724 additions and 51 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,6 +1,9 @@
|
||||||
*.pyc
|
*.pyc
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
bootstrap3
|
bootstrap3
|
||||||
cas/
|
cas/
|
||||||
db.sqlite3
|
db.sqlite3
|
||||||
manage.py
|
manage.py
|
||||||
|
|
||||||
|
.tox
|
||||||
|
|
21
.travis.yml
Normal file
21
.travis.yml
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
language: python
|
||||||
|
python:
|
||||||
|
- "2.7"
|
||||||
|
env:
|
||||||
|
global:
|
||||||
|
- PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
|
||||||
|
matrix:
|
||||||
|
- TOX_ENV=py27-django17
|
||||||
|
- TOX_ENV=py27-django18
|
||||||
|
- TOX_ENV=flake8
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- $HOME/.pip-cache/
|
||||||
|
install:
|
||||||
|
- "travis_retry pip install setuptools --upgrade"
|
||||||
|
- "pip install tox"
|
||||||
|
script:
|
||||||
|
- tox -e $TOX_ENV
|
||||||
|
after_script:
|
||||||
|
- cat .tox/$TOX_ENV/log/*.log
|
||||||
|
|
|
@ -27,26 +27,14 @@ class UserCredential(forms.Form):
|
||||||
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
warn = forms.BooleanField(label=_('warn'), required=False)
|
warn = forms.BooleanField(label=_('warn'), required=False)
|
||||||
|
|
||||||
def __init__(self, request, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.request = request
|
|
||||||
super(UserCredential, self).__init__(*args, **kwargs)
|
super(UserCredential, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
cleaned_data = super(UserCredential, self).clean()
|
cleaned_data = super(UserCredential, self).clean()
|
||||||
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
|
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
|
||||||
if auth.test_password(cleaned_data.get("password")):
|
if auth.test_password(cleaned_data.get("password")):
|
||||||
try:
|
cleaned_data["username"] = auth.username
|
||||||
user = models.User.objects.get(
|
|
||||||
username=auth.username,
|
|
||||||
session_key=self.request.session.session_key
|
|
||||||
)
|
|
||||||
user.save()
|
|
||||||
except models.User.DoesNotExist:
|
|
||||||
user = models.User.objects.create(
|
|
||||||
username=auth.username,
|
|
||||||
session_key=self.request.session.session_key
|
|
||||||
)
|
|
||||||
user.save()
|
|
||||||
else:
|
else:
|
||||||
raise forms.ValidationError(_(u"Bad user"))
|
raise forms.ValidationError(_(u"Bad user"))
|
||||||
|
|
||||||
|
|
|
@ -89,11 +89,14 @@ class LogoutView(View, LogoutMixin):
|
||||||
request = None
|
request = None
|
||||||
service = None
|
service = None
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def init_get(self, request):
|
||||||
"""methode called on GET request on this view"""
|
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.url = request.GET.get('url')
|
self.url = request.GET.get('url')
|
||||||
|
|
||||||
|
def get(self, request, *args, **kwargs):
|
||||||
|
"""methode called on GET request on this view"""
|
||||||
|
self.init_get(request)
|
||||||
self.logout()
|
self.logout()
|
||||||
# if service is set, redirect to service after logout
|
# if service is set, redirect to service after logout
|
||||||
if self.service:
|
if self.service:
|
||||||
|
@ -105,6 +108,7 @@ class LogoutView(View, LogoutMixin):
|
||||||
# else redirect to login page
|
# else redirect to login page
|
||||||
else:
|
else:
|
||||||
if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT:
|
if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT:
|
||||||
|
|
||||||
messages.add_message(request, messages.SUCCESS, _(u'Successfully logout'))
|
messages.add_message(request, messages.SUCCESS, _(u'Successfully logout'))
|
||||||
return redirect("cas_server:login")
|
return redirect("cas_server:login")
|
||||||
else:
|
else:
|
||||||
|
@ -129,67 +133,110 @@ class LoginView(View, LogoutMixin):
|
||||||
renewed = False
|
renewed = False
|
||||||
warned = False
|
warned = False
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
INVALID_LOGIN_TICKET = 1
|
||||||
"""methode called on POST request on this view"""
|
USER_LOGIN_OK = 2
|
||||||
|
USER_LOGIN_FAILURE = 3
|
||||||
|
USER_ALREADY_LOGGED = 4
|
||||||
|
USER_AUTHENTICATED = 5
|
||||||
|
USER_NOT_AUTHENTICATED = 6
|
||||||
|
|
||||||
|
def init_post(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.POST.get('service')
|
self.service = request.POST.get('service')
|
||||||
self.renew = True if request.POST.get('renew') else False
|
self.renew = True if request.POST.get('renew') else False
|
||||||
self.gateway = request.POST.get('gateway')
|
self.gateway = request.POST.get('gateway')
|
||||||
self.method = request.POST.get('method')
|
self.method = request.POST.get('method')
|
||||||
|
|
||||||
|
def check_lt(self):
|
||||||
# save LT for later check
|
# save LT for later check
|
||||||
lt_valid = request.session.get('lt')
|
lt_valid = self.request.session.get('lt')
|
||||||
lt_send = request.POST.get('lt')
|
lt_send = self.request.POST.get('lt')
|
||||||
# generate a new LT (by posting the LT has been consumed)
|
# generate a new LT (by posting the LT has been consumed)
|
||||||
request.session['lt'] = utils.gen_lt()
|
self.request.session['lt'] = utils.gen_lt()
|
||||||
|
|
||||||
# check if send LT is valid
|
# check if send LT is valid
|
||||||
if lt_valid is None or lt_valid != lt_send:
|
if lt_valid is None or lt_valid != lt_send:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def post(self, request, *args, **kwargs):
|
||||||
|
"""methode called on POST request on this view"""
|
||||||
|
self.init_post(request)
|
||||||
|
ret = self.process_post()
|
||||||
|
if ret == self.INVALID_LOGIN_TICKET:
|
||||||
messages.add_message(
|
messages.add_message(
|
||||||
self.request,
|
self.request,
|
||||||
messages.ERROR,
|
messages.ERROR,
|
||||||
_(u"Invalid login ticket")
|
_(u"Invalid login ticket")
|
||||||
)
|
)
|
||||||
values = request.POST.copy()
|
elif ret == self.USER_LOGIN_OK:
|
||||||
# if not set a new LT and fail
|
try:
|
||||||
values['lt'] = request.session['lt']
|
|
||||||
self.init_form(values)
|
|
||||||
|
|
||||||
elif not request.session.get("authenticated") or self.renew:
|
|
||||||
self.init_form(request.POST)
|
|
||||||
if self.form.is_valid():
|
|
||||||
self.user = models.User.objects.get(
|
self.user = models.User.objects.get(
|
||||||
username=self.form.cleaned_data['username'],
|
username=self.request.session['username'],
|
||||||
session_key=self.request.session.session_key
|
session_key=self.request.session.session_key
|
||||||
)
|
)
|
||||||
request.session.set_expiry(0)
|
self.user.save()
|
||||||
request.session["username"] = self.form.cleaned_data['username']
|
except models.User.DoesNotExist:
|
||||||
request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
|
self.user = models.User.objects.create(
|
||||||
request.session["authenticated"] = True
|
username=self.request.session['username'],
|
||||||
self.renewed = True
|
session_key=self.request.session.session_key
|
||||||
self.warned = True
|
)
|
||||||
else:
|
self.user.save()
|
||||||
self.logout()
|
elif ret == self.USER_LOGIN_FAILURE: # bad user login
|
||||||
|
self.logout()
|
||||||
|
elif ret == self.USER_ALREADY_LOGGED:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise EnvironmentError("invalid output for LoginView.process_post")
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def process_post(self, pytest=False):
|
||||||
"""methode called on GET request on this view"""
|
if not self.check_lt():
|
||||||
|
values = self.request.POST.copy()
|
||||||
|
# if not set a new LT and fail
|
||||||
|
values['lt'] = self.request.session['lt']
|
||||||
|
self.init_form(values)
|
||||||
|
return self.INVALID_LOGIN_TICKET
|
||||||
|
elif not self.request.session.get("authenticated") or self.renew:
|
||||||
|
self.init_form(self.request.POST)
|
||||||
|
if self.form.is_valid():
|
||||||
|
self.request.session.set_expiry(0)
|
||||||
|
self.request.session["username"] = self.form.cleaned_data['username']
|
||||||
|
self.request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
|
||||||
|
self.request.session["authenticated"] = True
|
||||||
|
self.renewed = True
|
||||||
|
self.warned = True
|
||||||
|
return self.USER_LOGIN_OK
|
||||||
|
else:
|
||||||
|
return self.USER_LOGIN_FAILURE
|
||||||
|
else:
|
||||||
|
return self.USER_ALREADY_LOGGED
|
||||||
|
|
||||||
|
def init_get(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.renew = True if request.GET.get('renew') else False
|
self.renew = True if request.GET.get('renew') else False
|
||||||
self.gateway = request.GET.get('gateway')
|
self.gateway = request.GET.get('gateway')
|
||||||
self.method = request.GET.get('method')
|
self.method = request.GET.get('method')
|
||||||
|
|
||||||
# generate a new LT if none is present
|
def get(self, request, *args, **kwargs):
|
||||||
request.session['lt'] = request.session.get('lt', utils.gen_lt())
|
"""methode called on GET request on this view"""
|
||||||
|
self.init_get(request)
|
||||||
if not request.session.get("authenticated") or self.renew:
|
self.process_get()
|
||||||
self.init_form()
|
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
|
def process_get(self):
|
||||||
|
# generate a new LT if none is present
|
||||||
|
self.request.session['lt'] = self.request.session.get('lt', utils.gen_lt())
|
||||||
|
|
||||||
|
if not self.request.session.get("authenticated") or self.renew:
|
||||||
|
self.init_form()
|
||||||
|
return self.USER_NOT_AUTHENTICATED
|
||||||
|
return self.USER_AUTHENTICATED
|
||||||
|
|
||||||
def init_form(self, values=None):
|
def init_form(self, values=None):
|
||||||
self.form = forms.UserCredential(
|
self.form = forms.UserCredential(
|
||||||
self.request,
|
|
||||||
values,
|
values,
|
||||||
initial={
|
initial={
|
||||||
'service': self.service,
|
'service': self.service,
|
||||||
|
@ -345,7 +392,6 @@ class Auth(View):
|
||||||
if not username or not password or not service:
|
if not username or not password or not service:
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse("no\n", content_type="text/plain")
|
||||||
form = forms.UserCredential(
|
form = forms.UserCredential(
|
||||||
request,
|
|
||||||
request.POST,
|
request.POST,
|
||||||
initial={
|
initial={
|
||||||
'service': service,
|
'service': service,
|
||||||
|
@ -355,10 +401,17 @@ class Auth(View):
|
||||||
)
|
)
|
||||||
if form.is_valid():
|
if form.is_valid():
|
||||||
try:
|
try:
|
||||||
user = models.User.objects.get(
|
try:
|
||||||
username=form.cleaned_data['username'],
|
user = models.User.objects.get(
|
||||||
session_key=request.session.session_key
|
username=form.cleaned_data['username'],
|
||||||
)
|
session_key=request.session.session_key
|
||||||
|
)
|
||||||
|
except models.User.DoesNotExist:
|
||||||
|
user = models.User.objects.create(
|
||||||
|
username=form.cleaned_data['username'],
|
||||||
|
session_key=request.session.session_key
|
||||||
|
)
|
||||||
|
user.save()
|
||||||
# is the service allowed
|
# is the service allowed
|
||||||
service_pattern = ServicePattern.validate(service)
|
service_pattern = ServicePattern.validate(service)
|
||||||
# is the current user allowed on this service
|
# is the current user allowed on this service
|
||||||
|
|
9
requirements-dev.txt
Normal file
9
requirements-dev.txt
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
tox==1.8.1
|
||||||
|
pytest==2.6.4
|
||||||
|
pytest-django==2.7.0
|
||||||
|
pytest-pythonpath==0.3
|
||||||
|
requests>=2.4
|
||||||
|
django-picklefield>=0.3.1
|
||||||
|
requests_futures>=0.9.5
|
||||||
|
django-bootstrap3>=5.4
|
||||||
|
lxml>=3.4
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
setuptools>=5.5
|
||||||
|
requests>=2.4
|
||||||
|
requests_futures>=0.9.5
|
||||||
|
django-picklefield>=0.3.1
|
||||||
|
django-bootstrap3>=5.4
|
||||||
|
lxml>=3.4
|
||||||
|
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
61
tests/dummy.py
Normal file
61
tests/dummy.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
class DummyUserManager(object):
|
||||||
|
def __init__(self, username, session_key):
|
||||||
|
self.username = username
|
||||||
|
self.session_key = session_key
|
||||||
|
def get(self, username=None, session_key=None):
|
||||||
|
if username == self.username and session_key == self.session_key:
|
||||||
|
return models.User(username=username, session_key=session_key)
|
||||||
|
else:
|
||||||
|
raise models.User.DoesNotExist()
|
||||||
|
|
||||||
|
class DummyTicketManager(object):
|
||||||
|
def __init__(self, ticket_class, service, ticket):
|
||||||
|
self.ticket_class = ticket_class
|
||||||
|
self.service = service
|
||||||
|
self.ticket = ticket
|
||||||
|
|
||||||
|
def create(self, **kwargs):
|
||||||
|
for field in models.ServiceTicket._meta.fields:
|
||||||
|
field.allow_unsaved_instance_assignment = True
|
||||||
|
return self.ticket_class(**kwargs)
|
||||||
|
|
||||||
|
def filter(self, *args, **kwargs):
|
||||||
|
return DummyQuerySet()
|
||||||
|
|
||||||
|
def get(self, **kwargs):
|
||||||
|
if 'value' in kwargs:
|
||||||
|
if kwargs['value'] != self.ticket:
|
||||||
|
raise self.ticket_class.DoesNotExist()
|
||||||
|
else:
|
||||||
|
kwargs['value'] = self.ticket
|
||||||
|
|
||||||
|
if 'service' in kwargs:
|
||||||
|
if kwargs['service'] != self.service:
|
||||||
|
raise self.ticket_class.DoesNotExist()
|
||||||
|
else:
|
||||||
|
kwargs['service'] = self.service
|
||||||
|
if not 'user' in kwargs:
|
||||||
|
kwargs['user'] = models.User(username="test")
|
||||||
|
|
||||||
|
for field in models.ServiceTicket._meta.fields:
|
||||||
|
field.allow_unsaved_instance_assignment = True
|
||||||
|
for key in kwargs.keys():
|
||||||
|
if '__' in key:
|
||||||
|
del kwargs[key]
|
||||||
|
kwargs['attributs'] = {'mail': 'test@example.com'}
|
||||||
|
kwargs['service_pattern'] = models.ServicePattern()
|
||||||
|
return self.ticket_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DummySession(dict):
|
||||||
|
session_key = "test_session"
|
||||||
|
|
||||||
|
def set_expiry(self, int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyQuerySet(set):
|
||||||
|
pass
|
32
tests/init.py
Normal file
32
tests/init.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
import django
|
||||||
|
from django.conf import settings
|
||||||
|
from django.contrib import messages
|
||||||
|
|
||||||
|
settings.configure()
|
||||||
|
settings.STATIC_URL = "/static/"
|
||||||
|
settings.DATABASES = {
|
||||||
|
'default': {
|
||||||
|
'ENGINE': 'django.db.backends.sqlite3',
|
||||||
|
'NAME': '/dev/null',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
settings.INSTALLED_APPS = (
|
||||||
|
'django.contrib.admin',
|
||||||
|
'django.contrib.auth',
|
||||||
|
'django.contrib.contenttypes',
|
||||||
|
'django.contrib.sessions',
|
||||||
|
'django.contrib.messages',
|
||||||
|
'django.contrib.staticfiles',
|
||||||
|
'bootstrap3',
|
||||||
|
'cas_server',
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.ROOT_URLCONF = "/"
|
||||||
|
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||||
|
|
||||||
|
try:
|
||||||
|
django.setup()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
messages.add_message = lambda x,y,z:None
|
||||||
|
|
93
tests/test_validate_service.py
Normal file
93
tests/test_validate_service.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from lxml import etree
|
||||||
|
from cas_server.views import ValidateService
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from .dummy import *
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_service_view_ok():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
models.ServiceTicket.save = lambda x:None
|
||||||
|
|
||||||
|
validate = ValidateService()
|
||||||
|
validate.allow_proxy_ticket = False
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].text == "test"
|
||||||
|
|
||||||
|
attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(attributes) == 1
|
||||||
|
|
||||||
|
attrs = {}
|
||||||
|
for attr in attributes[0]:
|
||||||
|
attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
|
||||||
|
|
||||||
|
assert 'mail' in attrs
|
||||||
|
assert attrs['mail'] == 'test@example.com'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_service_view_badservice():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
|
||||||
|
models.ServiceTicket.save = lambda x:None
|
||||||
|
|
||||||
|
validate = ValidateService()
|
||||||
|
validate.allow_proxy_ticket = False
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
|
||||||
|
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(error) == 1
|
||||||
|
assert error[0].attrib['code'] == 'INVALID_SERVICE'
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_service_view_badticket():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
|
||||||
|
models.ServiceTicket.save = lambda x:None
|
||||||
|
|
||||||
|
validate = ValidateService()
|
||||||
|
validate.allow_proxy_ticket = False
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
|
||||||
|
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
|
||||||
|
assert len(error) == 1
|
||||||
|
assert error[0].attrib['code'] == 'INVALID_TICKET'
|
49
tests/test_views_auth.py
Normal file
49
tests/test_views_auth.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cas_server.views import Auth
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from .dummy import *
|
||||||
|
|
||||||
|
settings.CAS_AUTH_SHARED_SECRET = "test"
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_auth_view_goodpass():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
||||||
|
|
||||||
|
auth = Auth()
|
||||||
|
response = auth.post(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content == "yes\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_view_badpass():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
||||||
|
|
||||||
|
auth = Auth()
|
||||||
|
response = auth.post(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content == "no\n"
|
||||||
|
|
170
tests/test_views_login.py
Normal file
170
tests/test_views_login.py
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cas_server.views import LoginView
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from .dummy import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_view_post_goodpass_goodlt():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session['lt'] = 'LT-random'
|
||||||
|
|
||||||
|
request.session["username"] = os.urandom(20)
|
||||||
|
request.session["warn"] = os.urandom(20)
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_post(request)
|
||||||
|
|
||||||
|
ret = login.process_post(pytest=True)
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_LOGIN_OK
|
||||||
|
assert request.session.get("authenticated") == True
|
||||||
|
assert request.session.get("username") == "test"
|
||||||
|
assert request.session.get("warn") == False
|
||||||
|
|
||||||
|
def test_login_view_post_badlt():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session['lt'] = 'LT-random2'
|
||||||
|
|
||||||
|
authenticated = os.urandom(20)
|
||||||
|
username = os.urandom(20)
|
||||||
|
warn = os.urandom(20)
|
||||||
|
|
||||||
|
request.session["authenticated"] = authenticated
|
||||||
|
request.session["username"] = username
|
||||||
|
request.session["warn"] = warn
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_post(request)
|
||||||
|
|
||||||
|
ret = login.process_post(pytest=True)
|
||||||
|
|
||||||
|
assert ret == LoginView.INVALID_LOGIN_TICKET
|
||||||
|
assert request.session.get("authenticated") == authenticated
|
||||||
|
assert request.session.get("username") == username
|
||||||
|
assert request.session.get("warn") == warn
|
||||||
|
|
||||||
|
def test_login_view_post_badpass_good_lt():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session['lt'] = 'LT-random'
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_post(request)
|
||||||
|
ret = login.process_post()
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_LOGIN_FAILURE
|
||||||
|
assert not request.session.get("authenticated")
|
||||||
|
assert not request.session.get("username")
|
||||||
|
assert not request.session.get("warn")
|
||||||
|
|
||||||
|
|
||||||
|
def test_view_login_get_unauth():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login')
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_get(request)
|
||||||
|
ret = login.process_get()
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_NOT_AUTHENTICATED
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
response = login.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_view_login_get_auth():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login')
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = False
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_get(request)
|
||||||
|
ret = login.process_get()
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
response = login.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_view_login_get_auth_service():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login?service=https://www.example.com')
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = False
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_get(request)
|
||||||
|
ret = login.process_get()
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
models.User.save = lambda x:None
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
||||||
|
models.ServiceTicket.save = lambda x:None
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
response = login.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_view_login_get_auth_service_warn():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.post('/login?service=https://www.example.com')
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = True
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
login.init_get(request)
|
||||||
|
ret = login.process_get()
|
||||||
|
|
||||||
|
assert ret == LoginView.USER_AUTHENTICATED
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
models.User.save = lambda x:None
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
|
||||||
|
models.ServiceTicket.save = lambda x:None
|
||||||
|
|
||||||
|
login = LoginView()
|
||||||
|
response = login.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
92
tests/test_views_logout.py
Normal file
92
tests/test_views_logout.py
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cas_server.views import LogoutView
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from .dummy import *
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_logout_view():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/logout')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = False
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
dlist = [None]
|
||||||
|
models.User.delete = lambda x:dlist.pop()
|
||||||
|
|
||||||
|
logout = LogoutView()
|
||||||
|
response = logout.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert dlist == []
|
||||||
|
assert not request.session.get("authenticated")
|
||||||
|
assert not request.session.get("username")
|
||||||
|
assert not request.session.get("warn")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_logout_view_url():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/logout?url=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = False
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
dlist = [None]
|
||||||
|
models.User.delete = lambda x:dlist.pop()
|
||||||
|
|
||||||
|
logout = LogoutView()
|
||||||
|
response = logout.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response['Location'] == 'https://www.example.com'
|
||||||
|
assert dlist == []
|
||||||
|
assert not request.session.get("authenticated")
|
||||||
|
assert not request.session.get("username")
|
||||||
|
assert not request.session.get("warn")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_logout_view_service():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/logout?service=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
request.session["authenticated"] = True
|
||||||
|
request.session["username"] = "test"
|
||||||
|
request.session["warn"] = False
|
||||||
|
|
||||||
|
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
|
||||||
|
dlist = [None]
|
||||||
|
models.User.delete = lambda x:dlist.pop()
|
||||||
|
|
||||||
|
logout = LogoutView()
|
||||||
|
response = logout.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response['Location'] == 'https://www.example.com'
|
||||||
|
assert dlist == []
|
||||||
|
assert not request.session.get("authenticated")
|
||||||
|
assert not request.session.get("username")
|
||||||
|
assert not request.session.get("warn")
|
||||||
|
|
||||||
|
|
61
tests/test_views_validate.py
Normal file
61
tests/test_views_validate.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .init import *
|
||||||
|
|
||||||
|
from django.test import RequestFactory
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cas_server.views import Validate
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
from .dummy import *
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_view_ok():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
|
||||||
|
validate = Validate()
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content == "yes\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_view_badservice():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||||
|
|
||||||
|
validate = Validate()
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content == "no\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_validate_view_badticket():
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
|
||||||
|
|
||||||
|
request.session = DummySession()
|
||||||
|
|
||||||
|
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1")
|
||||||
|
|
||||||
|
validate = Validate()
|
||||||
|
response = validate.get(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content == "no\n"
|
34
tox.ini
Normal file
34
tox.ini
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
[tox]
|
||||||
|
envlist=
|
||||||
|
py27-django17,
|
||||||
|
py27-django18,
|
||||||
|
flake8,
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
max-line-length=100
|
||||||
|
exclude=migrations
|
||||||
|
|
||||||
|
[base]
|
||||||
|
deps =
|
||||||
|
-r{toxinidir}/requirements-dev.txt
|
||||||
|
|
||||||
|
[testenv]
|
||||||
|
commands=py.test --tb native {posargs:tests}
|
||||||
|
|
||||||
|
[testenv:py27-django17]
|
||||||
|
basepython=python2.7
|
||||||
|
deps =
|
||||||
|
Django>=1.7,<1.8
|
||||||
|
{[base]deps}
|
||||||
|
|
||||||
|
[testenv:py27-django18]
|
||||||
|
basepython=python2.7
|
||||||
|
deps =
|
||||||
|
Django>=1.8,<1.9
|
||||||
|
{[base]deps}
|
||||||
|
|
||||||
|
[testenv:flake8]
|
||||||
|
basepython=python
|
||||||
|
deps=flake8
|
||||||
|
commands=flake8 {toxinidir}/cas_server
|
||||||
|
|
Loading…
Reference in a new issue