wireguard-vpn-setup/worker.py
Kumi 6fd8d7cc1a
refactor: lazy-load provider libraries
Defer the import of cloud provider libraries (boto3, hcloud,
digitalocean, azure) until they are needed within the respective
server creation functions. This prevents unnecessary imports
when only a subset of providers is used, reducing resource
consumption and avoiding import errors for unused libraries.
2024-07-18 18:09:18 +02:00

460 lines
15 KiB
Python

import argparse
import configparser
import logging
import os
import random
import re
import secrets
import shlex
import string
import subprocess
import time

import paramiko
# Configure root logging once at import time: timestamped records, INFO and up.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
)

# Random 8-character suffix (lowercase letters and digits) appended to the
# names of every cloud resource created during this run.
random_string = "".join(
    random.choices(population=string.ascii_lowercase + string.digits, k=8)
)
def create_hetzner_server(
    hetzner_config, location, server_type, timeout=600, poll_interval=5
):
    """Create an IPv6-only server on Hetzner Cloud and wait until it runs.

    Args:
        hetzner_config: Mapping with at least ``api_token`` and ``image``.
        location: Hetzner location name, resolved via
            ``LocationsClient.get_by_name``.
        server_type: Hetzner server type name.
        timeout: Maximum number of seconds to wait for the server to reach
            the ``running`` status.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The bound server object (``response.server``) once its status is
        ``running``.

    Raises:
        ImportError: If the ``hcloud`` library is not installed.
        ValueError: If *location* is unknown or the project has no SSH keys.
        TimeoutError: If the server is not running within *timeout* seconds.
    """
    try:
        from hcloud import Client
        from hcloud.images.client import Image
        from hcloud.locations.client import LocationsClient
        from hcloud.server_types.client import ServerType
        from hcloud.servers.domain import ServerCreatePublicNetwork
        from hcloud.ssh_keys.client import SSHKeysClient
    except ImportError as err:
        raise ImportError(
            "Please install the hcloud library by running `pip install hcloud`"
        ) from err

    client = Client(token=hetzner_config["api_token"])
    image = Image(name=hetzner_config["image"])

    # get_by_name returns None for an unknown name; fail loudly instead of
    # silently letting Hetzner pick some other location.
    hcloud_location = LocationsClient(client).get_by_name(location)
    if hcloud_location is None:
        raise ValueError(f"Unknown Hetzner location: {location!r}")

    ssh_keys = SSHKeysClient(client).get_all()
    if not ssh_keys:
        raise ValueError(
            "No SSH keys found in the Hetzner project; "
            "upload one before creating a server"
        )

    logging.info("Creating Hetzner server...")
    response = client.servers.create(
        name="wireguard-vps-" + random_string,
        server_type=ServerType(name=server_type),
        image=image,
        location=hcloud_location,
        # Only the first project key is installed.
        ssh_keys=[ssh_keys[0]],
        # IPv6-only: no IPv4 address is requested.
        public_net=ServerCreatePublicNetwork(enable_ipv4=False, enable_ipv6=True),
    )

    server = response.server
    deadline = time.monotonic() + timeout
    while server.status != "running":
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Hetzner server {server.name} not running after {timeout} "
                f"seconds (status: {server.status})"
            )
        time.sleep(poll_interval)
        server.reload()
    return server
def get_ami_id_by_name(ec2_client, ami_name, owners=None):
    """Return the ID of the newest available AMI whose name matches *ami_name*.

    Args:
        ec2_client: A boto3 EC2 client (anything exposing ``describe_images``).
        ami_name: AMI name or name pattern, passed verbatim as the ``name``
            filter value.
        owners: Owner aliases or account IDs to search (a single string is
            accepted too).  Defaults to ``["self", "amazon"]``.  Images
            published by a vendor's own AWS account are only found when that
            account ID is included here.

    Returns:
        The ``ImageId`` of the matching image with the latest ``CreationDate``.

    Raises:
        ValueError: If no available image matches.
    """
    if owners is None:
        owners = ["self", "amazon"]
    elif isinstance(owners, str):
        owners = [owners]

    response = ec2_client.describe_images(
        Filters=[
            {"Name": "name", "Values": [ami_name]},
            {"Name": "state", "Values": ["available"]},
        ],
        Owners=list(owners),
    )
    images = response.get("Images", [])
    if not images:
        raise ValueError(f"No AMI found with name pattern {ami_name}")

    # CreationDate is an ISO 8601 string, so string order is chronological;
    # max() keeps the first of equal dates, like the previous stable sort.
    newest = max(images, key=lambda image: image["CreationDate"])
    return newest["ImageId"]
def create_aws_server(aws_config, location, server_type):
    """Launch one EC2 instance and wait until it is running.

    Args:
        aws_config: Mapping with ``access_key``, ``secret_key``, ``region``,
            ``ami_name`` and ``key_pair``.
        location: AWS region to launch in.  Falls back to
            ``aws_config["region"]`` when empty.
        server_type: EC2 instance type.

    Returns:
        The boto3 ``ec2.Instance`` resource, reloaded after it reached the
        ``running`` state.  Its ``public_ip_address`` may still be None if
        the subnet does not assign public IPs.

    Raises:
        ImportError: If boto3 is not installed.
        ValueError: If no matching AMI is found (see ``get_ami_id_by_name``).
    """
    try:
        import boto3
    except ImportError as err:
        raise ImportError(
            "Please install the boto3 library by running `pip install boto3`"
        ) from err

    # Read every setting up front so a missing key fails before any API call.
    access_key = aws_config["access_key"]
    secret_key = aws_config["secret_key"]
    # Honor the caller's region (e.g. the --location override) and fall back
    # to the configured default only when none was given.
    region = location or aws_config["region"]
    ami_name = aws_config["ami_name"]
    key_pair = aws_config["key_pair"]

    ec2 = boto3.resource(
        "ec2",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        region_name=region,
    )
    # The resource's low-level client shares its credentials and region.
    ami_id = get_ami_id_by_name(ec2.meta.client, ami_name)

    logging.info("Creating AWS server...")
    instances = ec2.create_instances(
        ImageId=ami_id,
        MinCount=1,
        MaxCount=1,
        InstanceType=server_type,
        KeyName=key_pair,
    )
    instance = instances[0]

    # create_instances returns while the instance is still pending, typically
    # before a public IP is assigned; wait, then refresh cached attributes.
    logging.info(f"Waiting for AWS instance {instance.id} to start...")
    instance.wait_until_running()
    instance.reload()
    return instance
def create_digitalocean_server(
    digitalocean_config, location, server_type, poll_interval=5
):
    """Create a DigitalOcean droplet and wait for its actions to finish.

    Args:
        digitalocean_config: Mapping with ``api_token`` and, optionally,
            ``image`` (default ``"debian-12-x64"``).
        location: Region slug for the droplet.
        server_type: Size slug for the droplet.
        poll_interval: Seconds to sleep between checks of a pending action.

    Returns:
        The ``Droplet`` object.  Its actions have finished, so a subsequent
        ``droplet.load()`` can pick up the assigned IP address.

    Raises:
        ImportError: If python-digitalocean is not installed.
        RuntimeError: If one of the droplet's actions does not complete.
    """
    try:
        from digitalocean import Droplet, Manager
    except ImportError as err:
        raise ImportError(
            "Please install the python-digitalocean library by running `pip install python-digitalocean`"
        ) from err

    api_token = digitalocean_config["api_token"]
    # The other providers read their image from config; allow the same here,
    # keeping the previous hard-coded image as the default.
    image = digitalocean_config.get("image", "debian-12-x64")
    manager = Manager(token=api_token)

    logging.info("Creating DigitalOcean server...")
    droplet = Droplet(
        token=api_token,
        name="wireguard-vps-" + random_string,
        region=location,
        image=image,
        size_slug=server_type,
        # Every SSH key registered on the account is installed.
        ssh_keys=manager.get_all_sshkeys(),
    )
    droplet.create()

    # create() returns once the API accepts the request; the droplet may have
    # no IP address until its create action finishes, so poll the actions.
    for action in droplet.get_actions():
        action.load()
        while action.status == "in-progress":
            time.sleep(poll_interval)
            action.load()
        if action.status != "completed":
            raise RuntimeError(
                f"DigitalOcean action {action.type} for droplet {droplet.id} "
                f"finished with status {action.status}"
            )
    return droplet
def create_azure_server(azure_config, location, server_type):
    """Create a Linux VM on Azure together with its networking resources.

    A new resource group, virtual network, subnet, dynamic public IP and
    network interface are created, all named with ``random_string``.

    Args:
        azure_config: Mapping with ``subscription_id``, ``client_id``,
            ``client_secret``, ``tenant_id``, ``image_publisher``,
            ``image_offer`` and ``image_sku``.  Optional keys:
            ``admin_username`` (default ``"azureuser"``) and
            ``admin_password`` (default: a random password generated with
            ``secrets``; it is neither logged nor returned).
        location: Azure region for every created resource.
        server_type: VM size.

    Returns:
        ``(vm, public_ip)``: the created VM model and the public IP address
        string, read back after the VM has been provisioned.

    Raises:
        ImportError: If the Azure SDK packages are not installed.
    """
    try:
        from azure.identity import ClientSecretCredential
        from azure.mgmt.compute import ComputeManagementClient
        from azure.mgmt.network import NetworkManagementClient
        from azure.mgmt.resource import ResourceManagementClient
    except ImportError as err:
        raise ImportError(
            "Please install the azure libraries by running "
            "`pip install azure-identity azure-mgmt-compute "
            "azure-mgmt-network azure-mgmt-resource`"
        ) from err

    # Read all settings before creating anything, so a missing key cannot
    # leave half-created cloud resources behind.
    subscription_id = azure_config["subscription_id"]
    client_id = azure_config["client_id"]
    client_secret = azure_config["client_secret"]
    tenant_id = azure_config["tenant_id"]
    image_publisher = azure_config["image_publisher"]
    image_offer = azure_config["image_offer"]
    image_sku = azure_config["image_sku"]
    admin_username = azure_config.get("admin_username") or "azureuser"
    # Never fall back to a fixed, publicly known password on an
    # internet-facing VM.  The fixed suffix guarantees the lower/upper/digit/
    # special character mix Azure's password rules ask for.
    admin_password = azure_config.get("admin_password") or (
        secrets.token_urlsafe(24) + "aA1!"
    )
    # NOTE(review): the VM only gets password authentication for
    # admin_username, while ssh_execute_command logs in as root with a key;
    # confirm how the SSH step is expected to reach Azure VMs.

    credential = ClientSecretCredential(
        client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
    )
    resource_client = ResourceManagementClient(credential, subscription_id)
    compute_client = ComputeManagementClient(credential, subscription_id)
    network_client = NetworkManagementClient(credential, subscription_id)

    resource_group_name = f"rg-{random_string}"
    network_name = f"vnet-{random_string}"
    subnet_name = f"subnet-{random_string}"
    ip_name = f"ip-{random_string}"
    nic_name = f"nic-{random_string}"
    vm_name = f"vm-{random_string}"

    logging.info("Creating Azure resource group...")
    # The group's properties are the required positional ``parameters``
    # argument; passing only ``location=`` as a keyword leaves it missing.
    resource_client.resource_groups.create_or_update(
        resource_group_name, {"location": location}
    )

    logging.info("Creating Azure virtual network...")
    network_client.virtual_networks.begin_create_or_update(
        resource_group_name,
        network_name,
        {"location": location, "address_space": {"address_prefixes": ["10.0.0.0/16"]}},
    ).result()

    logging.info("Creating Azure subnet...")
    subnet = network_client.subnets.begin_create_or_update(
        resource_group_name,
        network_name,
        subnet_name,
        {"address_prefix": "10.0.0.0/24"},
    ).result()

    logging.info("Creating Azure public IP address...")
    ip_address = network_client.public_ip_addresses.begin_create_or_update(
        resource_group_name,
        ip_name,
        {"location": location, "public_ip_allocation_method": "Dynamic"},
    ).result()

    logging.info("Creating Azure network interface...")
    nic = network_client.network_interfaces.begin_create_or_update(
        resource_group_name,
        nic_name,
        {
            "location": location,
            "ip_configurations": [
                {
                    "name": "ipconfig1",
                    # Reference resources by the IDs the API returned instead
                    # of hand-building the ARM path.
                    "public_ip_address": {"id": ip_address.id},
                    "subnet": {"id": subnet.id},
                }
            ],
        },
    ).result()

    logging.info("Creating Azure virtual machine...")
    vm = compute_client.virtual_machines.begin_create_or_update(
        resource_group_name,
        vm_name,
        {
            "location": location,
            "storage_profile": {
                "image_reference": {
                    "publisher": image_publisher,
                    "offer": image_offer,
                    "sku": image_sku,
                    "version": "latest",
                }
            },
            "hardware_profile": {"vm_size": server_type},
            "os_profile": {
                "computer_name": vm_name,
                "admin_username": admin_username,
                "admin_password": admin_password,
            },
            "network_profile": {"network_interfaces": [{"id": nic.id}]},
        },
    ).result()

    # A "Dynamic" public IP is only allocated once it is attached to a
    # running VM, so the object returned at creation time has no address;
    # read it back now that the VM exists.
    public_ip = network_client.public_ip_addresses.get(
        resource_group_name, ip_name
    ).ip_address
    return vm, public_ip
def ssh_execute_command(ip, command, username="root", key_filename="~/.ssh/id_rsa"):
    """Run *command* on host *ip* over SSH and return its output.

    Args:
        ip: Hostname or IP address to connect to.
        command: Command line executed by the remote shell.
        username: Remote login user (default ``"root"``).
        key_filename: Private key file used to authenticate; ``~`` is
            expanded.

    Returns:
        ``(stdout, stderr)`` decoded as UTF-8; undecodable bytes are
        replaced.  A non-zero exit status is logged as a warning, not raised.
    """
    # Commands may embed WireGuard config text; keep key material out of logs.
    shown = re.sub(
        r"(PrivateKey|PresharedKey)(\s*=\s*)\S+", r"\1\2<redacted>", command
    )
    logging.info(f"Executing command on {ip}: {shown}")

    ssh = paramiko.SSHClient()
    # Freshly created servers have host keys we cannot know in advance, so
    # unknown keys are accepted automatically.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(
            ip, username=username, key_filename=os.path.expanduser(key_filename)
        )
        _stdin, stdout, stderr = ssh.exec_command(command)
        output = stdout.read().decode(errors="replace")
        error = stderr.read().decode(errors="replace")
        exit_status = stdout.channel.recv_exit_status()
    finally:
        # Always release the connection, even if connect/exec raised.
        ssh.close()

    if exit_status != 0:
        logging.warning(
            f"Command on {ip} exited with status {exit_status}: {shown}"
        )
    return output, error
def generate_private_key():
    """Return a new WireGuard private key as printed by ``wg genkey``.

    Raises:
        FileNotFoundError: If the ``wg`` tool is not installed locally.
        subprocess.CalledProcessError: If ``wg genkey`` exits non-zero.
    """
    # Run wg directly; nothing here needs a shell.
    return subprocess.check_output(["wg", "genkey"]).decode().strip()
def private_to_public_key(private_key):
    """Derive the WireGuard public key for *private_key* with ``wg pubkey``.

    The key is written to ``wg``'s stdin instead of being interpolated into a
    shell command line.  It therefore never shows up in the local process
    list, and a malformed value (e.g. read from config.ini) cannot inject
    shell syntax.

    Raises:
        FileNotFoundError: If the ``wg`` tool is not installed locally.
        subprocess.CalledProcessError: If ``wg pubkey`` rejects the key.
    """
    # The trailing newline mirrors what the previous ``echo`` pipeline sent.
    return (
        subprocess.check_output(
            ["wg", "pubkey"], input=(private_key.strip() + "\n").encode()
        )
        .decode()
        .strip()
    )
def generate_preshared_key():
    """Return a new WireGuard preshared key as printed by ``wg genpsk``.

    Raises:
        FileNotFoundError: If the ``wg`` tool is not installed locally.
        subprocess.CalledProcessError: If ``wg genpsk`` exits non-zero.
    """
    # Run wg directly; nothing here needs a shell.
    return subprocess.check_output(["wg", "genpsk"]).decode().strip()
def generate_wireguard_keys():
    """Create a fresh WireGuard key set.

    Returns:
        Tuple ``(private_key, public_key, preshared_key)``; the public key is
        derived from the newly generated private key.
    """
    logging.info("Generating Wireguard keys...")
    private = generate_private_key()
    # Tuple items evaluate left to right: pubkey first, then the PSK.
    return private, private_to_public_key(private), generate_preshared_key()
def configure_wireguard(
    wireguard_config, server_ip, private_key, public_key, preshared_key
):
    """Write wg0.conf on the server, bring wg0 up and add forwarding/NAT rules.

    Args:
        wireguard_config: The ``[wireguard]`` config section.  Reads
            ``address``, ``listen_port``, ``peer_public_key``,
            ``peer_allowed_ips``, ``peer_endpoint``,
            ``peer_persistent_keepalive`` and, optionally, ``wan_interface``
            (the server's outbound interface used for masquerading, default
            ``"eth0"``).
        server_ip: Address of the server, configured over SSH.
        private_key: The server's WireGuard private key.
        public_key: Unused; kept for call compatibility.
        preshared_key: Preshared key shared with the peer.
    """
    address = wireguard_config["address"]
    listen_port = wireguard_config["listen_port"]
    peer_public_key = wireguard_config["peer_public_key"]
    allowed_ips = wireguard_config["peer_allowed_ips"]
    endpoint = wireguard_config["peer_endpoint"]
    persistent_keepalive = wireguard_config["peer_persistent_keepalive"]
    wan_interface = wireguard_config.get("wan_interface", "eth0")

    wg_config = "\n".join(
        [
            "",
            "[Interface]",
            f"Address = {address}",
            f"PrivateKey = {private_key}",
            f"ListenPort = {listen_port}",
            "[Peer]",
            f"PublicKey = {peer_public_key}",
            f"PresharedKey = {preshared_key}",
            f"AllowedIPs = {allowed_ips}",
            f"Endpoint = {endpoint}",
            f"PersistentKeepalive = {persistent_keepalive}",
            "",
        ]
    )
    # shlex.quote keeps the text intact even if a value contains quotes or
    # shell metacharacters.  umask 077 makes the new file owner-only because
    # it holds the private key.
    ssh_execute_command(
        server_ip,
        f"umask 077 && printf '%s\\n' {shlex.quote(wg_config)} "
        "> /etc/wireguard/wg0.conf",
    )
    ssh_execute_command(server_ip, "wg-quick up wg0")

    # Accept forwarded traffic in and out of wg0 and masquerade it behind the
    # outbound interface, first for IPv6, then for IPv4.
    nat_interface = shlex.quote(wan_interface)
    for tool in ("ip6tables", "iptables"):
        for rule in (
            "-A FORWARD -i wg0 -j ACCEPT",
            "-A FORWARD -o wg0 -j ACCEPT",
            f"-t nat -A POSTROUTING -o {nat_interface} -j MASQUERADE",
        ):
            ssh_execute_command(server_ip, f"{tool} {rule}")
def _provision_server(config, provider, location, server_type):
    """Create a server with *provider* and return the IP address to SSH into.

    ``location`` and ``server_type`` take precedence over the defaults in the
    provider's section of *config* whenever they are truthy.

    Raises:
        ValueError: If *provider* is not one of the supported names.
    """
    if provider == "hetzner":
        section = config["hetzner"]
        server = create_hetzner_server(
            section,
            location or section["location"],
            server_type or section["server_type"],
        )
        # ipv6.ip is a CIDR prefix; appending "1" to its network part
        # targets the ::1 address inside that prefix.
        network_part = server.public_net.ipv6.ip.split("/")[0]
        return network_part + "1"

    if provider == "aws":
        section = config["aws"]
        instance = create_aws_server(
            section,
            location or section["region"],
            server_type or section["instance_type"],
        )
        return instance.public_ip_address

    if provider == "digitalocean":
        section = config["digitalocean"]
        droplet = create_digitalocean_server(
            section,
            location or section["region"],
            server_type or section["server_type"],
        )
        droplet.load()
        return droplet.ip_address

    if provider == "azure":
        section = config["azure"]
        _vm, public_ip = create_azure_server(
            section,
            location or section["location"],
            server_type or section["vm_size"],
        )
        return public_ip

    raise ValueError("Unsupported provider")


def run(config, provider, location, server_type, endpoint_only):
    """Provision a server, install WireGuard on it and print the peer config.

    Args:
        config: Parsed configuration (a ``ConfigParser``).
        provider: One of ``hetzner``, ``aws``, ``digitalocean``, ``azure``.
        location: Optional override for the provider's configured location.
        server_type: Optional override for the configured server size.
        endpoint_only: When true, print only ``[ip]:port`` instead of the
            ``[Peer]`` section.
    """
    server_ip = _provision_server(config, provider, location, server_type)
    logging.info(f"Server IP: {server_ip}")

    # Give the new server time to finish booting before connecting over SSH.
    logging.info("Waiting for server to boot up...")
    time.sleep(30)

    bootstrap_steps = (
        "apt update",
        "apt install -y wireguard",
        "echo 'net.ipv6.conf.all.forwarding=1' >> /etc/sysctl.conf",
        "echo 'net.ipv4.ip_forward=1' >> /etc/sysctl.conf",
        "sysctl -p",
    )
    for step in bootstrap_steps:
        ssh_execute_command(server_ip, step)

    # Reuse keys from the config when set (non-empty), otherwise generate them.
    private_key = (
        config.get("wireguard", "private_key", fallback=None)
        or generate_private_key()
    )
    preshared_key = (
        config.get("wireguard", "preshared_key", fallback=None)
        or generate_preshared_key()
    )
    public_key = private_to_public_key(private_key)
    configure_wireguard(
        config["wireguard"], server_ip, private_key, public_key, preshared_key
    )

    # Values for the client-side [Peer] section.
    wg_section = config["wireguard"]
    wireguard_address = wg_section["address"]
    routed_addresses = wg_section["routed_addresses"]
    listen_port = wg_section["listen_port"]
    persistent_keepalive = wg_section["peer_persistent_keepalive"]

    endpoint = f"[{server_ip}]:{listen_port}"
    if endpoint_only:
        print(endpoint)
        return

    print(
        "\n".join(
            [
                "",
                "[Peer]",
                f"PublicKey = {public_key}",
                f"PresharedKey = {preshared_key}",
                f"AllowedIPs = {wireguard_address}, {routed_addresses}",
                f"Endpoint = {endpoint}",
                f"PersistentKeepalive = {persistent_keepalive}",
                "",
            ]
        )
    )
def main(argv=None):
    """Command-line entry point: parse arguments, load the config, run.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for
            tests).  ``None`` keeps argparse's default behavior.

    Raises:
        SystemExit: Via ``parser.error`` if the config file cannot be read.
        ValueError: If a required config section is missing.
    """
    parser = argparse.ArgumentParser(
        description="Create and configure a Wireguard server."
    )
    parser.add_argument(
        "--provider",
        type=str,
        choices=["hetzner", "aws", "digitalocean", "azure"],
        default="hetzner",
        help="Cloud provider (default: hetzner)",
    )
    parser.add_argument("--location", type=str, help="Server location")
    # The hyphenated spelling matches --endpoint-only; the original
    # underscore spelling keeps working.
    parser.add_argument(
        "--server_type", "--server-type", dest="server_type", type=str,
        help="Server type",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.ini",
        help="Path to the configuration file (default: config.ini)",
    )
    parser.add_argument(
        "--endpoint-only",
        action="store_true",
        help="Return Wireguard endpoint (host:port) instead of full configuration",
    )
    args = parser.parse_args(argv)

    config = configparser.ConfigParser()
    # ConfigParser.read silently skips unreadable files and returns the list
    # of files it did read; report a missing file instead of a misleading
    # "missing section" error.
    if not config.read(args.config):
        parser.error(f"could not read configuration file {args.config}")

    for section in ("wireguard", args.provider):
        if section not in config:
            raise ValueError(f"Missing section {section} in {args.config}")

    run(config, args.provider, args.location, args.server_type, args.endpoint_only)


if __name__ == "__main__":
    main()