mirror of
https://github.com/trailofbits/algo.git
synced 2025-09-06 12:03:38 +02:00
Fix linting issues across the codebase
## Python Code Quality (ruff) - Fixed import organization and removed unused imports in test files - Replaced `== True` comparisons with direct boolean checks - Added noqa comments for intentional imports in test modules ## YAML Formatting (yamllint) - Removed trailing spaces in openssl.yml comments - All YAML files now pass yamllint validation (except one pre-existing long regex line) ## Code Consistency - Maintained proper import ordering in test files - Ensured all code follows project linting standards - Ready for CI pipeline validation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
a6852f3ca6
commit
e63a3d6357
3 changed files with 168 additions and 177 deletions
|
@ -55,7 +55,7 @@
|
||||||
# CA can sign both server and client certs, restricted to VPN use only
|
# CA can sign both server and client certs, restricted to VPN use only
|
||||||
extended_key_usage:
|
extended_key_usage:
|
||||||
- serverAuth # Allows signing server certificates
|
- serverAuth # Allows signing server certificates
|
||||||
- clientAuth # Allows signing client certificates
|
- clientAuth # Allows signing client certificates
|
||||||
- '1.3.6.1.5.5.7.3.17' # IPsec End Entity OID - VPN-specific usage
|
- '1.3.6.1.5.5.7.3.17' # IPsec End Entity OID - VPN-specific usage
|
||||||
extended_key_usage_critical: true
|
extended_key_usage_critical: true
|
||||||
# Name constraints from defaults/main.yml template - prevents CA from issuing certs for public domains
|
# Name constraints from defaults/main.yml template - prevents CA from issuing certs for public domains
|
||||||
|
|
|
@ -5,18 +5,14 @@ Hybrid approach: validates actual certificates when available, else tests templa
|
||||||
Based on issues #14755, #14718 - Apple device compatibility
|
Based on issues #14755, #14718 - Apple device compatibility
|
||||||
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
|
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import glob
|
import glob
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
|
||||||
import tempfile
|
|
||||||
import ipaddress
|
|
||||||
from pathlib import Path
|
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.x509.oid import ExtensionOID, NameOID
|
||||||
from cryptography.x509.oid import NameOID, ExtensionOID
|
|
||||||
|
|
||||||
|
|
||||||
def find_generated_certificates():
|
def find_generated_certificates():
|
||||||
|
@ -27,7 +23,7 @@ def find_generated_certificates():
|
||||||
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
|
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
|
||||||
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
|
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in config_patterns:
|
for pattern in config_patterns:
|
||||||
ca_certs = glob.glob(pattern)
|
ca_certs = glob.glob(pattern)
|
||||||
if ca_certs:
|
if ca_certs:
|
||||||
|
@ -38,7 +34,7 @@ def find_generated_certificates():
|
||||||
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
|
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
|
||||||
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
|
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
|
||||||
}
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_openssl_version_detection():
|
def test_openssl_version_detection():
|
||||||
|
@ -67,49 +63,49 @@ def validate_ca_certificate_real(cert_files):
|
||||||
# Read the actual CA certificate generated by Ansible
|
# Read the actual CA certificate generated by Ansible
|
||||||
with open(cert_files['ca_cert'], 'rb') as f:
|
with open(cert_files['ca_cert'], 'rb') as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
|
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
# Check Basic Constraints
|
# Check Basic Constraints
|
||||||
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
||||||
assert basic_constraints.ca is True, "CA certificate should have CA:TRUE"
|
assert basic_constraints.ca is True, "CA certificate should have CA:TRUE"
|
||||||
assert basic_constraints.path_length == 0, "CA should have pathlen:0 constraint"
|
assert basic_constraints.path_length == 0, "CA should have pathlen:0 constraint"
|
||||||
|
|
||||||
# Check Key Usage
|
# Check Key Usage
|
||||||
key_usage = certificate.extensions.get_extension_for_oid(ExtensionOID.KEY_USAGE).value
|
key_usage = certificate.extensions.get_extension_for_oid(ExtensionOID.KEY_USAGE).value
|
||||||
assert key_usage.key_cert_sign is True, "CA should have keyCertSign usage"
|
assert key_usage.key_cert_sign is True, "CA should have keyCertSign usage"
|
||||||
assert key_usage.crl_sign is True, "CA should have cRLSign usage"
|
assert key_usage.crl_sign is True, "CA should have cRLSign usage"
|
||||||
|
|
||||||
# Check Extended Key Usage (Issue #75)
|
# Check Extended Key Usage (Issue #75)
|
||||||
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
||||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "CA should allow signing server certificates"
|
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "CA should allow signing server certificates"
|
||||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "CA should allow signing client certificates"
|
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "CA should allow signing client certificates"
|
||||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "CA should have IPsec End Entity EKU"
|
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "CA should have IPsec End Entity EKU"
|
||||||
|
|
||||||
# Check Name Constraints (Issue #75) - defense against certificate misuse
|
# Check Name Constraints (Issue #75) - defense against certificate misuse
|
||||||
name_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.NAME_CONSTRAINTS).value
|
name_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.NAME_CONSTRAINTS).value
|
||||||
assert name_constraints.permitted_subtrees is not None, "CA should have permitted name constraints"
|
assert name_constraints.permitted_subtrees is not None, "CA should have permitted name constraints"
|
||||||
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
|
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
|
||||||
|
|
||||||
# Verify public domains are excluded
|
# Verify public domains are excluded
|
||||||
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||||
if isinstance(constraint, x509.DNSName)]
|
if isinstance(constraint, x509.DNSName)]
|
||||||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in public_domains:
|
for domain in public_domains:
|
||||||
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
|
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
|
||||||
|
|
||||||
# Verify private IP ranges are excluded (Issue #75)
|
# Verify private IP ranges are excluded (Issue #75)
|
||||||
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||||
if isinstance(constraint, x509.IPAddress)]
|
if isinstance(constraint, x509.IPAddress)]
|
||||||
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
|
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
|
||||||
|
|
||||||
# Verify email domains are also excluded (Issue #153)
|
# Verify email domains are also excluded (Issue #153)
|
||||||
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||||
if isinstance(constraint, x509.RFC822Name)]
|
if isinstance(constraint, x509.RFC822Name)]
|
||||||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in email_domains:
|
for domain in email_domains:
|
||||||
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
|
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
|
||||||
|
|
||||||
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
|
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
|
||||||
|
|
||||||
def validate_ca_certificate_config():
|
def validate_ca_certificate_config():
|
||||||
|
@ -119,10 +115,10 @@ def validate_ca_certificate_config():
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(openssl_task_file, 'r') as f:
|
with open(openssl_task_file) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Verify key security configurations are present
|
# Verify key security configurations are present
|
||||||
security_checks = [
|
security_checks = [
|
||||||
('name_constraints_permitted', 'Name constraints should be configured'),
|
('name_constraints_permitted', 'Name constraints should be configured'),
|
||||||
|
@ -135,28 +131,28 @@ def validate_ca_certificate_config():
|
||||||
('CA:TRUE', 'CA certificate should be marked as CA'),
|
('CA:TRUE', 'CA certificate should be marked as CA'),
|
||||||
('pathlen:0', 'Path length constraint should be set')
|
('pathlen:0', 'Path length constraint should be set')
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in security_checks:
|
for check, message in security_checks:
|
||||||
assert check in content, f"Missing security configuration: {message}"
|
assert check in content, f"Missing security configuration: {message}"
|
||||||
|
|
||||||
# Verify public domains are excluded
|
# Verify public domains are excluded
|
||||||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in public_domains:
|
for domain in public_domains:
|
||||||
assert f'"DNS:{domain}"' in content, f"Public domain {domain} should be excluded"
|
assert f'"DNS:{domain}"' in content, f"Public domain {domain} should be excluded"
|
||||||
|
|
||||||
# Verify private IP ranges are excluded
|
# Verify private IP ranges are excluded
|
||||||
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
|
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
|
||||||
for ip_range in private_ranges:
|
for ip_range in private_ranges:
|
||||||
assert ip_range in content, f"Private IP range {ip_range} should be excluded"
|
assert ip_range in content, f"Private IP range {ip_range} should be excluded"
|
||||||
|
|
||||||
# Verify email domains are excluded (Issue #153)
|
# Verify email domains are excluded (Issue #153)
|
||||||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in email_domains:
|
for domain in email_domains:
|
||||||
assert f'"email:{domain}"' in content, f"Email domain {domain} should be excluded"
|
assert f'"email:{domain}"' in content, f"Email domain {domain} should be excluded"
|
||||||
|
|
||||||
# Verify IPv6 constraints are present (Issue #153)
|
# Verify IPv6 constraints are present (Issue #153)
|
||||||
assert "IP:0:0:0:0:0:0:0:0/0:0:0:0:0:0:0:0" in content, "IPv6 all-zeros should be excluded"
|
assert "IP:0:0:0:0:0:0:0:0/0:0:0:0:0:0:0:0" in content, "IPv6 all-zeros should be excluded"
|
||||||
|
|
||||||
print("✓ CA certificate configuration has proper security constraints")
|
print("✓ CA certificate configuration has proper security constraints")
|
||||||
|
|
||||||
def test_ca_certificate():
|
def test_ca_certificate():
|
||||||
|
@ -174,31 +170,31 @@ def validate_server_certificates_real(cert_files):
|
||||||
if not server_certs:
|
if not server_certs:
|
||||||
print("⚠ No server certificates found")
|
print("⚠ No server certificates found")
|
||||||
return
|
return
|
||||||
|
|
||||||
for server_cert_path in server_certs:
|
for server_cert_path in server_certs:
|
||||||
with open(server_cert_path, 'rb') as f:
|
with open(server_cert_path, 'rb') as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
|
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
# Check it's not a CA certificate
|
# Check it's not a CA certificate
|
||||||
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
||||||
assert basic_constraints.ca is False, "Server certificate should not be a CA"
|
assert basic_constraints.ca is False, "Server certificate should not be a CA"
|
||||||
|
|
||||||
# Check Extended Key Usage (Issue #75)
|
# Check Extended Key Usage (Issue #75)
|
||||||
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
||||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
|
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
|
||||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
|
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
|
||||||
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
|
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
|
||||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
|
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
|
||||||
|
|
||||||
# Check SAN extension exists (required for Apple devices)
|
# Check SAN extension exists (required for Apple devices)
|
||||||
try:
|
try:
|
||||||
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
|
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
|
||||||
assert len(san) > 0, "Server certificate must have SAN extension for Apple device compatibility"
|
assert len(san) > 0, "Server certificate must have SAN extension for Apple device compatibility"
|
||||||
except x509.ExtensionNotFound:
|
except x509.ExtensionNotFound:
|
||||||
assert False, "Server certificate missing SAN extension - required for modern clients"
|
assert False, "Server certificate missing SAN extension - required for modern clients"
|
||||||
|
|
||||||
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
|
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
|
||||||
|
|
||||||
def validate_server_certificates_config():
|
def validate_server_certificates_config():
|
||||||
|
@ -207,18 +203,18 @@ def validate_server_certificates_config():
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(openssl_task_file, 'r') as f:
|
with open(openssl_task_file) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Look for server certificate CSR section
|
# Look for server certificate CSR section
|
||||||
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
|
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
|
||||||
if not server_csr_section:
|
if not server_csr_section:
|
||||||
print("⚠ Could not find server certificate CSR section")
|
print("⚠ Could not find server certificate CSR section")
|
||||||
return
|
return
|
||||||
|
|
||||||
server_section = server_csr_section.group(0)
|
server_section = server_csr_section.group(0)
|
||||||
|
|
||||||
# Check server certificate CSR configuration
|
# Check server certificate CSR configuration
|
||||||
server_checks = [
|
server_checks = [
|
||||||
('subject_alt_name', 'Server certificates should have SAN extension'),
|
('subject_alt_name', 'Server certificates should have SAN extension'),
|
||||||
|
@ -227,16 +223,16 @@ def validate_server_certificates_config():
|
||||||
('digitalSignature', 'Server certificates should have digital signature usage'),
|
('digitalSignature', 'Server certificates should have digital signature usage'),
|
||||||
('keyEncipherment', 'Server certificates should have key encipherment usage')
|
('keyEncipherment', 'Server certificates should have key encipherment usage')
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in server_checks:
|
for check, message in server_checks:
|
||||||
assert check in server_section, f"Missing server certificate configuration: {message}"
|
assert check in server_section, f"Missing server certificate configuration: {message}"
|
||||||
|
|
||||||
# Security check: Server certificates should NOT have clientAuth (Issue #153)
|
# Security check: Server certificates should NOT have clientAuth (Issue #153)
|
||||||
assert 'clientAuth' not in server_section, "Server certificates should NOT have clientAuth EKU for role separation"
|
assert 'clientAuth' not in server_section, "Server certificates should NOT have clientAuth EKU for role separation"
|
||||||
|
|
||||||
# Verify SAN extension is configured for Apple compatibility
|
# Verify SAN extension is configured for Apple compatibility
|
||||||
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
||||||
|
|
||||||
print("✓ Server certificate configuration has proper EKU and SAN settings")
|
print("✓ Server certificate configuration has proper EKU and SAN settings")
|
||||||
|
|
||||||
def test_server_certificates():
|
def test_server_certificates():
|
||||||
|
@ -255,34 +251,34 @@ def validate_client_certificates_real(cert_files):
|
||||||
for cert_path in cert_files['server_certs']:
|
for cert_path in cert_files['server_certs']:
|
||||||
if 'cacert.pem' in cert_path:
|
if 'cacert.pem' in cert_path:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with open(cert_path, 'rb') as f:
|
with open(cert_path, 'rb') as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
# Check if this looks like a client cert vs server cert
|
# Check if this looks like a client cert vs server cert
|
||||||
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
||||||
# Server certs typically have IP addresses or domain names as CN
|
# Server certs typically have IP addresses or domain names as CN
|
||||||
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
|
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
|
||||||
client_certs.append((cert_path, certificate))
|
client_certs.append((cert_path, certificate))
|
||||||
|
|
||||||
if not client_certs:
|
if not client_certs:
|
||||||
print("⚠ No client certificates found")
|
print("⚠ No client certificates found")
|
||||||
return
|
return
|
||||||
|
|
||||||
for cert_path, certificate in client_certs:
|
for cert_path, certificate in client_certs:
|
||||||
# Check it's not a CA certificate
|
# Check it's not a CA certificate
|
||||||
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
basic_constraints = certificate.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS).value
|
||||||
assert basic_constraints.ca is False, "Client certificate should not be a CA"
|
assert basic_constraints.ca is False, "Client certificate should not be a CA"
|
||||||
|
|
||||||
# Check Extended Key Usage restrictions (Issue #75)
|
# Check Extended Key Usage restrictions (Issue #75)
|
||||||
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
eku = certificate.extensions.get_extension_for_oid(ExtensionOID.EXTENDED_KEY_USAGE).value
|
||||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "Client cert must have clientAuth EKU"
|
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku, "Client cert must have clientAuth EKU"
|
||||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
|
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
|
||||||
|
|
||||||
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
|
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
|
||||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
||||||
|
|
||||||
# Check SAN extension for email
|
# Check SAN extension for email
|
||||||
try:
|
try:
|
||||||
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
|
san = certificate.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value
|
||||||
|
@ -290,7 +286,7 @@ def validate_client_certificates_real(cert_files):
|
||||||
assert len(email_sans) > 0, "Client certificate should have email SAN"
|
assert len(email_sans) > 0, "Client certificate should have email SAN"
|
||||||
except x509.ExtensionNotFound:
|
except x509.ExtensionNotFound:
|
||||||
print(f"⚠ Client certificate missing SAN extension: {os.path.basename(cert_path)}")
|
print(f"⚠ Client certificate missing SAN extension: {os.path.basename(cert_path)}")
|
||||||
|
|
||||||
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
|
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
|
||||||
|
|
||||||
def validate_client_certificates_config():
|
def validate_client_certificates_config():
|
||||||
|
@ -299,18 +295,18 @@ def validate_client_certificates_config():
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(openssl_task_file, 'r') as f:
|
with open(openssl_task_file) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Look for client certificate CSR section
|
# Look for client certificate CSR section
|
||||||
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
|
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
|
||||||
if not client_csr_section:
|
if not client_csr_section:
|
||||||
print("⚠ Could not find client certificate CSR section")
|
print("⚠ Could not find client certificate CSR section")
|
||||||
return
|
return
|
||||||
|
|
||||||
client_section = client_csr_section.group(0)
|
client_section = client_csr_section.group(0)
|
||||||
|
|
||||||
# Check client certificate configuration
|
# Check client certificate configuration
|
||||||
client_checks = [
|
client_checks = [
|
||||||
('clientAuth', 'Client certificates should have clientAuth EKU'),
|
('clientAuth', 'Client certificates should have clientAuth EKU'),
|
||||||
|
@ -319,16 +315,16 @@ def validate_client_certificates_config():
|
||||||
('keyEncipherment', 'Client certificates should have key encipherment usage'),
|
('keyEncipherment', 'Client certificates should have key encipherment usage'),
|
||||||
('email:', 'Client certificates should have email SAN')
|
('email:', 'Client certificates should have email SAN')
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in client_checks:
|
for check, message in client_checks:
|
||||||
assert check in client_section, f"Missing client certificate configuration: {message}"
|
assert check in client_section, f"Missing client certificate configuration: {message}"
|
||||||
|
|
||||||
# Security check: Client certificates should NOT have serverAuth (Issue #153)
|
# Security check: Client certificates should NOT have serverAuth (Issue #153)
|
||||||
assert 'serverAuth' not in client_section, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
|
assert 'serverAuth' not in client_section, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
|
||||||
|
|
||||||
# Verify client certificates use unique email domains (Issue #153)
|
# Verify client certificates use unique email domains (Issue #153)
|
||||||
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
|
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
|
||||||
|
|
||||||
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
|
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
|
||||||
|
|
||||||
def test_client_certificates():
|
def test_client_certificates():
|
||||||
|
@ -345,28 +341,28 @@ def validate_pkcs12_files_real(cert_files):
|
||||||
if not cert_files.get('p12_files'):
|
if not cert_files.get('p12_files'):
|
||||||
print("⚠ No PKCS#12 files found")
|
print("⚠ No PKCS#12 files found")
|
||||||
return
|
return
|
||||||
|
|
||||||
major, minor = test_openssl_version_detection()
|
major, minor = test_openssl_version_detection()
|
||||||
|
|
||||||
for p12_file in cert_files['p12_files']:
|
for p12_file in cert_files['p12_files']:
|
||||||
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
|
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
|
||||||
|
|
||||||
# Test that PKCS#12 file can be read (validates format)
|
# Test that PKCS#12 file can be read (validates format)
|
||||||
legacy_flag = ['-legacy'] if major >= 3 else []
|
legacy_flag = ['-legacy'] if major >= 3 else []
|
||||||
|
|
||||||
result = subprocess.run([
|
result = subprocess.run([
|
||||||
'openssl', 'pkcs12', '-info',
|
'openssl', 'pkcs12', '-info',
|
||||||
'-in', p12_file,
|
'-in', p12_file,
|
||||||
'-passin', 'pass:', # Try empty password first
|
'-passin', 'pass:', # Try empty password first
|
||||||
'-noout'
|
'-noout'
|
||||||
] + legacy_flag, capture_output=True, text=True)
|
] + legacy_flag, capture_output=True, text=True)
|
||||||
|
|
||||||
# PKCS#12 files should be readable (even if password-protected)
|
# PKCS#12 files should be readable (even if password-protected)
|
||||||
# We're just testing format validity, not trying to extract contents
|
# We're just testing format validity, not trying to extract contents
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
# Try with common password patterns if empty password fails
|
# Try with common password patterns if empty password fails
|
||||||
print(f"⚠ PKCS#12 file may require password: {os.path.basename(p12_file)}")
|
print(f"⚠ PKCS#12 file may require password: {os.path.basename(p12_file)}")
|
||||||
|
|
||||||
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
|
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
|
||||||
|
|
||||||
def validate_pkcs12_files_config():
|
def validate_pkcs12_files_config():
|
||||||
|
@ -375,10 +371,10 @@ def validate_pkcs12_files_config():
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(openssl_task_file, 'r') as f:
|
with open(openssl_task_file) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Check PKCS#12 generation configuration
|
# Check PKCS#12 generation configuration
|
||||||
p12_checks = [
|
p12_checks = [
|
||||||
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
|
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
|
||||||
|
@ -389,10 +385,10 @@ def validate_pkcs12_files_config():
|
||||||
('passphrase', 'PKCS#12 files should be password protected'),
|
('passphrase', 'PKCS#12 files should be password protected'),
|
||||||
('mode: "0600"', 'PKCS#12 files should have secure permissions')
|
('mode: "0600"', 'PKCS#12 files should have secure permissions')
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in p12_checks:
|
for check, message in p12_checks:
|
||||||
assert check in content, f"Missing PKCS#12 configuration: {message}"
|
assert check in content, f"Missing PKCS#12 configuration: {message}"
|
||||||
|
|
||||||
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
|
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
|
||||||
|
|
||||||
def test_pkcs12_files():
|
def test_pkcs12_files():
|
||||||
|
@ -410,31 +406,31 @@ def validate_certificate_chain_real(cert_files):
|
||||||
with open(cert_files['ca_cert'], 'rb') as f:
|
with open(cert_files['ca_cert'], 'rb') as f:
|
||||||
ca_cert_data = f.read()
|
ca_cert_data = f.read()
|
||||||
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
|
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
|
||||||
|
|
||||||
# Test that all other certificates are signed by the CA
|
# Test that all other certificates are signed by the CA
|
||||||
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
|
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
|
||||||
|
|
||||||
if not other_certs:
|
if not other_certs:
|
||||||
print("⚠ No client/server certificates found to validate")
|
print("⚠ No client/server certificates found to validate")
|
||||||
return
|
return
|
||||||
|
|
||||||
for cert_path in other_certs:
|
for cert_path in other_certs:
|
||||||
with open(cert_path, 'rb') as f:
|
with open(cert_path, 'rb') as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
# Verify the certificate was signed by our CA
|
# Verify the certificate was signed by our CA
|
||||||
assert certificate.issuer == ca_certificate.subject, f"Certificate {cert_path} not signed by CA"
|
assert certificate.issuer == ca_certificate.subject, f"Certificate {cert_path} not signed by CA"
|
||||||
|
|
||||||
# Verify certificate is currently valid (not expired)
|
# Verify certificate is currently valid (not expired)
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
assert certificate.not_valid_before <= now, f"Certificate {cert_path} not yet valid"
|
assert certificate.not_valid_before <= now, f"Certificate {cert_path} not yet valid"
|
||||||
assert certificate.not_valid_after >= now, f"Certificate {cert_path} has expired"
|
assert certificate.not_valid_after >= now, f"Certificate {cert_path} has expired"
|
||||||
|
|
||||||
print(f"✓ Real certificate chain valid: {os.path.basename(cert_path)}")
|
print(f"✓ Real certificate chain valid: {os.path.basename(cert_path)}")
|
||||||
|
|
||||||
print(f"✓ All real certificates properly signed by CA")
|
print("✓ All real certificates properly signed by CA")
|
||||||
|
|
||||||
def validate_certificate_chain_config():
|
def validate_certificate_chain_config():
|
||||||
"""Validate certificate chain configuration in Ansible files (CI mode)"""
|
"""Validate certificate chain configuration in Ansible files (CI mode)"""
|
||||||
|
@ -442,10 +438,10 @@ def validate_certificate_chain_config():
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(openssl_task_file, 'r') as f:
|
with open(openssl_task_file) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Check certificate signing configuration
|
# Check certificate signing configuration
|
||||||
chain_checks = [
|
chain_checks = [
|
||||||
('provider: ownca', 'Certificates should be signed by own CA'),
|
('provider: ownca', 'Certificates should be signed by own CA'),
|
||||||
|
@ -457,10 +453,10 @@ def validate_certificate_chain_config():
|
||||||
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
|
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
|
||||||
('type: ECC', 'Should use elliptic curve keys for better security')
|
('type: ECC', 'Should use elliptic curve keys for better security')
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in chain_checks:
|
for check, message in chain_checks:
|
||||||
assert check in content, f"Missing certificate chain configuration: {message}"
|
assert check in content, f"Missing certificate chain configuration: {message}"
|
||||||
|
|
||||||
print("✓ Certificate chain configuration properly set up for CA signing")
|
print("✓ Certificate chain configuration properly set up for CA signing")
|
||||||
|
|
||||||
def test_certificate_chain():
|
def test_certificate_chain():
|
||||||
|
@ -481,12 +477,12 @@ def find_ansible_file(relative_path):
|
||||||
"../..", # Grandparent (from tests/unit to project root)
|
"../..", # Grandparent (from tests/unit to project root)
|
||||||
"../../..", # Alternative deep path
|
"../../..", # Alternative deep path
|
||||||
]
|
]
|
||||||
|
|
||||||
for base in possible_bases:
|
for base in possible_bases:
|
||||||
full_path = os.path.join(base, relative_path)
|
full_path = os.path.join(base, relative_path)
|
||||||
if os.path.exists(full_path):
|
if os.path.exists(full_path):
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -8,7 +8,6 @@ import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Add library directory to path to import our custom module
|
# Add library directory to path to import our custom module
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
|
||||||
|
@ -29,7 +28,7 @@ def test_wireguard_tools_available():
|
||||||
def test_x25519_module_import():
|
def test_x25519_module_import():
|
||||||
"""Test that our custom x25519_pubkey module can be imported and used"""
|
"""Test that our custom x25519_pubkey module can be imported and used"""
|
||||||
try:
|
try:
|
||||||
from x25519_pubkey import run_module
|
import x25519_pubkey # noqa: F401
|
||||||
print("✓ x25519_pubkey module imports successfully")
|
print("✓ x25519_pubkey module imports successfully")
|
||||||
return True
|
return True
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
@ -40,24 +39,24 @@ def generate_test_private_key():
|
||||||
"""Generate a test private key using the same method as Algo"""
|
"""Generate a test private key using the same method as Algo"""
|
||||||
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
|
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
|
||||||
raw_key_path = temp_file.name
|
raw_key_path = temp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
|
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
|
||||||
import secrets
|
import secrets
|
||||||
raw_data = secrets.token_bytes(32)
|
raw_data = secrets.token_bytes(32)
|
||||||
|
|
||||||
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
|
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
|
||||||
with open(raw_key_path, 'wb') as f:
|
with open(raw_key_path, 'wb') as f:
|
||||||
f.write(raw_data)
|
f.write(raw_data)
|
||||||
|
|
||||||
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
|
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
|
||||||
|
|
||||||
b64_key = base64.b64encode(raw_data).decode()
|
b64_key = base64.b64encode(raw_data).decode()
|
||||||
|
|
||||||
print(f"✓ Generated private key (base64): {b64_key[:12]}...")
|
print(f"✓ Generated private key (base64): {b64_key[:12]}...")
|
||||||
|
|
||||||
return raw_key_path, b64_key
|
return raw_key_path, b64_key
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Clean up on error
|
# Clean up on error
|
||||||
if os.path.exists(raw_key_path):
|
if os.path.exists(raw_key_path):
|
||||||
|
@ -68,33 +67,32 @@ def generate_test_private_key():
|
||||||
def test_x25519_pubkey_from_raw_file():
|
def test_x25519_pubkey_from_raw_file():
|
||||||
"""Test our x25519_pubkey module with raw private key file"""
|
"""Test our x25519_pubkey module with raw private key file"""
|
||||||
raw_key_path, b64_key = generate_test_private_key()
|
raw_key_path, b64_key = generate_test_private_key()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Import here so we can mock the module_utils if needed
|
# Import here so we can mock the module_utils if needed
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
# Mock the AnsibleModule for testing
|
# Mock the AnsibleModule for testing
|
||||||
class MockModule:
|
class MockModule:
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
self.params = params
|
self.params = params
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def fail_json(self, **kwargs):
|
def fail_json(self, **kwargs):
|
||||||
raise Exception(f"Module failed: {kwargs}")
|
raise Exception(f"Module failed: {kwargs}")
|
||||||
|
|
||||||
def exit_json(self, **kwargs):
|
def exit_json(self, **kwargs):
|
||||||
self.result = kwargs
|
self.result = kwargs
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
|
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
|
||||||
public_key_path = temp_pub.name
|
public_key_path = temp_pub.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test the module logic directly
|
# Test the module logic directly
|
||||||
from x25519_pubkey import run_module
|
|
||||||
import x25519_pubkey
|
import x25519_pubkey
|
||||||
|
from x25519_pubkey import run_module
|
||||||
|
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Mock the module call
|
# Mock the module call
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({
|
||||||
|
@ -102,38 +100,38 @@ def test_x25519_pubkey_from_raw_file():
|
||||||
'public_key_path': public_key_path,
|
'public_key_path': public_key_path,
|
||||||
'private_key_b64': None
|
'private_key_b64': None
|
||||||
})
|
})
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
|
|
||||||
# Run the module
|
# Run the module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
# Check the result
|
# Check the result
|
||||||
assert 'public_key' in mock_module.result
|
assert 'public_key' in mock_module.result
|
||||||
assert mock_module.result['changed'] == True
|
assert mock_module.result['changed']
|
||||||
assert os.path.exists(public_key_path)
|
assert os.path.exists(public_key_path)
|
||||||
|
|
||||||
with open(public_key_path, 'r') as f:
|
with open(public_key_path) as f:
|
||||||
derived_pubkey = f.read().strip()
|
derived_pubkey = f.read().strip()
|
||||||
|
|
||||||
# Validate base64 format
|
# Validate base64 format
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(derived_pubkey, validate=True)
|
decoded = base64.b64decode(derived_pubkey, validate=True)
|
||||||
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
|
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert False, f"Invalid base64 public key: {e}"
|
assert False, f"Invalid base64 public key: {e}"
|
||||||
|
|
||||||
print(f"✓ Derived public key from raw file: {derived_pubkey[:12]}...")
|
print(f"✓ Derived public key from raw file: {derived_pubkey[:12]}...")
|
||||||
|
|
||||||
return derived_pubkey
|
return derived_pubkey
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(public_key_path):
|
if os.path.exists(public_key_path):
|
||||||
os.unlink(public_key_path)
|
os.unlink(public_key_path)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(raw_key_path):
|
if os.path.exists(raw_key_path):
|
||||||
os.unlink(raw_key_path)
|
os.unlink(raw_key_path)
|
||||||
|
@ -142,56 +140,55 @@ def test_x25519_pubkey_from_raw_file():
|
||||||
def test_x25519_pubkey_from_b64_string():
|
def test_x25519_pubkey_from_b64_string():
|
||||||
"""Test our x25519_pubkey module with base64 private key string"""
|
"""Test our x25519_pubkey module with base64 private key string"""
|
||||||
raw_key_path, b64_key = generate_test_private_key()
|
raw_key_path, b64_key = generate_test_private_key()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
class MockModule:
|
class MockModule:
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
self.params = params
|
self.params = params
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def fail_json(self, **kwargs):
|
def fail_json(self, **kwargs):
|
||||||
raise Exception(f"Module failed: {kwargs}")
|
raise Exception(f"Module failed: {kwargs}")
|
||||||
|
|
||||||
def exit_json(self, **kwargs):
|
def exit_json(self, **kwargs):
|
||||||
self.result = kwargs
|
self.result = kwargs
|
||||||
|
|
||||||
from x25519_pubkey import run_module
|
|
||||||
import x25519_pubkey
|
import x25519_pubkey
|
||||||
|
from x25519_pubkey import run_module
|
||||||
|
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({
|
||||||
'private_key_b64': b64_key,
|
'private_key_b64': b64_key,
|
||||||
'private_key_path': None,
|
'private_key_path': None,
|
||||||
'public_key_path': None
|
'public_key_path': None
|
||||||
})
|
})
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
|
|
||||||
# Run the module
|
# Run the module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
# Check the result
|
# Check the result
|
||||||
assert 'public_key' in mock_module.result
|
assert 'public_key' in mock_module.result
|
||||||
derived_pubkey = mock_module.result['public_key']
|
derived_pubkey = mock_module.result['public_key']
|
||||||
|
|
||||||
# Validate base64 format
|
# Validate base64 format
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(derived_pubkey, validate=True)
|
decoded = base64.b64decode(derived_pubkey, validate=True)
|
||||||
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
|
assert len(decoded) == 32, f"Public key should be 32 bytes, got {len(decoded)}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert False, f"Invalid base64 public key: {e}"
|
assert False, f"Invalid base64 public key: {e}"
|
||||||
|
|
||||||
print(f"✓ Derived public key from base64 string: {derived_pubkey[:12]}...")
|
print(f"✓ Derived public key from base64 string: {derived_pubkey[:12]}...")
|
||||||
|
|
||||||
return derived_pubkey
|
return derived_pubkey
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(raw_key_path):
|
if os.path.exists(raw_key_path):
|
||||||
os.unlink(raw_key_path)
|
os.unlink(raw_key_path)
|
||||||
|
@ -201,45 +198,44 @@ def test_wireguard_validation():
|
||||||
"""Test that our derived keys work with actual WireGuard tools"""
|
"""Test that our derived keys work with actual WireGuard tools"""
|
||||||
if not test_wireguard_tools_available():
|
if not test_wireguard_tools_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Generate keys using our method
|
# Generate keys using our method
|
||||||
raw_key_path, b64_key = generate_test_private_key()
|
raw_key_path, b64_key = generate_test_private_key()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Derive public key using our module
|
# Derive public key using our module
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
class MockModule:
|
class MockModule:
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
self.params = params
|
self.params = params
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def fail_json(self, **kwargs):
|
def fail_json(self, **kwargs):
|
||||||
raise Exception(f"Module failed: {kwargs}")
|
raise Exception(f"Module failed: {kwargs}")
|
||||||
|
|
||||||
def exit_json(self, **kwargs):
|
def exit_json(self, **kwargs):
|
||||||
self.result = kwargs
|
self.result = kwargs
|
||||||
|
|
||||||
from x25519_pubkey import run_module
|
|
||||||
import x25519_pubkey
|
import x25519_pubkey
|
||||||
|
from x25519_pubkey import run_module
|
||||||
|
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({
|
||||||
'private_key_b64': b64_key,
|
'private_key_b64': b64_key,
|
||||||
'private_key_path': None,
|
'private_key_path': None,
|
||||||
'public_key_path': None
|
'public_key_path': None
|
||||||
})
|
})
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
derived_pubkey = mock_module.result['public_key']
|
derived_pubkey = mock_module.result['public_key']
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
|
||||||
# Create a WireGuard config using our keys
|
# Create a WireGuard config using our keys
|
||||||
wg_config = f"""[Interface]
|
wg_config = f"""[Interface]
|
||||||
|
@ -252,33 +248,33 @@ AllowedIPs = 10.19.49.2/32
|
||||||
"""
|
"""
|
||||||
temp_config.write(wg_config)
|
temp_config.write(wg_config)
|
||||||
config_path = temp_config.name
|
config_path = temp_config.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test that WireGuard can parse our config
|
# Test that WireGuard can parse our config
|
||||||
result = subprocess.run([
|
result = subprocess.run([
|
||||||
'wg-quick', 'strip', config_path
|
'wg-quick', 'strip', config_path
|
||||||
], capture_output=True, text=True)
|
], capture_output=True, text=True)
|
||||||
|
|
||||||
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
|
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
|
||||||
|
|
||||||
# Test key derivation with wg pubkey command
|
# Test key derivation with wg pubkey command
|
||||||
wg_result = subprocess.run([
|
wg_result = subprocess.run([
|
||||||
'wg', 'pubkey'
|
'wg', 'pubkey'
|
||||||
], input=b64_key, capture_output=True, text=True)
|
], input=b64_key, capture_output=True, text=True)
|
||||||
|
|
||||||
if wg_result.returncode == 0:
|
if wg_result.returncode == 0:
|
||||||
wg_derived = wg_result.stdout.strip()
|
wg_derived = wg_result.stdout.strip()
|
||||||
assert wg_derived == derived_pubkey, f"Key mismatch: wg={wg_derived} vs ours={derived_pubkey}"
|
assert wg_derived == derived_pubkey, f"Key mismatch: wg={wg_derived} vs ours={derived_pubkey}"
|
||||||
print(f"✓ WireGuard validation: keys match wg pubkey output")
|
print("✓ WireGuard validation: keys match wg pubkey output")
|
||||||
else:
|
else:
|
||||||
print(f"⚠ Could not validate with wg pubkey: {wg_result.stderr}")
|
print(f"⚠ Could not validate with wg pubkey: {wg_result.stderr}")
|
||||||
|
|
||||||
print("✓ WireGuard accepts our generated configuration")
|
print("✓ WireGuard accepts our generated configuration")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
os.unlink(config_path)
|
os.unlink(config_path)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(raw_key_path):
|
if os.path.exists(raw_key_path):
|
||||||
os.unlink(raw_key_path)
|
os.unlink(raw_key_path)
|
||||||
|
@ -288,49 +284,48 @@ def test_key_consistency():
|
||||||
"""Test that the same private key always produces the same public key"""
|
"""Test that the same private key always produces the same public key"""
|
||||||
# Generate one private key to reuse
|
# Generate one private key to reuse
|
||||||
raw_key_path, b64_key = generate_test_private_key()
|
raw_key_path, b64_key = generate_test_private_key()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
def derive_pubkey_from_same_key():
|
def derive_pubkey_from_same_key():
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
class MockModule:
|
class MockModule:
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
self.params = params
|
self.params = params
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def fail_json(self, **kwargs):
|
def fail_json(self, **kwargs):
|
||||||
raise Exception(f"Module failed: {kwargs}")
|
raise Exception(f"Module failed: {kwargs}")
|
||||||
|
|
||||||
def exit_json(self, **kwargs):
|
def exit_json(self, **kwargs):
|
||||||
self.result = kwargs
|
self.result = kwargs
|
||||||
|
|
||||||
from x25519_pubkey import run_module
|
|
||||||
import x25519_pubkey
|
import x25519_pubkey
|
||||||
|
from x25519_pubkey import run_module
|
||||||
|
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({
|
||||||
'private_key_b64': b64_key, # SAME key each time
|
'private_key_b64': b64_key, # SAME key each time
|
||||||
'private_key_path': None,
|
'private_key_path': None,
|
||||||
'public_key_path': None
|
'public_key_path': None
|
||||||
})
|
})
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
return mock_module.result['public_key']
|
return mock_module.result['public_key']
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
||||||
# Derive public key multiple times from same private key
|
# Derive public key multiple times from same private key
|
||||||
pubkey1 = derive_pubkey_from_same_key()
|
pubkey1 = derive_pubkey_from_same_key()
|
||||||
pubkey2 = derive_pubkey_from_same_key()
|
pubkey2 = derive_pubkey_from_same_key()
|
||||||
|
|
||||||
assert pubkey1 == pubkey2, f"Key derivation not consistent: {pubkey1} vs {pubkey2}"
|
assert pubkey1 == pubkey2, f"Key derivation not consistent: {pubkey1} vs {pubkey2}"
|
||||||
print("✓ Key derivation is consistent")
|
print("✓ Key derivation is consistent")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(raw_key_path):
|
if os.path.exists(raw_key_path):
|
||||||
os.unlink(raw_key_path)
|
os.unlink(raw_key_path)
|
||||||
|
@ -344,7 +339,7 @@ if __name__ == "__main__":
|
||||||
test_key_consistency,
|
test_key_consistency,
|
||||||
test_wireguard_validation,
|
test_wireguard_validation,
|
||||||
]
|
]
|
||||||
|
|
||||||
failed = 0
|
failed = 0
|
||||||
for test in tests:
|
for test in tests:
|
||||||
try:
|
try:
|
||||||
|
@ -355,9 +350,9 @@ if __name__ == "__main__":
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ {test.__name__} error: {e}")
|
print(f"✗ {test.__name__} error: {e}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
print(f"\n{failed} tests failed")
|
print(f"\n{failed} tests failed")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
print(f"\nAll {len(tests)} tests passed!")
|
print(f"\nAll {len(tests)} tests passed!")
|
||||||
|
|
Loading…
Add table
Reference in a new issue