wireguard-vpn-setup/worker.py
Kumi 1aa3932555
refactor: rename and clarify peer config print statement
Renamed the 'chimpman_config' variable to 'peer_config' for clarity, and updated the print statement to more accurately reflect the content being displayed. This improves code readability and eliminates the use of ambiguous terminology.

No changes to functionality.
2024-07-18 17:50:36 +02:00

408 lines
13 KiB
Python

import argparse
import configparser
import logging
import os
import random
import shlex
import string
import subprocess
import time

import boto3
import paramiko
from azure.identity import ClientSecretCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
from digitalocean import Manager, Droplet
from hcloud import Client
from hcloud.images.client import Image
from hcloud.locations.client import LocationsClient
from hcloud.server_types.client import ServerType
from hcloud.servers.domain import ServerCreatePublicNetwork
from hcloud.ssh_keys.client import SSHKeysClient
# Log timestamped messages at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
)

# Provider credentials and WireGuard settings live in config.ini in the
# current working directory. ConfigParser.read() silently skips files it
# cannot open, so a missing config.ini only shows up later as a KeyError.
config = configparser.ConfigParser()
config.read("config.ini")

# Command-line interface: provider plus optional location/type overrides
# (each provider falls back to its config.ini defaults when omitted).
parser = argparse.ArgumentParser(
    description="Create and configure a Wireguard server."
)
parser.add_argument(
    "--provider",
    choices=["hetzner", "aws", "digitalocean", "azure"],
    default="hetzner",
    type=str,
    help="Cloud provider (default: hetzner)",
)
parser.add_argument("--location", help="Server location", type=str)
parser.add_argument("--server_type", help="Server type", type=str)
args = parser.parse_args()

# Eight random lowercase letters/digits, used as a suffix to keep the names
# of created cloud resources unique.
random_string = "".join(
    random.choices(string.ascii_lowercase + string.digits, k=8)
)
def create_hetzner_server(location, server_type, timeout=600, poll_interval=5):
    """Create an IPv6-only Hetzner Cloud server and wait until it is running.

    Args:
        location: Hetzner location name, resolved via ``get_by_name``.
        server_type: Hetzner server type name.
        timeout: Maximum seconds to wait for the ``running`` status.
        poll_interval: Seconds to sleep between status polls.

    Returns:
        The server object from the create response, reloaded until its
        status is ``running``.

    Raises:
        ValueError: If the location name is unknown or the project has no
            SSH keys.
        TimeoutError: If the server is not running within ``timeout`` seconds.
    """
    hetzner_cfg = config["hetzner"]
    client = Client(token=hetzner_cfg["api_token"])

    location_obj = LocationsClient(client).get_by_name(location)
    if location_obj is None:
        # get_by_name() yields None for unknown names; passing None on would
        # omit the location and let Hetzner choose one silently.
        raise ValueError(f"Unknown Hetzner location: {location!r}")

    ssh_keys = SSHKeysClient(client).get_all()
    if not ssh_keys:
        raise ValueError(
            "No SSH keys found in the Hetzner project; "
            "add one so the new server is reachable over SSH."
        )

    logging.info("Creating Hetzner server...")
    response = client.servers.create(
        name="wireguard-vps-" + random_string,
        server_type=ServerType(name=server_type),
        image=Image(name=hetzner_cfg["image"]),
        location=location_obj,
        # Authorize every project key (as the DigitalOcean path does) so the
        # local ~/.ssh/id_rsa is accepted whichever project key it matches.
        ssh_keys=ssh_keys,
        # IPv4 disabled: the server only gets a public IPv6 network.
        public_net=ServerCreatePublicNetwork(enable_ipv4=False, enable_ipv6=True),
    )
    server = response.server

    deadline = time.monotonic() + timeout
    while server.status != "running":
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Hetzner server did not reach 'running' within {timeout}s "
                f"(last status: {server.status!r})"
            )
        time.sleep(poll_interval)
        server.reload()
    return server
def get_ami_id_by_name(ec2_client, ami_name):
    """Return the ID of the newest available AMI whose name matches ``ami_name``.

    Images owned by this account ("self") or by Amazon are considered. When
    several share the newest ``CreationDate``, the first one in the API
    response wins.

    Args:
        ec2_client: boto3 EC2 client (anything with ``describe_images``).
        ami_name: Value for the ``name`` filter; may contain wildcards.

    Raises:
        ValueError: If no matching available image exists.
    """
    name_filter = {"Name": "name", "Values": [ami_name]}
    state_filter = {"Name": "state", "Values": ["available"]}
    candidates = ec2_client.describe_images(
        Filters=[name_filter, state_filter],
        Owners=["self", "amazon"],
    )["Images"]

    if not candidates:
        raise ValueError(f"No AMI found with name pattern {ami_name}")

    newest = max(candidates, key=lambda image: image["CreationDate"])
    return newest["ImageId"]
def create_aws_server(location, server_type):
    """Launch one EC2 instance and wait until it is running.

    Args:
        location: AWS region name. When empty/None, ``[aws] region`` from
            config.ini is used. Previously this argument was ignored, so
            ``--location`` had no effect.
        server_type: EC2 instance type.

    Returns:
        The boto3 ``ec2.Instance``, reloaded after it reached the running
        state so attributes such as ``public_ip_address`` are populated.
    """
    aws_cfg = config["aws"]
    access_key = aws_cfg["access_key"]
    secret_key = aws_cfg["secret_key"]
    region = location or aws_cfg["region"]
    ami_name = aws_cfg["ami_name"]
    key_pair = aws_cfg["key_pair"]

    ec2 = boto3.resource(
        "ec2",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        region_name=region,
    )
    # The resource's low-level client shares its credentials and region,
    # so a second boto3.client(...) is unnecessary.
    ami_id = get_ami_id_by_name(ec2.meta.client, ami_name)

    logging.info("Creating AWS server...")
    instances = ec2.create_instances(
        ImageId=ami_id,
        MinCount=1,
        MaxCount=1,
        InstanceType=server_type,
        KeyName=key_pair,
    )
    instance = instances[0]
    # create_instances() returns while the instance is still pending and has
    # no public IP yet; the waiter blocks until it is running, and reload()
    # refreshes the cached attributes.
    # NOTE(review): no security group is specified, and ssh_execute_command
    # logs in as root; whether either works depends on the VPC defaults and
    # the AMI's SSH configuration. Verify.
    instance.wait_until_running()
    instance.reload()
    return instance
def create_digitalocean_server(location, server_type, timeout=600, poll_interval=5):
    """Create a DigitalOcean droplet and wait until it is active with an IP.

    Args:
        location: Region slug for the droplet.
        server_type: Droplet size slug.
        timeout: Maximum seconds to wait for the droplet to become usable.
        poll_interval: Seconds to sleep between status polls.

    Returns:
        The ``Droplet``, loaded after its status is ``active`` and it has an
        ``ip_address``.

    Raises:
        TimeoutError: If the droplet is not active with an IP address within
            ``timeout`` seconds.
    """
    do_cfg = config["digitalocean"]
    api_token = do_cfg["api_token"]
    manager = Manager(token=api_token)
    logging.info("Creating DigitalOcean server...")
    droplet = Droplet(
        token=api_token,
        name="wireguard-vps-" + random_string,
        region=location,
        # Optional [digitalocean] image override; the default is the image
        # this script always used.
        image=do_cfg.get("image", fallback="debian-12-x64"),
        size_slug=server_type,
        # Authorize every key returned by the account's SSH key listing.
        ssh_keys=manager.get_all_sshkeys(),
    )
    droplet.create()

    # create() only submits the request. Poll until the droplet is active
    # and has been given a public address.
    deadline = time.monotonic() + timeout
    droplet.load()
    while droplet.status != "active" or not droplet.ip_address:
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"DigitalOcean droplet not active with an IP within {timeout}s "
                f"(last status: {droplet.status!r})"
            )
        time.sleep(poll_interval)
        droplet.load()
    return droplet
def create_azure_server(location, server_type):
    """Provision an Azure Linux VM together with its own network resources.

    A new resource group is created holding a virtual network, a subnet, a
    dynamic public IP, a network interface and the VM. Every name is suffixed
    with the module-level ``random_string``.

    The admin account (``azureuser``) is configured for SSH public-key login
    only. This replaces a hard-coded, publicly known password. The key is
    read from ``[azure] ssh_public_key_path`` in config.ini, defaulting to
    ``~/.ssh/id_rsa.pub``, the public half of the private key that
    ``ssh_execute_command`` uses.

    Args:
        location: Azure region used for every resource.
        server_type: VM size.

    Returns:
        tuple: ``(vm, public_ip)``, the VM returned by the create poller and
        the public IP address string re-read after the VM was created.

    Raises:
        OSError: If the SSH public key file cannot be read. This is raised
            before any Azure resource is created.
    """
    azure_cfg = config["azure"]
    subscription_id = azure_cfg["subscription_id"]
    client_id = azure_cfg["client_id"]
    client_secret = azure_cfg["client_secret"]
    tenant_id = azure_cfg["tenant_id"]
    vm_size = server_type
    image_publisher = azure_cfg["image_publisher"]
    image_offer = azure_cfg["image_offer"]
    image_sku = azure_cfg["image_sku"]
    admin_username = "azureuser"

    # Read the key first so a missing file fails before anything is billed.
    key_path = os.path.expanduser(
        azure_cfg.get("ssh_public_key_path", fallback="~/.ssh/id_rsa.pub")
    )
    with open(key_path, encoding="utf-8") as key_file:
        ssh_public_key = key_file.read().strip()

    credential = ClientSecretCredential(
        client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
    )
    resource_client = ResourceManagementClient(credential, subscription_id)
    compute_client = ComputeManagementClient(credential, subscription_id)
    network_client = NetworkManagementClient(credential, subscription_id)

    resource_group_name = f"rg-{random_string}"
    network_name = f"vnet-{random_string}"
    subnet_name = f"subnet-{random_string}"
    ip_name = f"ip-{random_string}"
    nic_name = f"nic-{random_string}"
    vm_name = f"vm-{random_string}"

    logging.info("Creating Azure resource group...")
    # Signature is create_or_update(resource_group_name, parameters): the
    # location goes in the parameters body, not as a keyword argument.
    resource_client.resource_groups.create_or_update(
        resource_group_name, {"location": location}
    )

    logging.info("Creating Azure virtual network...")
    network_client.virtual_networks.begin_create_or_update(
        resource_group_name,
        network_name,
        {"location": location, "address_space": {"address_prefixes": ["10.0.0.0/16"]}},
    ).result()

    logging.info("Creating Azure subnet...")
    subnet = network_client.subnets.begin_create_or_update(
        resource_group_name,
        network_name,
        subnet_name,
        {"address_prefix": "10.0.0.0/24"},
    ).result()

    logging.info("Creating Azure public IP address...")
    ip_address = network_client.public_ip_addresses.begin_create_or_update(
        resource_group_name,
        ip_name,
        {"location": location, "public_ip_allocation_method": "Dynamic"},
    ).result()

    logging.info("Creating Azure network interface...")
    nic = network_client.network_interfaces.begin_create_or_update(
        resource_group_name,
        nic_name,
        {
            "location": location,
            "ip_configurations": [
                {
                    "name": "ipconfig1",
                    "public_ip_address": {"id": ip_address.id},
                    "subnet": {"id": subnet.id},
                }
            ],
        },
    ).result()

    logging.info("Creating Azure virtual machine...")
    vm = compute_client.virtual_machines.begin_create_or_update(
        resource_group_name,
        vm_name,
        {
            "location": location,
            "storage_profile": {
                "image_reference": {
                    "publisher": image_publisher,
                    "offer": image_offer,
                    "sku": image_sku,
                    "version": "latest",
                }
            },
            "hardware_profile": {"vm_size": vm_size},
            "os_profile": {
                "computer_name": vm_name,
                "admin_username": admin_username,
                "linux_configuration": {
                    "disable_password_authentication": True,
                    "ssh": {
                        "public_keys": [
                            {
                                "path": f"/home/{admin_username}/.ssh/authorized_keys",
                                "key_data": ssh_public_key,
                            }
                        ]
                    },
                },
            },
            "network_profile": {"network_interfaces": [{"id": nic.id}]},
        },
    ).result()

    # A Dynamic public IP is only allocated once it is attached to a running
    # VM, so the create response above carries no address. Read it again now.
    # NOTE(review): ssh_execute_command connects as root, while this VM's
    # login user is azureuser; verify root SSH access on the chosen image.
    public_ip = network_client.public_ip_addresses.get(
        resource_group_name, ip_name
    ).ip_address
    return vm, public_ip
def ssh_execute_command(ip, command):
    """Run ``command`` on ``ip`` over SSH as root and return its output.

    Authenticates with the local private key ``~/.ssh/id_rsa``. Unknown host
    keys are accepted automatically (``AutoAddPolicy``). This is convenient
    for freshly created servers but gives no man-in-the-middle protection.

    A non-zero exit status is logged as a warning but not raised, because
    callers treat commands as best-effort.

    Returns:
        tuple[str, str]: The decoded ``(stdout, stderr)`` of the command.
    """
    logging.info(f"Executing command on {ip}: {command}")
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(
            ip, username="root", key_filename=os.path.expanduser("~/.ssh/id_rsa")
        )
        _stdin, stdout, stderr = ssh.exec_command(command)
        # errors="replace": stray non-UTF-8 bytes in remote output must not
        # abort provisioning.
        output = stdout.read().decode(errors="replace")
        error = stderr.read().decode(errors="replace")
        exit_status = stdout.channel.recv_exit_status()
    finally:
        # Always release the connection, even if connect/exec/read raised.
        ssh.close()

    if exit_status != 0:
        logging.warning(
            "Command on %s exited with status %d: %s",
            ip,
            exit_status,
            error.strip(),
        )
    return output, error
def generate_private_key():
    """Return a new WireGuard private key produced by ``wg genkey``.

    Requires the ``wg`` tool on the local PATH.
    """
    # An argument list avoids spawning a shell just to run one program.
    return subprocess.check_output(["wg", "genkey"]).decode().strip()
def private_to_public_key(private_key):
    """Derive the WireGuard public key for ``private_key`` via ``wg pubkey``.

    Requires the ``wg`` tool on the local PATH.
    """
    # Feed the key on stdin instead of interpolating it into a shell command.
    # The old `echo {key} | wg pubkey` form allowed shell injection through a
    # key taken from config.ini, and exposed the key in the process list.
    # The trailing newline mirrors what `echo` used to send.
    return (
        subprocess.check_output(
            ["wg", "pubkey"], input=(private_key + "\n").encode()
        )
        .decode()
        .strip()
    )
def generate_preshared_key():
    """Return a new WireGuard preshared key produced by ``wg genpsk``.

    Requires the ``wg`` tool on the local PATH.
    """
    # An argument list avoids spawning a shell just to run one program.
    return subprocess.check_output(["wg", "genpsk"]).decode().strip()
def generate_wireguard_keys():
    """Create a fresh WireGuard key set.

    Returns:
        tuple: ``(private_key, public_key, preshared_key)``, where the public
        key is derived from the newly generated private key.
    """
    logging.info("Generating Wireguard keys...")
    secret = generate_private_key()
    # Tuple elements evaluate left to right: pubkey derivation, then PSK.
    return secret, private_to_public_key(secret), generate_preshared_key()
def configure_wireguard(server_ip, private_key, public_key, preshared_key):
    """Install wg0.conf on the server, bring ``wg0`` up and add forwarding/NAT rules.

    The interface address/port and the single ``[Peer]`` entry come from the
    ``[wireguard]`` section of config.ini.

    Args:
        server_ip: Host passed to ``ssh_execute_command``.
        private_key: The server's WireGuard private key, written to wg0.conf.
        public_key: The server's public key. Not used here; kept for
            signature compatibility.
        preshared_key: Preshared key for the configured peer.
    """
    wg = config["wireguard"]
    # Built line by line so the file content does not depend on the
    # indentation of a triple-quoted literal.
    wg_config = "\n".join(
        [
            "[Interface]",
            f"Address = {wg['address']}",
            f"PrivateKey = {private_key}",
            f"ListenPort = {wg['listen_port']}",
            "[Peer]",
            f"PublicKey = {wg['peer_public_key']}",
            f"PresharedKey = {preshared_key}",
            f"AllowedIPs = {wg['peer_allowed_ips']}",
            f"Endpoint = {wg['peer_endpoint']}",
            f"PersistentKeepalive = {wg['peer_persistent_keepalive']}",
        ]
    ) + "\n"

    conf_path = "/etc/wireguard/wg0.conf"
    # The file holds the private key, so it is created with umask 077 and
    # forced to 0600 even if it already existed. shlex.quote keeps the
    # content intact if a config value contains a quote character.
    # NOTE(review): ssh_execute_command logs the full command at INFO, so the
    # private key still reaches the local log; consider uploading via SFTP.
    ssh_execute_command(
        server_ip,
        f"umask 077 && printf '%s' {shlex.quote(wg_config)} > {conf_path}"
        f" && chmod 600 {conf_path}",
    )
    ssh_execute_command(server_ip, "wg-quick up wg0")

    # Accept traffic forwarded to/from wg0 and masquerade it out of eth0.
    # NOTE(review): neither `wg-quick up` nor these `-A` rules appear to be
    # persisted across reboots; consider wg-quick@wg0 + PostUp/PostDown.
    iptables_rules = [
        "ip6tables -A FORWARD -i wg0 -j ACCEPT",
        "ip6tables -A FORWARD -o wg0 -j ACCEPT",
        "ip6tables -t nat -A POSTROUTING -o eth0 -j MASQUERADE",
        "iptables -A FORWARD -i wg0 -j ACCEPT",
        "iptables -A FORWARD -o wg0 -j ACCEPT",
        "iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE",
    ]
    for rule in iptables_rules:
        ssh_execute_command(server_ip, rule)
def _provision_server(provider, location, server_type):
    """Create a server on ``provider``, defaulting location/type from config.ini.

    Returns:
        tuple: ``(server, server_ip)``, the provider's server object and the
        address used for SSH and as the WireGuard endpoint (may be None if
        the provider reported no address).

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    if provider == "hetzner":
        location = location or config["hetzner"]["location"]
        server_type = server_type or config["hetzner"]["server_type"]
        server = create_hetzner_server(location, server_type)
        # Hetzner reports an IPv6 network such as "2001:db8:1:2::/64"; this
        # assumes the host itself uses the ::1 address of that prefix.
        server_ip = server.public_net.ipv6.ip.split("/")[0] + "1"
    elif provider == "aws":
        location = location or config["aws"]["region"]
        server_type = server_type or config["aws"]["instance_type"]
        server = create_aws_server(location, server_type)
        server_ip = server.public_ip_address
    elif provider == "digitalocean":
        location = location or config["digitalocean"]["region"]
        server_type = server_type or config["digitalocean"]["server_type"]
        server = create_digitalocean_server(location, server_type)
        server.load()
        server_ip = server.ip_address
    elif provider == "azure":
        location = location or config["azure"]["location"]
        server_type = server_type or config["azure"]["vm_size"]
        server, server_ip = create_azure_server(location, server_type)
    else:
        raise ValueError("Unsupported provider")
    return server, server_ip


def _format_peer_config(public_key, preshared_key, server_ip):
    """Return the ``[Peer]`` section a client uses to connect to this server."""
    wg = config["wireguard"]
    # Brackets are only valid around IPv6 literals; some clients reject a
    # bracketed IPv4 address.
    host = f"[{server_ip}]" if ":" in server_ip else server_ip
    lines = [
        "[Peer]",
        f"PublicKey = {public_key}",
        f"PresharedKey = {preshared_key}",
        f"AllowedIPs = {wg['address']}, {wg['routed_addresses']}",
        f"Endpoint = {host}:{wg['listen_port']}",
        f"PersistentKeepalive = {wg['peer_persistent_keepalive']}",
    ]
    return "\n" + "\n".join(lines) + "\n"


def main(provider, location, server_type):
    """Provision a server, set up WireGuard on it and print the client peer config.

    Args:
        provider: One of "hetzner", "aws", "digitalocean", "azure".
        location: Optional location/region override (config.ini default if falsy).
        server_type: Optional size/type override (config.ini default if falsy).

    Raises:
        ValueError: If ``provider`` is not supported.
        RuntimeError: If no public IP address could be determined.
    """
    _server, server_ip = _provision_server(provider, location, server_type)
    if not server_ip:
        raise RuntimeError(
            f"Could not determine a public IP address for the {provider} server"
        )
    logging.info(f"Server IP: {server_ip}")

    # Fixed grace period for the server to finish booting before SSH.
    logging.info("Waiting for server to boot up...")
    time.sleep(30)

    # Install WireGuard and enable IPv4/IPv6 forwarding so peer traffic is routed.
    commands = [
        "apt update",
        "apt install -y wireguard",
        "echo 'net.ipv6.conf.all.forwarding=1' >> /etc/sysctl.conf",
        "echo 'net.ipv4.ip_forward=1' >> /etc/sysctl.conf",
        "sysctl -p",
    ]
    for command in commands:
        ssh_execute_command(server_ip, command)

    # Use keys from config.ini when set (missing or blank options count as
    # unset); otherwise generate new ones.
    wg = config["wireguard"]
    private_key = wg.get("private_key", fallback="").strip() or generate_private_key()
    preshared_key = (
        wg.get("preshared_key", fallback="").strip() or generate_preshared_key()
    )
    public_key = private_to_public_key(private_key)
    configure_wireguard(server_ip, private_key, public_key, preshared_key)

    # Peer section for the client side of the tunnel.
    peer_config = _format_peer_config(public_key, preshared_key, server_ip)
    print("Wireguard Configuration:")
    print(peer_config)
# Provision and configure a server only when executed as a script, so that
# importing this module does not create cloud resources.
if __name__ == "__main__":
    main(args.provider, args.location, args.server_type)