refactor: improve configuration handling in server setup

Refactored the code to pass configuration sections as parameters to server creation functions instead of accessing them globally. This enhances modularity and clarity by making function dependencies explicit. Additionally, consolidated configuration reading and argument parsing in the main function, and added a validation step to check for required configuration sections.

This change makes the codebase more maintainable and prepares it for potential future extensions.
This commit is contained in:
Kumi 2024-07-18 17:59:33 +02:00
parent 1aa3932555
commit eb575e8f49
Signed by: kumi
GPG key ID: ECBCC9082395383F

114
worker.py
View file

@ -26,32 +26,15 @@ logging.basicConfig(
level=logging.INFO, level=logging.INFO,
) )
# Read configuration
config = configparser.ConfigParser()
config.read("config.ini")
# Argument parser
parser = argparse.ArgumentParser(description="Create and configure a Wireguard server.")
parser.add_argument(
"--provider",
type=str,
choices=["hetzner", "aws", "digitalocean", "azure"],
default="hetzner",
help="Cloud provider (default: hetzner)",
)
parser.add_argument("--location", type=str, help="Server location")
parser.add_argument("--server_type", type=str, help="Server type")
args = parser.parse_args()
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=8)) random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
# Function to create a server on Hetzner Cloud # Function to create a server on Hetzner Cloud
def create_hetzner_server(location, server_type): def create_hetzner_server(hetzner_config, location, server_type):
api_token = config["hetzner"]["api_token"] api_token = hetzner_config["api_token"]
client = Client(token=api_token) client = Client(token=api_token)
server_type = ServerType(name=server_type) server_type = ServerType(name=server_type)
image_name = config["hetzner"]["image"] image_name = hetzner_config["image"]
image = Image(name=image_name) image = Image(name=image_name)
location = LocationsClient(client).get_by_name(location) location = LocationsClient(client).get_by_name(location)
ssh_key = SSHKeysClient(client).get_all()[0] ssh_key = SSHKeysClient(client).get_all()[0]
@ -92,12 +75,12 @@ def get_ami_id_by_name(ec2_client, ami_name):
# Function to create a server on AWS # Function to create a server on AWS
def create_aws_server(location, server_type): def create_aws_server(aws_config, location, server_type):
access_key = config["aws"]["access_key"] access_key = aws_config["access_key"]
secret_key = config["aws"]["secret_key"] secret_key = aws_config["secret_key"]
region = config["aws"]["region"] region = aws_config["region"]
ami_name = config["aws"]["ami_name"] ami_name = aws_config["ami_name"]
key_pair = config["aws"]["key_pair"] key_pair = aws_config["key_pair"]
ec2_client = boto3.client( ec2_client = boto3.client(
"ec2", "ec2",
@ -129,8 +112,8 @@ def create_aws_server(location, server_type):
# Function to create a server on DigitalOcean # Function to create a server on DigitalOcean
def create_digitalocean_server(location, server_type): def create_digitalocean_server(digitalocean_config, location, server_type):
api_token = config["digitalocean"]["api_token"] api_token = digitalocean_config["api_token"]
manager = Manager(token=api_token) manager = Manager(token=api_token)
logging.info("Creating DigitalOcean server...") logging.info("Creating DigitalOcean server...")
@ -149,15 +132,15 @@ def create_digitalocean_server(location, server_type):
# Function to create a server on Azure # Function to create a server on Azure
def create_azure_server(location, server_type): def create_azure_server(azure_config, location, server_type):
subscription_id = config["azure"]["subscription_id"] subscription_id = azure_config["subscription_id"]
client_id = config["azure"]["client_id"] client_id = azure_config["client_id"]
client_secret = config["azure"]["client_secret"] client_secret = azure_config["client_secret"]
tenant_id = config["azure"]["tenant_id"] tenant_id = azure_config["tenant_id"]
vm_size = server_type vm_size = server_type
image_publisher = config["azure"]["image_publisher"] image_publisher = azure_config["image_publisher"]
image_offer = config["azure"]["image_offer"] image_offer = azure_config["image_offer"]
image_sku = config["azure"]["image_sku"] image_sku = azure_config["image_sku"]
credential = ClientSecretCredential( credential = ClientSecretCredential(
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
@ -287,14 +270,14 @@ def generate_wireguard_keys():
# Function to configure Wireguard on the VPS # Function to configure Wireguard on the VPS
def configure_wireguard(server_ip, private_key, public_key, preshared_key): def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key):
address = config["wireguard"]["address"] address = wireguard_config["address"]
listen_port = config["wireguard"]["listen_port"] listen_port = wireguard_config["listen_port"]
peer_public_key = config["wireguard"]["peer_public_key"] peer_public_key = wireguard_config["peer_public_key"]
allowed_ips = config["wireguard"]["peer_allowed_ips"] allowed_ips = wireguard_config["peer_allowed_ips"]
endpoint = config["wireguard"]["peer_endpoint"] endpoint = wireguard_config["peer_endpoint"]
persistent_keepalive = config["wireguard"]["peer_persistent_keepalive"] persistent_keepalive = wireguard_config["peer_persistent_keepalive"]
wg_config = f""" wg_config = f"""
[Interface] [Interface]
@ -328,27 +311,27 @@ PersistentKeepalive = {persistent_keepalive}
# Main function to create and configure the server # Main function to create and configure the server
def main(provider, location, server_type): def run(config, provider, location, server_type):
if provider == "hetzner": if provider == "hetzner":
location = location or config["hetzner"]["location"] location = location or config["hetzner"]["location"]
server_type = server_type or config["hetzner"]["server_type"] server_type = server_type or config["hetzner"]["server_type"]
server = create_hetzner_server(location, server_type) server = create_hetzner_server(config["hetzner"], location, server_type)
server_ip = server.public_net.ipv6.ip.split("/")[0] + "1" server_ip = server.public_net.ipv6.ip.split("/")[0] + "1"
elif provider == "aws": elif provider == "aws":
location = location or config["aws"]["region"] location = location or config["aws"]["region"]
server_type = server_type or config["aws"]["instance_type"] server_type = server_type or config["aws"]["instance_type"]
server = create_aws_server(location, server_type) server = create_aws_server(config["aws"], location, server_type)
server_ip = server.public_ip_address server_ip = server.public_ip_address
elif provider == "digitalocean": elif provider == "digitalocean":
location = location or config["digitalocean"]["region"] location = location or config["digitalocean"]["region"]
server_type = server_type or config["digitalocean"]["server_type"] server_type = server_type or config["digitalocean"]["server_type"]
server = create_digitalocean_server(location, server_type) server = create_digitalocean_server(config["digitalocean"], location, server_type)
server.load() server.load()
server_ip = server.ip_address server_ip = server.ip_address
elif provider == "azure": elif provider == "azure":
location = location or config["azure"]["location"] location = location or config["azure"]["location"]
server_type = server_type or config["azure"]["vm_size"] server_type = server_type or config["azure"]["vm_size"]
server, server_ip = create_azure_server(location, server_type) server, server_ip = create_azure_server(config["azure"], location, server_type)
else: else:
raise ValueError("Unsupported provider") raise ValueError("Unsupported provider")
@ -383,7 +366,7 @@ def main(provider, location, server_type):
public_key = private_to_public_key(private_key) public_key = private_to_public_key(private_key)
configure_wireguard(server_ip, private_key, public_key, preshared_key) configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key)
# Generate client configuration for Chimpman # Generate client configuration for Chimpman
wireguard_address = config["wireguard"]["address"] wireguard_address = config["wireguard"]["address"]
@ -404,5 +387,36 @@ PersistentKeepalive = {persistent_keepalive}
print(peer_config) print(peer_config)
def main():
# Read configuration
config = configparser.ConfigParser()
config.read("config.ini")
# Argument parser
parser = argparse.ArgumentParser(
description="Create and configure a Wireguard server."
)
parser.add_argument(
"--provider",
type=str,
choices=["hetzner", "aws", "digitalocean", "azure"],
default="hetzner",
help="Cloud provider (default: hetzner)",
)
parser.add_argument("--location", type=str, help="Server location")
parser.add_argument("--server_type", type=str, help="Server type")
args = parser.parse_args()
# Check if the configuration has all the required sections
required = ["wireguard", args.provider]
for section in required:
if section not in config:
raise ValueError(f"Missing section {section} in config.ini")
# Run the main function with parsed arguments # Run the main function with parsed arguments
main(args.provider, args.location, args.server_type) run(config, args.provider, args.location, args.server_type)
if __name__ == "__main__":
main()