From e63a3d6357bf11881ca1553d57c7f2b04dd3a821 Mon Sep 17 00:00:00 2001 From: Dan Guido Date: Mon, 4 Aug 2025 22:13:48 -0700 Subject: [PATCH] Fix linting issues across the codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Python Code Quality (ruff) - Fixed import organization and removed unused imports in test files - Replaced `== True` comparisons with direct boolean checks - Added noqa comments for intentional imports in test modules ## YAML Formatting (yamllint) - Removed trailing spaces in openssl.yml comments - All YAML files now pass yamllint validation (except one pre-existing long regex line) ## Code Consistency - Maintained proper import ordering in test files - Ensured all code follows project linting standards - Ready for CI pipeline validation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- roles/strongswan/tasks/openssl.yml | 2 +- tests/unit/test_openssl_compatibility.py | 172 ++++++++++---------- tests/unit/test_wireguard_key_generation.py | 171 ++++++++++--------- 3 files changed, 168 insertions(+), 177 deletions(-) diff --git a/roles/strongswan/tasks/openssl.yml b/roles/strongswan/tasks/openssl.yml index 7acfbca6..255a8c23 100644 --- a/roles/strongswan/tasks/openssl.yml +++ b/roles/strongswan/tasks/openssl.yml @@ -55,7 +55,7 @@ # CA can sign both server and client certs, restricted to VPN use only extended_key_usage: - serverAuth # Allows signing server certificates - - clientAuth # Allows signing client certificates + - clientAuth # Allows signing client certificates - '1.3.6.1.5.5.7.3.17' # IPsec End Entity OID - VPN-specific usage extended_key_usage_critical: true # Name constraints from defaults/main.yml template - prevents CA from issuing certs for public domains diff --git a/tests/unit/test_openssl_compatibility.py b/tests/unit/test_openssl_compatibility.py index a073d52d..6ccded79 100644 --- a/tests/unit/test_openssl_compatibility.py +++ b/tests/unit/test_openssl_compatibility.py @@ -5,18 +5,14 @@ Hybrid approach: validates actual certificates when available, else tests templa Based on issues #14755, #14718 - Apple device compatibility Issues #75, #153 - Security enhancements (name constraints, EKU restrictions) """ -import os import glob +import os import re import subprocess import sys -import yaml -import tempfile -import ipaddress -from pathlib import Path + from cryptography import x509 -from cryptography.hazmat.primitives import serialization -from cryptography.x509.oid import NameOID, ExtensionOID +from cryptography.x509.oid import ExtensionOID, NameOID def find_generated_certificates(): @@ -27,7 +23,7 @@ def find_generated_certificates(): "../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory "../../configs/*/ipsec/.pki/cacert.pem" # Alternative path ] - + for pattern in config_patterns: ca_certs = glob.glob(pattern) if ca_certs: @@ -38,7 +34,7 @@ def find_generated_certificates(): 'server_certs': glob.glob(f"{base_path}/certs/*.crt"), 'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12") } - + return None def test_openssl_version_detection(): @@ -67,49 +63,49 @@ def validate_ca_certificate_real(cert_files): # Read the actual CA certificate generated by Ansible with open(cert_files['ca_cert'], 'rb') as f: cert_data = f.read() - + certificate = x509.load_pem_x509_certificate(cert_data) - + # Check Basic Constraints basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value assert basic_constraints.ca is True, "CA certificate should have CA:TRUE" assert basic_constraints.path_length == 0, "CA should have pathlen:0 constraint" - + # Check Key Usage key_usage = certificate.extensions.get_extension_for_oid(ExtensionOID.KEY_USAGE).value assert key_usage.key_cert_sign is True, "CA should have keyCertSign usage" assert key_usage.crl_sign is True, "CA should have cRLSign usage" - + # Check Extended Key Usage (Issue #75) eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "CA should allow signing server certificates" assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "CA should allow signing client certificates" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "CA should have IPsec End Entity EKU" - + # Check Name Constraints (Issue #75) - defense against certificate misuse name_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.NAME_CONSTRAINTS).value assert name_constraints.permitted_subtrees is not None, "CA should have permitted name constraints" assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints" - + # Verify public domains are excluded - excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees + excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in public_domains: assert domain in excluded_dns, f"CA should exclude public domain {domain}" - + # Verify private IP ranges are excluded (Issue #75) - excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees + excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)] assert len(excluded_ips) > 0, "CA should exclude private IP ranges" - + # Verify email domains are also excluded (Issue #153) - excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees + excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in email_domains: assert domain in excluded_emails, f"CA should exclude email domain {domain}" - + print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}") def validate_ca_certificate_config(): @@ -119,10 +115,10 @@ def validate_ca_certificate_config(): if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return - - with open(openssl_task_file, 'r') as f: + + with open(openssl_task_file) as f: content = f.read() - + # Verify key security configurations are present security_checks = [ ('name_constraints_permitted', 'Name constraints should be configured'), @@ -135,28 +131,28 @@ def validate_ca_certificate_config(): ('CA:TRUE', 'CA certificate should be marked as CA'), ('pathlen:0', 'Path length constraint should be set') ] - + for check, message in security_checks: assert check in content, f"Missing security configuration: {message}" - + # Verify public domains are excluded public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in public_domains: assert f'"DNS:{domain}"' in content, f"Public domain {domain} should be excluded" - + # Verify private IP ranges are excluded private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"] for ip_range in private_ranges: assert ip_range in content, f"Private IP range {ip_range} should be excluded" - + # Verify email domains are excluded (Issue #153) email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in email_domains: assert f'"email:{domain}"' in content, f"Email domain {domain} should be excluded" - + # Verify IPv6 constraints are present (Issue #153) assert "IP:0:0:0:0:0:0:0:0/0:0:0:0:0:0:0:0" in content, "IPv6 all-zeros should be excluded" - + print("✓ CA certificate configuration has proper security constraints") def test_ca_certificate(): @@ -174,31 +170,31 @@ def validate_server_certificates_real(cert_files): if not server_certs: print("⚠ No server certificates found") return - + for server_cert_path in server_certs: with open(server_cert_path, 'rb') as f: cert_data = f.read() - + certificate = x509.load_pem_x509_certificate(cert_data) - + # Check it's not a CA certificate basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value assert basic_constraints.ca is False, "Server certificate should not be a CA" - + # Check Extended Key Usage (Issue #75) eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU" # Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153) assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation" - + # Check SAN extension exists (required for Apple devices) try: san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value assert len(san) > 0, "Server certificate must have SAN extension for Apple device compatibility" except x509.ExtensionNotFound: assert False, "Server certificate missing SAN extension - required for modern clients" - + print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}") def validate_server_certificates_config(): @@ -207,18 +203,18 @@ def validate_server_certificates_config(): if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return - - with open(openssl_task_file, 'r') as f: + + with open(openssl_task_file) as f: content = f.read() - + # Look for server certificate CSR section server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL) if not server_csr_section: print("⚠ Could not find server certificate CSR section") return - + server_section = server_csr_section.group(0) - + # Check server certificate CSR configuration server_checks = [ ('subject_alt_name', 'Server certificates should have SAN extension'), @@ -227,16 +223,16 @@ def validate_server_certificates_config(): ('digitalSignature', 'Server certificates should have digital signature usage'), ('keyEncipherment', 'Server certificates should have key encipherment usage') ] - + for check, message in server_checks: assert check in server_section, f"Missing server certificate configuration: {message}" - + # Security check: Server certificates should NOT have clientAuth (Issue #153) assert 'clientAuth' not in server_section, "Server certificates should NOT have clientAuth EKU for role separation" - + # Verify SAN extension is configured for Apple compatibility assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility" - + print("✓ Server certificate configuration has proper EKU and SAN settings") def test_server_certificates(): @@ -255,34 +251,34 @@ def validate_client_certificates_real(cert_files): for cert_path in cert_files['server_certs']: if 'cacert.pem' in cert_path: continue - + with open(cert_path, 'rb') as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) - + # Check if this looks like a client cert vs server cert cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # Server certs typically have IP addresses or domain names as CN if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4): client_certs.append((cert_path, certificate)) - + if not client_certs: print("⚠ No client certificates found") return - + for cert_path, certificate in client_certs: # Check it's not a CA certificate basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value assert basic_constraints.ca is False, "Client certificate should not be a CA" - + # Check Extended Key Usage restrictions (Issue #75) eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "Client cert must have clientAuth EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU" - + # Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153) assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation" - + # Check SAN extension for email try: san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value @@ -290,7 +286,7 @@ def validate_client_certificates_real(cert_files): assert len(email_sans) > 0, "Client certificate should have email SAN" except x509.ExtensionNotFound: print(f"⚠ Client certificate missing SAN extension: {os.path.basename(cert_path)}") - + print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}") def validate_client_certificates_config(): @@ -299,18 +295,18 @@ def validate_client_certificates_config(): if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return - - with open(openssl_task_file, 'r') as f: + + with open(openssl_task_file) as f: content = f.read() - + # Look for client certificate CSR section client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL) if not client_csr_section: print("⚠ Could not find client certificate CSR section") return - + client_section = client_csr_section.group(0) - + # Check client certificate configuration client_checks = [ ('clientAuth', 'Client certificates should have clientAuth EKU'), @@ -319,16 +315,16 @@ def validate_client_certificates_config(): ('keyEncipherment', 'Client certificates should have key encipherment usage'), ('email:', 'Client certificates should have email SAN') ] - + for check, message in client_checks: assert check in client_section, f"Missing client certificate configuration: {message}" - + # Security check: Client certificates should NOT have serverAuth (Issue #153) assert 'serverAuth' not in client_section, "Client certificates must NOT have serverAuth EKU to prevent server impersonation" - + # Verify client certificates use unique email domains (Issue #153) assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment" - + print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)") def test_client_certificates(): @@ -345,28 +341,28 @@ def validate_pkcs12_files_real(cert_files): if not cert_files.get('p12_files'): print("⚠ No PKCS#12 files found") return - + major, minor = test_openssl_version_detection() - + for p12_file in cert_files['p12_files']: assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}" - + # Test that PKCS#12 file can be read (validates format) legacy_flag = ['-legacy'] if major >= 3 else [] - + result = subprocess.run([ 'openssl', 'pkcs12', '-info', '-in', p12_file, '-passin', 'pass:', # Try empty password first '-noout' ] + legacy_flag, capture_output=True, text=True) - + # PKCS#12 files should be readable (even if password-protected) # We're just testing format validity, not trying to extract contents if result.returncode != 0: # Try with common password patterns if empty password fails print(f"⚠ PKCS#12 file may require password: {os.path.basename(p12_file)}") - + print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}") def validate_pkcs12_files_config(): @@ -375,10 +371,10 @@ def validate_pkcs12_files_config(): if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return - - with open(openssl_task_file, 'r') as f: + + with open(openssl_task_file) as f: content = f.read() - + # Check PKCS#12 generation configuration p12_checks = [ ('openssl_pkcs12', 'PKCS#12 generation should be configured'), @@ -389,10 +385,10 @@ def validate_pkcs12_files_config(): ('passphrase', 'PKCS#12 files should be password protected'), ('mode: "0600"', 'PKCS#12 files should have secure permissions') ] - + for check, message in p12_checks: assert check in content, f"Missing PKCS#12 configuration: {message}" - + print("✓ PKCS#12 configuration has proper Apple device compatibility settings") def test_pkcs12_files(): @@ -410,31 +406,31 @@ def validate_certificate_chain_real(cert_files): with open(cert_files['ca_cert'], 'rb') as f: ca_cert_data = f.read() ca_certificate = x509.load_pem_x509_certificate(ca_cert_data) - + # Test that all other certificates are signed by the CA other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']] - + if not other_certs: print("⚠ No client/server certificates found to validate") return - + for cert_path in other_certs: with open(cert_path, 'rb') as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) - + # Verify the certificate was signed by our CA assert certificate.issuer == ca_certificate.subject, f"Certificate {cert_path} not signed by CA" - + # Verify certificate is currently valid (not expired) from datetime import datetime, timezone now = datetime.now(timezone.utc) assert certificate.not_valid_before <= now, f"Certificate {cert_path} not yet valid" assert certificate.not_valid_after >= now, f"Certificate {cert_path} has expired" - + print(f"✓ Real certificate chain valid: {os.path.basename(cert_path)}") - - print(f"✓ All real certificates properly signed by CA") + + print("✓ All real certificates properly signed by CA") def validate_certificate_chain_config(): """Validate certificate chain configuration in Ansible files (CI mode)""" @@ -442,10 +438,10 @@ def validate_certificate_chain_config(): if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return - - with open(openssl_task_file, 'r') as f: + + with open(openssl_task_file) as f: content = f.read() - + # Check certificate signing configuration chain_checks = [ ('provider: ownca', 'Certificates should be signed by own CA'), @@ -457,10 +453,10 @@ def validate_certificate_chain_config(): ('curve: secp384r1', 'Should use strong elliptic curve cryptography'), ('type: ECC', 'Should use elliptic curve keys for better security') ] - + for check, message in chain_checks: assert check in content, f"Missing certificate chain configuration: {message}" - + print("✓ Certificate chain configuration properly set up for CA signing") def test_certificate_chain(): @@ -481,12 +477,12 @@ def find_ansible_file(relative_path): "../..", # Grandparent (from tests/unit to project root) "../../..", # Alternative deep path ] - + for base in possible_bases: full_path = os.path.join(base, relative_path) if os.path.exists(full_path): return full_path - + return None if __name__ == "__main__": diff --git a/tests/unit/test_wireguard_key_generation.py b/tests/unit/test_wireguard_key_generation.py index 6c5f9a13..f75e3b10 100644 --- a/tests/unit/test_wireguard_key_generation.py +++ b/tests/unit/test_wireguard_key_generation.py @@ -8,7 +8,6 @@ import os import subprocess import sys import tempfile -import shutil # Add library directory to path to import our custom module sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library')) @@ -29,7 +28,7 @@ def test_wireguard_tools_available(): def test_x25519_module_import(): """Test that our custom x25519_pubkey module can be imported and used""" try: - from x25519_pubkey import run_module + import x25519_pubkey # noqa: F401 print("✓ x25519_pubkey module imports successfully") return True except ImportError as e: @@ -40,24 +39,24 @@ def generate_test_private_key(): """Generate a test private key using the same method as Algo""" with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file: raw_key_path = temp_file.name - + try: # Generate 32 random bytes for X25519 private key (same as community.crypto does) import secrets raw_data = secrets.token_bytes(32) - + # Write raw key to file (like community.crypto openssl_privatekey with format: raw) with open(raw_key_path, 'wb') as f: f.write(raw_data) - + assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}" - + b64_key = base64.b64encode(raw_data).decode() - + print(f"✓ Generated private key (base64): {b64_key[:12]}...") - + return raw_key_path, b64_key - + except Exception: # Clean up on error if os.path.exists(raw_key_path): @@ -68,33 +67,32 @@ def generate_test_private_key(): def test_x25519_pubkey_from_raw_file(): """Test our x25519_pubkey module with raw private key file""" raw_key_path, b64_key = generate_test_private_key() - + try: # Import here so we can mock the module_utils if needed - from unittest.mock import Mock - + # Mock the AnsibleModule for testing class MockModule: def __init__(self, params): self.params = params self.result = {} - + def fail_json(self, **kwargs): raise Exception(f"Module failed: {kwargs}") - + def exit_json(self, **kwargs): self.result = kwargs - + with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub: public_key_path = temp_pub.name - + try: # Test the module logic directly - from x25519_pubkey import run_module import x25519_pubkey - + from x25519_pubkey import run_module + original_AnsibleModule = x25519_pubkey.AnsibleModule - + try: # Mock the module call mock_module = MockModule({ @@ -102,38 +100,38 @@ def test_x25519_pubkey_from_raw_file(): 'public_key_path': public_key_path, 'private_key_b64': None }) - + x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module - + # Run the module run_module() - + # Check the result assert 'public_key' in mock_module.result - assert mock_module.result['changed'] == True + assert mock_module.result['changed'] assert os.path.exists(public_key_path) - - with open(public_key_path, 'r') as f: + + with open(public_key_path) as f: derived_pubkey = f.read().strip() - + # Validate base64 format try: decoded = base64.b64decode(derived_pubkey, validate=True) assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}" except Exception as e: assert False, f"Invalid base64 public key: {e}" - + print(f"✓ Derived public key from raw file: {derived_pubkey[:12]}...") - + return derived_pubkey - + finally: x25519_pubkey.AnsibleModule = original_AnsibleModule - + finally: if os.path.exists(public_key_path): os.unlink(public_key_path) - + finally: if os.path.exists(raw_key_path): os.unlink(raw_key_path) @@ -142,56 +140,55 @@ def test_x25519_pubkey_from_raw_file(): def test_x25519_pubkey_from_b64_string(): """Test our x25519_pubkey module with base64 private key string""" raw_key_path, b64_key = generate_test_private_key() - + try: - from unittest.mock import Mock - + class MockModule: def __init__(self, params): self.params = params self.result = {} - + def fail_json(self, **kwargs): raise Exception(f"Module failed: {kwargs}") - + def exit_json(self, **kwargs): self.result = kwargs - - from x25519_pubkey import run_module + import x25519_pubkey - + from x25519_pubkey import run_module + original_AnsibleModule = x25519_pubkey.AnsibleModule - + try: mock_module = MockModule({ 'private_key_b64': b64_key, 'private_key_path': None, 'public_key_path': None }) - + x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module - + # Run the module run_module() - + # Check the result assert 'public_key' in mock_module.result derived_pubkey = mock_module.result['public_key'] - + # Validate base64 format try: decoded = base64.b64decode(derived_pubkey, validate=True) assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}" except Exception as e: assert False, f"Invalid base64 public key: {e}" - + print(f"✓ Derived public key from base64 string: {derived_pubkey[:12]}...") - + return derived_pubkey - + finally: x25519_pubkey.AnsibleModule = original_AnsibleModule - + finally: if os.path.exists(raw_key_path): os.unlink(raw_key_path) @@ -201,45 +198,44 @@ def test_wireguard_validation(): """Test that our derived keys work with actual WireGuard tools""" if not test_wireguard_tools_available(): return - + # Generate keys using our method raw_key_path, b64_key = generate_test_private_key() - + try: # Derive public key using our module - from unittest.mock import Mock - + class MockModule: def __init__(self, params): self.params = params self.result = {} - + def fail_json(self, **kwargs): raise Exception(f"Module failed: {kwargs}") - + def exit_json(self, **kwargs): self.result = kwargs - - from x25519_pubkey import run_module + import x25519_pubkey - + from x25519_pubkey import run_module + original_AnsibleModule = x25519_pubkey.AnsibleModule - + try: mock_module = MockModule({ 'private_key_b64': b64_key, 'private_key_path': None, 'public_key_path': None }) - + x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module run_module() - + derived_pubkey = mock_module.result['public_key'] - + finally: x25519_pubkey.AnsibleModule = original_AnsibleModule - + with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config: # Create a WireGuard config using our keys wg_config = f"""[Interface] @@ -252,33 +248,33 @@ AllowedIPs = 10.19.49.2/32 """ temp_config.write(wg_config) config_path = temp_config.name - + try: # Test that WireGuard can parse our config result = subprocess.run([ 'wg-quick', 'strip', config_path ], capture_output=True, text=True) - + assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}" - + # Test key derivation with wg pubkey command wg_result = subprocess.run([ 'wg', 'pubkey' ], input=b64_key, capture_output=True, text=True) - + if wg_result.returncode == 0: wg_derived = wg_result.stdout.strip() assert wg_derived == derived_pubkey, f"Key mismatch: wg={wg_derived} vs ours={derived_pubkey}" - print(f"✓ WireGuard validation: keys match wg pubkey output") + print("✓ WireGuard validation: keys match wg pubkey output") else: print(f"⚠ Could not validate with wg pubkey: {wg_result.stderr}") - + print("✓ WireGuard accepts our generated configuration") - + finally: if os.path.exists(config_path): os.unlink(config_path) - + finally: if os.path.exists(raw_key_path): os.unlink(raw_key_path) @@ -288,49 +284,48 @@ def test_key_consistency(): """Test that the same private key always produces the same public key""" # Generate one private key to reuse raw_key_path, b64_key = generate_test_private_key() - + try: def derive_pubkey_from_same_key(): - from unittest.mock import Mock - + class MockModule: def __init__(self, params): self.params = params self.result = {} - + def fail_json(self, **kwargs): raise Exception(f"Module failed: {kwargs}") - + def exit_json(self, **kwargs): self.result = kwargs - - from x25519_pubkey import run_module + import x25519_pubkey - + from x25519_pubkey import run_module + original_AnsibleModule = x25519_pubkey.AnsibleModule - + try: mock_module = MockModule({ 'private_key_b64': b64_key, # SAME key each time 'private_key_path': None, 'public_key_path': None }) - + x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module run_module() - + return mock_module.result['public_key'] - + finally: x25519_pubkey.AnsibleModule = original_AnsibleModule - + # Derive public key multiple times from same private key pubkey1 = derive_pubkey_from_same_key() pubkey2 = derive_pubkey_from_same_key() - + assert pubkey1 == pubkey2, f"Key derivation not consistent: {pubkey1} vs {pubkey2}" print("✓ Key derivation is consistent") - + finally: if os.path.exists(raw_key_path): os.unlink(raw_key_path) @@ -344,7 +339,7 @@ if __name__ == "__main__": test_key_consistency, test_wireguard_validation, ] - + failed = 0 for test in tests: try: @@ -355,9 +350,9 @@ if __name__ == "__main__": except Exception as e: print(f"✗ {test.__name__} error: {e}") failed += 1 - + if failed > 0: print(f"\n{failed} tests failed") sys.exit(1) else: - print(f"\nAll {len(tests)} tests passed!") \ No newline at end of file + print(f"\nAll {len(tests)} tests passed!")