From eb575e8f494031f396bcef062614a220f6a1bd77 Mon Sep 17 00:00:00 2001 From: Kumi Date: Thu, 18 Jul 2024 17:59:33 +0200 Subject: [PATCH] refactor: improve configuration handling in server setup Refactored the code to pass configuration sections as parameters to server creation functions instead of accessing them globally. This enhances modularity and clarity by making function dependencies explicit. Additionally, consolidated configuration reading and argument parsing in the main function, and added a validation step to check for required configuration sections. This change makes the codebase more maintainable and prepares it for potential future extensions. --- worker.py | 116 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 65 insertions(+), 51 deletions(-) diff --git a/worker.py b/worker.py index c9e2775..f66c51d 100644 --- a/worker.py +++ b/worker.py @@ -26,32 +26,15 @@ logging.basicConfig( level=logging.INFO, ) -# Read configuration -config = configparser.ConfigParser() -config.read("config.ini") - -# Argument parser -parser = argparse.ArgumentParser(description="Create and configure a Wireguard server.") -parser.add_argument( - "--provider", - type=str, - choices=["hetzner", "aws", "digitalocean", "azure"], - default="hetzner", - help="Cloud provider (default: hetzner)", -) -parser.add_argument("--location", type=str, help="Server location") -parser.add_argument("--server_type", type=str, help="Server type") -args = parser.parse_args() - random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=8)) # Function to create a server on Hetzner Cloud -def create_hetzner_server(location, server_type): - api_token = config["hetzner"]["api_token"] +def create_hetzner_server(hetzner_config, location, server_type): + api_token = hetzner_config["api_token"] client = Client(token=api_token) server_type = ServerType(name=server_type) - image_name = config["hetzner"]["image"] + image_name = hetzner_config["image"] image = Image(name=image_name) location = LocationsClient(client).get_by_name(location) ssh_key = SSHKeysClient(client).get_all()[0] @@ -92,12 +75,12 @@ def get_ami_id_by_name(ec2_client, ami_name): # Function to create a server on AWS -def create_aws_server(location, server_type): - access_key = config["aws"]["access_key"] - secret_key = config["aws"]["secret_key"] - region = config["aws"]["region"] - ami_name = config["aws"]["ami_name"] - key_pair = config["aws"]["key_pair"] +def create_aws_server(aws_config, location, server_type): + access_key = aws_config["access_key"] + secret_key = aws_config["secret_key"] + region = aws_config["region"] + ami_name = aws_config["ami_name"] + key_pair = aws_config["key_pair"] ec2_client = boto3.client( "ec2", @@ -129,8 +112,8 @@ def create_aws_server(location, server_type): # Function to create a server on DigitalOcean -def create_digitalocean_server(location, server_type): - api_token = config["digitalocean"]["api_token"] +def create_digitalocean_server(digitalocean_config, location, server_type): + api_token = digitalocean_config["api_token"] manager = Manager(token=api_token) logging.info("Creating DigitalOcean server...") @@ -149,15 +132,15 @@ def create_digitalocean_server(location, server_type): # Function to create a server on Azure -def create_azure_server(location, server_type): - subscription_id = config["azure"]["subscription_id"] - client_id = config["azure"]["client_id"] - client_secret = config["azure"]["client_secret"] - tenant_id = config["azure"]["tenant_id"] +def create_azure_server(azure_config, location, server_type): + subscription_id = azure_config["subscription_id"] + client_id = azure_config["client_id"] + client_secret = azure_config["client_secret"] + tenant_id = azure_config["tenant_id"] vm_size = server_type - image_publisher = config["azure"]["image_publisher"] - image_offer = config["azure"]["image_offer"] - image_sku = config["azure"]["image_sku"] + image_publisher = azure_config["image_publisher"] + image_offer = azure_config["image_offer"] + image_sku = azure_config["image_sku"] credential = ClientSecretCredential( client_id=client_id, client_secret=client_secret, tenant_id=tenant_id @@ -287,14 +270,14 @@ def generate_wireguard_keys(): # Function to configure Wireguard on the VPS -def configure_wireguard(server_ip, private_key, public_key, preshared_key): - address = config["wireguard"]["address"] - listen_port = config["wireguard"]["listen_port"] +def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key): + address = wireguard_config["address"] + listen_port = wireguard_config["listen_port"] - peer_public_key = config["wireguard"]["peer_public_key"] - allowed_ips = config["wireguard"]["peer_allowed_ips"] - endpoint = config["wireguard"]["peer_endpoint"] - persistent_keepalive = config["wireguard"]["peer_persistent_keepalive"] + peer_public_key = wireguard_config["peer_public_key"] + allowed_ips = wireguard_config["peer_allowed_ips"] + endpoint = wireguard_config["peer_endpoint"] + persistent_keepalive = wireguard_config["peer_persistent_keepalive"] wg_config = f""" [Interface] @@ -328,27 +311,27 @@ PersistentKeepalive = {persistent_keepalive} # Main function to create and configure the server -def main(provider, location, server_type): +def run(config, provider, location, server_type): if provider == "hetzner": location = location or config["hetzner"]["location"] server_type = server_type or config["hetzner"]["server_type"] - server = create_hetzner_server(location, server_type) + server = create_hetzner_server(config["hetzner"], location, server_type) server_ip = server.public_net.ipv6.ip.split("/")[0] + "1" elif provider == "aws": location = location or config["aws"]["region"] server_type = server_type or config["aws"]["instance_type"] - server = create_aws_server(location, server_type) + server = create_aws_server(config["aws"], location, server_type) server_ip = server.public_ip_address elif provider == "digitalocean": location = location or config["digitalocean"]["region"] server_type = server_type or config["digitalocean"]["server_type"] - server = create_digitalocean_server(location, server_type) + server = create_digitalocean_server(config["digitalocean"], location, server_type) server.load() server_ip = server.ip_address elif provider == "azure": location = location or config["azure"]["location"] server_type = server_type or config["azure"]["vm_size"] - server, server_ip = create_azure_server(location, server_type) + server, server_ip = create_azure_server(config["azure"], location, server_type) else: raise ValueError("Unsupported provider") @@ -383,7 +366,7 @@ def main(provider, location, server_type): public_key = private_to_public_key(private_key) - configure_wireguard(server_ip, private_key, public_key, preshared_key) + configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key) # Generate client configuration for Chimpman wireguard_address = config["wireguard"]["address"] @@ -404,5 +387,36 @@ PersistentKeepalive = {persistent_keepalive} print(peer_config) -# Run the main function with parsed arguments -main(args.provider, args.location, args.server_type) +def main(): + # Read configuration + config = configparser.ConfigParser() + config.read("config.ini") + + # Argument parser + parser = argparse.ArgumentParser( + description="Create and configure a Wireguard server." + ) + parser.add_argument( + "--provider", + type=str, + choices=["hetzner", "aws", "digitalocean", "azure"], + default="hetzner", + help="Cloud provider (default: hetzner)", + ) + parser.add_argument("--location", type=str, help="Server location") + parser.add_argument("--server_type", type=str, help="Server type") + args = parser.parse_args() + + # Check if the configuration has all the required sections + required = ["wireguard", args.provider] + + for section in required: + if section not in config: + raise ValueError(f"Missing section {section} in config.ini") + + # Run the main function with parsed arguments + run(config, args.provider, args.location, args.server_type) + + +if __name__ == "__main__": + main()