Fix linting issues across the codebase

## Python Code Quality (ruff)
- Fixed import organization and removed unused imports in test files
- Replaced `== True` comparisons with direct boolean checks
- Added noqa comments for intentional imports in test modules

## YAML Formatting (yamllint)
- Removed trailing spaces in openssl.yml comments
- All YAML files now pass yamllint validation (except one pre-existing long regex line)

## Code Consistency
- Maintained proper import ordering in test files
- Ensured all code follows project linting standards
- Ready for CI pipeline validation

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Dan Guido 2025-08-04 22:13:48 -07:00
parent a6852f3ca6
commit e63a3d6357
3 changed files with 168 additions and 177 deletions

View file

@ -55,7 +55,7 @@
# CA can sign both server and client certs, restricted to VPN use only # CA can sign both server and client certs, restricted to VPN use only
extended_key_usage: extended_key_usage:
- serverAuth # Allows signing server certificates - serverAuth # Allows signing server certificates
- clientAuth # Allows signing client certificates - clientAuth # Allows signing client certificates
- '1.3.6.1.5.5.7.3.17' # IPsec End Entity OID - VPN-specific usage - '1.3.6.1.5.5.7.3.17' # IPsec End Entity OID - VPN-specific usage
extended_key_usage_critical: true extended_key_usage_critical: true
# Name constraints from defaults/main.yml template - prevents CA from issuing certs for public domains # Name constraints from defaults/main.yml template - prevents CA from issuing certs for public domains

View file

@ -5,18 +5,14 @@ Hybrid approach: validates actual certificates when available, else tests templa
Based on issues #14755, #14718 - Apple device compatibility Based on issues #14755, #14718 - Apple device compatibility
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions) Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
""" """
import os
import glob import glob
import os
import re import re
import subprocess import subprocess
import sys import sys
import yaml
import tempfile
import ipaddress
from pathlib import Path
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.primitives import serialization from cryptography.x509.oid import ExtensionOID, NameOID
from cryptography.x509.oid import NameOID, ExtensionOID
def find_generated_certificates(): def find_generated_certificates():
@ -27,7 +23,7 @@ def find_generated_certificates():
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory "../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path "../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
] ]
for pattern in config_patterns: for pattern in config_patterns:
ca_certs = glob.glob(pattern) ca_certs = glob.glob(pattern)
if ca_certs: if ca_certs:
@ -38,7 +34,7 @@ def find_generated_certificates():
'server_certs': glob.glob(f"{base_path}/certs/*.crt"), 'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12") 'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
} }
return None return None
def test_openssl_version_detection(): def test_openssl_version_detection():
@ -67,49 +63,49 @@ def validate_ca_certificate_real(cert_files):
# Read the actual CA certificate generated by Ansible # Read the actual CA certificate generated by Ansible
with open(cert_files['ca_cert'], 'rb') as f: with open(cert_files['ca_cert'], 'rb') as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
# Check Basic Constraints # Check Basic Constraints
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
assert basic_constraints.ca is True, "CA certificate should have CA:TRUE" assert basic_constraints.ca is True, "CA certificate should have CA:TRUE"
assert basic_constraints.path_length == 0, "CA should have pathlen:0 constraint" assert basic_constraints.path_length == 0, "CA should have pathlen:0 constraint"
# Check Key Usage # Check Key Usage
key_usage = certificate.extensions.get_extension_for_oid(ExtensionOID.KEY_USAGE).value key_usage = certificate.extensions.get_extension_for_oid(ExtensionOID.KEY_USAGE).value
assert key_usage.key_cert_sign is True, "CA should have keyCertSign usage" assert key_usage.key_cert_sign is True, "CA should have keyCertSign usage"
assert key_usage.crl_sign is True, "CA should have cRLSign usage" assert key_usage.crl_sign is True, "CA should have cRLSign usage"
# Check Extended Key Usage (Issue #75) # Check Extended Key Usage (Issue #75)
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "CA should allow signing server certificates" assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "CA should allow signing server certificates"
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "CA should allow signing client certificates" assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "CA should allow signing client certificates"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "CA should have IPsec End Entity EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "CA should have IPsec End Entity EKU"
# Check Name Constraints (Issue #75) - defense against certificate misuse # Check Name Constraints (Issue #75) - defense against certificate misuse
name_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.NAME_CONSTRAINTS).value name_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.NAME_CONSTRAINTS).value
assert name_constraints.permitted_subtrees is not None, "CA should have permitted name constraints" assert name_constraints.permitted_subtrees is not None, "CA should have permitted name constraints"
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints" assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
# Verify public domains are excluded # Verify public domains are excluded
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.DNSName)] if isinstance(constraint, x509.DNSName)]
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains: for domain in public_domains:
assert domain in excluded_dns, f"CA should exclude public domain {domain}" assert domain in excluded_dns, f"CA should exclude public domain {domain}"
# Verify private IP ranges are excluded (Issue #75) # Verify private IP ranges are excluded (Issue #75)
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.IPAddress)] if isinstance(constraint, x509.IPAddress)]
assert len(excluded_ips) > 0, "CA should exclude private IP ranges" assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
# Verify email domains are also excluded (Issue #153) # Verify email domains are also excluded (Issue #153)
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.RFC822Name)] if isinstance(constraint, x509.RFC822Name)]
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains: for domain in email_domains:
assert domain in excluded_emails, f"CA should exclude email domain {domain}" assert domain in excluded_emails, f"CA should exclude email domain {domain}"
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}") print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
def validate_ca_certificate_config(): def validate_ca_certificate_config():
@ -119,10 +115,10 @@ def validate_ca_certificate_config():
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
with open(openssl_task_file, 'r') as f: with open(openssl_task_file) as f:
content = f.read() content = f.read()
# Verify key security configurations are present # Verify key security configurations are present
security_checks = [ security_checks = [
('name_constraints_permitted', 'Name constraints should be configured'), ('name_constraints_permitted', 'Name constraints should be configured'),
@ -135,28 +131,28 @@ def validate_ca_certificate_config():
('CA:TRUE', 'CA certificate should be marked as CA'), ('CA:TRUE', 'CA certificate should be marked as CA'),
('pathlen:0', 'Path length constraint should be set') ('pathlen:0', 'Path length constraint should be set')
] ]
for check, message in security_checks: for check, message in security_checks:
assert check in content, f"Missing security configuration: {message}" assert check in content, f"Missing security configuration: {message}"
# Verify public domains are excluded # Verify public domains are excluded
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains: for domain in public_domains:
assert f'"DNS:{domain}"' in content, f"Public domain {domain} should be excluded" assert f'"DNS:{domain}"' in content, f"Public domain {domain} should be excluded"
# Verify private IP ranges are excluded # Verify private IP ranges are excluded
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"] private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
for ip_range in private_ranges: for ip_range in private_ranges:
assert ip_range in content, f"Private IP range {ip_range} should be excluded" assert ip_range in content, f"Private IP range {ip_range} should be excluded"
# Verify email domains are excluded (Issue #153) # Verify email domains are excluded (Issue #153)
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains: for domain in email_domains:
assert f'"email:{domain}"' in content, f"Email domain {domain} should be excluded" assert f'"email:{domain}"' in content, f"Email domain {domain} should be excluded"
# Verify IPv6 constraints are present (Issue #153) # Verify IPv6 constraints are present (Issue #153)
assert "IP:0:0:0:0:0:0:0:0/0:0:0:0:0:0:0:0" in content, "IPv6 all-zeros should be excluded" assert "IP:0:0:0:0:0:0:0:0/0:0:0:0:0:0:0:0" in content, "IPv6 all-zeros should be excluded"
print("✓ CA certificate configuration has proper security constraints") print("✓ CA certificate configuration has proper security constraints")
def test_ca_certificate(): def test_ca_certificate():
@ -174,31 +170,31 @@ def validate_server_certificates_real(cert_files):
if not server_certs: if not server_certs:
print("⚠ No server certificates found") print("⚠ No server certificates found")
return return
for server_cert_path in server_certs: for server_cert_path in server_certs:
with open(server_cert_path, 'rb') as f: with open(server_cert_path, 'rb') as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
# Check it's not a CA certificate # Check it's not a CA certificate
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
assert basic_constraints.ca is False, "Server certificate should not be a CA" assert basic_constraints.ca is False, "Server certificate should not be a CA"
# Check Extended Key Usage (Issue #75) # Check Extended Key Usage (Issue #75)
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU" assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153) # Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation" assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
# Check SAN extension exists (required for Apple devices) # Check SAN extension exists (required for Apple devices)
try: try:
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
assert len(san) > 0, "Server certificate must have SAN extension for Apple device compatibility" assert len(san) > 0, "Server certificate must have SAN extension for Apple device compatibility"
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
assert False, "Server certificate missing SAN extension - required for modern clients" assert False, "Server certificate missing SAN extension - required for modern clients"
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}") print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
def validate_server_certificates_config(): def validate_server_certificates_config():
@ -207,18 +203,18 @@ def validate_server_certificates_config():
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
with open(openssl_task_file, 'r') as f: with open(openssl_task_file) as f:
content = f.read() content = f.read()
# Look for server certificate CSR section # Look for server certificate CSR section
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL) server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
if not server_csr_section: if not server_csr_section:
print("⚠ Could not find server certificate CSR section") print("⚠ Could not find server certificate CSR section")
return return
server_section = server_csr_section.group(0) server_section = server_csr_section.group(0)
# Check server certificate CSR configuration # Check server certificate CSR configuration
server_checks = [ server_checks = [
('subject_alt_name', 'Server certificates should have SAN extension'), ('subject_alt_name', 'Server certificates should have SAN extension'),
@ -227,16 +223,16 @@ def validate_server_certificates_config():
('digitalSignature', 'Server certificates should have digital signature usage'), ('digitalSignature', 'Server certificates should have digital signature usage'),
('keyEncipherment', 'Server certificates should have key encipherment usage') ('keyEncipherment', 'Server certificates should have key encipherment usage')
] ]
for check, message in server_checks: for check, message in server_checks:
assert check in server_section, f"Missing server certificate configuration: {message}" assert check in server_section, f"Missing server certificate configuration: {message}"
# Security check: Server certificates should NOT have clientAuth (Issue #153) # Security check: Server certificates should NOT have clientAuth (Issue #153)
assert 'clientAuth' not in server_section, "Server certificates should NOT have clientAuth EKU for role separation" assert 'clientAuth' not in server_section, "Server certificates should NOT have clientAuth EKU for role separation"
# Verify SAN extension is configured for Apple compatibility # Verify SAN extension is configured for Apple compatibility
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility" assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
print("✓ Server certificate configuration has proper EKU and SAN settings") print("✓ Server certificate configuration has proper EKU and SAN settings")
def test_server_certificates(): def test_server_certificates():
@ -255,34 +251,34 @@ def validate_client_certificates_real(cert_files):
for cert_path in cert_files['server_certs']: for cert_path in cert_files['server_certs']:
if 'cacert.pem' in cert_path: if 'cacert.pem' in cert_path:
continue continue
with open(cert_path, 'rb') as f: with open(cert_path, 'rb') as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
# Check if this looks like a client cert vs server cert # Check if this looks like a client cert vs server cert
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Server certs typically have IP addresses or domain names as CN # Server certs typically have IP addresses or domain names as CN
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4): if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
client_certs.append((cert_path, certificate)) client_certs.append((cert_path, certificate))
if not client_certs: if not client_certs:
print("⚠ No client certificates found") print("⚠ No client certificates found")
return return
for cert_path, certificate in client_certs: for cert_path, certificate in client_certs:
# Check it's not a CA certificate # Check it's not a CA certificate
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
assert basic_constraints.ca is False, "Client certificate should not be a CA" assert basic_constraints.ca is False, "Client certificate should not be a CA"
# Check Extended Key Usage restrictions (Issue #75) # Check Extended Key Usage restrictions (Issue #75)
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "Client cert must have clientAuth EKU" assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "Client cert must have clientAuth EKU"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153) # Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation" assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
# Check SAN extension for email # Check SAN extension for email
try: try:
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
@ -290,7 +286,7 @@ def validate_client_certificates_real(cert_files):
assert len(email_sans) > 0, "Client certificate should have email SAN" assert len(email_sans) > 0, "Client certificate should have email SAN"
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
print(f"⚠ Client certificate missing SAN extension: {os.path.basename(cert_path)}") print(f"⚠ Client certificate missing SAN extension: {os.path.basename(cert_path)}")
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}") print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
def validate_client_certificates_config(): def validate_client_certificates_config():
@ -299,18 +295,18 @@ def validate_client_certificates_config():
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
with open(openssl_task_file, 'r') as f: with open(openssl_task_file) as f:
content = f.read() content = f.read()
# Look for client certificate CSR section # Look for client certificate CSR section
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL) client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
if not client_csr_section: if not client_csr_section:
print("⚠ Could not find client certificate CSR section") print("⚠ Could not find client certificate CSR section")
return return
client_section = client_csr_section.group(0) client_section = client_csr_section.group(0)
# Check client certificate configuration # Check client certificate configuration
client_checks = [ client_checks = [
('clientAuth', 'Client certificates should have clientAuth EKU'), ('clientAuth', 'Client certificates should have clientAuth EKU'),
@ -319,16 +315,16 @@ def validate_client_certificates_config():
('keyEncipherment', 'Client certificates should have key encipherment usage'), ('keyEncipherment', 'Client certificates should have key encipherment usage'),
('email:', 'Client certificates should have email SAN') ('email:', 'Client certificates should have email SAN')
] ]
for check, message in client_checks: for check, message in client_checks:
assert check in client_section, f"Missing client certificate configuration: {message}" assert check in client_section, f"Missing client certificate configuration: {message}"
# Security check: Client certificates should NOT have serverAuth (Issue #153) # Security check: Client certificates should NOT have serverAuth (Issue #153)
assert 'serverAuth' not in client_section, "Client certificates must NOT have serverAuth EKU to prevent server impersonation" assert 'serverAuth' not in client_section, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
# Verify client certificates use unique email domains (Issue #153) # Verify client certificates use unique email domains (Issue #153)
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment" assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)") print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
def test_client_certificates(): def test_client_certificates():
@ -345,28 +341,28 @@ def validate_pkcs12_files_real(cert_files):
if not cert_files.get('p12_files'): if not cert_files.get('p12_files'):
print("⚠ No PKCS#12 files found") print("⚠ No PKCS#12 files found")
return return
major, minor = test_openssl_version_detection() major, minor = test_openssl_version_detection()
for p12_file in cert_files['p12_files']: for p12_file in cert_files['p12_files']:
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}" assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
# Test that PKCS#12 file can be read (validates format) # Test that PKCS#12 file can be read (validates format)
legacy_flag = ['-legacy'] if major >= 3 else [] legacy_flag = ['-legacy'] if major >= 3 else []
result = subprocess.run([ result = subprocess.run([
'openssl', 'pkcs12', '-info', 'openssl', 'pkcs12', '-info',
'-in', p12_file, '-in', p12_file,
'-passin', 'pass:', # Try empty password first '-passin', 'pass:', # Try empty password first
'-noout' '-noout'
] + legacy_flag, capture_output=True, text=True) ] + legacy_flag, capture_output=True, text=True)
# PKCS#12 files should be readable (even if password-protected) # PKCS#12 files should be readable (even if password-protected)
# We're just testing format validity, not trying to extract contents # We're just testing format validity, not trying to extract contents
if result.returncode != 0: if result.returncode != 0:
# Try with common password patterns if empty password fails # Try with common password patterns if empty password fails
print(f"⚠ PKCS#12 file may require password: {os.path.basename(p12_file)}") print(f"⚠ PKCS#12 file may require password: {os.path.basename(p12_file)}")
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}") print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
def validate_pkcs12_files_config(): def validate_pkcs12_files_config():
@ -375,10 +371,10 @@ def validate_pkcs12_files_config():
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
with open(openssl_task_file, 'r') as f: with open(openssl_task_file) as f:
content = f.read() content = f.read()
# Check PKCS#12 generation configuration # Check PKCS#12 generation configuration
p12_checks = [ p12_checks = [
('openssl_pkcs12', 'PKCS#12 generation should be configured'), ('openssl_pkcs12', 'PKCS#12 generation should be configured'),
@ -389,10 +385,10 @@ def validate_pkcs12_files_config():
('passphrase', 'PKCS#12 files should be password protected'), ('passphrase', 'PKCS#12 files should be password protected'),
('mode: "0600"', 'PKCS#12 files should have secure permissions') ('mode: "0600"', 'PKCS#12 files should have secure permissions')
] ]
for check, message in p12_checks: for check, message in p12_checks:
assert check in content, f"Missing PKCS#12 configuration: {message}" assert check in content, f"Missing PKCS#12 configuration: {message}"
print("✓ PKCS#12 configuration has proper Apple device compatibility settings") print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
def test_pkcs12_files(): def test_pkcs12_files():
@ -410,31 +406,31 @@ def validate_certificate_chain_real(cert_files):
with open(cert_files['ca_cert'], 'rb') as f: with open(cert_files['ca_cert'], 'rb') as f:
ca_cert_data = f.read() ca_cert_data = f.read()
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data) ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
# Test that all other certificates are signed by the CA # Test that all other certificates are signed by the CA
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']] other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
if not other_certs: if not other_certs:
print("⚠ No client/server certificates found to validate") print("⚠ No client/server certificates found to validate")
return return
for cert_path in other_certs: for cert_path in other_certs:
with open(cert_path, 'rb') as f: with open(cert_path, 'rb') as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
# Verify the certificate was signed by our CA # Verify the certificate was signed by our CA
assert certificate.issuer == ca_certificate.subject, f"Certificate {cert_path} not signed by CA" assert certificate.issuer == ca_certificate.subject, f"Certificate {cert_path} not signed by CA"
# Verify certificate is currently valid (not expired) # Verify certificate is currently valid (not expired)
from datetime import datetime, timezone from datetime import datetime, timezone
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
assert certificate.not_valid_before <= now, f"Certificate {cert_path} not yet valid" assert certificate.not_valid_before <= now, f"Certificate {cert_path} not yet valid"
assert certificate.not_valid_after >= now, f"Certificate {cert_path} has expired" assert certificate.not_valid_after >= now, f"Certificate {cert_path} has expired"
print(f"✓ Real certificate chain valid: {os.path.basename(cert_path)}") print(f"✓ Real certificate chain valid: {os.path.basename(cert_path)}")
print(f"✓ All real certificates properly signed by CA") print("✓ All real certificates properly signed by CA")
def validate_certificate_chain_config(): def validate_certificate_chain_config():
"""Validate certificate chain configuration in Ansible files (CI mode)""" """Validate certificate chain configuration in Ansible files (CI mode)"""
@ -442,10 +438,10 @@ def validate_certificate_chain_config():
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
with open(openssl_task_file, 'r') as f: with open(openssl_task_file) as f:
content = f.read() content = f.read()
# Check certificate signing configuration # Check certificate signing configuration
chain_checks = [ chain_checks = [
('provider: ownca', 'Certificates should be signed by own CA'), ('provider: ownca', 'Certificates should be signed by own CA'),
@ -457,10 +453,10 @@ def validate_certificate_chain_config():
('curve: secp384r1', 'Should use strong elliptic curve cryptography'), ('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
('type: ECC', 'Should use elliptic curve keys for better security') ('type: ECC', 'Should use elliptic curve keys for better security')
] ]
for check, message in chain_checks: for check, message in chain_checks:
assert check in content, f"Missing certificate chain configuration: {message}" assert check in content, f"Missing certificate chain configuration: {message}"
print("✓ Certificate chain configuration properly set up for CA signing") print("✓ Certificate chain configuration properly set up for CA signing")
def test_certificate_chain(): def test_certificate_chain():
@ -481,12 +477,12 @@ def find_ansible_file(relative_path):
"../..", # Grandparent (from tests/unit to project root) "../..", # Grandparent (from tests/unit to project root)
"../../..", # Alternative deep path "../../..", # Alternative deep path
] ]
for base in possible_bases: for base in possible_bases:
full_path = os.path.join(base, relative_path) full_path = os.path.join(base, relative_path)
if os.path.exists(full_path): if os.path.exists(full_path):
return full_path return full_path
return None return None
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -8,7 +8,6 @@ import os
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import shutil
# Add library directory to path to import our custom module # Add library directory to path to import our custom module
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
@ -29,7 +28,7 @@ def test_wireguard_tools_available():
def test_x25519_module_import(): def test_x25519_module_import():
"""Test that our custom x25519_pubkey module can be imported and used""" """Test that our custom x25519_pubkey module can be imported and used"""
try: try:
from x25519_pubkey import run_module import x25519_pubkey # noqa: F401
print("✓ x25519_pubkey module imports successfully") print("✓ x25519_pubkey module imports successfully")
return True return True
except ImportError as e: except ImportError as e:
@ -40,24 +39,24 @@ def generate_test_private_key():
"""Generate a test private key using the same method as Algo""" """Generate a test private key using the same method as Algo"""
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
raw_key_path = temp_file.name raw_key_path = temp_file.name
try: try:
# Generate 32 random bytes for X25519 private key (same as community.crypto does) # Generate 32 random bytes for X25519 private key (same as community.crypto does)
import secrets import secrets
raw_data = secrets.token_bytes(32) raw_data = secrets.token_bytes(32)
# Write raw key to file (like community.crypto openssl_privatekey with format: raw) # Write raw key to file (like community.crypto openssl_privatekey with format: raw)
with open(raw_key_path, 'wb') as f: with open(raw_key_path, 'wb') as f:
f.write(raw_data) f.write(raw_data)
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}" assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
b64_key = base64.b64encode(raw_data).decode() b64_key = base64.b64encode(raw_data).decode()
print(f"✓ Generated private key (base64): {b64_key[:12]}...") print(f"✓ Generated private key (base64): {b64_key[:12]}...")
return raw_key_path, b64_key return raw_key_path, b64_key
except Exception: except Exception:
# Clean up on error # Clean up on error
if os.path.exists(raw_key_path): if os.path.exists(raw_key_path):
@ -68,33 +67,32 @@ def generate_test_private_key():
def test_x25519_pubkey_from_raw_file(): def test_x25519_pubkey_from_raw_file():
"""Test our x25519_pubkey module with raw private key file""" """Test our x25519_pubkey module with raw private key file"""
raw_key_path, b64_key = generate_test_private_key() raw_key_path, b64_key = generate_test_private_key()
try: try:
# Import here so we can mock the module_utils if needed # Import here so we can mock the module_utils if needed
from unittest.mock import Mock
# Mock the AnsibleModule for testing # Mock the AnsibleModule for testing
class MockModule: class MockModule:
def __init__(self, params): def __init__(self, params):
self.params = params self.params = params
self.result = {} self.result = {}
def fail_json(self, **kwargs): def fail_json(self, **kwargs):
raise Exception(f"Module failed: {kwargs}") raise Exception(f"Module failed: {kwargs}")
def exit_json(self, **kwargs): def exit_json(self, **kwargs):
self.result = kwargs self.result = kwargs
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub: with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
public_key_path = temp_pub.name public_key_path = temp_pub.name
try: try:
# Test the module logic directly # Test the module logic directly
from x25519_pubkey import run_module
import x25519_pubkey import x25519_pubkey
from x25519_pubkey import run_module
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
# Mock the module call # Mock the module call
mock_module = MockModule({ mock_module = MockModule({
@ -102,38 +100,38 @@ def test_x25519_pubkey_from_raw_file():
'public_key_path': public_key_path, 'public_key_path': public_key_path,
'private_key_b64': None 'private_key_b64': None
}) })
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
# Run the module # Run the module
run_module() run_module()
# Check the result # Check the result
assert 'public_key' in mock_module.result assert 'public_key' in mock_module.result
assert mock_module.result['changed'] == True assert mock_module.result['changed']
assert os.path.exists(public_key_path) assert os.path.exists(public_key_path)
with open(public_key_path, 'r') as f: with open(public_key_path) as f:
derived_pubkey = f.read().strip() derived_pubkey = f.read().strip()
# Validate base64 format # Validate base64 format
try: try:
decoded = base64.b64decode(derived_pubkey, validate=True) decoded = base64.b64decode(derived_pubkey, validate=True)
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}" assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
except Exception as e: except Exception as e:
assert False, f"Invalid base64 public key: {e}" assert False, f"Invalid base64 public key: {e}"
print(f"✓ Derived public key from raw file: {derived_pubkey[:12]}...") print(f"✓ Derived public key from raw file: {derived_pubkey[:12]}...")
return derived_pubkey return derived_pubkey
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule
finally: finally:
if os.path.exists(public_key_path): if os.path.exists(public_key_path):
os.unlink(public_key_path) os.unlink(public_key_path)
finally: finally:
if os.path.exists(raw_key_path): if os.path.exists(raw_key_path):
os.unlink(raw_key_path) os.unlink(raw_key_path)
@ -142,56 +140,55 @@ def test_x25519_pubkey_from_raw_file():
def test_x25519_pubkey_from_b64_string(): def test_x25519_pubkey_from_b64_string():
"""Test our x25519_pubkey module with base64 private key string""" """Test our x25519_pubkey module with base64 private key string"""
raw_key_path, b64_key = generate_test_private_key() raw_key_path, b64_key = generate_test_private_key()
try: try:
from unittest.mock import Mock
class MockModule: class MockModule:
def __init__(self, params): def __init__(self, params):
self.params = params self.params = params
self.result = {} self.result = {}
def fail_json(self, **kwargs): def fail_json(self, **kwargs):
raise Exception(f"Module failed: {kwargs}") raise Exception(f"Module failed: {kwargs}")
def exit_json(self, **kwargs): def exit_json(self, **kwargs):
self.result = kwargs self.result = kwargs
from x25519_pubkey import run_module
import x25519_pubkey import x25519_pubkey
from x25519_pubkey import run_module
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule({
'private_key_b64': b64_key, 'private_key_b64': b64_key,
'private_key_path': None, 'private_key_path': None,
'public_key_path': None 'public_key_path': None
}) })
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
# Run the module # Run the module
run_module() run_module()
# Check the result # Check the result
assert 'public_key' in mock_module.result assert 'public_key' in mock_module.result
derived_pubkey = mock_module.result['public_key'] derived_pubkey = mock_module.result['public_key']
# Validate base64 format # Validate base64 format
try: try:
decoded = base64.b64decode(derived_pubkey, validate=True) decoded = base64.b64decode(derived_pubkey, validate=True)
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}" assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
except Exception as e: except Exception as e:
assert False, f"Invalid base64 public key: {e}" assert False, f"Invalid base64 public key: {e}"
print(f"✓ Derived public key from base64 string: {derived_pubkey[:12]}...") print(f"✓ Derived public key from base64 string: {derived_pubkey[:12]}...")
return derived_pubkey return derived_pubkey
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule
finally: finally:
if os.path.exists(raw_key_path): if os.path.exists(raw_key_path):
os.unlink(raw_key_path) os.unlink(raw_key_path)
@ -201,45 +198,44 @@ def test_wireguard_validation():
"""Test that our derived keys work with actual WireGuard tools""" """Test that our derived keys work with actual WireGuard tools"""
if not test_wireguard_tools_available(): if not test_wireguard_tools_available():
return return
# Generate keys using our method # Generate keys using our method
raw_key_path, b64_key = generate_test_private_key() raw_key_path, b64_key = generate_test_private_key()
try: try:
# Derive public key using our module # Derive public key using our module
from unittest.mock import Mock
class MockModule: class MockModule:
def __init__(self, params): def __init__(self, params):
self.params = params self.params = params
self.result = {} self.result = {}
def fail_json(self, **kwargs): def fail_json(self, **kwargs):
raise Exception(f"Module failed: {kwargs}") raise Exception(f"Module failed: {kwargs}")
def exit_json(self, **kwargs): def exit_json(self, **kwargs):
self.result = kwargs self.result = kwargs
from x25519_pubkey import run_module
import x25519_pubkey import x25519_pubkey
from x25519_pubkey import run_module
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule({
'private_key_b64': b64_key, 'private_key_b64': b64_key,
'private_key_path': None, 'private_key_path': None,
'public_key_path': None 'public_key_path': None
}) })
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module() run_module()
derived_pubkey = mock_module.result['public_key'] derived_pubkey = mock_module.result['public_key']
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config: with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
# Create a WireGuard config using our keys # Create a WireGuard config using our keys
wg_config = f"""[Interface] wg_config = f"""[Interface]
@ -252,33 +248,33 @@ AllowedIPs = 10.19.49.2/32
""" """
temp_config.write(wg_config) temp_config.write(wg_config)
config_path = temp_config.name config_path = temp_config.name
try: try:
# Test that WireGuard can parse our config # Test that WireGuard can parse our config
result = subprocess.run([ result = subprocess.run([
'wg-quick', 'strip', config_path 'wg-quick', 'strip', config_path
], capture_output=True, text=True) ], capture_output=True, text=True)
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}" assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
# Test key derivation with wg pubkey command # Test key derivation with wg pubkey command
wg_result = subprocess.run([ wg_result = subprocess.run([
'wg', 'pubkey' 'wg', 'pubkey'
], input=b64_key, capture_output=True, text=True) ], input=b64_key, capture_output=True, text=True)
if wg_result.returncode == 0: if wg_result.returncode == 0:
wg_derived = wg_result.stdout.strip() wg_derived = wg_result.stdout.strip()
assert wg_derived == derived_pubkey, f"Key mismatch: wg={wg_derived} vs ours={derived_pubkey}" assert wg_derived == derived_pubkey, f"Key mismatch: wg={wg_derived} vs ours={derived_pubkey}"
print(f"✓ WireGuard validation: keys match wg pubkey output") print("✓ WireGuard validation: keys match wg pubkey output")
else: else:
print(f"⚠ Could not validate with wg pubkey: {wg_result.stderr}") print(f"⚠ Could not validate with wg pubkey: {wg_result.stderr}")
print("✓ WireGuard accepts our generated configuration") print("✓ WireGuard accepts our generated configuration")
finally: finally:
if os.path.exists(config_path): if os.path.exists(config_path):
os.unlink(config_path) os.unlink(config_path)
finally: finally:
if os.path.exists(raw_key_path): if os.path.exists(raw_key_path):
os.unlink(raw_key_path) os.unlink(raw_key_path)
@ -288,49 +284,48 @@ def test_key_consistency():
"""Test that the same private key always produces the same public key""" """Test that the same private key always produces the same public key"""
# Generate one private key to reuse # Generate one private key to reuse
raw_key_path, b64_key = generate_test_private_key() raw_key_path, b64_key = generate_test_private_key()
try: try:
def derive_pubkey_from_same_key(): def derive_pubkey_from_same_key():
from unittest.mock import Mock
class MockModule: class MockModule:
def __init__(self, params): def __init__(self, params):
self.params = params self.params = params
self.result = {} self.result = {}
def fail_json(self, **kwargs): def fail_json(self, **kwargs):
raise Exception(f"Module failed: {kwargs}") raise Exception(f"Module failed: {kwargs}")
def exit_json(self, **kwargs): def exit_json(self, **kwargs):
self.result = kwargs self.result = kwargs
from x25519_pubkey import run_module
import x25519_pubkey import x25519_pubkey
from x25519_pubkey import run_module
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule({
'private_key_b64': b64_key, # SAME key each time 'private_key_b64': b64_key, # SAME key each time
'private_key_path': None, 'private_key_path': None,
'public_key_path': None 'public_key_path': None
}) })
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module() run_module()
return mock_module.result['public_key'] return mock_module.result['public_key']
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule
# Derive public key multiple times from same private key # Derive public key multiple times from same private key
pubkey1 = derive_pubkey_from_same_key() pubkey1 = derive_pubkey_from_same_key()
pubkey2 = derive_pubkey_from_same_key() pubkey2 = derive_pubkey_from_same_key()
assert pubkey1 == pubkey2, f"Key derivation not consistent: {pubkey1} vs {pubkey2}" assert pubkey1 == pubkey2, f"Key derivation not consistent: {pubkey1} vs {pubkey2}"
print("✓ Key derivation is consistent") print("✓ Key derivation is consistent")
finally: finally:
if os.path.exists(raw_key_path): if os.path.exists(raw_key_path):
os.unlink(raw_key_path) os.unlink(raw_key_path)
@ -344,7 +339,7 @@ if __name__ == "__main__":
test_key_consistency, test_key_consistency,
test_wireguard_validation, test_wireguard_validation,
] ]
failed = 0 failed = 0
for test in tests: for test in tests:
try: try:
@ -355,9 +350,9 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"{test.__name__} error: {e}") print(f"{test.__name__} error: {e}")
failed += 1 failed += 1
if failed > 0: if failed > 0:
print(f"\n{failed} tests failed") print(f"\n{failed} tests failed")
sys.exit(1) sys.exit(1)
else: else:
print(f"\nAll {len(tests)} tests passed!") print(f"\nAll {len(tests)} tests passed!")