feat: add type annotations and improve docstrings in Git classes

Enhanced the InMemoryRepo and Git classes with type annotations and detailed docstrings to improve code readability and maintainability. Type hints were added to function signatures, and docstrings were expanded to describe arguments and return types. Also added an optional `commit_sha` parameter to `Git.get_directory_structure` and `Git.get_file_content`, so that the directory structure and file content can be fetched for a specific commit hash; when no hash is provided, the HEAD commit is used.
This commit is contained in:
Kumi 2024-06-19 09:50:32 +02:00
parent d5fc940e60
commit ee9a84d92e
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@ -2,11 +2,16 @@ import requests
import logging import logging
from dulwich.objects import Tree, Blob, ShaFile, Tree from typing import Generator, Dict, List, Optional
from dulwich.objects import Tree, Blob
from dulwich.client import HttpGitClient, get_transport_and_path from dulwich.client import HttpGitClient, get_transport_and_path
from dulwich.repo import MemoryRepo from dulwich.repo import MemoryRepo
class InMemoryRepo(MemoryRepo): class InMemoryRepo(MemoryRepo):
"""A subclass of MemoryRepo that provides additional methods to interact with the repository."""
def get_tree(self, commit_sha: bytes) -> Tree: def get_tree(self, commit_sha: bytes) -> Tree:
"""Return the tree object for the given commit. """Return the tree object for the given commit.
@ -19,7 +24,9 @@ class InMemoryRepo(MemoryRepo):
commit = self.get_object(commit_sha) commit = self.get_object(commit_sha)
return self.get_object(commit.tree) return self.get_object(commit.tree)
def list_tree(self, tree, path="", prefix=""): def list_tree(
self, tree: Tree, path: str = "", prefix: str = ""
) -> Generator[str, None, None]:
"""List the directory structure of the tree object. """List the directory structure of the tree object.
Args: Args:
@ -47,7 +54,9 @@ class InMemoryRepo(MemoryRepo):
if isinstance(self.get_object(entry.sha), Tree): if isinstance(self.get_object(entry.sha), Tree):
if path: if path:
for _ in self.list_tree( for _ in self.list_tree(
self.get_object(entry.sha), path="/".join(path_parts[1:]), prefix="/".join(path_parts[1:]) self.get_object(entry.sha),
path="/".join(path_parts[1:]),
prefix="/".join(path_parts[1:]),
): ):
yield (_) yield (_)
else: else:
@ -58,17 +67,30 @@ class InMemoryRepo(MemoryRepo):
else: else:
yield (entry_path) yield (entry_path)
def get_file_content(self, tree, file_path): def get_file_content(self, tree: Tree, file_path: str) -> bytes:
"""Get the content of a file in the tree object.
Args:
tree (Tree): The tree object.
file_path (str): The path of the file.
Returns:
bytes: The content of the file.
"""
parts = file_path.split("/") parts = file_path.split("/")
for entry in tree.items(): for entry in tree.items():
entry_name = entry.path.decode("utf-8") entry_name = entry.path.decode("utf-8")
if entry_name == parts[0]: if entry_name == parts[0]:
# If we are already in the last part of the path, return the file content
if len(parts) == 1: if len(parts) == 1:
file_obj = self.get_object(entry.sha) file_obj = self.get_object(entry.sha)
if isinstance(file_obj, Blob): if isinstance(file_obj, Blob):
return file_obj.data return file_obj.data
else: else:
raise ValueError(f"Path {file_path} is not a file.") raise ValueError(f"Path {file_path} is not a file.")
# If there are more parts in the path, try the next depth level
else: else:
if isinstance(self.get_object(entry.sha), Tree): if isinstance(self.get_object(entry.sha), Tree):
return self.get_file_content( return self.get_file_content(
@ -80,51 +102,107 @@ class InMemoryRepo(MemoryRepo):
class Git: class Git:
"""A class to interact with a remote Git repository.
This class is not specific to any hosting service and should work with any Git repository.
"""
def __init__(self, repo_url): def __init__(self, repo_url):
"""Initialize the Git class.
Args:
repo_url (str): The URL of the remote Git repository.
"""
self.repo_url = repo_url.rstrip("/") self.repo_url = repo_url.rstrip("/")
self.client = HttpGitClient(self.repo_url) self.client = HttpGitClient(self.repo_url)
def get_remote_refs(self): def get_remote_refs(self) -> Dict[bytes, bytes]:
"""Get the remote references of the repository.
Returns:
Dict[bytes, bytes]: A dictionary of references.
"""
client, path = get_transport_and_path(self.repo_url) client, path = get_transport_and_path(self.repo_url)
refs = client.fetch(path, self.repo) refs = client.fetch(path, self.repo)
return refs return refs
def get_head_commit(self, refs): def get_head_commit(self, refs) -> bytes:
"""Get the commit hash of the HEAD reference.
Args:
refs (Dict[bytes, bytes]): A dictionary of references.
Returns:
bytes: The commit hash of the HEAD reference.
"""
return refs[b"HEAD"] return refs[b"HEAD"]
def get_pack_data(self, commit_sha): def get_pack_data(self, commit_sha: str) -> bytes:
"""Get the pack data for the given commit.
Args:
commit_sha (str): The commit hash.
Returns:
bytes: The pack data.
"""
url = f"{self.repo_url}/git-upload-pack" url = f"{self.repo_url}/git-upload-pack"
request_body = f"0032want {commit_sha} multi_ack_detailed side-band-64k thin-pack ofs-delta agent=git/2.28.0\n00000009done\n" request_body = f"0032want {commit_sha} multi_ack_detailed side-band-64k thin-pack ofs-delta agent=git/2.28.0\n00000009done\n"
response = requests.post(url, data=request_body.encode("utf-8")) response = requests.post(url, data=request_body.encode("utf-8"))
response.raise_for_status() response.raise_for_status()
return response.content return response.content
def get_directory_structure(self, path=""): def get_directory_structure(
self, path: str = "", commit_sha: Optional[bytes] = None
) -> List[str]:
"""Get the directory structure of the repository.
Args:
path (str): The path within the repository.
commit_sha (bytes): The commit hash. If not provided, the HEAD commit will be used.
Returns:
List[str]: A list of file and directory paths under the given path.
"""
# Initialize an in-memory repository # Initialize an in-memory repository
self.repo = InMemoryRepo() self.repo = InMemoryRepo()
# Fetch the remote references and objects into the in-memory repository if not commit_sha:
refs = self.get_remote_refs() # Fetch the remote references and objects into the in-memory repository
head_commit_hash = self.get_head_commit(refs) refs = self.get_remote_refs()
commit_sha = self.get_head_commit(refs)
# Get the tree object for the HEAD commit # Get the tree object for the given commit
tree = self.repo.get_tree(head_commit_hash) tree = self.repo.get_tree(commit_sha)
# List the directory structure # List the directory structure
return list(self.repo.list_tree(tree, path=path)) return list(self.repo.list_tree(tree, path=path))
def get_file_content(self, file_path): def get_file_content(
self, file_path: str, commit_sha: Optional[bytes] = None
) -> bytes:
"""Get the content of a file in the repository.
Args:
file_path (str): The path of the file.
commit_sha (bytes): The commit hash. If not provided, the HEAD commit will be used.
Returns:
bytes: The content of the file.
"""
file_path = file_path.lstrip("/") file_path = file_path.lstrip("/")
# Initialize an in-memory repository # Initialize an in-memory repository
self.repo = InMemoryRepo() self.repo = InMemoryRepo()
# Fetch the remote references and objects into the in-memory repository if not commit_sha:
refs = self.get_remote_refs() # Fetch the remote references and objects into the in-memory repository
head_commit_hash = self.get_head_commit(refs) refs = self.get_remote_refs()
commit_sha = self.get_head_commit(refs)
# Get the tree object for the HEAD commit # Get the tree object for the given commit
tree = self.repo.get_tree(head_commit_hash) tree = self.repo.get_tree(commit_sha)
# Get the file content # Get the file content
return self.repo.get_file_content(tree, file_path) return self.repo.get_file_content(tree, file_path)