From d4b9d6605101616a598d1e10cf40d25cf209bf1f Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Wed, 29 Jun 2016 20:51:30 +0200 Subject: [PATCH] Cleaner BaseHTTPRequestHandler --- cas_server/tests/test_view.py | 8 ++++---- cas_server/tests/utils.py | 6 +++--- cas_server/utils.py | 31 +++++++++++++++++++++++++------ 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index 8326d77..ce81db9 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -863,7 +863,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): def test_validate_service_view_ok_pgturl(self): """test the retrieval of a ProxyGrantingTicket""" - (host, port) = utils.PGTUrlHandler.run()[1:3] + (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) ticket = get_user_ticket_request(service)[1] @@ -873,7 +873,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): '/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service} ) - pgt_params = utils.PGTUrlHandler.PARAMS.copy() + pgt_params = httpd.PARAMS self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) @@ -887,7 +887,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): def test_validate_service_pgturl_sslerror(self): """test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl""" - (host, port) = utils.PGTUrlHandler.run()[1:3] + (host, port) = utils.HttpParamsHandler.run()[1:3] service = "https://%s:%s" % (host, port) ticket = get_user_ticket_request(service)[1] @@ -907,7 +907,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error. PGT creation should be aborted but the ticket still be valid """ - (host, port) = utils.PGTUrlHandler404.run()[1:3] + (host, port) = utils.Http404Handler.run()[1:3] service = "http://%s:%s" % (host, port) ticket = get_user_ticket_request(service)[1] diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py index db49dc9..286c477 100644 --- a/cas_server/tests/utils.py +++ b/cas_server/tests/utils.py @@ -55,14 +55,14 @@ def get_user_ticket_request(service): def get_pgt(): """return a dict contening a service, user and PGT ticket for this service""" - (host, port) = utils.PGTUrlHandler.run()[1:3] + (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) - (user, ticket) = get_user_ticket_request(service) + (user, ticket) = get_user_ticket_request(service)[:2] client = Client() client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) - params = utils.PGTUrlHandler.PARAMS.copy() + params = httpd.PARAMS params["service"] = service params["user"] = user diff --git a/cas_server/utils.py b/cas_server/utils.py index acdbb6c..f85c25e 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -23,6 +23,7 @@ import hashlib import crypt import base64 import six +import cgi from threading import Thread from importlib import import_module from six.moves import BaseHTTPServer @@ -150,9 +151,11 @@ def gen_saml_id(): return _gen_ticket('_') -class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): - """A simple http server that return 200 on GET and store GET parameters. Used in unit tests""" - PARAMS = {} +class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """ + A simple http server that return 200 on GET or POST + and store GET or POST parameters. Used in unit tests + """ def do_GET(self): """Called on a GET request on the BaseHTTPServer""" @@ -162,7 +165,19 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): self.wfile.write(b"ok") url = urlparse(self.path) params = dict(parse_qsl(url.query)) - PGTUrlHandler.PARAMS.update(params) + self.server.PARAMS = params + + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + ctype, pdict = cgi.parse_header(self.headers.get('content-type')) + if ctype == 'multipart/form-data': + postvars = cgi.parse_multipart(self.rfile, pdict) + elif ctype == 'application/x-www-form-urlencoded': + length = int(self.headers.get('content-length')) + postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1) + else: + postvars = {} + self.server.PARAMS = postvars def log_message(self, *args): """silent any log message""" @@ -183,10 +198,10 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): httpd_thread = Thread(target=lauch) httpd_thread.daemon = True httpd_thread.start() - return (httpd_thread, host, port) + return (httpd, host, port) -class PGTUrlHandler404(PGTUrlHandler): +class Http404Handler(HttpParamsHandler): """A simple http server that always return 404 not found. Used in unit tests""" def do_GET(self): """Called on a GET request on the BaseHTTPServer""" @@ -195,6 +210,10 @@ class PGTUrlHandler404(PGTUrlHandler): self.end_headers() self.wfile.write(b"error 404 not found") + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + return self.do_GET() + class LdapHashUserPassword(object): """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""