Add support for redirect_uris with query params

Some clients might add extra parameters to the redirect_uri, for
instance as extra verification if proper state parameter handling is not
supported.

This patch adds proper handling of redirect_uris with query parameters.
This commit is contained in:
Maarten van Schaik 2015-07-10 12:22:25 +02:00
parent 6ce523edaa
commit 7632054aad
3 changed files with 54 additions and 15 deletions

View file

@ -2,6 +2,11 @@ from datetime import timedelta
import logging import logging
from django.utils import timezone from django.utils import timezone
try:
from urllib import urlencode
from urlparse import urlsplit, parse_qs, urlunsplit
except ImportError:
from urllib.parse import urlsplit, parse_qs, urlunsplit, urlencode
from oidc_provider.lib.errors import * from oidc_provider.lib.errors import *
from oidc_provider.lib.utils.params import * from oidc_provider.lib.utils.params import *
@ -72,7 +77,9 @@ class AuthorizeEndpoint(object):
try: try:
self.client = Client.objects.get(client_id=self.params.client_id) self.client = Client.objects.get(client_id=self.params.client_id)
if not (self.params.redirect_uri in self.client.redirect_uris): clean_redirect_uri = urlsplit(self.params.redirect_uri)
clean_redirect_uri = urlunsplit(clean_redirect_uri._replace(query=''))
if not (clean_redirect_uri in self.client.redirect_uris):
logger.error('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri) logger.error('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri)
raise RedirectUriError() raise RedirectUriError()
@ -88,6 +95,10 @@ class AuthorizeEndpoint(object):
raise ClientIdError() raise ClientIdError()
def create_response_uri(self): def create_response_uri(self):
uri = urlsplit(self.params.redirect_uri)
query_params = parse_qs(uri.query)
query_fragment = parse_qs(uri.fragment)
try: try:
if self.grant_type == 'authorization_code': if self.grant_type == 'authorization_code':
code = create_code( code = create_code(
@ -97,8 +108,8 @@ class AuthorizeEndpoint(object):
code.save() code.save()
# Create the response uri. query_params['code'] = code.code
uri = self.params.redirect_uri + '?code={0}'.format(code.code) query_params['state'] = self.params.state if self.params.state else ''
elif self.grant_type == 'implicit': elif self.grant_type == 'implicit':
id_token_dic = create_id_token( id_token_dic = create_id_token(
@ -117,18 +128,17 @@ class AuthorizeEndpoint(object):
id_token = encode_id_token( id_token = encode_id_token(
id_token_dic, self.client.client_secret) id_token_dic, self.client.client_secret)
# Create the response uri. query_fragment['token_type'] = 'bearer'
uri = self.params.redirect_uri + \ query_fragment['id_token'] = id_token
'#token_type={0}&id_token={1}&expires_in={2}'.format( query_fragment['expires_in'] = 60 * 10
'bearer',
id_token,
60 * 10,
)
# Check if response_type is 'id_token token' then # Check if response_type is 'id_token token' then
# add access_token to the fragment. # add access_token to the fragment.
if self.params.response_type == 'id_token token': if self.params.response_type == 'id_token token':
uri += '&access_token={0}'.format(token.access_token) query_fragment['access_token'] = token.access_token
query_fragment['state'] = self.params.state if self.params.state else ''
except Exception as error: except Exception as error:
logger.error('[Authorize] Error when trying to create response uri: %s', error) logger.error('[Authorize] Error when trying to create response uri: %s', error)
raise AuthorizeError( raise AuthorizeError(
@ -136,10 +146,10 @@ class AuthorizeEndpoint(object):
'server_error', 'server_error',
self.grant_type) self.grant_type)
# Add state if present. uri = uri._replace(query=urlencode(query_params, doseq=True))
uri += ('&state={0}'.format(self.params.state) if self.params.state else '') uri = uri._replace(fragment=urlencode(query_fragment, doseq=True))
return uri return urlunsplit(uri)
def set_client_user_consent(self): def set_client_user_consent(self):
""" """

View file

@ -258,3 +258,26 @@ class AuthorizationCodeFlowTestCase(TestCase):
client=self.client) client=self.client)
self.assertEqual(is_code_ok, True, self.assertEqual(is_code_ok, True,
msg='Code returned is invalid or missing.') msg='Code returned is invalid or missing.')
def test_response_uri_is_properly_constructed(self):
post_data = {
'client_id': self.client.client_id,
'redirect_uri': self.client.default_redirect_uri + "?redirect_state=xyz",
'response_type': 'code',
'scope': 'openid email',
'state': self.state,
'allow': 'Accept',
}
request = self.factory.post(reverse('oidc_provider:authorize'),
data=post_data)
# Simulate that the user is logged.
request.user = self.user
response = AuthorizeView.as_view()(request)
is_code_ok = is_code_valid(url=response['Location'],
user=self.user,
client=self.client)
self.assertEqual(is_code_ok, True,
msg='Code returned is invalid.')

View file

@ -1,4 +1,8 @@
from django.contrib.auth.models import User from django.contrib.auth.models import User
try:
from urlparse import parse_qs, urlsplit
except ImportError:
from urllib.parse import parse_qs, urlsplit
from oidc_provider.models import * from oidc_provider.models import *
@ -40,7 +44,9 @@ def is_code_valid(url, user, client):
Check if the code inside the url is valid. Check if the code inside the url is valid.
""" """
try: try:
code = (url.split('code='))[1].split('&')[0] parsed = urlsplit(url)
params = parse_qs(parsed.query)
code = params['code'][0]
code = Code.objects.get(code=code) code = Code.objects.get(code=code)
is_code_ok = (code.client == client) and \ is_code_ok = (code.client == client) and \
(code.user == user) (code.user == user)