diff --git a/worker.py b/worker.py index c9e2775..f66c51d 100644 --- a/worker.py +++ b/worker.py @@ -26,32 +26,15 @@ logging.basicConfig( level=logging.INFO, ) -# Read configuration -config = configparser.ConfigParser() -config.read("config.ini") - -# Argument parser -parser = argparse.ArgumentParser(description="Create and configure a Wireguard server.") -parser.add_argument( - "--provider", - type=str, - choices=["hetzner", "aws", "digitalocean", "azure"], - default="hetzner", - help="Cloud provider (default: hetzner)", -) -parser.add_argument("--location", type=str, help="Server location") -parser.add_argument("--server_type", type=str, help="Server type") -args = parser.parse_args() - random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=8)) # Function to create a server on Hetzner Cloud -def create_hetzner_server(location, server_type): - api_token = config["hetzner"]["api_token"] +def create_hetzner_server(hetzner_config, location, server_type): + api_token = hetzner_config["api_token"] client = Client(token=api_token) server_type = ServerType(name=server_type) - image_name = config["hetzner"]["image"] + image_name = hetzner_config["image"] image = Image(name=image_name) location = LocationsClient(client).get_by_name(location) ssh_key = SSHKeysClient(client).get_all()[0] @@ -92,12 +75,12 @@ def get_ami_id_by_name(ec2_client, ami_name): # Function to create a server on AWS -def create_aws_server(location, server_type): - access_key = config["aws"]["access_key"] - secret_key = config["aws"]["secret_key"] - region = config["aws"]["region"] - ami_name = config["aws"]["ami_name"] - key_pair = config["aws"]["key_pair"] +def create_aws_server(aws_config, location, server_type): + access_key = aws_config["access_key"] + secret_key = aws_config["secret_key"] + region = aws_config["region"] + ami_name = aws_config["ami_name"] + key_pair = aws_config["key_pair"] ec2_client = boto3.client( "ec2", @@ -129,8 +112,8 @@ def create_aws_server(location, server_type): # Function to create a server on DigitalOcean -def create_digitalocean_server(location, server_type): - api_token = config["digitalocean"]["api_token"] +def create_digitalocean_server(digitalocean_config, location, server_type): + api_token = digitalocean_config["api_token"] manager = Manager(token=api_token) logging.info("Creating DigitalOcean server...") @@ -149,15 +132,15 @@ def create_digitalocean_server(location, server_type): # Function to create a server on Azure -def create_azure_server(location, server_type): - subscription_id = config["azure"]["subscription_id"] - client_id = config["azure"]["client_id"] - client_secret = config["azure"]["client_secret"] - tenant_id = config["azure"]["tenant_id"] +def create_azure_server(azure_config, location, server_type): + subscription_id = azure_config["subscription_id"] + client_id = azure_config["client_id"] + client_secret = azure_config["client_secret"] + tenant_id = azure_config["tenant_id"] vm_size = server_type - image_publisher = config["azure"]["image_publisher"] - image_offer = config["azure"]["image_offer"] - image_sku = config["azure"]["image_sku"] + image_publisher = azure_config["image_publisher"] + image_offer = azure_config["image_offer"] + image_sku = azure_config["image_sku"] credential = ClientSecretCredential( client_id=client_id, client_secret=client_secret, tenant_id=tenant_id @@ -287,14 +270,14 @@ def generate_wireguard_keys(): # Function to configure Wireguard on the VPS -def configure_wireguard(server_ip, private_key, public_key, preshared_key): - address = config["wireguard"]["address"] - listen_port = config["wireguard"]["listen_port"] +def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key): + address = wireguard_config["address"] + listen_port = wireguard_config["listen_port"] - peer_public_key = config["wireguard"]["peer_public_key"] - allowed_ips = config["wireguard"]["peer_allowed_ips"] - endpoint = config["wireguard"]["peer_endpoint"] - persistent_keepalive = config["wireguard"]["peer_persistent_keepalive"] + peer_public_key = wireguard_config["peer_public_key"] + allowed_ips = wireguard_config["peer_allowed_ips"] + endpoint = wireguard_config["peer_endpoint"] + persistent_keepalive = wireguard_config["peer_persistent_keepalive"] wg_config = f""" [Interface] @@ -328,27 +311,27 @@ PersistentKeepalive = {persistent_keepalive} # Main function to create and configure the server -def main(provider, location, server_type): +def run(config, provider, location, server_type): if provider == "hetzner": location = location or config["hetzner"]["location"] server_type = server_type or config["hetzner"]["server_type"] - server = create_hetzner_server(location, server_type) + server = create_hetzner_server(config["hetzner"], location, server_type) server_ip = server.public_net.ipv6.ip.split("/")[0] + "1" elif provider == "aws": location = location or config["aws"]["region"] server_type = server_type or config["aws"]["instance_type"] - server = create_aws_server(location, server_type) + server = create_aws_server(config["aws"], location, server_type) server_ip = server.public_ip_address elif provider == "digitalocean": location = location or config["digitalocean"]["region"] server_type = server_type or config["digitalocean"]["server_type"] - server = create_digitalocean_server(location, server_type) + server = create_digitalocean_server(config["digitalocean"], location, server_type) server.load() server_ip = server.ip_address elif provider == "azure": location = location or config["azure"]["location"] server_type = server_type or config["azure"]["vm_size"] - server, server_ip = create_azure_server(location, server_type) + server, server_ip = create_azure_server(config["azure"], location, server_type) else: raise ValueError("Unsupported provider") @@ -383,7 +366,7 @@ def main(provider, location, server_type): public_key = private_to_public_key(private_key) - configure_wireguard(server_ip, private_key, public_key, preshared_key) + configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key) # Generate client configuration for Chimpman wireguard_address = config["wireguard"]["address"] @@ -404,5 +387,36 @@ PersistentKeepalive = {persistent_keepalive} print(peer_config) -# Run the main function with parsed arguments -main(args.provider, args.location, args.server_type) +def main(): + # Read configuration + config = configparser.ConfigParser() + config.read("config.ini") + + # Argument parser + parser = argparse.ArgumentParser( + description="Create and configure a Wireguard server." + ) + parser.add_argument( + "--provider", + type=str, + choices=["hetzner", "aws", "digitalocean", "azure"], + default="hetzner", + help="Cloud provider (default: hetzner)", + ) + parser.add_argument("--location", type=str, help="Server location") + parser.add_argument("--server_type", type=str, help="Server type") + args = parser.parse_args() + + # Check if the configuration has all the required sections + required = ["wireguard", args.provider] + + for section in required: + if section not in config: + raise ValueError(f"Missing section {section} in config.ini") + + # Run the main function with parsed arguments + run(config, args.provider, args.location, args.server_type) + + +if __name__ == "__main__": + main()