Kumi
6fd8d7cc1a
Defer the import of cloud provider libraries (boto3, hcloud, digitalocean, azure) until they are needed within the respective server creation functions. This prevents unnecessary imports when only a subset of providers is used, reducing resource consumption and avoiding import errors for unused libraries.
460 lines
15 KiB
Python
460 lines
15 KiB
Python
import os
|
|
import paramiko
|
|
import subprocess
|
|
import time
|
|
import logging
|
|
import random
|
|
import string
|
|
import argparse
|
|
import configparser
|
|
|
|
# Configure root logging once, at import time: each record is prefixed with a
# timestamp and level name, and anything below INFO is dropped.
_LOG_FORMAT = "[%(asctime)s] %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Eight random lowercase letters/digits, computed once per process and used as
# the suffix of the resource names created in this run.
_SUFFIX_ALPHABET = string.ascii_lowercase + string.digits
random_string = "".join(random.choices(_SUFFIX_ALPHABET, k=8))
|
|
|
|
|
|
# Function to create a server on Hetzner Cloud
def create_hetzner_server(
    hetzner_config, location, server_type, poll_interval=5, timeout=None
):
    """Create an IPv6-only Hetzner Cloud server and block until it is running.

    The server is named ``wireguard-vps-<random_string>``. It gets the first
    SSH key registered in the Hetzner project.

    Args:
        hetzner_config: Mapping (e.g. a ``configparser`` section) providing
            ``api_token`` and ``image``.
        location: Hetzner location name, resolved with ``get_by_name``.
        server_type: Hetzner server type name.
        poll_interval: Seconds to sleep between status checks (default 5).
        timeout: Maximum seconds to wait for the "running" status. ``None``
            (the default) waits indefinitely, as this function always did.

    Returns:
        The hcloud server object, reloaded until its status is "running".

    Raises:
        ImportError: if the ``hcloud`` package is not installed.
        RuntimeError: if the Hetzner project has no SSH keys.
        TimeoutError: if ``timeout`` elapses before the server is running.
    """
    try:
        from hcloud import Client
        from hcloud.server_types.client import ServerType
        from hcloud.images.client import Image
        from hcloud.locations.client import LocationsClient
        from hcloud.ssh_keys.client import SSHKeysClient
        from hcloud.servers.domain import ServerCreatePublicNetwork
    except ImportError as err:
        raise ImportError(
            "Please install the hcloud library by running `pip install hcloud`"
        ) from err

    client = Client(token=hetzner_config["api_token"])
    hetzner_location = LocationsClient(client).get_by_name(location)

    ssh_keys = SSHKeysClient(client).get_all()
    if not ssh_keys:
        # Without a key the server would be unreachable by ssh_execute_command.
        raise RuntimeError(
            "No SSH keys found in the Hetzner project; add one before creating a server"
        )

    logging.info("Creating Hetzner server...")

    response = client.servers.create(
        name="wireguard-vps-" + random_string,
        server_type=ServerType(name=server_type),
        image=Image(name=hetzner_config["image"]),
        location=hetzner_location,
        ssh_keys=[ssh_keys[0]],
        # IPv6 only: no public IPv4 address is requested.
        public_net=ServerCreatePublicNetwork(enable_ipv4=False, enable_ipv6=True),
    )
    server = response.server

    deadline = None if timeout is None else time.monotonic() + timeout
    while server.status != "running":
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Hetzner server did not reach 'running' within {timeout} seconds"
            )
        time.sleep(poll_interval)
        # reload() refreshes the object's data (including status) from the API.
        server.reload()

    return server
|
|
|
|
|
|
# Function to get the latest AMI ID by name
def get_ami_id_by_name(ec2_client, ami_name):
    """Return the ID of the newest available AMI matching *ami_name*.

    *ami_name* is used as the value of EC2's ``name`` filter. Only images
    in state "available" are considered, and only those owned by the
    caller's account ("self") or by Amazon.

    Raises:
        ValueError: if no image matches.
    """
    filters = [
        {"Name": "name", "Values": [ami_name]},
        {"Name": "state", "Values": ["available"]},
    ]
    candidates = ec2_client.describe_images(
        Filters=filters,
        Owners=["self", "amazon"],
    )["Images"]

    if not candidates:
        raise ValueError(f"No AMI found with name pattern {ami_name}")

    # max() keeps the first of several equally new images. That matches a
    # stable descending sort followed by taking element 0.
    newest = max(candidates, key=lambda image: image["CreationDate"])
    return newest["ImageId"]
|
|
|
|
|
|
# Function to create a server on AWS
def create_aws_server(aws_config, location, server_type):
    """Launch one EC2 instance and wait until it is running.

    Args:
        aws_config: Mapping with ``access_key``, ``secret_key``, ``region``,
            ``ami_name`` and ``key_pair``.
        location: AWS region to launch in. When empty, ``aws_config["region"]``
            is used. Previously this argument was ignored, so ``--location``
            had no effect for AWS.
        server_type: EC2 instance type.

    Returns:
        The boto3 ``ec2.Instance``, after it reached the running state and
        its attributes (including ``public_ip_address``) were reloaded.

    Raises:
        ImportError: if ``boto3`` is not installed.
        ValueError: if no AMI matches ``ami_name`` (from get_ami_id_by_name).
    """
    try:
        import boto3
    except ImportError as err:
        raise ImportError(
            "Please install the boto3 library by running `pip install boto3`"
        ) from err

    region = location or aws_config["region"]
    ami_name = aws_config["ami_name"]
    key_pair = aws_config["key_pair"]
    session_args = {
        "aws_access_key_id": aws_config["access_key"],
        "aws_secret_access_key": aws_config["secret_key"],
        "region_name": region,
    }

    # AMI IDs are per region, so look the image up in the region we launch in.
    ec2_client = boto3.client("ec2", **session_args)
    ami_id = get_ami_id_by_name(ec2_client, ami_name)

    ec2 = boto3.resource("ec2", **session_args)

    logging.info("Creating AWS server...")

    instance = ec2.create_instances(
        ImageId=ami_id,
        MinCount=1,
        MaxCount=1,
        InstanceType=server_type,
        KeyName=key_pair,
    )[0]

    # A freshly launched instance is still pending and has no public IP yet.
    # Wait for it to run, then refresh its cached attributes so callers can
    # read public_ip_address.
    instance.wait_until_running()
    instance.reload()
    return instance
|
|
|
|
|
|
# Function to create a server on DigitalOcean
def create_digitalocean_server(
    digitalocean_config, location, server_type, poll_interval=5, timeout=600
):
    """Create a DigitalOcean droplet and wait until it reports "active".

    The droplet is named ``wireguard-vps-<random_string>``. It gets every SSH
    key registered on the account.

    Args:
        digitalocean_config: Mapping with ``api_token``. An optional ``image``
            key selects the image slug (default ``"debian-12-x64"``, the value
            previously hard-coded).
        location: Region slug.
        server_type: Size slug.
        poll_interval: Seconds between status refreshes.
        timeout: Maximum seconds to wait for the droplet to become active;
            ``None`` waits indefinitely.

    Returns:
        The droplet object, loaded after it became active.

    Raises:
        ImportError: if ``python-digitalocean`` is not installed.
        TimeoutError: if the droplet is not active before ``timeout``.
    """
    try:
        from digitalocean import Manager, Droplet
    except ImportError as err:
        raise ImportError(
            "Please install the python-digitalocean library by running `pip install python-digitalocean`"
        ) from err

    api_token = digitalocean_config["api_token"]
    manager = Manager(token=api_token)

    logging.info("Creating DigitalOcean server...")

    droplet = Droplet(
        token=api_token,
        name="wireguard-vps-" + random_string,
        region=location,
        image=digitalocean_config.get("image", "debian-12-x64"),
        size_slug=server_type,
        ssh_keys=manager.get_all_sshkeys(),
    )
    droplet.create()

    # create() only submits the request. Poll until provisioning has finished
    # so the caller sees a populated IP address.
    deadline = None if timeout is None else time.monotonic() + timeout
    droplet.load()
    while droplet.status != "active":
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"DigitalOcean droplet did not become active within {timeout} seconds"
            )
        time.sleep(poll_interval)
        droplet.load()

    return droplet
|
|
|
|
|
|
# Function to create a server on Azure
def create_azure_server(azure_config, location, server_type):
    """Create a Linux VM, with its own resource group and networking, on Azure.

    All resources go into a new resource group ``rg-<random_string>``: a
    10.0.0.0/16 virtual network, a 10.0.0.0/24 subnet, a public IP, a NIC
    and the VM.

    Args:
        azure_config: Mapping with ``subscription_id``, ``client_id``,
            ``client_secret``, ``tenant_id``, ``image_publisher``,
            ``image_offer`` and ``image_sku``. It may also set
            ``admin_username`` (default "azureuser") and ``admin_password``.
            If no password is configured, a random one is generated; the old
            fixed, publicly known password is no longer used.
        location: Azure region used for every resource.
        server_type: VM size.

    Returns:
        ``(vm, public_ip)``. ``public_ip`` is re-read after the VM was
        created, because a dynamically allocated address can be empty on the
        object returned when the IP resource is first created.

    Raises:
        ImportError: if the Azure SDK packages are not installed.
    """
    try:
        from azure.identity import ClientSecretCredential
        from azure.mgmt.compute import ComputeManagementClient
        from azure.mgmt.network import NetworkManagementClient
        from azure.mgmt.resource import ResourceManagementClient
    except ImportError as err:
        raise ImportError(
            "Please install the azure libraries by running `pip install azure-identity azure-mgmt-compute azure-mgmt-network azure-mgmt-resource`"
        ) from err

    import secrets

    # Read every required key up front so a missing one fails before any
    # billable resource is created.
    subscription_id = azure_config["subscription_id"]
    client_id = azure_config["client_id"]
    client_secret = azure_config["client_secret"]
    tenant_id = azure_config["tenant_id"]
    image_publisher = azure_config["image_publisher"]
    image_offer = azure_config["image_offer"]
    image_sku = azure_config["image_sku"]

    admin_username = azure_config.get("admin_username") or "azureuser"
    admin_password = azure_config.get("admin_password")
    if not admin_password:
        # A fixed password baked into the source would be shared by every
        # deployment of this script. Generate one per VM instead. The "Aa1!"
        # prefix is meant to satisfy Azure's password-complexity rules
        # (upper, lower, digit, symbol).
        admin_password = "Aa1!" + secrets.token_urlsafe(24)
        logging.info(
            "No admin_password configured; using a randomly generated password for %s",
            admin_username,
        )

    credential = ClientSecretCredential(
        client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
    )
    resource_client = ResourceManagementClient(credential, subscription_id)
    compute_client = ComputeManagementClient(credential, subscription_id)
    network_client = NetworkManagementClient(credential, subscription_id)

    resource_group_name = f"rg-{random_string}"
    network_name = f"vnet-{random_string}"
    subnet_name = f"subnet-{random_string}"
    ip_name = f"ip-{random_string}"
    nic_name = f"nic-{random_string}"
    vm_name = f"vm-{random_string}"

    logging.info("Creating Azure resource group...")
    # create_or_update takes (resource_group_name, parameters). The previous
    # call passed location= as a keyword and omitted `parameters`.
    resource_client.resource_groups.create_or_update(
        resource_group_name, {"location": location}
    )

    logging.info("Creating Azure virtual network...")
    network_client.virtual_networks.begin_create_or_update(
        resource_group_name,
        network_name,
        {"location": location, "address_space": {"address_prefixes": ["10.0.0.0/16"]}},
    ).result()

    logging.info("Creating Azure subnet...")
    subnet = network_client.subnets.begin_create_or_update(
        resource_group_name,
        network_name,
        subnet_name,
        {"address_prefix": "10.0.0.0/24"},
    ).result()

    logging.info("Creating Azure public IP address...")
    public_ip = network_client.public_ip_addresses.begin_create_or_update(
        resource_group_name,
        ip_name,
        {"location": location, "public_ip_allocation_method": "Dynamic"},
    ).result()

    logging.info("Creating Azure network interface...")
    nic = network_client.network_interfaces.begin_create_or_update(
        resource_group_name,
        nic_name,
        {
            "location": location,
            "ip_configurations": [
                {
                    "name": "ipconfig1",
                    # Reference the created resources by id instead of
                    # embedding a model object or hand-building the id.
                    "public_ip_address": {"id": public_ip.id},
                    "subnet": {"id": subnet.id},
                }
            ],
        },
    ).result()

    logging.info("Creating Azure virtual machine...")
    vm = compute_client.virtual_machines.begin_create_or_update(
        resource_group_name,
        vm_name,
        {
            "location": location,
            "storage_profile": {
                "image_reference": {
                    "publisher": image_publisher,
                    "offer": image_offer,
                    "sku": image_sku,
                    "version": "latest",
                }
            },
            "hardware_profile": {"vm_size": server_type},
            "os_profile": {
                "computer_name": vm_name,
                "admin_username": admin_username,
                "admin_password": admin_password,
            },
            "network_profile": {"network_interfaces": [{"id": nic.id}]},
        },
    ).result()

    # A "Dynamic" public IP may have no address at creation time. Re-read it
    # now that it is attached to the VM.
    public_ip = network_client.public_ip_addresses.get(resource_group_name, ip_name)
    return vm, public_ip.ip_address
|
|
|
|
|
|
# Function to execute commands on the server
def ssh_execute_command(ip, command, username="root", key_filename="~/.ssh/id_rsa"):
    """Run *command* on *ip* over SSH and return its decoded output.

    Unknown host keys are accepted automatically (``AutoAddPolicy``).

    Args:
        ip: Host to connect to.
        command: Command line executed by the remote shell.
        username: Remote login user (default "root", the previous fixed value).
        key_filename: Private key path; ``~`` is expanded.

    Returns:
        ``(stdout, stderr)`` as decoded strings. A non-zero exit status does
        not raise; it is logged as a warning.
    """
    import re

    # Commands can carry WireGuard secrets (configure_wireguard sends the
    # whole wg0.conf through this function), so mask key values in the log.
    loggable = re.sub(
        r"(PrivateKey|PresharedKey)(\s*=\s*)\S+", r"\1\2<redacted>", command
    )
    logging.info("Executing command on %s: %s", ip, loggable)

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(ip, username=username, key_filename=os.path.expanduser(key_filename))
        _stdin, stdout, stderr = ssh.exec_command(command)
        output = stdout.read().decode()
        error = stderr.read().decode()
        exit_status = stdout.channel.recv_exit_status()
    finally:
        # Release the connection even if connect/exec/read raised.
        ssh.close()

    if exit_status != 0:
        logging.warning("Command on %s exited with status %s", ip, exit_status)
    return output, error
|
|
|
|
|
|
# Function to generate a random private key
def generate_private_key():
    """Return a new WireGuard private key, as printed by ``wg genkey``.

    ``wg`` is executed directly (no shell): nothing here needs shell parsing.

    Raises:
        FileNotFoundError: if the ``wg`` binary is not on PATH.
        subprocess.CalledProcessError: if ``wg genkey`` exits non-zero.
    """
    return subprocess.check_output(["wg", "genkey"]).decode().strip()
|
|
|
|
|
|
# Function to convert private key to public key
def private_to_public_key(private_key):
    """Derive the WireGuard public key for *private_key* using ``wg pubkey``.

    The key is fed to ``wg`` on stdin. It used to be interpolated into an
    ``echo KEY | wg pubkey`` shell string, which put the secret on the
    shell's command line (visible in the process list) and let the shell
    interpret any metacharacters in a key read from config.ini.

    Raises:
        FileNotFoundError: if the ``wg`` binary is not on PATH.
        subprocess.CalledProcessError: if ``wg pubkey`` exits non-zero.
    """
    return (
        subprocess.check_output(
            ["wg", "pubkey"], input=f"{private_key.strip()}\n".encode()
        )
        .decode()
        .strip()
    )
|
|
|
|
|
|
# Function to generate a random preshared key
def generate_preshared_key():
    """Return a new WireGuard preshared key, as printed by ``wg genpsk``.

    ``wg`` is executed directly (no shell): nothing here needs shell parsing.

    Raises:
        FileNotFoundError: if the ``wg`` binary is not on PATH.
        subprocess.CalledProcessError: if ``wg genpsk`` exits non-zero.
    """
    return subprocess.check_output(["wg", "genpsk"]).decode().strip()
|
|
|
|
|
|
# Function to generate Wireguard keys
def generate_wireguard_keys():
    """Create a fresh WireGuard key set.

    Returns:
        A ``(private_key, public_key, preshared_key)`` tuple of strings. The
        public key is derived from the freshly generated private key.
    """
    logging.info("Generating Wireguard keys...")
    secret = generate_private_key()
    # Tuple items are evaluated left to right: pubkey derivation, then PSK.
    return secret, private_to_public_key(secret), generate_preshared_key()
|
|
|
|
|
|
# Function to configure Wireguard on the VPS
def configure_wireguard(
    wireguard_config, server_ip, private_key, public_key, preshared_key
):
    """Set up WireGuard on the server over SSH.

    Writes ``/etc/wireguard/wg0.conf``, brings ``wg0`` up, and adds IPv4/IPv6
    forwarding and masquerade rules.

    Args:
        wireguard_config: Mapping with ``address``, ``listen_port``,
            ``peer_public_key``, ``peer_allowed_ips``, ``peer_endpoint`` and
            ``peer_persistent_keepalive``. An optional ``wan_interface``
            (default ``"eth0"``, the previously hard-coded value) names the
            interface that NAT masquerades through.
        server_ip: Address the SSH commands are sent to.
        private_key: The server's WireGuard private key.
        public_key: The server's public key. Unused here; kept for interface
            compatibility.
        preshared_key: Preshared key shared with the peer.
    """
    wan_interface = wireguard_config.get("wan_interface", "eth0")

    wg_config = "\n".join(
        [
            "[Interface]",
            f"Address = {wireguard_config['address']}",
            f"PrivateKey = {private_key}",
            f"ListenPort = {wireguard_config['listen_port']}",
            "",
            "[Peer]",
            f"PublicKey = {wireguard_config['peer_public_key']}",
            f"PresharedKey = {preshared_key}",
            f"AllowedIPs = {wireguard_config['peer_allowed_ips']}",
            f"Endpoint = {wireguard_config['peer_endpoint']}",
            f"PersistentKeepalive = {wireguard_config['peer_persistent_keepalive']}",
        ]
    )

    # The file holds the private key. umask 077 creates it readable by root
    # only (the old echo produced a world-readable file). The quoted 'EOF'
    # delimiter turns off shell expansion, so values are written literally
    # and a "'" in a value can no longer break the command.
    ssh_execute_command(
        server_ip,
        f"umask 077 && cat > /etc/wireguard/wg0.conf << 'EOF'\n{wg_config}\nEOF",
    )
    ssh_execute_command(server_ip, "chmod 600 /etc/wireguard/wg0.conf")
    ssh_execute_command(server_ip, "wg-quick up wg0")

    # Configure ip(6)tables: forward traffic to/from wg0 and masquerade it
    # out of the WAN interface. These rules are not persisted across reboots.
    iptables_rules = [
        "ip6tables -A FORWARD -i wg0 -j ACCEPT",
        "ip6tables -A FORWARD -o wg0 -j ACCEPT",
        f"ip6tables -t nat -A POSTROUTING -o {wan_interface} -j MASQUERADE",
        "iptables -A FORWARD -i wg0 -j ACCEPT",
        "iptables -A FORWARD -o wg0 -j ACCEPT",
        f"iptables -t nat -A POSTROUTING -o {wan_interface} -j MASQUERADE",
    ]
    for rule in iptables_rules:
        ssh_execute_command(server_ip, rule)
|
|
|
|
|
|
# Main function to create and configure the server
def run(config, provider, location, server_type, endpoint_only):
    """Provision a server, install and configure WireGuard, and print the client peer.

    Args:
        config: Parsed ``configparser.ConfigParser`` with a ``wireguard``
            section and a section named after *provider*.
        provider: One of "hetzner", "aws", "digitalocean", "azure".
        location: Provider location/region; falls back to the provider section.
        server_type: Provider instance size; falls back to the provider section.
        endpoint_only: If true, print only the endpoint (``host:port``, or
            ``[host]:port`` for IPv6) instead of the full ``[Peer]`` block.

    Raises:
        ValueError: for an unsupported provider.
        RuntimeError: if the provider reported no public IP for the server.
    """
    server_ip = _provision_server(config, provider, location, server_type)
    if not server_ip:
        # Fail clearly here rather than attempting SSH connections to "None".
        raise RuntimeError(
            f"Provider {provider!r} did not report a public IP address for the new server"
        )
    logging.info(f"Server IP: {server_ip}")

    # Giving server time to boot up
    logging.info("Waiting for server to boot up...")
    time.sleep(30)

    _prepare_host(server_ip)

    private_key, public_key, preshared_key = _resolve_wireguard_keys(config)
    configure_wireguard(
        config["wireguard"], server_ip, private_key, public_key, preshared_key
    )

    wireguard = config["wireguard"]
    endpoint = _format_endpoint(server_ip, wireguard["listen_port"])
    if endpoint_only:
        print(endpoint)
        return

    # Generate client configuration for Chimpman
    wireguard_address = wireguard["address"]
    routed_addresses = wireguard["routed_addresses"]
    persistent_keepalive = wireguard["peer_persistent_keepalive"]
    peer_config = f"""
[Peer]
PublicKey = {public_key}
PresharedKey = {preshared_key}
AllowedIPs = {wireguard_address}, {routed_addresses}
Endpoint = {endpoint}
PersistentKeepalive = {persistent_keepalive}
"""
    print(peer_config)


def _provision_server(config, provider, location, server_type):
    """Create a server with *provider* and return its public IP (falsy if none was reported)."""
    if provider == "hetzner":
        section = config["hetzner"]
        server = create_hetzner_server(
            section,
            location or section["location"],
            server_type or section["server_type"],
        )
        # The server is IPv6-only and exposes a network such as "…::/64"
        # (assumed form). Appending "1" to the prefix gives the "…::1" host
        # address used for SSH.
        return server.public_net.ipv6.ip.split("/")[0] + "1"
    if provider == "aws":
        section = config["aws"]
        server = create_aws_server(
            section,
            location or section["region"],
            server_type or section["instance_type"],
        )
        return server.public_ip_address
    if provider == "digitalocean":
        section = config["digitalocean"]
        server = create_digitalocean_server(
            section,
            location or section["region"],
            server_type or section["server_type"],
        )
        # Refresh the droplet so ip_address reflects its current networks.
        server.load()
        return server.ip_address
    if provider == "azure":
        section = config["azure"]
        _vm, azure_ip = create_azure_server(
            section,
            location or section["location"],
            server_type or section["vm_size"],
        )
        return azure_ip
    raise ValueError("Unsupported provider")


def _prepare_host(server_ip):
    """Install WireGuard on the server and enable IPv4/IPv6 forwarding."""
    commands = [
        "apt update",
        "apt install -y wireguard",
        "echo 'net.ipv6.conf.all.forwarding=1' >> /etc/sysctl.conf",
        "echo 'net.ipv4.ip_forward=1' >> /etc/sysctl.conf",
        "sysctl -p",
    ]
    for command in commands:
        ssh_execute_command(server_ip, command)


def _resolve_wireguard_keys(config):
    """Return (private, public, preshared) keys, generating any not set in [wireguard]."""
    private_key = (
        config.get("wireguard", "private_key", fallback=None)
        or generate_private_key()
    )
    preshared_key = (
        config.get("wireguard", "preshared_key", fallback=None)
        or generate_preshared_key()
    )
    return private_key, private_to_public_key(private_key), preshared_key


def _format_endpoint(host, port):
    """Return ``host:port``, putting brackets around the host only when it is an IPv6 literal."""
    return f"[{host}]:{port}" if ":" in host else f"{host}:{port}"
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, load the INI config, and call ``run``.

    Raises:
        FileNotFoundError: if the configuration file could not be read.
        ValueError: if the configuration lacks the ``wireguard`` section or
            the section for the chosen provider.
    """
    # Argument parser
    parser = argparse.ArgumentParser(
        description="Create and configure a Wireguard server."
    )
    parser.add_argument(
        "--provider",
        type=str,
        choices=["hetzner", "aws", "digitalocean", "azure"],
        default="hetzner",
        help="Cloud provider (default: hetzner)",
    )
    parser.add_argument("--location", type=str, help="Server location")
    # "--server-type" is accepted as well, matching the hyphenated style of
    # "--endpoint-only". "--server_type" keeps working.
    parser.add_argument(
        "--server_type", "--server-type", dest="server_type", type=str, help="Server type"
    )
    parser.add_argument(
        "--endpoint-only",
        action="store_true",
        help="Return Wireguard endpoint (host:port) instead of full configuration",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.ini",
        help="Path to the INI configuration file (default: config.ini)",
    )
    args = parser.parse_args()

    # Read configuration. ConfigParser.read() silently skips files it cannot
    # open and returns the list it did read, so an empty list means nothing
    # was loaded.
    config = configparser.ConfigParser()
    if not config.read(args.config):
        raise FileNotFoundError(f"Could not read configuration file {args.config}")

    # Check if the configuration has all the required sections
    for section in ("wireguard", args.provider):
        if section not in config:
            raise ValueError(f"Missing section {section} in {args.config}")

    # Run the main function with parsed arguments
    run(config, args.provider, args.location, args.server_type, args.endpoint_only)


if __name__ == "__main__":
    main()
|