refactor: lazy-load provider libraries

Defer the import of cloud provider libraries (boto3, hcloud,
digitalocean, azure) until they are needed within the respective
server creation functions. This prevents unnecessary imports
when only a subset of providers is used, reducing resource
consumption and avoiding import errors for unused libraries.
This commit is contained in:
Kumi 2024-07-18 18:09:18 +02:00
parent 844d420439
commit 6fd8d7cc1a
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@ -7,18 +7,6 @@ import random
import string
import argparse
import configparser
import boto3
from digitalocean import Manager, Droplet
from hcloud import Client
from hcloud.server_types.client import ServerType
from hcloud.images.client import Image
from hcloud.locations.client import LocationsClient
from hcloud.ssh_keys.client import SSHKeysClient
from hcloud.servers.domain import ServerCreatePublicNetwork
from azure.identity import ClientSecretCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
# Set up logging
logging.basicConfig(
@ -31,6 +19,18 @@ random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k
# Function to create a server on Hetzner Cloud
def create_hetzner_server(hetzner_config, location, server_type):
try:
from hcloud import Client
from hcloud.server_types.client import ServerType
from hcloud.images.client import Image
from hcloud.locations.client import LocationsClient
from hcloud.ssh_keys.client import SSHKeysClient
from hcloud.servers.domain import ServerCreatePublicNetwork
except ImportError:
raise ImportError(
"Please install the hcloud library by running `pip install hcloud`"
)
api_token = hetzner_config["api_token"]
client = Client(token=api_token)
server_type = ServerType(name=server_type)
@ -76,6 +76,13 @@ def get_ami_id_by_name(ec2_client, ami_name):
# Function to create a server on AWS
def create_aws_server(aws_config, location, server_type):
try:
import boto3
except ImportError:
raise ImportError(
"Please install the boto3 library by running `pip install boto3`"
)
access_key = aws_config["access_key"]
secret_key = aws_config["secret_key"]
region = aws_config["region"]
@ -113,6 +120,13 @@ def create_aws_server(aws_config, location, server_type):
# Function to create a server on DigitalOcean
def create_digitalocean_server(digitalocean_config, location, server_type):
try:
from digitalocean import Manager, Droplet
except ImportError:
raise ImportError(
"Please install the python-digitalocean library by running `pip install python-digitalocean`"
)
api_token = digitalocean_config["api_token"]
manager = Manager(token=api_token)
@ -133,6 +147,16 @@ def create_digitalocean_server(digitalocean_config, location, server_type):
# Function to create a server on Azure
def create_azure_server(azure_config, location, server_type):
try:
from azure.identity import ClientSecretCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
except ImportError:
raise ImportError(
"Please install the azure libraries by running `pip install azure-mgmt-compute azure-mgmt-network azure-mgmt-resource`"
)
subscription_id = azure_config["subscription_id"]
client_id = azure_config["client_id"]
client_secret = azure_config["client_secret"]
@ -270,7 +294,9 @@ def generate_wireguard_keys():
# Function to configure Wireguard on the VPS
def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key):
def configure_wireguard(
wireguard_config, server_ip, private_key, public_key, preshared_key
):
address = wireguard_config["address"]
listen_port = wireguard_config["listen_port"]
@ -325,7 +351,9 @@ def run(config, provider, location, server_type, endpoint_only):
elif provider == "digitalocean":
location = location or config["digitalocean"]["region"]
server_type = server_type or config["digitalocean"]["server_type"]
server = create_digitalocean_server(config["digitalocean"], location, server_type)
server = create_digitalocean_server(
config["digitalocean"], location, server_type
)
server.load()
server_ip = server.ip_address
elif provider == "azure":
@ -366,7 +394,9 @@ def run(config, provider, location, server_type, endpoint_only):
public_key = private_to_public_key(private_key)
configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key)
configure_wireguard(
config["wireguard"], server_ip, private_key, public_key, preshared_key
)
# Generate client configuration for Chimpman
wireguard_address = config["wireguard"]["address"]
@ -408,7 +438,11 @@ def main():
)
parser.add_argument("--location", type=str, help="Server location")
parser.add_argument("--server_type", type=str, help="Server type")
parser.add_argument("--endpoint-only", action="store_true", help="Return Wireguard endpoint (host:port) instead of full configuration")
parser.add_argument(
"--endpoint-only",
action="store_true",
help="Return Wireguard endpoint (host:port) instead of full configuration",
)
args = parser.parse_args()
# Check if the configuration has all the required sections