From 6fd8d7cc1ac7b252d25e7dc402927eb9aa8f7531 Mon Sep 17 00:00:00 2001 From: Kumi Date: Thu, 18 Jul 2024 18:09:18 +0200 Subject: [PATCH] refactor: lazy-load provider libraries Defer the import of cloud provider libraries (boto3, hcloud, digitalocean, azure) until they are needed within the respective server creation functions. This prevents unnecessary imports when only a subset of providers is used, reducing resource consumption and avoiding import errors for unused libraries. --- worker.py | 66 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/worker.py b/worker.py index 14fd853..e255a40 100644 --- a/worker.py +++ b/worker.py @@ -7,18 +7,6 @@ import random import string import argparse import configparser -import boto3 -from digitalocean import Manager, Droplet -from hcloud import Client -from hcloud.server_types.client import ServerType -from hcloud.images.client import Image -from hcloud.locations.client import LocationsClient -from hcloud.ssh_keys.client import SSHKeysClient -from hcloud.servers.domain import ServerCreatePublicNetwork -from azure.identity import ClientSecretCredential -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient # Set up logging logging.basicConfig( @@ -31,6 +19,18 @@ random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k # Function to create a server on Hetzner Cloud def create_hetzner_server(hetzner_config, location, server_type): + try: + from hcloud import Client + from hcloud.server_types.client import ServerType + from hcloud.images.client import Image + from hcloud.locations.client import LocationsClient + from hcloud.ssh_keys.client import SSHKeysClient + from hcloud.servers.domain import ServerCreatePublicNetwork + except ImportError: + raise ImportError( + "Please install the hcloud library by running `pip install hcloud`" + ) + api_token = hetzner_config["api_token"] client = Client(token=api_token) server_type = ServerType(name=server_type) @@ -76,6 +76,13 @@ def get_ami_id_by_name(ec2_client, ami_name): # Function to create a server on AWS def create_aws_server(aws_config, location, server_type): + try: + import boto3 + except ImportError: + raise ImportError( + "Please install the boto3 library by running `pip install boto3`" + ) + access_key = aws_config["access_key"] secret_key = aws_config["secret_key"] region = aws_config["region"] @@ -113,6 +120,13 @@ def create_aws_server(aws_config, location, server_type): # Function to create a server on DigitalOcean def create_digitalocean_server(digitalocean_config, location, server_type): + try: + from digitalocean import Manager, Droplet + except ImportError: + raise ImportError( + "Please install the python-digitalocean library by running `pip install python-digitalocean`" + ) + api_token = digitalocean_config["api_token"] manager = Manager(token=api_token) @@ -133,6 +147,16 @@ def create_digitalocean_server(digitalocean_config, location, server_type): # Function to create a server on Azure def create_azure_server(azure_config, location, server_type): + try: + from azure.identity import ClientSecretCredential + from azure.mgmt.compute import ComputeManagementClient + from azure.mgmt.network import NetworkManagementClient + from azure.mgmt.resource import ResourceManagementClient + except ImportError: + raise ImportError( + "Please install the azure libraries by running `pip install azure-mgmt-compute azure-mgmt-network azure-mgmt-resource`" + ) + subscription_id = azure_config["subscription_id"] client_id = azure_config["client_id"] client_secret = azure_config["client_secret"] @@ -270,7 +294,9 @@ def generate_wireguard_keys(): # Function to configure Wireguard on the VPS -def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key): +def configure_wireguard( + wireguard_config, server_ip, private_key, public_key, preshared_key +): address = wireguard_config["address"] listen_port = wireguard_config["listen_port"] @@ -325,7 +351,9 @@ def run(config, provider, location, server_type, endpoint_only): elif provider == "digitalocean": location = location or config["digitalocean"]["region"] server_type = server_type or config["digitalocean"]["server_type"] - server = create_digitalocean_server(config["digitalocean"], location, server_type) + server = create_digitalocean_server( + config["digitalocean"], location, server_type + ) server.load() server_ip = server.ip_address elif provider == "azure": @@ -366,7 +394,9 @@ def run(config, provider, location, server_type, endpoint_only): public_key = private_to_public_key(private_key) - configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key) + configure_wireguard( + config["wireguard"], server_ip, private_key, public_key, preshared_key + ) # Generate client configuration for Chimpman wireguard_address = config["wireguard"]["address"] @@ -408,7 +438,11 @@ def main(): ) parser.add_argument("--location", type=str, help="Server location") parser.add_argument("--server_type", type=str, help="Server type") - parser.add_argument("--endpoint-only", action="store_true", help="Return Wireguard endpoint (host:port) instead of full configuration") + parser.add_argument( + "--endpoint-only", + action="store_true", + help="Return Wireguard endpoint (host:port) instead of full configuration", + ) args = parser.parse_args() # Check if the configuration has all the required sections