diff --git a/test/test_networking.py b/test/test_networking.py index d4eba2a5d..1bd6afc88 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -804,10 +804,10 @@ class TestUrllibRequestHandler(TestRequestHandlerBase): assert not isinstance(exc_info.value, TransportError) -def run_validation(handler, fail, req, **handler_kwargs): +def run_validation(handler, error, req, **handler_kwargs): with handler(**handler_kwargs) as rh: - if fail: - with pytest.raises(UnsupportedRequest): + if error: + with pytest.raises(error): rh.validate(req) else: rh.validate(req) @@ -824,6 +824,9 @@ class TestRequestHandlerValidation: _SUPPORTED_PROXY_SCHEMES = None _SUPPORTED_URL_SCHEMES = None + def _check_extensions(self, extensions): + extensions.clear() + class HTTPSupportedRH(ValidationRH): _SUPPORTED_URL_SCHEMES = ('http',) @@ -834,26 +837,26 @@ class TestRequestHandlerValidation: ('https', False, {}), ('data', False, {}), ('ftp', False, {}), - ('file', True, {}), + ('file', UnsupportedRequest, {}), ('file', False, {'enable_file_urls': True}), ]), (NoCheckRH, [('http', False, {})]), - (ValidationRH, [('http', True, {})]) + (ValidationRH, [('http', UnsupportedRequest, {})]) ] PROXY_SCHEME_TESTS = [ # scheme, expected to fail ('Urllib', [ ('http', False), - ('https', True), + ('https', UnsupportedRequest), ('socks4', False), ('socks4a', False), ('socks5', False), ('socks5h', False), - ('socks', True), + ('socks', UnsupportedRequest), ]), (NoCheckRH, [('http', False)]), - (HTTPSupportedRH, [('http', True)]), + (HTTPSupportedRH, [('http', UnsupportedRequest)]), ] PROXY_KEY_TESTS = [ @@ -863,8 +866,22 @@ class TestRequestHandlerValidation: ('unrelated', False), ]), (NoCheckRH, [('all', False)]), - (HTTPSupportedRH, [('all', True)]), - (HTTPSupportedRH, [('no', True)]), + (HTTPSupportedRH, [('all', UnsupportedRequest)]), + (HTTPSupportedRH, [('no', UnsupportedRequest)]), + ] + + EXTENSION_TESTS = [ + ('Urllib', [ + ({'cookiejar': 'notacookiejar'}, AssertionError), + ({'cookiejar': CookieJar()}, False), + ({'timeout': 1}, False), + ({'timeout': 'notatimeout'}, AssertionError), + ({'unsupported': 'value'}, UnsupportedRequest), + ]), + (NoCheckRH, [ + ({'cookiejar': 'notacookiejar'}, False), + ({'somerandom': 'test'}, False), # but any extension is allowed through + ]), ] @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @@ -907,15 +924,16 @@ class TestRequestHandlerValidation: @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1']) @pytest.mark.parametrize('handler', ['Urllib'], indirect=True) def test_missing_proxy_scheme(self, handler, proxy_url): - run_validation(handler, True, Request('http://', proxies={'http': 'example.com'})) + run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'})) - @pytest.mark.parametrize('handler', ['Urllib'], indirect=True) - def test_cookiejar_extension(self, handler): - run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'})) - - @pytest.mark.parametrize('handler', ['Urllib'], indirect=True) - def test_timeout_extension(self, handler): - run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'})) + @pytest.mark.parametrize('handler,extensions,fail', [ + (handler_tests[0], extensions, fail) + for handler_tests in EXTENSION_TESTS + for extensions, fail in handler_tests[1] + ], indirect=['handler']) + def test_extension(self, handler, extensions, fail): + run_validation( + handler, fail, Request('http://', extensions=extensions)) def test_invalid_request_type(self): rh = self.ValidationRH(logger=FakeLogger()) diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index ff3a22c8c..3fe5fa52e 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -385,6 +385,11 @@ class UrllibRH(RequestHandler, InstanceStoreMixin): if self.enable_file_urls: self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file') + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('cookiejar', None) + extensions.pop('timeout', None) + def _create_instance(self, proxies, cookiejar): opener = urllib.request.OpenerDirector() handlers = [ diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index 7f7457978..ab26a0628 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -21,6 +21,7 @@ from .exceptions import ( TransportError, UnsupportedRequest, ) +from ..compat.types import NoneType from ..utils import ( bug_reports_message, classproperty, @@ -147,6 +148,7 @@ class RequestHandler(abc.ABC): a proxy url with an url scheme not in this list will raise an UnsupportedRequest. - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum. + The above may be set to None to disable the checks. Parameters: @@ -169,9 +171,14 @@ class RequestHandler(abc.ABC): Requests may have additional optional parameters defined as extensions. RequestHandler subclasses may choose to support custom extensions. + If an extension is supported, subclasses should extend _check_extensions(extensions) + to pop and validate the extension. + - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised. + The following extensions are defined for RequestHandler: - - `cookiejar`: Cookiejar to use for this request - - `timeout`: socket timeout to use for this request + - `cookiejar`: Cookiejar to use for this request. + - `timeout`: socket timeout to use for this request. + To enable these, add extensions.pop('', None) to _check_extensions Apart from the url protocol, proxies dict may contain the following keys: - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol. @@ -263,26 +270,19 @@ class RequestHandler(abc.ABC): if scheme not in self._SUPPORTED_PROXY_SCHEMES: raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"') - def _check_cookiejar_extension(self, extensions): - if not extensions.get('cookiejar'): - return - if not isinstance(extensions['cookiejar'], CookieJar): - raise UnsupportedRequest('cookiejar is not a CookieJar') - - def _check_timeout_extension(self, extensions): - if extensions.get('timeout') is None: - return - if not isinstance(extensions['timeout'], (float, int)): - raise UnsupportedRequest('timeout is not a float or int') - def _check_extensions(self, extensions): - self._check_cookiejar_extension(extensions) - self._check_timeout_extension(extensions) + """Check extensions for unsupported extensions. Subclasses should extend this.""" + assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType)) + assert isinstance(extensions.get('timeout'), (float, int, NoneType)) def _validate(self, request): self._check_url_scheme(request) self._check_proxies(request.proxies or self.proxies) - self._check_extensions(request.extensions) + extensions = request.extensions.copy() + self._check_extensions(extensions) + if extensions: + # TODO: add support for optional extensions + raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}') @wrap_request_errors def validate(self, request: Request):