refactor: improve configuration handling in server setup
Refactored the code to pass configuration sections as parameters to server creation functions instead of accessing them globally. This enhances modularity and clarity by making function dependencies explicit. Additionally, consolidated configuration reading and argument parsing in the main function, and added a validation step to check for required configuration sections. This change makes the codebase more maintainable and prepares it for potential future extensions.
This commit is contained in:
parent
1aa3932555
commit
eb575e8f49
1 changed files with 65 additions and 51 deletions
116
worker.py
116
worker.py
|
@ -26,32 +26,15 @@ logging.basicConfig(
|
|||
level=logging.INFO,
|
||||
)
|
||||
|
||||
# Read configuration
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
|
||||
# Argument parser
|
||||
parser = argparse.ArgumentParser(description="Create and configure a Wireguard server.")
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
choices=["hetzner", "aws", "digitalocean", "azure"],
|
||||
default="hetzner",
|
||||
help="Cloud provider (default: hetzner)",
|
||||
)
|
||||
parser.add_argument("--location", type=str, help="Server location")
|
||||
parser.add_argument("--server_type", type=str, help="Server type")
|
||||
args = parser.parse_args()
|
||||
|
||||
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
|
||||
|
||||
|
||||
# Function to create a server on Hetzner Cloud
|
||||
def create_hetzner_server(location, server_type):
|
||||
api_token = config["hetzner"]["api_token"]
|
||||
def create_hetzner_server(hetzner_config, location, server_type):
|
||||
api_token = hetzner_config["api_token"]
|
||||
client = Client(token=api_token)
|
||||
server_type = ServerType(name=server_type)
|
||||
image_name = config["hetzner"]["image"]
|
||||
image_name = hetzner_config["image"]
|
||||
image = Image(name=image_name)
|
||||
location = LocationsClient(client).get_by_name(location)
|
||||
ssh_key = SSHKeysClient(client).get_all()[0]
|
||||
|
@ -92,12 +75,12 @@ def get_ami_id_by_name(ec2_client, ami_name):
|
|||
|
||||
|
||||
# Function to create a server on AWS
|
||||
def create_aws_server(location, server_type):
|
||||
access_key = config["aws"]["access_key"]
|
||||
secret_key = config["aws"]["secret_key"]
|
||||
region = config["aws"]["region"]
|
||||
ami_name = config["aws"]["ami_name"]
|
||||
key_pair = config["aws"]["key_pair"]
|
||||
def create_aws_server(aws_config, location, server_type):
|
||||
access_key = aws_config["access_key"]
|
||||
secret_key = aws_config["secret_key"]
|
||||
region = aws_config["region"]
|
||||
ami_name = aws_config["ami_name"]
|
||||
key_pair = aws_config["key_pair"]
|
||||
|
||||
ec2_client = boto3.client(
|
||||
"ec2",
|
||||
|
@ -129,8 +112,8 @@ def create_aws_server(location, server_type):
|
|||
|
||||
|
||||
# Function to create a server on DigitalOcean
|
||||
def create_digitalocean_server(location, server_type):
|
||||
api_token = config["digitalocean"]["api_token"]
|
||||
def create_digitalocean_server(digitalocean_config, location, server_type):
|
||||
api_token = digitalocean_config["api_token"]
|
||||
manager = Manager(token=api_token)
|
||||
|
||||
logging.info("Creating DigitalOcean server...")
|
||||
|
@ -149,15 +132,15 @@ def create_digitalocean_server(location, server_type):
|
|||
|
||||
|
||||
# Function to create a server on Azure
|
||||
def create_azure_server(location, server_type):
|
||||
subscription_id = config["azure"]["subscription_id"]
|
||||
client_id = config["azure"]["client_id"]
|
||||
client_secret = config["azure"]["client_secret"]
|
||||
tenant_id = config["azure"]["tenant_id"]
|
||||
def create_azure_server(azure_config, location, server_type):
|
||||
subscription_id = azure_config["subscription_id"]
|
||||
client_id = azure_config["client_id"]
|
||||
client_secret = azure_config["client_secret"]
|
||||
tenant_id = azure_config["tenant_id"]
|
||||
vm_size = server_type
|
||||
image_publisher = config["azure"]["image_publisher"]
|
||||
image_offer = config["azure"]["image_offer"]
|
||||
image_sku = config["azure"]["image_sku"]
|
||||
image_publisher = azure_config["image_publisher"]
|
||||
image_offer = azure_config["image_offer"]
|
||||
image_sku = azure_config["image_sku"]
|
||||
|
||||
credential = ClientSecretCredential(
|
||||
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
|
||||
|
@ -287,14 +270,14 @@ def generate_wireguard_keys():
|
|||
|
||||
|
||||
# Function to configure Wireguard on the VPS
|
||||
def configure_wireguard(server_ip, private_key, public_key, preshared_key):
|
||||
address = config["wireguard"]["address"]
|
||||
listen_port = config["wireguard"]["listen_port"]
|
||||
def configure_wireguard(wireguard_config, server_ip, private_key, public_key, preshared_key):
|
||||
address = wireguard_config["address"]
|
||||
listen_port = wireguard_config["listen_port"]
|
||||
|
||||
peer_public_key = config["wireguard"]["peer_public_key"]
|
||||
allowed_ips = config["wireguard"]["peer_allowed_ips"]
|
||||
endpoint = config["wireguard"]["peer_endpoint"]
|
||||
persistent_keepalive = config["wireguard"]["peer_persistent_keepalive"]
|
||||
peer_public_key = wireguard_config["peer_public_key"]
|
||||
allowed_ips = wireguard_config["peer_allowed_ips"]
|
||||
endpoint = wireguard_config["peer_endpoint"]
|
||||
persistent_keepalive = wireguard_config["peer_persistent_keepalive"]
|
||||
|
||||
wg_config = f"""
|
||||
[Interface]
|
||||
|
@ -328,27 +311,27 @@ PersistentKeepalive = {persistent_keepalive}
|
|||
|
||||
|
||||
# Main function to create and configure the server
|
||||
def main(provider, location, server_type):
|
||||
def run(config, provider, location, server_type):
|
||||
if provider == "hetzner":
|
||||
location = location or config["hetzner"]["location"]
|
||||
server_type = server_type or config["hetzner"]["server_type"]
|
||||
server = create_hetzner_server(location, server_type)
|
||||
server = create_hetzner_server(config["hetzner"], location, server_type)
|
||||
server_ip = server.public_net.ipv6.ip.split("/")[0] + "1"
|
||||
elif provider == "aws":
|
||||
location = location or config["aws"]["region"]
|
||||
server_type = server_type or config["aws"]["instance_type"]
|
||||
server = create_aws_server(location, server_type)
|
||||
server = create_aws_server(config["aws"], location, server_type)
|
||||
server_ip = server.public_ip_address
|
||||
elif provider == "digitalocean":
|
||||
location = location or config["digitalocean"]["region"]
|
||||
server_type = server_type or config["digitalocean"]["server_type"]
|
||||
server = create_digitalocean_server(location, server_type)
|
||||
server = create_digitalocean_server(config["digitalocean"], location, server_type)
|
||||
server.load()
|
||||
server_ip = server.ip_address
|
||||
elif provider == "azure":
|
||||
location = location or config["azure"]["location"]
|
||||
server_type = server_type or config["azure"]["vm_size"]
|
||||
server, server_ip = create_azure_server(location, server_type)
|
||||
server, server_ip = create_azure_server(config["azure"], location, server_type)
|
||||
else:
|
||||
raise ValueError("Unsupported provider")
|
||||
|
||||
|
@ -383,7 +366,7 @@ def main(provider, location, server_type):
|
|||
|
||||
public_key = private_to_public_key(private_key)
|
||||
|
||||
configure_wireguard(server_ip, private_key, public_key, preshared_key)
|
||||
configure_wireguard(config["wireguard"], server_ip, private_key, public_key, preshared_key)
|
||||
|
||||
# Generate client configuration for Chimpman
|
||||
wireguard_address = config["wireguard"]["address"]
|
||||
|
@ -404,5 +387,36 @@ PersistentKeepalive = {persistent_keepalive}
|
|||
print(peer_config)
|
||||
|
||||
|
||||
# Run the main function with parsed arguments
|
||||
main(args.provider, args.location, args.server_type)
|
||||
def main():
|
||||
# Read configuration
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
|
||||
# Argument parser
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create and configure a Wireguard server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
choices=["hetzner", "aws", "digitalocean", "azure"],
|
||||
default="hetzner",
|
||||
help="Cloud provider (default: hetzner)",
|
||||
)
|
||||
parser.add_argument("--location", type=str, help="Server location")
|
||||
parser.add_argument("--server_type", type=str, help="Server type")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if the configuration has all the required sections
|
||||
required = ["wireguard", args.provider]
|
||||
|
||||
for section in required:
|
||||
if section not in config:
|
||||
raise ValueError(f"Missing section {section} in config.ini")
|
||||
|
||||
# Run the main function with parsed arguments
|
||||
run(config, args.provider, args.location, args.server_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in a new issue