From 352d63fdb52452f6e99d5603757c54c3f5c186d7 Mon Sep 17 00:00:00 2001 From: pukkandan Date: Wed, 21 Jul 2021 11:17:27 +0530 Subject: [PATCH] [utils] Improve `traverse_obj` --- yt_dlp/extractor/youtube.py | 10 +++++----- yt_dlp/utils.py | 21 +++++++++++++++------ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/yt_dlp/extractor/youtube.py b/yt_dlp/extractor/youtube.py index aa0421a72..afe31a12d 100644 --- a/yt_dlp/extractor/youtube.py +++ b/yt_dlp/extractor/youtube.py @@ -1929,10 +1929,11 @@ class YoutubeIE(YoutubeBaseInfoExtractor): return sts def _mark_watched(self, video_id, player_responses): - playback_url = url_or_none((traverse_obj( - player_responses, ('playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'), - expected_type=str) or [None])[0]) + playback_url = traverse_obj( + player_responses, (..., 'playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'), + expected_type=url_or_none, get_all=False) if not playback_url: + self.report_warning('Unable to mark watched') return parsed_playback_url = compat_urlparse.urlparse(playback_url) qs = compat_urlparse.parse_qs(parsed_playback_url.query) @@ -2606,8 +2607,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor): self._get_requested_clients(url, smuggled_data), video_id, webpage, master_ytcfg, player_url, identity_token)) - get_first = lambda obj, keys, **kwargs: ( - traverse_obj(obj, (..., *variadic(keys)), **kwargs) or [None])[0] + get_first = lambda obj, keys, **kwargs: traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False) playability_statuses = traverse_obj( player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[]) diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 4d3cbc7b4..4d12c0a8e 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -6225,7 +6225,7 @@ def load_plugins(name, suffix, namespace): def traverse_obj( - obj, *path_list, default=None, expected_type=None, + obj, *path_list, default=None, expected_type=None, get_all=True, casesense=True, is_user_input=False, traverse_string=False): ''' Traverse nested list/dict/tuple @param path_list A list of paths which are checked one by one. @@ -6234,7 +6234,8 @@ def traverse_obj( all the keys given in the tuple are traversed, and "..." traverses all the keys in the object @param default Default value to return - @param expected_type Only accept final value of this type + @param expected_type Only accept final value of this type (Can also be any callable) + @param get_all Return all the values obtained from a path or only the first one @param casesense Whether to consider dictionary keys as case sensitive @param is_user_input Whether the keys are generated from user input. If True, strings are converted to int/slice if necessary @@ -6281,6 +6282,13 @@ def traverse_obj( return None return obj + if isinstance(expected_type, type): + type_test = lambda val: val if isinstance(val, expected_type) else None + elif expected_type is not None: + type_test = expected_type + else: + type_test = lambda val: val + for path in path_list: depth = 0 val = _traverse_obj(obj, path) @@ -6288,12 +6296,13 @@ def traverse_obj( if depth: for _ in range(depth - 1): val = itertools.chain.from_iterable(v for v in val if v is not None) - val = ([v for v in val if v is not None] if expected_type is None - else [v for v in val if isinstance(v, expected_type)]) + val = [v for v in map(type_test, val) if v is not None] if val: + return val if get_all else val[0] + else: + val = type_test(val) + if val is not None: return val - elif expected_type is None or isinstance(val, expected_type): - return val return default