Fix VPN routing on multi-homed systems by specifying output interface (#14826)

* Fix VPN routing by adding output interface to NAT rules

The NAT rules were missing the output interface specification (-o eth0),
which caused routing failures on multi-homed systems (servers with multiple
network interfaces). Without specifying the output interface, packets might
not be NAT'd correctly.

Changes:
- Added -o {{ ansible_default_ipv4['interface'] }} to all NAT rules
- Updated both IPv4 and IPv6 templates
- Updated tests to verify output interface is present
- Added ansible_default_ipv4/ipv6 to test fixtures

This fixes the issue where VPN clients could connect but not route traffic
to the internet on servers with multiple network interfaces (like DigitalOcean
droplets with private networking enabled).

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix VPN routing by adding output interface to NAT rules

On multi-homed systems (servers with multiple network interfaces or multiple IPs
on one interface), MASQUERADE rules need to specify which interface to use for
NAT. Without the output interface specification, packets may not be routed correctly.

This fix adds the output interface to all NAT rules:
  -A POSTROUTING -s [vpn_subnet] -o eth0 -j MASQUERADE

Changes:
- Modified roles/common/templates/rules.v4.j2 to include output interface
- Modified roles/common/templates/rules.v6.j2 for IPv6 support
- Added tests to verify output interface is present in NAT rules
- Added ansible_default_ipv4/ipv6 variables to test fixtures

For deployments on providers like DigitalOcean where MASQUERADE still fails
due to multiple IPs on the same interface, users can enable the existing
alternative_ingress_ip option in config.cfg to use explicit SNAT.
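
For reference, the rendered rules look roughly like this (interface, subnet,
and SNAT address are illustrative):

  # Default: MASQUERADE pinned to the default-route interface
  -A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE
  # With alternative_ingress_ip enabled: explicit SNAT instead
  -A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to-source 203.0.113.10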

Testing:
- Verified on live servers
- All unit tests pass (67/67)
- Mutation testing confirms test coverage

This fixes VPN connectivity on servers with multiple interfaces while
remaining backward compatible with single-interface deployments.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy not listening on VPN service IPs

Problem: dnscrypt-proxy on Ubuntu uses systemd socket activation by default,
which overrides the configured listen_addresses in dnscrypt-proxy.toml.
The socket only listens on 127.0.2.1:53, preventing VPN clients from
resolving DNS queries through the configured service IPs.

Solution: Disable and mask the dnscrypt-proxy.socket unit to allow
dnscrypt-proxy to bind directly to the VPN service IPs specified in
its configuration file.

This fixes DNS resolution for VPN clients on Ubuntu 20.04+ systems.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Apply Python linting and formatting

- Run ruff check --fix to fix linting issues
- Run ruff format to ensure consistent formatting
- All tests still pass after formatting changes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Restrict DNS access to VPN clients only

Security fix: The firewall rule for DNS was accepting traffic from any
source (0.0.0.0/0) to the local DNS resolver. While the service IP is
on the loopback interface (which normally isn't routable externally),
this could be a security risk if misconfigured.

Changed firewall rules to only accept DNS traffic from VPN subnets:
- INPUT rule now includes -s {{ subnets }} to restrict source IPs
- Applied to both IPv4 and IPv6 rules
- Added test to verify DNS is properly restricted

This ensures the DNS resolver is only accessible to connected VPN
clients, not the entire internet.
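
As rendered, the restriction looks roughly like this (subnet and service IP
are illustrative; one such rule per VPN subnet):

  -A INPUT -s 10.49.0.0/16 -d 172.24.117.23 -p udp --dport 53 -j ACCEPT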

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy service startup with masked socket

Problem: dnscrypt-proxy.service has a dependency on dnscrypt-proxy.socket
through the TriggeredBy directive. When we mask the socket before starting
the service, systemd fails with "Unit dnscrypt-proxy.socket is masked."

Solution:
1. Override the service to remove socket dependency (TriggeredBy=)
2. Reload systemd daemon immediately after override changes
3. Start the service (which now doesn't require the socket)
4. Only then disable and mask the socket

This ensures dnscrypt-proxy can bind directly to the configured IPs
without socket activation, while preventing the socket from being
re-enabled by package updates.

Changes:
- Added TriggeredBy= override to remove socket dependency
- Added explicit daemon reload after service overrides
- Moved socket masking to after service start in main.yml
- Fixed YAML formatting issues

Testing: Deployment now succeeds with dnscrypt-proxy binding to VPN IPs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy by not masking the socket

Problem: Masking dnscrypt-proxy.socket prevents the service from starting
because the service has a Requires=dnscrypt-proxy.socket dependency.

Solution: Simply stop and disable the socket without masking it. This
prevents socket activation while allowing the service to start and bind
directly to the configured IPs.
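
The equivalent manual steps, for illustration only:

  systemctl stop dnscrypt-proxy.socket
  systemctl disable dnscrypt-proxy.socket
  systemctl restart dnscrypt-proxy.service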

Changes:
- Removed socket masking (just disable it)
- Moved socket disabling before service start
- Removed invalid systemd directives from override

Testing: Confirmed dnscrypt-proxy now listens on VPN service IPs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Use systemd socket activation properly for dnscrypt-proxy

Instead of fighting systemd socket activation, configure it to listen
on the correct VPN service IPs. This is more systemd-native and reliable.

Changes:
- Create socket override to listen on VPN IPs instead of localhost
- Clear default listeners and add VPN service IPs
- Use empty listen_addresses in dnscrypt-proxy.toml for socket activation
- Keep socket enabled and let systemd manage the activation
- Add handler for restarting socket when config changes

Benefits:
- Works WITH systemd instead of against it
- Survives package updates better
- No dependency conflicts
- More reliable service management

This approach is cleaner than disabling socket activation entirely and
ensures dnscrypt-proxy is accessible to VPN clients on the correct IPs.
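
A minimal sketch of the resulting socket override (the service IP is
illustrative; the real value comes from local_service_ip):

  # /etc/systemd/system/dnscrypt-proxy.socket.d/10-algo-override.conf
  [Socket]
  ListenStream=
  ListenDatagram=
  ListenStream=172.24.117.23:53
  ListenDatagram=172.24.117.23:53

After writing the override, a daemon-reload and socket restart make the new
listeners take effect.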

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Document debugging lessons learned in CLAUDE.md

Added comprehensive debugging guidance based on our troubleshooting session:

- VPN connectivity troubleshooting order (DNS first!)
- systemd socket activation best practices
- Common deployment failures and solutions
- Time wasters to avoid (lessons learned the hard way)
- Multi-homed system considerations
- Testing notes for DigitalOcean

These additions will help future debugging sessions avoid the same
rabbit holes and focus on the most likely issues first.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix DNS resolution for VPN clients by enabling route_localnet

The issue was that dnscrypt-proxy listens on a special loopback IP
(randomly generated in the 172.16.0.0/12 range), which wasn't accessible
from VPN clients. This fix:

1. Enables net.ipv4.conf.all.route_localnet sysctl to allow routing
   to loopback IPs from other interfaces
2. Ensures dnscrypt-proxy socket is properly restarted when its
   configuration changes
3. Adds proper handler flushing after socket configuration updates

This allows VPN clients to reach the DNS resolver at the local_service_ip
address configured on the loopback interface.
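
For illustration, the sysctl this enables:

  # Allow packets arriving on other interfaces to reach loopback-bound IPs
  sysctl -w net.ipv4.conf.all.route_localnet=1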

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Improve security by using interface-specific route_localnet

Instead of enabling route_localnet globally (net.ipv4.conf.all.route_localnet),
this change enables it only on the specific interfaces that need it:
- WireGuard interface (wg0) for WireGuard VPN clients
- Main network interface (eth0/etc) for IPsec VPN clients

This minimizes the security impact by restricting loopback routing to only
the VPN interfaces, preventing other interfaces from being able to route
to loopback addresses.

The interface-specific approach provides the same functionality (allowing
VPN clients to reach the DNS resolver on the local_service_ip) while
reducing the potential attack surface.
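
For illustration, the interface-specific form looks like this (interface
names vary by deployment):

  net.ipv4.conf.wg0.route_localnet=1
  net.ipv4.conf.eth0.route_localnet=1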

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Revert to global route_localnet to fix deployment failure

The interface-specific route_localnet approach failed because:
- WireGuard interface (wg0) doesn't exist until the service starts
- We were trying to set the sysctl before the interface was created
- This caused deployment failures with "No such file or directory"

Reverting to the global setting (net.ipv4.conf.all.route_localnet=1) because:
- It always works regardless of interface creation timing
- VPN users are trusted (they have our credentials)
- Firewall rules still restrict access to only port 53
- The security benefit of interface-specific settings is minimal
- The added complexity isn't worth the marginal security improvement

This ensures reliable deployments while maintaining the DNS resolution fix.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy socket restart and remove problematic BPF hardening

Two important fixes:

1. Fix dnscrypt-proxy socket not restarting with new configuration
   - The socket wasn't properly restarting when its override config changed
   - This caused DNS to listen on the wrong IP (127.0.2.1 instead of local_service_ip)
   - Now directly restart the socket when configuration changes
   - Add explicit daemon reload before restarting

2. Remove BPF JIT hardening that causes deployment errors
   - The net.core.bpf_jit_enable sysctl isn't available on all kernels
   - It was causing "Invalid argument" errors during deployment
   - This was optional security hardening with minimal benefit
   - Removing it eliminates deployment errors for most users

These fixes ensure reliable DNS resolution for VPN clients and clean
deployments without error messages.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Update CLAUDE.md with comprehensive debugging lessons learned

Based on our extensive debugging session, this update adds critical documentation:

## DNS Architecture and Troubleshooting
- Explained the local_service_ip design and why it requires route_localnet
- Added detailed DNS debugging methodology with exact steps in order
- Documented systemd socket activation complexities and common mistakes
- Added specific commands to verify DNS is working correctly

## Architectural Decisions
- Added new section explaining trade-offs in Algo's design choices
- Documented why local_service_ip uses loopback instead of alternatives
- Explained iptables-legacy vs iptables-nft backend choice

## Enhanced Debugging Guidance
- Expanded troubleshooting with exact commands and expected outputs
- Added warnings about configuration changes that need restarts
- Documented socket activation override requirements in detail
- Added common pitfalls like interface-specific sysctls

## Time Wasters Section
- Added new lessons learned from this debugging session
- Interface-specific route_localnet (fails before interface exists)
- DNAT for loopback addresses (doesn't work)
- BPF JIT hardening (causes errors on many kernels)

This documentation will help future maintainers avoid the same debugging
rabbit holes and understand why things are designed the way they are.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
Dan Guido 2025-08-17 22:12:23 -04:00 committed by GitHub
parent 9cc0b029ac
commit f668af22d0
38 changed files with 1484 additions and 1228 deletions

CLAUDE.md

@@ -176,19 +176,64 @@ This practice ensures:
- Too many tasks to fix immediately (113+)
- Focus on new code having proper names
### 2. DNS Architecture and Common Issues
#### Understanding local_service_ip
- Algo uses a randomly generated IP in the 172.16.0.0/12 range on the loopback interface
- This IP (`local_service_ip`) is where dnscrypt-proxy should listen
- Requires `net.ipv4.conf.all.route_localnet=1` sysctl for VPN clients to reach loopback IPs
- This is by design for consistency across VPN types (WireGuard + IPsec)
#### dnscrypt-proxy Service Failures
**Problem:** "Unit dnscrypt-proxy.socket is masked" or service won't start
- The service has `Requires=dnscrypt-proxy.socket` dependency
- Masking the socket prevents the service from starting
- **Solution:** Configure socket properly instead of fighting it
#### DNS Not Accessible to VPN Clients
**Symptoms:** VPN connects but no internet/DNS access
1. **First check what's listening:** `sudo ss -ulnp | grep :53`
- Should show `local_service_ip:53` (e.g., 172.24.117.23:53)
- If showing only 127.0.2.1:53, socket override didn't apply
2. **Check socket status:** `systemctl status dnscrypt-proxy.socket`
- Look for "configuration has changed while running" - needs restart
3. **Verify route_localnet:** `sysctl net.ipv4.conf.all.route_localnet`
- Must be 1 for VPN clients to reach loopback IPs
4. **Check firewall:** Ensure the rules allow the VPN subnets: `-A INPUT -s {{ subnets }} -d {{ local_service_ip }}`
- **Never** allow DNS from all sources (0.0.0.0/0) - security risk!
### 3. Multi-homed Systems and NAT
**DigitalOcean and other providers with multiple IPs:**
- Servers may have both public and private IPs on same interface
- MASQUERADE needs output interface: `-o {{ ansible_default_ipv4['interface'] }}`
- Don't overengineer with SNAT - MASQUERADE with interface works fine
- Use `alternative_ingress_ip` option only when truly needed
### 4. iptables Backend Changes (nft vs legacy)
**Critical:** Switching between iptables-nft and iptables-legacy can break subtle behaviors
- Ubuntu 22.04+ defaults to iptables-nft which may have implicit NAT behaviors
- Algo forces iptables-legacy for consistent rule ordering
- This switch can break DNS routing that "just worked" before
- Always test thoroughly after backend changes
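For reference, pinning the legacy backend on Ubuntu/Debian amounts to the
following (a sketch of the equivalent manual commands):

```bash
update-alternatives --set iptables /usr/sbin/iptables-legacy
update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy
```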
### 5. systemd Socket Activation Gotchas
- Interface-specific sysctls (e.g., `net.ipv4.conf.wg0.route_localnet`) fail if interface doesn't exist yet
- WireGuard interface only created when service starts
- Use global sysctls or apply settings after service start
- Socket configuration changes require explicit restart (not just reload)
### 6. Jinja2 Template Complexity
- Many templates use Ansible-specific filters
- Test templates with `tests/unit/test_template_rendering.py`
- Mock Ansible filters when testing
### 7. OpenSSL Version Compatibility
```yaml
# Check version and use appropriate flags
{{ (openssl_version is version('3', '>=')) | ternary('-legacy', '') }}
```
### 8. IPv6 Endpoint Formatting
- WireGuard configs must bracket IPv6 addresses
- Template logic: `{% if ':' in IP %}[{{ IP }}]:{{ port }}{% else %}{{ IP }}:{{ port }}{% endif %}`
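- Example rendered endpoints: `203.0.113.10:51820` (IPv4) vs `[2001:db8::1]:51820` (IPv6; documentation addresses, WireGuard's default port)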
@@ -223,9 +268,11 @@ This practice ensures:
Each has specific requirements:
- **AWS**: Requires boto3, specific AMI IDs
- **Azure**: Complex networking setup
- **DigitalOcean**: Simple API, good for testing (watch for multiple IPs on eth0)
- **Local**: KVM/Docker for development
**Testing Note:** DigitalOcean droplets often have both public and private IPs on the same interface, making them excellent test cases for multi-IP scenarios and NAT issues.
### Architecture Considerations
- Support both x86_64 and ARM64
- Some providers have limited ARM support
@@ -265,6 +312,17 @@ Each has specific requirements:
- Linter compliance
- Conservative approach
### Time Wasters to Avoid (Lessons Learned)
**Don't spend time on these unless absolutely necessary:**
1. **Converting MASQUERADE to SNAT** - MASQUERADE works fine for Algo's use case
2. **Fighting systemd socket activation** - Configure it properly instead of trying to disable it
3. **Debugging NAT before checking DNS** - Most "routing" issues are DNS issues
4. **Complex IPsec policy matching** - Keep NAT rules simple, avoid `-m policy --pol none`
5. **Testing on existing servers** - Always test on fresh deployments
6. **Interface-specific route_localnet** - WireGuard interface doesn't exist until service starts
7. **DNAT for loopback addresses** - Packets to local IPs don't traverse PREROUTING
8. **BPF JIT hardening** - It's optional and causes errors on many kernels (it was removed for exactly this reason)
## Working with Algo
### Local Development Setup
@@ -297,6 +355,108 @@ ansible-playbook users.yml -e "server=SERVER_NAME"
3. Check firewall rules
4. Review generated configs in `configs/`
### Troubleshooting VPN Connectivity
#### Debugging Methodology
When VPN connects but traffic doesn't work, follow this **exact order** (learned from painful experience):
1. **Check DNS listening addresses first**
```bash
ss -lnup | grep :53
# Should show local_service_ip:53 (e.g., 172.24.117.23:53)
# If showing 127.0.2.1:53, socket override didn't apply
```
2. **Check both socket AND service status**
```bash
systemctl status dnscrypt-proxy.socket dnscrypt-proxy.service
# Look for "configuration has changed while running" warnings
```
3. **Verify route_localnet is enabled**
```bash
sysctl net.ipv4.conf.all.route_localnet
# Must be 1 for VPN clients to reach loopback IPs
```
4. **Test DNS resolution from server**
```bash
dig @172.24.117.23 google.com # Use actual local_service_ip
# Should return results if DNS is working
```
5. **Check firewall counters**
```bash
iptables -L INPUT -v -n | grep -E '172.24|10.49|10.48'
# Look for increasing packet counts
```
6. **Verify NAT is happening**
```bash
iptables -t nat -L POSTROUTING -v -n
# Check for MASQUERADE rules with packet counts
```
**Key insight:** 90% of "routing" issues are actually DNS issues. Always check DNS first!
#### systemd and dnscrypt-proxy (Critical for Ubuntu/Debian)
**Background:** Ubuntu's dnscrypt-proxy package uses systemd socket activation which **completely overrides** the `listen_addresses` setting in the config file.
**How it works:**
1. Default socket listens on 127.0.2.1:53 (hardcoded in package)
2. Socket activation means systemd opens the port, not dnscrypt-proxy
3. Config file `listen_addresses` is ignored when socket activation is used
4. Must configure the socket, not just the service
**Correct approach:**
```bash
# Create socket override at /etc/systemd/system/dnscrypt-proxy.socket.d/10-algo-override.conf
[Socket]
# Clear ALL defaults first (TCP and UDP)
ListenStream=
ListenDatagram=
# Add TCP and UDP listeners on the VPN service IP
ListenStream=172.x.x.x:53
ListenDatagram=172.x.x.x:53
```
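After changing the override, reload and restart so the new listeners take effect:

```bash
systemctl daemon-reload
systemctl restart dnscrypt-proxy.socket
ss -lnup | grep :53   # confirm the actual listening addresses
```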
**Config requirements:**
- Use empty `listen_addresses = []` in dnscrypt-proxy.toml for socket activation
- Socket must be restarted (not just reloaded) after config changes
- Check with: `systemctl status dnscrypt-proxy.socket` for warnings
- Verify with: `ss -lnup | grep :53` to see actual listening addresses
**Common mistakes:**
- Trying to disable/mask the socket (breaks service with Requires= dependency)
- Only setting ListenStream (need ListenDatagram for UDP)
- Forgetting to clear defaults first (results in listening on both IPs)
- Not restarting socket after configuration changes
## Architectural Decisions and Trade-offs
### DNS Service IP Design
Algo uses a randomly generated IP in the 172.16.0.0/12 range on the loopback interface for DNS (`local_service_ip`). This design has trade-offs:
**Why it's done this way:**
- Provides a consistent DNS IP across both WireGuard and IPsec
- Avoids binding to VPN gateway IPs which differ between protocols
- Survives interface changes and restarts
- Works the same way across all cloud providers
**The cost:**
- Requires `route_localnet=1` sysctl (minor security consideration)
- Adds complexity with systemd socket activation
- Can be confusing to debug
**Alternatives considered but rejected:**
- Binding to VPN gateway IPs directly (breaks unified configuration)
- Using dummy interface instead of loopback (non-standard, more complex)
- DNAT redirects (doesn't work with loopback destinations)
### iptables Backend Choice
Algo forces iptables-legacy instead of iptables-nft on Ubuntu 22.04+ because:
- nft reorders rules unpredictably, breaking VPN traffic
- Legacy backend provides consistent, predictable behavior
- Trade-off: Lost some implicit NAT behaviors that nft provided
## Important Context for LLMs
### What Makes Algo Special


@@ -9,11 +9,9 @@ import time
from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.digital_ocean import DigitalOceanHelper
ANSIBLE_METADATA = {'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'}
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: digital_ocean_floating_ip
short_description: Manage DigitalOcean Floating IPs
@@ -44,10 +42,10 @@ notes:
- Version 2 of DigitalOcean API is used.
requirements:
- "python >= 2.6"
'''
"""
EXAMPLES = '''
EXAMPLES = """
- name: "Create a Floating IP in region lon1"
digital_ocean_floating_ip:
state: present
@@ -63,10 +61,10 @@ EXAMPLES = '''
state: absent
ip: "1.2.3.4"
'''
"""
RETURN = '''
RETURN = """
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
data:
description: a DigitalOcean Floating IP resource
@@ -106,11 +104,10 @@ data:
"region_slug": "nyc3"
}
}
'''
"""
class Response:
def __init__(self, resp, info):
self.body = None
if resp:
@@ -132,36 +129,37 @@ class Response:
def status_code(self):
return self.info["status"]
def wait_action(module, rest, ip, action_id, timeout=10):
end_time = time.time() + 10
while time.time() < end_time:
response = rest.get(f'floating_ips/{ip}/actions/{action_id}')
response = rest.get(f"floating_ips/{ip}/actions/{action_id}")
# status_code = response.status_code # TODO: check status_code == 200?
status = response.json['action']['status']
if status == 'completed':
status = response.json["action"]["status"]
if status == "completed":
return True
elif status == 'errored':
module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json)
elif status == "errored":
module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json)
module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json)
module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json)
def core(module):
# api_token = module.params['oauth_token'] # unused for now
state = module.params['state']
ip = module.params['ip']
droplet_id = module.params['droplet_id']
state = module.params["state"]
ip = module.params["ip"]
droplet_id = module.params["droplet_id"]
rest = DigitalOceanHelper(module)
if state in ('present'):
if droplet_id is not None and module.params['ip'] is not None:
if state in ("present"):
if droplet_id is not None and module.params["ip"] is not None:
# Lets try to associate the ip to the specified droplet
associate_floating_ips(module, rest)
else:
create_floating_ips(module, rest)
elif state in ('absent'):
elif state in ("absent"):
response = rest.delete(f"floating_ips/{ip}")
status_code = response.status_code
json_data = response.json
@@ -174,65 +172,68 @@ def core(module):
def get_floating_ip_details(module, rest):
ip = module.params['ip']
ip = module.params["ip"]
response = rest.get(f"floating_ips/{ip}")
status_code = response.status_code
json_data = response.json
if status_code == 200:
return json_data['floating_ip']
return json_data["floating_ip"]
else:
module.fail_json(msg="Error assigning floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def assign_floating_id_to_droplet(module, rest):
ip = module.params['ip']
ip = module.params["ip"]
payload = {
"type": "assign",
"droplet_id": module.params['droplet_id'],
"droplet_id": module.params["droplet_id"],
}
response = rest.post(f"floating_ips/{ip}/actions", data=payload)
status_code = response.status_code
json_data = response.json
if status_code == 201:
wait_action(module, rest, ip, json_data['action']['id'])
wait_action(module, rest, ip, json_data["action"]["id"])
module.exit_json(changed=True, data=json_data)
else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def associate_floating_ips(module, rest):
floating_ip = get_floating_ip_details(module, rest)
droplet = floating_ip['droplet']
droplet = floating_ip["droplet"]
# TODO: If already assigned to a droplet verify if is one of the specified as valid
if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]:
if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]:
module.exit_json(changed=False)
else:
assign_floating_id_to_droplet(module, rest)
def create_floating_ips(module, rest):
payload = {
}
payload = {}
floating_ip_data = None
if module.params['region'] is not None:
payload["region"] = module.params['region']
if module.params["region"] is not None:
payload["region"] = module.params["region"]
if module.params['droplet_id'] is not None:
payload["droplet_id"] = module.params['droplet_id']
if module.params["droplet_id"] is not None:
payload["droplet_id"] = module.params["droplet_id"]
floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips')
floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips")
for floating_ip in floating_ips:
if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']:
floating_ip_data = {'floating_ip': floating_ip}
if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]:
floating_ip_data = {"floating_ip": floating_ip}
if floating_ip_data:
module.exit_json(changed=False, data=floating_ip_data)
@@ -244,36 +245,34 @@ def create_floating_ips(module, rest):
if status_code == 202:
module.exit_json(changed=True, data=json_data)
else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def main():
module = AnsibleModule(
argument_spec={
'state': {'choices': ['present', 'absent'], 'default': 'present'},
'ip': {'aliases': ['id'], 'required': False},
'region': {'required': False},
'droplet_id': {'required': False, 'type': 'int'},
'oauth_token': {
'no_log': True,
"state": {"choices": ["present", "absent"], "default": "present"},
"ip": {"aliases": ["id"], "required": False},
"region": {"required": False},
"droplet_id": {"required": False, "type": "int"},
"oauth_token": {
"no_log": True,
# Support environment variable for DigitalOcean OAuth Token
'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']),
'required': True,
"fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]),
"required": True,
},
'validate_certs': {'type': 'bool', 'default': True},
'timeout': {'type': 'int', 'default': 30},
"validate_certs": {"type": "bool", "default": True},
"timeout": {"type": "int", "default": 30},
},
required_if=[
('state', 'delete', ['ip'])
],
mutually_exclusive=[
['region', 'droplet_id']
],
required_if=[("state", "delete", ["ip"])],
mutually_exclusive=[["region", "droplet_id"]],
)
core(module)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@@ -1,7 +1,6 @@
#!/usr/bin/python
import json
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
@@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
# Documentation
################################################################################
ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'}
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
################################################################################
# Main
@@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported
def main():
module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}})
module = GcpModule(
argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}}
)
if module._name == 'gcp_compute_image_facts':
module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13')
if module._name == "gcp_compute_image_facts":
module.deprecate(
"The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13"
)
if not module.params['scopes']:
module.params['scopes'] = ['https://www.googleapis.com/auth/compute']
if not module.params["scopes"]:
module.params["scopes"] = ["https://www.googleapis.com/auth/compute"]
items = fetch_list(module, collection(module), query_options(module.params['filters']))
if items.get('items'):
items = items.get('items')
items = fetch_list(module, collection(module), query_options(module.params["filters"]))
if items.get("items"):
items = items.get("items")
else:
items = []
return_value = {'resources': items}
return_value = {"resources": items}
module.exit_json(**return_value)
@@ -40,14 +43,14 @@ def collection(module):
def fetch_list(module, link, query):
auth = GcpSession(module, 'compute')
response = auth.get(link, params={'filter': query})
auth = GcpSession(module, "compute")
response = auth.get(link, params={"filter": query})
return return_if_object(module, response)
def query_options(filters):
if not filters:
return ''
return ""
if len(filters) == 1:
return filters[0]
@@ -55,12 +58,12 @@ def query_options(filters):
queries = []
for f in filters:
# For multiple queries, all queries should have ()
if f[0] != '(' and f[-1] != ')':
queries.append("({})".format(''.join(f)))
if f[0] != "(" and f[-1] != ")":
queries.append("({})".format("".join(f)))
else:
queries.append(f)
return ' '.join(queries)
return " ".join(queries)
def return_if_object(module, response):
@@ -75,11 +78,11 @@ def return_if_object(module, response):
try:
module.raise_for_status(response)
result = response.json()
except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst:
except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
module.fail_json(msg=f"Invalid JSON response with error: {inst}")
if navigate_hash(result, ['error', 'errors']):
module.fail_json(msg=navigate_hash(result, ['error', 'errors']))
if navigate_hash(result, ["error", "errors"]):
module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
return result


@@ -3,12 +3,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = {'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: lightsail_region_facts
short_description: Gather facts about AWS Lightsail regions.
@@ -24,15 +21,15 @@ requirements:
extends_documentation_fragment:
- aws
- ec2
'''
"""
EXAMPLES = '''
EXAMPLES = """
# Gather facts about all regions
- lightsail_region_facts:
'''
"""
RETURN = '''
RETURN = """
regions:
returned: on success
description: >
@@ -46,12 +43,13 @@ regions:
"displayName": "Virginia",
"name": "us-east-1"
}]"
'''
"""
import traceback
try:
import botocore
HAS_BOTOCORE = True
except ImportError:
HAS_BOTOCORE = False
@@ -86,18 +84,19 @@ def main():
client = None
try:
client = boto3_conn(module, conn_type='client', resource='lightsail',
region=region, endpoint=ec2_url, **aws_connect_kwargs)
client = boto3_conn(
module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs
)
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc())
module.fail_json(
msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc()
)
response = client.get_regions(
includeAvailabilityZones=False
)
response = client.get_regions(includeAvailabilityZones=False)
module.exit_json(changed=False, data=response)
except (botocore.exceptions.ClientError, Exception) as e:
module.fail_json(msg=str(e), exception=traceback.format_exc())
if __name__ == '__main__':
if __name__ == "__main__":
main()


@@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None
try:
from linode_api4 import LinodeClient, StackScript
HAS_LINODE_DEPENDENCY = True
except ImportError:
LINODE_IMP_ERR = traceback.format_exc()
@@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs):
response = client.linode.stackscript_create(**kwargs)
return response._raw_json
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def stackscript_available(module, client):
"""Try to retrieve a stackscript."""
try:
label = module.params['label']
desc = module.params['description']
label = module.params["label"]
desc = module.params["description"]
result = client.linode.stackscripts(StackScript.label == label,
StackScript.description == desc,
mine_only=True
)
result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True)
return result[0]
except IndexError:
return None
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module():
"""Initialise the module parameter specification."""
return AnsibleModule(
argument_spec=dict(
label=dict(type='str', required=True),
state=dict(
type='str',
required=True,
choices=['present', 'absent']
),
label=dict(type="str", required=True),
state=dict(type="str", required=True, choices=["present", "absent"]),
access_token=dict(
type='str',
type="str",
required=True,
no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
),
script=dict(type='str', required=True),
images=dict(type='list', required=True),
description=dict(type='str', required=False),
public=dict(type='bool', required=False, default=False),
script=dict(type="str", required=True),
images=dict(type="list", required=True),
description=dict(type="str", required=False),
public=dict(type="bool", required=False, default=False),
),
supports_check_mode=False
supports_check_mode=False,
)
def build_client(module):
"""Build a LinodeClient."""
return LinodeClient(
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
def main():
@@ -79,30 +70,31 @@ def main():
module = initialise_module()
if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module)
stackscript = stackscript_available(module, client)
if module.params['state'] == 'present' and stackscript is not None:
if module.params["state"] == "present" and stackscript is not None:
module.exit_json(changed=False, stackscript=stackscript._raw_json)
elif module.params['state'] == 'present' and stackscript is None:
elif module.params["state"] == "present" and stackscript is None:
stackscript_json = create_stackscript(
module, client,
label=module.params['label'],
script=module.params['script'],
images=module.params['images'],
desc=module.params['description'],
public=module.params['public'],
module,
client,
label=module.params["label"],
script=module.params["script"],
images=module.params["images"],
desc=module.params["description"],
public=module.params["public"],
)
module.exit_json(changed=True, stackscript=stackscript_json)
elif module.params['state'] == 'absent' and stackscript is not None:
elif module.params["state"] == "absent" and stackscript is not None:
stackscript.delete()
module.exit_json(changed=True, stackscript=stackscript._raw_json)
elif module.params['state'] == 'absent' and stackscript is None:
elif module.params["state"] == "absent" and stackscript is None:
module.exit_json(changed=False, stackscript={})


@@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None
try:
from linode_api4 import Instance, LinodeClient
HAS_LINODE_DEPENDENCY = True
except ImportError:
LINODE_IMP_ERR = traceback.format_exc()
@@ -21,82 +22,72 @@ except ImportError:
def create_linode(module, client, **kwargs):
"""Creates a Linode instance and handles return format."""
if kwargs['root_pass'] is None:
kwargs.pop('root_pass')
if kwargs["root_pass"] is None:
kwargs.pop("root_pass")
try:
response = client.linode.instance_create(**kwargs)
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
try:
if isinstance(response, tuple):
instance, root_pass = response
instance_json = instance._raw_json
instance_json.update({'root_pass': root_pass})
instance_json.update({"root_pass": root_pass})
return instance_json
else:
return response._raw_json
except TypeError:
module.fail_json(msg='Unable to parse Linode instance creation'
' response. Please raise a bug against this'
' module on https://github.com/ansible/ansible/issues'
)
module.fail_json(
msg="Unable to parse Linode instance creation"
" response. Please raise a bug against this"
" module on https://github.com/ansible/ansible/issues"
)
def maybe_instance_from_label(module, client):
"""Try to retrieve an instance based on a label."""
try:
label = module.params['label']
label = module.params["label"]
result = client.linode.instances(Instance.label == label)
return result[0]
except IndexError:
return None
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module():
"""Initialise the module parameter specification."""
return AnsibleModule(
argument_spec=dict(
label=dict(type='str', required=True),
state=dict(
type='str',
required=True,
choices=['present', 'absent']
),
label=dict(type="str", required=True),
state=dict(type="str", required=True, choices=["present", "absent"]),
access_token=dict(
type='str',
type="str",
required=True,
no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
),
authorized_keys=dict(type='list', required=False),
group=dict(type='str', required=False),
image=dict(type='str', required=False),
region=dict(type='str', required=False),
root_pass=dict(type='str', required=False, no_log=True),
tags=dict(type='list', required=False),
type=dict(type='str', required=False),
stackscript_id=dict(type='int', required=False),
authorized_keys=dict(type="list", required=False),
group=dict(type="str", required=False),
image=dict(type="str", required=False),
region=dict(type="str", required=False),
root_pass=dict(type="str", required=False, no_log=True),
tags=dict(type="list", required=False),
type=dict(type="str", required=False),
stackscript_id=dict(type="int", required=False),
),
supports_check_mode=False,
required_one_of=(
['state', 'label'],
),
required_together=(
['region', 'image', 'type'],
)
required_one_of=(["state", "label"],),
required_together=(["region", "image", "type"],),
)
def build_client(module):
"""Build a LinodeClient."""
return LinodeClient(
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
def main():
@@ -104,34 +95,35 @@ def main():
module = initialise_module()
if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module)
instance = maybe_instance_from_label(module, client)
if module.params['state'] == 'present' and instance is not None:
if module.params["state"] == "present" and instance is not None:
module.exit_json(changed=False, instance=instance._raw_json)
elif module.params['state'] == 'present' and instance is None:
elif module.params["state"] == "present" and instance is None:
instance_json = create_linode(
module, client,
authorized_keys=module.params['authorized_keys'],
group=module.params['group'],
image=module.params['image'],
label=module.params['label'],
region=module.params['region'],
root_pass=module.params['root_pass'],
tags=module.params['tags'],
ltype=module.params['type'],
stackscript_id=module.params['stackscript_id'],
module,
client,
authorized_keys=module.params["authorized_keys"],
group=module.params["group"],
image=module.params["image"],
label=module.params["label"],
region=module.params["region"],
root_pass=module.params["root_pass"],
tags=module.params["tags"],
ltype=module.params["type"],
stackscript_id=module.params["stackscript_id"],
)
module.exit_json(changed=True, instance=instance_json)
elif module.params['state'] == 'absent' and instance is not None:
elif module.params["state"] == "absent" and instance is not None:
instance.delete()
module.exit_json(changed=True, instance=instance._raw_json)
elif module.params['state'] == 'absent' and instance is None:
elif module.params["state"] == "absent" and instance is None:
module.exit_json(changed=False, instance={})


@@ -8,14 +8,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = {
'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'
}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: scaleway_compute
short_description: Scaleway compute management module
@@ -120,9 +115,9 @@ options:
- If no value provided, the default security group or current security group will be used
required: false
version_added: "2.8"
'''
"""
EXAMPLES = '''
EXAMPLES = """
- name: Create a server
scaleway_compute:
name: foobar
@@ -156,10 +151,10 @@ EXAMPLES = '''
organization: 951df375-e094-4d26-97c1-ba548eeb9c42
region: ams1
commercial_type: VC1S
'''
"""
RETURN = '''
'''
RETURN = """
"""
import datetime
import time
@@ -167,19 +162,9 @@ import time
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
SCALEWAY_SERVER_STATES = (
'stopped',
'stopping',
'starting',
'running',
'locked'
)
SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked")
SCALEWAY_TRANSITIONS_STATES = (
"stopping",
"starting",
"pending"
)
SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending")
def check_image_id(compute_api, image_id):
@@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id):
if response.ok and response.json:
image_ids = [image["id"] for image in response.json["images"]]
if image_id not in image_ids:
compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url')))
compute_api.module.fail_json(
msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url"))
)
else:
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url'))
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url"))
def fetch_state(compute_api, server):
@@ -201,7 +188,7 @@ def fetch_state(compute_api, server):
return "absent"
if not response.ok:
msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json)
msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip):
# We check that the IP we want to attach exists, if so its ID is returned
response = compute_api.get("ips")
if not response.ok:
msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json)
msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
ip_list = []
@@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip):
def create_server(compute_api, server):
compute_api.module.debug("Starting a create_server")
target_server = None
data = {"enable_ipv6": server["enable_ipv6"],
"tags": server["tags"],
"commercial_type": server["commercial_type"],
"image": server["image"],
"dynamic_ip_required": server["dynamic_ip_required"],
"name": server["name"],
"organization": server["organization"]
}
data = {
"enable_ipv6": server["enable_ipv6"],
"tags": server["tags"],
"commercial_type": server["commercial_type"],
"image": server["image"],
"dynamic_ip_required": server["dynamic_ip_required"],
"name": server["name"],
"organization": server["organization"],
}
if server["boot_type"]:
data["boot_type"] = server["boot_type"]
@@ -278,7 +266,7 @@ def create_server(compute_api, server):
response = compute_api.post(path="servers", data=data)
if not response.ok:
msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json)
msg = "Error during server creation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@@ -304,10 +292,9 @@ def start_server(compute_api, server):
def perform_action(compute_api, server, action):
response = compute_api.post(path="servers/%s/action" % server["id"],
data={"action": action})
response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action})
if not response.ok:
msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json)
msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server)
@@ -319,7 +306,7 @@ def remove_server(compute_api, server):
compute_api.module.debug("Starting remove server strategy")
response = compute_api.delete(path="servers/%s" % server["id"])
if not response.ok:
msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json)
msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server)
@@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
return changed, target_server
@@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server):
response = stop_server(compute_api=compute_api, server=target_server)
if not response.ok:
err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]'
err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server):
response = remove_server(compute_api=compute_api, server=target_server)
if not response.ok:
err_msg = f'Error while removing server [{response.status_code}: {response.json}]'
err_msg = f"Error while removing server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg)
return changed, {"status": "Server %s deleted" % target_server["id"]}
@@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
current_state = fetch_state(compute_api=compute_api, server=target_server)
if current_state not in ("running", "starting"):
@@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server):
response = start_server(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while running server [{response.status_code}: {response.json}]'
msg = f"Error while running server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg)
return changed, target_server
@@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server):
changed = False
if not query_results:
if compute_api.module.check_mode:
return changed, {"status": "A server would be created before being stopped."}
@@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug("stop_strategy: Servers are found.")
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]}
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug(response.ok)
if not response.ok:
msg = f'Error while stopping server [{response.status_code}: {response.json}]'
msg = f"Error while stopping server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg)
return changed, target_server
@@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api,
target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]}
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
changed = True
if compute_api.module.check_mode:
@@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server):
response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].'
msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg)
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].'
msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg)
return changed, target_server
@@ -531,18 +530,17 @@ state_strategy = {
"restarted": restart_strategy,
"stopped": stop_strategy,
"running": running_strategy,
"absent": absent_strategy
"absent": absent_strategy,
}
def find(compute_api, wished_server, per_page=1):
compute_api.module.debug("Getting inside find")
# Only the name attribute is accepted in the Compute query API
response = compute_api.get("servers", params={"name": wished_server["name"],
"per_page": per_page})
response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page})
if not response.ok:
msg = 'Error during server search: (%s) %s' % (response.status_code, response.json)
msg = "Error during server search: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
search_results = response.json["servers"]
@@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve
compute_api.module.debug("Checking if server attributes should be changed")
compute_api.module.debug("Current Server: %s" % target_server)
compute_api.module.debug("Wished Server: %s" % wished_server)
debug_dict = dict((x, (target_server[x], wished_server[x]))
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
if x in target_server and x in wished_server)
debug_dict = dict(
(x, (target_server[x], wished_server[x]))
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
if x in target_server and x in wished_server
)
compute_api.module.debug("Debug dict %s" % debug_dict)
try:
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
if key in target_server and key in wished_server:
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys(
) and target_server[key]["id"] != wished_server[key]:
if (
isinstance(target_server[key], dict)
and wished_server[key]
and "id" in target_server[key].keys()
and target_server[key]["id"] != wished_server[key]
):
return True
# Handling other structure compare simply the two objects content
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
@@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server):
elif not isinstance(target_server[key], dict):
patch_payload[key] = wished_server[key]
response = compute_api.patch(path="servers/%s" % target_server["id"],
data=patch_payload)
response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload)
if not response.ok:
msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json)
msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@ -625,9 +628,9 @@ def core(module):
"boot_type": module.params["boot_type"],
"tags": module.params["tags"],
"organization": module.params["organization"],
"security_group": module.params["security_group"]
"security_group": module.params["security_group"],
}
module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"]
module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"]
compute_api = Scaleway(module=module)
@ -643,22 +646,24 @@ def core(module):
def main():
argument_spec = scaleway_argument_spec()
argument_spec.update(dict(
image=dict(required=True),
name=dict(),
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
commercial_type=dict(required=True),
enable_ipv6=dict(default=False, type="bool"),
boot_type=dict(choices=['bootscript', 'local']),
public_ip=dict(default="absent"),
state=dict(choices=state_strategy.keys(), default='present'),
tags=dict(type="list", default=[]),
organization=dict(required=True),
wait=dict(type="bool", default=False),
wait_timeout=dict(type="int", default=300),
wait_sleep_time=dict(type="int", default=3),
security_group=dict(),
))
argument_spec.update(
dict(
image=dict(required=True),
name=dict(),
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
commercial_type=dict(required=True),
enable_ipv6=dict(default=False, type="bool"),
boot_type=dict(choices=["bootscript", "local"]),
public_ip=dict(default="absent"),
state=dict(choices=state_strategy.keys(), default="present"),
tags=dict(type="list", default=[]),
organization=dict(required=True),
wait=dict(type="bool", default=False),
wait_timeout=dict(type="int", default=300),
wait_sleep_time=dict(type="int", default=3),
security_group=dict(),
)
)
module = AnsibleModule(
argument_spec=argument_spec,
supports_check_mode=True,
@ -667,5 +672,5 @@ def main():
core(module)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -32,32 +32,30 @@ Returns:
def run_module():
"""
Main execution function for the x25519_pubkey Ansible module.
Handles parameter validation, private key processing, public key derivation,
and optional file output with idempotent behavior.
"""
module_args = {
'private_key_b64': {'type': 'str', 'required': False},
'private_key_path': {'type': 'path', 'required': False},
'public_key_path': {'type': 'path', 'required': False},
"private_key_b64": {"type": "str", "required": False},
"private_key_path": {"type": "path", "required": False},
"public_key_path": {"type": "path", "required": False},
}
result = {
'changed': False,
'public_key': '',
"changed": False,
"public_key": "",
}
module = AnsibleModule(
argument_spec=module_args,
required_one_of=[['private_key_b64', 'private_key_path']],
supports_check_mode=True
argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True
)
priv_b64 = None
if module.params['private_key_path']:
if module.params["private_key_path"]:
try:
with open(module.params['private_key_path'], 'rb') as f:
with open(module.params["private_key_path"], "rb") as f:
data = f.read()
try:
# First attempt: assume file contains base64 text data
@ -71,12 +69,14 @@ def run_module():
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
# Stripping would corrupt the key and cause "got 31 bytes" errors
if len(data) != 32:
module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes")
module.fail_json(
msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes"
)
priv_b64 = base64.b64encode(data).decode()
except OSError as e:
module.fail_json(msg=f"Failed to read private key file: {e}")
else:
priv_b64 = module.params['private_key_b64']
priv_b64 = module.params["private_key_b64"]
# Validate input parameters
if not priv_b64:
@ -93,15 +93,12 @@ def run_module():
try:
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
pub_key = priv_key.public_key()
pub_raw = pub_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
pub_b64 = base64.b64encode(pub_raw).decode()
result['public_key'] = pub_b64
result["public_key"] = pub_b64
if module.params['public_key_path']:
pub_path = module.params['public_key_path']
if module.params["public_key_path"]:
pub_path = module.params["public_key_path"]
existing = None
try:
@ -112,13 +109,13 @@ def run_module():
if existing != pub_b64:
try:
with open(pub_path, 'w') as f:
with open(pub_path, "w") as f:
f.write(pub_b64)
result['changed'] = True
result["changed"] = True
except OSError as e:
module.fail_json(msg=f"Failed to write public key file: {e}")
result['public_key_path'] = pub_path
result["public_key_path"] = pub_path
except Exception as e:
module.fail_json(msg=f"Failed to derive public key: {e}")
@ -131,5 +128,5 @@ def main():
run_module()
if __name__ == '__main__':
if __name__ == "__main__":
main()
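For reference, a minimal standalone sketch of the derivation this module performs, assuming only the cryptography package (the all-zero sample key is illustrative, not a deployment key):

import base64

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519


def derive_public_key_b64(private_key_b64: str) -> str:
    """Derive a base64 X25519 public key from a base64 private key."""
    priv_raw = base64.b64decode(private_key_b64)
    if len(priv_raw) != 32:  # X25519 private keys are exactly 32 raw bytes
        raise ValueError(f"expected 32 bytes, got {len(priv_raw)}")
    pub_raw = (
        x25519.X25519PrivateKey.from_private_bytes(priv_raw)
        .public_key()
        .public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
    )
    return base64.b64encode(pub_raw).decode()


# Usage: derive the public half of an all-zero test key
print(derive_public_key_b64(base64.b64encode(bytes(32)).decode()))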


@ -136,6 +136,8 @@
value: 1
- item: "{{ 'net.ipv6.conf.all.forwarding' if ipv6_support else none }}"
value: 1
- item: net.ipv4.conf.all.route_localnet
value: 1
- name: Install packages (batch optimization)
include_tasks: packages.yml
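A minimal sketch for verifying the new route_localnet setting on a provisioned host (reads the standard procfs path for this sysctl):

# net.ipv4.conf.all.route_localnet lets packets arriving on other interfaces
# be routed to 127.0.0.0/8 addresses instead of being dropped as martians,
# which matters when services are bound to IPs installed on loopback.
with open("/proc/sys/net/ipv4/conf/all/route_localnet") as f:
    assert f.read().strip() == "1", "route_localnet is not enabled"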


@ -38,11 +38,11 @@ COMMIT
# Allow traffic from the VPN network to the outside world, and replies
{% if ipsec_enabled %}
# For IPsec traffic - NAT the decrypted packets from the VPN subnet
-A POSTROUTING -s {{ strongswan_network }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
-A POSTROUTING -s {{ strongswan_network }} -o {{ ansible_default_ipv4['interface'] }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
{% endif %}
{% if wireguard_enabled %}
# For WireGuard traffic - NAT packets from the VPN subnet
-A POSTROUTING -s {{ wireguard_network_ipv4 }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
-A POSTROUTING -s {{ wireguard_network_ipv4 }} -o {{ ansible_default_ipv4['interface'] }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
{% endif %}
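A small render sketch of what the POSTROUTING lines above expand to, using Jinja2 directly (the subnet, interface, and SNAT address are illustrative, not from a real deployment):

from jinja2 import Template

tpl = Template(
    "-A POSTROUTING -s {{ net }} -o {{ iface }} "
    "{{ '-j SNAT --to ' + snat if snat else '-j MASQUERADE' }}"
)
# Default path: masquerade out of the detected default interface
print(tpl.render(net="10.49.0.0/16", iface="eth0", snat=None))
# alternative_ingress_ip path: pin egress to one explicit source address
print(tpl.render(net="10.49.0.0/16", iface="eth0", snat="198.51.100.7"))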
@ -85,8 +85,8 @@ COMMIT
# DUMMY interfaces are the proper way to install IPs without assigning them any
# particular virtual (tun,tap,...) or physical (ethernet) interface.
# Accept DNS traffic to the local DNS resolver
-A INPUT -d {{ local_service_ip }} -p udp --dport 53 -j ACCEPT
# Accept DNS traffic to the local DNS resolver from VPN clients only
-A INPUT -s {{ subnets | join(',') }} -d {{ local_service_ip }} -p udp --dport 53 -j ACCEPT
# Drop traffic between VPN clients
-A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }}


@ -37,11 +37,11 @@ COMMIT
# Allow traffic from the VPN network to the outside world, and replies
{% if ipsec_enabled %}
# For IPsec traffic - NAT the decrypted packets from the VPN subnet
-A POSTROUTING -s {{ strongswan_network_ipv6 }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
-A POSTROUTING -s {{ strongswan_network_ipv6 }} -o {{ ansible_default_ipv6['interface'] }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
{% endif %}
{% if wireguard_enabled %}
# For WireGuard traffic - NAT packets from the VPN subnet
-A POSTROUTING -s {{ wireguard_network_ipv6 }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
-A POSTROUTING -s {{ wireguard_network_ipv6 }} -o {{ ansible_default_ipv6['interface'] }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
{% endif %}
COMMIT
@ -95,8 +95,8 @@ COMMIT
# DUMMY interfaces are the proper way to install IPs without assigning them any
# particular virtual (tun,tap,...) or physical (ethernet) interface.
# Accept DNS traffic to the local DNS resolver
-A INPUT -d {{ local_service_ipv6 }}/128 -p udp --dport 53 -j ACCEPT
# Accept DNS traffic to the local DNS resolver from VPN clients only
-A INPUT -s {{ subnets | join(',') }} -d {{ local_service_ipv6 }}/128 -p udp --dport 53 -j ACCEPT
# Drop traffic between VPN clients
-A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }}


@ -3,9 +3,16 @@
systemd:
daemon_reload: true
- name: restart dnscrypt-proxy.socket
systemd:
name: dnscrypt-proxy.socket
state: restarted
daemon_reload: true
when: ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian'
- name: restart dnscrypt-proxy
systemd:
name: dnscrypt-proxy
state: restarted
daemon_reload: true
when: ansible_distribution == 'Ubuntu'
when: ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian'


@ -3,7 +3,6 @@
include_tasks: ubuntu.yml
when: ansible_distribution == 'Debian' or ansible_distribution == 'Ubuntu'
- name: dnscrypt-proxy ip-blacklist configured
template:
src: ip-blacklist.txt.j2
@ -26,6 +25,14 @@
- meta: flush_handlers
- name: Ubuntu | Ensure dnscrypt-proxy socket is enabled and started
systemd:
name: dnscrypt-proxy.socket
enabled: true
state: started
daemon_reload: true
when: ansible_distribution == 'Debian' or ansible_distribution == 'Ubuntu'
- name: dnscrypt-proxy enabled and started
service:
name: dnscrypt-proxy


@ -50,6 +50,49 @@
owner: root
group: root
- name: Ubuntu | Ensure socket override directory exists
file:
path: /etc/systemd/system/dnscrypt-proxy.socket.d/
state: directory
mode: '0755'
owner: root
group: root
- name: Ubuntu | Configure dnscrypt-proxy socket to listen on VPN IPs
copy:
dest: /etc/systemd/system/dnscrypt-proxy.socket.d/10-algo-override.conf
content: |
[Socket]
# Clear default listeners
ListenStream=
ListenDatagram=
# Add VPN service IPs
ListenStream={{ local_service_ip }}:53
ListenDatagram={{ local_service_ip }}:53
{% if ipv6_support %}
ListenStream=[{{ local_service_ipv6 }}]:53
ListenDatagram=[{{ local_service_ipv6 }}]:53
{% endif %}
NoDelay=true
DeferAcceptSec=1
mode: '0644'
register: socket_override
notify:
- daemon-reload
- restart dnscrypt-proxy.socket
- restart dnscrypt-proxy
- name: Ubuntu | Reload systemd daemon after socket configuration
systemd:
daemon_reload: true
when: socket_override.changed
- name: Ubuntu | Restart dnscrypt-proxy socket to apply configuration
systemd:
name: dnscrypt-proxy.socket
state: restarted
when: socket_override.changed
- name: Ubuntu | Add custom requirements to successfully start the unit
copy:
dest: /etc/systemd/system/dnscrypt-proxy.service.d/99-algo.conf
@ -61,8 +104,12 @@
[Service]
AmbientCapabilities=CAP_NET_BIND_SERVICE
notify:
- restart dnscrypt-proxy
register: dnscrypt_override
- name: Ubuntu | Reload systemd daemon if override changed
systemd:
daemon_reload: true
when: dnscrypt_override.changed
- name: Ubuntu | Apply systemd security hardening for dnscrypt-proxy
copy:
@ -95,6 +142,9 @@
owner: root
group: root
mode: '0644'
notify:
- daemon-reload
- restart dnscrypt-proxy
register: dnscrypt_hardening
- name: Ubuntu | Reload systemd daemon if hardening changed
systemd:
daemon_reload: true
when: dnscrypt_hardening.changed
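One way to confirm the socket override took effect end to end is to send a bare DNS query to the service IP; a hedged sketch using only the standard library (the address is illustrative, since Algo generates local_service_ip per deployment):

import socket

# Minimal DNS query for example.com/A: a 12-byte header (id 0xabcd, RD set,
# one question) followed by the encoded question section.
query = bytes.fromhex(
    "abcd01000001000000000000"    # header
    "076578616d706c6503636f6d00"  # QNAME: example.com
    "00010001"                    # QTYPE=A, QCLASS=IN
)
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
    s.settimeout(2)
    s.sendto(query, ("172.16.0.1", 53))  # illustrative local_service_ip
    data, _ = s.recvfrom(512)
assert data[:2] == query[:2], "DNS transaction id mismatch"
print(f"resolver answered with {len(data)} bytes")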


@ -37,10 +37,16 @@
## List of local addresses and ports to listen to. Can be IPv4 and/or IPv6.
## Note: When using systemd socket activation, choose an empty set (i.e. [] ).
{% if ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian' %}
# Using systemd socket activation on Ubuntu/Debian
listen_addresses = []
{% else %}
# Direct binding on non-systemd systems
listen_addresses = [
'{{ local_service_ip }}:53'{% if ipv6_support %},
'[{{ local_service_ipv6 }}]:53'{% endif %}
]
{% endif %}
## Maximum number of simultaneous client connections to accept


@ -12,15 +12,6 @@
- { name: 'kernel.dmesg_restrict', value: '1' }
when: privacy_advanced.reduce_kernel_verbosity | bool
- name: Disable BPF JIT if available (optional security hardening)
sysctl:
name: net.core.bpf_jit_enable
value: '0'
state: present
reload: yes
when: privacy_advanced.reduce_kernel_verbosity | bool
ignore_errors: yes
- name: Configure kernel parameters for privacy
lineinfile:
path: /etc/sysctl.d/99-privacy.conf
@ -31,18 +22,8 @@
- "# Privacy enhancements - reduce kernel logging"
- "kernel.printk = 3 4 1 3"
- "kernel.dmesg_restrict = 1"
- "# Note: net.core.bpf_jit_enable may not be available on all kernels"
when: privacy_advanced.reduce_kernel_verbosity | bool
- name: Add BPF JIT disable to sysctl config if kernel supports it
lineinfile:
path: /etc/sysctl.d/99-privacy.conf
line: "net.core.bpf_jit_enable = 0 # Disable BPF JIT to reduce attack surface"
create: yes
mode: '0644'
when: privacy_advanced.reduce_kernel_verbosity | bool
ignore_errors: yes
- name: Configure journal settings for privacy
lineinfile:
path: /etc/systemd/journald.conf


@ -3,6 +3,7 @@
Track test effectiveness by analyzing CI failures and correlating them with issues/PRs.
This helps identify which tests actually catch bugs versus those that just fail randomly.
"""
import json
import subprocess
from collections import defaultdict
@ -12,7 +13,7 @@ from pathlib import Path
def get_github_api_data(endpoint):
"""Fetch data from GitHub API"""
cmd = ['gh', 'api', endpoint]
cmd = ["gh", "api", endpoint]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Error fetching {endpoint}: {result.stderr}")
@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30):
since = (datetime.now() - timedelta(days=days_back)).isoformat()
# Get workflow runs
runs = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure'
)
runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure")
if not runs:
return {}
test_failures = defaultdict(list)
for run in runs.get('workflow_runs', []):
for run in runs.get("workflow_runs", []):
# Get jobs for this run
jobs = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs'
)
jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs")
if not jobs:
continue
for job in jobs.get('jobs', []):
if job['conclusion'] == 'failure':
for job in jobs.get("jobs", []):
if job["conclusion"] == "failure":
# Try to extract which test failed from logs
logs_url = job.get('logs_url')
logs_url = job.get("logs_url")
if logs_url:
# Parse logs to find test failures
test_name = extract_failed_test(job['name'], run['id'])
test_name = extract_failed_test(job["name"], run["id"])
if test_name:
test_failures[test_name].append({
'run_id': run['id'],
'run_number': run['run_number'],
'date': run['created_at'],
'branch': run['head_branch'],
'commit': run['head_sha'][:7],
'pr': extract_pr_number(run)
})
test_failures[test_name].append(
{
"run_id": run["id"],
"run_number": run["run_number"],
"date": run["created_at"],
"branch": run["head_branch"],
"commit": run["head_sha"][:7],
"pr": extract_pr_number(run),
}
)
return test_failures
@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id):
"""Extract test name from job - this is simplified"""
# Map job names to test categories
job_to_tests = {
'Basic sanity tests': 'test_basic_sanity',
'Ansible syntax check': 'ansible_syntax',
'Docker build test': 'docker_tests',
'Configuration generation test': 'config_generation',
'Ansible dry-run validation': 'ansible_dry_run'
"Basic sanity tests": "test_basic_sanity",
"Ansible syntax check": "ansible_syntax",
"Docker build test": "docker_tests",
"Configuration generation test": "config_generation",
"Ansible dry-run validation": "ansible_dry_run",
}
return job_to_tests.get(job_name, job_name)
def extract_pr_number(run):
"""Extract PR number from workflow run"""
for pr in run.get('pull_requests', []):
return pr['number']
for pr in run.get("pull_requests", []):
return pr["number"]
return None
def correlate_with_issues(repo_owner, repo_name, test_failures):
"""Correlate test failures with issues/PRs that fixed them"""
correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0})
correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0})
for test_name, failures in test_failures.items():
for failure in failures:
if failure['pr']:
if failure["pr"]:
# Check if PR was merged (indicating it fixed a real issue)
pr = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}'
)
pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}")
if pr and pr.get('merged'):
if pr and pr.get("merged"):
# Check PR title/body for bug indicators
title = pr.get('title', '').lower()
body = pr.get('body', '').lower()
title = pr.get("title", "").lower()
body = pr.get("body", "").lower()
bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail']
is_bug_fix = any(keyword in title or keyword in body
for keyword in bug_keywords)
bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"]
is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords)
if is_bug_fix:
correlations[test_name]['caught_bugs'] += 1
correlations[test_name]["caught_bugs"] += 1
else:
correlations[test_name]['false_positives'] += 1
correlations[test_name]["false_positives"] += 1
return correlations
@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations):
scores = []
for test_name, failures in test_failures.items():
failure_count = len(failures)
caught = correlations[test_name]['caught_bugs']
false_pos = correlations[test_name]['false_positives']
caught = correlations[test_name]["caught_bugs"]
false_pos = correlations[test_name]["false_positives"]
# Calculate effectiveness (bugs caught / total failures)
if failure_count > 0:
@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations):
elif effectiveness > 0.8:
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
return '\n'.join(report)
return "\n".join(report)
def save_metrics(test_failures, correlations):
"""Save metrics to JSON for historical tracking"""
metrics_file = Path('.metrics/test-effectiveness.json')
metrics_file = Path(".metrics/test-effectiveness.json")
metrics_file.parent.mkdir(exist_ok=True)
# Load existing metrics
@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations):
# Add current metrics
current = {
'date': datetime.now().isoformat(),
'test_failures': {
test: len(failures) for test, failures in test_failures.items()
},
'effectiveness': {
"date": datetime.now().isoformat(),
"test_failures": {test: len(failures) for test, failures in test_failures.items()},
"effectiveness": {
test: {
'caught_bugs': data['caught_bugs'],
'false_positives': data['false_positives'],
'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives'])
if (data['caught_bugs'] + data['false_positives']) > 0 else 0
"caught_bugs": data["caught_bugs"],
"false_positives": data["false_positives"],
"score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"])
if (data["caught_bugs"] + data["false_positives"]) > 0
else 0,
}
for test, data in correlations.items()
}
},
}
historical.append(current)
# Keep last 12 months of data
cutoff = datetime.now() - timedelta(days=365)
historical = [
h for h in historical
if datetime.fromisoformat(h['date']) > cutoff
]
historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff]
with open(metrics_file, 'w') as f:
with open(metrics_file, "w") as f:
json.dump(historical, f, indent=2)
if __name__ == '__main__':
if __name__ == "__main__":
# Configure these for your repo
REPO_OWNER = 'trailofbits'
REPO_NAME = 'algo'
REPO_OWNER = "trailofbits"
REPO_NAME = "algo"
print("Analyzing test effectiveness...")
@ -223,9 +215,9 @@ if __name__ == '__main__':
print("\n" + report)
# Save report
report_file = Path('.metrics/test-effectiveness-report.md')
report_file = Path(".metrics/test-effectiveness-report.md")
report_file.parent.mkdir(exist_ok=True)
with open(report_file, 'w') as f:
with open(report_file, "w") as f:
f.write(report)
print(f"\nReport saved to: {report_file}")


@ -1,4 +1,5 @@
"""Test fixtures for Algo unit tests"""
from pathlib import Path
import yaml
@ -6,7 +7,7 @@ import yaml
def load_test_variables():
"""Load test variables from YAML fixture"""
fixture_path = Path(__file__).parent / 'test_variables.yml'
fixture_path = Path(__file__).parent / "test_variables.yml"
with open(fixture_path) as f:
return yaml.safe_load(f)


@ -72,6 +72,12 @@ CA_password: test-ca-pass
# System
ansible_ssh_port: 4160
ansible_python_interpreter: /usr/bin/python3
ansible_default_ipv4:
interface: eth0
address: 10.0.0.1
ansible_default_ipv6:
interface: eth0
address: 'fd9d:bc11:4020::1'
BetweenClients_DROP: 'Y'
ssh_tunnels_config_path: /etc/ssh/ssh_tunnels
config_prefix: /etc/algo


@ -2,49 +2,56 @@
"""
Wrapper for Ansible's service module that always succeeds for known services
"""
import json
import sys
# Parse module arguments
args = json.loads(sys.stdin.read())
module_args = args.get('ANSIBLE_MODULE_ARGS', {})
module_args = args.get("ANSIBLE_MODULE_ARGS", {})
service_name = module_args.get('name', '')
state = module_args.get('state', 'started')
service_name = module_args.get("name", "")
state = module_args.get("state", "started")
# Known services that should always succeed
known_services = [
'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter',
'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd',
'systemd-resolved', 'rsyslog', 'ipfw', 'cron'
"netfilter-persistent",
"iptables",
"wg-quick@wg0",
"strongswan-starter",
"ipsec",
"apparmor",
"unattended-upgrades",
"systemd-networkd",
"systemd-resolved",
"rsyslog",
"ipfw",
"cron",
]
# Check if it's a known service
service_found = False
for svc in known_services:
if service_name == svc or service_name.startswith(svc + '.'):
if service_name == svc or service_name.startswith(svc + "."):
service_found = True
break
if service_found:
# Return success
result = {
'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False,
'name': service_name,
'state': state,
'status': {
'LoadState': 'loaded',
'ActiveState': 'active' if state != 'stopped' else 'inactive',
'SubState': 'running' if state != 'stopped' else 'dead'
}
"changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False,
"name": service_name,
"state": state,
"status": {
"LoadState": "loaded",
"ActiveState": "active" if state != "stopped" else "inactive",
"SubState": "running" if state != "stopped" else "dead",
},
}
print(json.dumps(result))
sys.exit(0)
else:
# Service not found
error = {
'failed': True,
'msg': f'Could not find the requested service {service_name}: '
}
error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "}
print(json.dumps(error))
sys.exit(1)
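The wrapper reads its arguments as a single JSON document on stdin, the same way Ansible delivers them, so it can be exercised directly; a sketch with a hypothetical path to the script:

import json
import subprocess

payload = json.dumps({"ANSIBLE_MODULE_ARGS": {"name": "iptables", "state": "restarted"}})
proc = subprocess.run(
    ["python3", "mocks/service.py"],  # hypothetical path to this wrapper
    input=payload,
    capture_output=True,
    text=True,
)
print(proc.returncode)          # 0 for a known service, 1 otherwise
print(json.loads(proc.stdout))  # e.g. {'changed': True, 'name': 'iptables', ...}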


@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'name': {'type': 'list', 'aliases': ['pkg', 'package']},
'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']},
'update_cache': {'type': 'bool', 'default': False},
'cache_valid_time': {'type': 'int', 'default': 0},
'install_recommends': {'type': 'bool'},
'force': {'type': 'bool', 'default': False},
'allow_unauthenticated': {'type': 'bool', 'default': False},
'allow_downgrade': {'type': 'bool', 'default': False},
'allow_change_held_packages': {'type': 'bool', 'default': False},
'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'},
'autoremove': {'type': 'bool', 'default': False},
'purge': {'type': 'bool', 'default': False},
'force_apt_get': {'type': 'bool', 'default': False},
"name": {"type": "list", "aliases": ["pkg", "package"]},
"state": {
"type": "str",
"default": "present",
"choices": ["present", "absent", "latest", "build-dep", "fixed"],
},
"update_cache": {"type": "bool", "default": False},
"cache_valid_time": {"type": "int", "default": 0},
"install_recommends": {"type": "bool"},
"force": {"type": "bool", "default": False},
"allow_unauthenticated": {"type": "bool", "default": False},
"allow_downgrade": {"type": "bool", "default": False},
"allow_change_held_packages": {"type": "bool", "default": False},
"dpkg_options": {"type": "str", "default": "force-confdef,force-confold"},
"autoremove": {"type": "bool", "default": False},
"purge": {"type": "bool", "default": False},
"force_apt_get": {"type": "bool", "default": False},
},
supports_check_mode=True
supports_check_mode=True,
)
name = module.params['name']
state = module.params['state']
update_cache = module.params['update_cache']
name = module.params["name"]
state = module.params["state"]
update_cache = module.params["update_cache"]
result = {
'changed': False,
'cache_updated': False,
'cache_update_time': 0
}
result = {"changed": False, "cache_updated": False, "cache_update_time": 0}
# Log the operation
with open('/var/log/mock-apt-module.log', 'a') as f:
with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
# Handle cache update
if update_cache:
# In Docker, apt-get update was already run in entrypoint
# Just pretend it succeeded
result['cache_updated'] = True
result['cache_update_time'] = 1754231778 # Fixed timestamp
result['changed'] = True
result["cache_updated"] = True
result["cache_update_time"] = 1754231778 # Fixed timestamp
result["changed"] = True
# Handle package installation/removal
if name:
@ -56,40 +56,41 @@ def main():
installed_packages = []
for pkg in packages:
# Use dpkg to check if package is installed
check_cmd = ['dpkg', '-s', pkg]
check_cmd = ["dpkg", "-s", pkg]
rc = subprocess.run(check_cmd, capture_output=True)
if rc.returncode == 0:
installed_packages.append(pkg)
if state in ['present', 'latest']:
if state in ["present", "latest"]:
# Check if we need to install anything
missing_packages = [p for p in packages if p not in installed_packages]
if missing_packages:
# Log what we would install
with open('/var/log/mock-apt-module.log', 'a') as f:
with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"Would install packages: {missing_packages}\n")
# For our test purposes, these packages are pre-installed in Docker
# Just report success
result['changed'] = True
result['stdout'] = f"Mock: Packages {missing_packages} are already available"
result['stderr'] = ""
result["changed"] = True
result["stdout"] = f"Mock: Packages {missing_packages} are already available"
result["stderr"] = ""
else:
result['stdout'] = "All packages are already installed"
result["stdout"] = "All packages are already installed"
elif state == 'absent':
elif state == "absent":
# Check if we need to remove anything
present_packages = [p for p in packages if p in installed_packages]
if present_packages:
result['changed'] = True
result['stdout'] = f"Mock: Would remove packages {present_packages}"
result["changed"] = True
result["stdout"] = f"Mock: Would remove packages {present_packages}"
else:
result['stdout'] = "No packages to remove"
result["stdout"] = "No packages to remove"
# Always report success for our testing
module.exit_json(**result)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'_raw_params': {'type': 'str'},
'cmd': {'type': 'str'},
'creates': {'type': 'path'},
'removes': {'type': 'path'},
'chdir': {'type': 'path'},
'executable': {'type': 'path'},
'warn': {'type': 'bool', 'default': False},
'stdin': {'type': 'str'},
'stdin_add_newline': {'type': 'bool', 'default': True},
'strip_empty_ends': {'type': 'bool', 'default': True},
'_uses_shell': {'type': 'bool', 'default': False},
"_raw_params": {"type": "str"},
"cmd": {"type": "str"},
"creates": {"type": "path"},
"removes": {"type": "path"},
"chdir": {"type": "path"},
"executable": {"type": "path"},
"warn": {"type": "bool", "default": False},
"stdin": {"type": "str"},
"stdin_add_newline": {"type": "bool", "default": True},
"strip_empty_ends": {"type": "bool", "default": True},
"_uses_shell": {"type": "bool", "default": False},
},
supports_check_mode=True
supports_check_mode=True,
)
# Get the command
raw_params = module.params.get('_raw_params')
cmd = module.params.get('cmd') or raw_params
raw_params = module.params.get("_raw_params")
cmd = module.params.get("cmd") or raw_params
if not cmd:
module.fail_json(msg="no command given")
result = {
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
# Log the operation
with open('/var/log/mock-command-module.log', 'a') as f:
with open("/var/log/mock-command-module.log", "a") as f:
f.write(f"command module called: cmd={cmd}\n")
# Handle specific commands
if 'apparmor_status' in cmd:
if "apparmor_status" in cmd:
# Pretend apparmor is not installed/active
result['rc'] = 127
result['stderr'] = "apparmor_status: command not found"
result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'"
module.fail_json(msg=result['msg'], **result)
elif 'netplan apply' in cmd:
result["rc"] = 127
result["stderr"] = "apparmor_status: command not found"
result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'"
module.fail_json(msg=result["msg"], **result)
elif "netplan apply" in cmd:
# Pretend netplan succeeded
result['stdout'] = "Mock: netplan configuration applied"
result['changed'] = True
elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
result["stdout"] = "Mock: netplan configuration applied"
result["changed"] = True
elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush
result['stdout'] = "1"
result['changed'] = True
result["stdout"] = "1"
result["changed"] = True
else:
# For other commands, try to run them
try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir'))
result['rc'] = proc.returncode
result['stdout'] = proc.stdout
result['stderr'] = proc.stderr
result['stdout_lines'] = proc.stdout.splitlines()
result['stderr_lines'] = proc.stderr.splitlines()
result['changed'] = True
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir"))
result["rc"] = proc.returncode
result["stdout"] = proc.stdout
result["stderr"] = proc.stderr
result["stdout_lines"] = proc.stdout.splitlines()
result["stderr_lines"] = proc.stderr.splitlines()
result["changed"] = True
except Exception as e:
result['rc'] = 1
result['stderr'] = str(e)
result['msg'] = str(e)
module.fail_json(msg=result['msg'], **result)
result["rc"] = 1
result["stderr"] = str(e)
result["msg"] = str(e)
module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0:
if result["rc"] == 0:
module.exit_json(**result)
else:
if 'msg' not in result:
result['msg'] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result)
if "msg" not in result:
result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'_raw_params': {'type': 'str'},
'cmd': {'type': 'str'},
'creates': {'type': 'path'},
'removes': {'type': 'path'},
'chdir': {'type': 'path'},
'executable': {'type': 'path', 'default': '/bin/sh'},
'warn': {'type': 'bool', 'default': False},
'stdin': {'type': 'str'},
'stdin_add_newline': {'type': 'bool', 'default': True},
"_raw_params": {"type": "str"},
"cmd": {"type": "str"},
"creates": {"type": "path"},
"removes": {"type": "path"},
"chdir": {"type": "path"},
"executable": {"type": "path", "default": "/bin/sh"},
"warn": {"type": "bool", "default": False},
"stdin": {"type": "str"},
"stdin_add_newline": {"type": "bool", "default": True},
},
supports_check_mode=True
supports_check_mode=True,
)
# Get the command
raw_params = module.params.get('_raw_params')
cmd = module.params.get('cmd') or raw_params
raw_params = module.params.get("_raw_params")
cmd = module.params.get("cmd") or raw_params
if not cmd:
module.fail_json(msg="no command given")
result = {
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
# Log the operation
with open('/var/log/mock-shell-module.log', 'a') as f:
with open("/var/log/mock-shell-module.log", "a") as f:
f.write(f"shell module called: cmd={cmd}\n")
# Handle specific commands
if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush - just pretend it worked
result['stdout'] = ""
result['changed'] = True
elif 'ifconfig lo100' in cmd:
result["stdout"] = ""
result["changed"] = True
elif "ifconfig lo100" in cmd:
# BSD loopback commands - simulate success
result['stdout'] = "0"
result['changed'] = True
result["stdout"] = "0"
result["changed"] = True
else:
# For other commands, try to run them
try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True,
executable=module.params.get('executable'),
cwd=module.params.get('chdir'))
result['rc'] = proc.returncode
result['stdout'] = proc.stdout
result['stderr'] = proc.stderr
result['stdout_lines'] = proc.stdout.splitlines()
result['stderr_lines'] = proc.stderr.splitlines()
result['changed'] = True
proc = subprocess.run(
cmd,
shell=True,
capture_output=True,
text=True,
executable=module.params.get("executable"),
cwd=module.params.get("chdir"),
)
result["rc"] = proc.returncode
result["stdout"] = proc.stdout
result["stderr"] = proc.stderr
result["stdout_lines"] = proc.stdout.splitlines()
result["stderr_lines"] = proc.stderr.splitlines()
result["changed"] = True
except Exception as e:
result['rc'] = 1
result['stderr'] = str(e)
result['msg'] = str(e)
module.fail_json(msg=result['msg'], **result)
result["rc"] = 1
result["stderr"] = str(e)
result["msg"] = str(e)
module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0:
if result["rc"] == 0:
module.exit_json(**result)
else:
if 'msg' not in result:
result['msg'] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result)
if "msg" not in result:
result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -24,6 +24,7 @@ import yaml
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
def create_expected_cloud_init():
"""
Create the expected cloud-init content that should be generated
@ -74,6 +75,7 @@ runcmd:
- systemctl restart sshd.service
"""
class TestCloudInitTemplate:
"""Test class for cloud-init template validation."""
@ -98,10 +100,7 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
required_sections = [
'package_update', 'package_upgrade', 'packages',
'users', 'write_files', 'runcmd'
]
required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"]
missing = [section for section in required_sections if section not in parsed]
assert not missing, f"Missing required sections: {missing}"
@ -114,35 +113,30 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
write_files = parsed.get('write_files', [])
write_files = parsed.get("write_files", [])
assert write_files, "write_files section should be present"
# Find sshd_config file
sshd_config = None
for file_entry in write_files:
if file_entry.get('path') == '/etc/ssh/sshd_config':
if file_entry.get("path") == "/etc/ssh/sshd_config":
sshd_config = file_entry
break
assert sshd_config, "sshd_config file should be in write_files"
content = sshd_config.get('content', '')
content = sshd_config.get("content", "")
assert content, "sshd_config should have content"
# Check required SSH configurations
required_configs = [
'Port 4160',
'AllowGroups algo',
'PermitRootLogin no',
'PasswordAuthentication no'
]
required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"]
missing = [config for config in required_configs if config not in content]
assert not missing, f"Missing SSH configurations: {missing}"
# Verify proper formatting - first line should be Port directive
lines = content.strip().split('\n')
assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}"
lines = content.strip().split("\n")
assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}"
print("✅ SSH configuration correct")
@ -152,26 +146,26 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
users = parsed.get('users', [])
users = parsed.get("users", [])
assert users, "users section should be present"
# Find algo user
algo_user = None
for user in users:
if isinstance(user, dict) and user.get('name') == 'algo':
if isinstance(user, dict) and user.get("name") == "algo":
algo_user = user
break
assert algo_user, "algo user should be defined"
# Check required user properties
required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys']
required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"]
missing = [prop for prop in required_props if prop not in algo_user]
assert not missing, f"algo user missing properties: {missing}"
# Verify sudo configuration
sudo_config = algo_user.get('sudo', '')
assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
sudo_config = algo_user.get("sudo", "")
assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
print("✅ User creation correct")
@ -181,13 +175,13 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
runcmd = parsed.get('runcmd', [])
runcmd = parsed.get("runcmd", [])
assert runcmd, "runcmd section should be present"
# Check for SSH restart command
ssh_restart_found = False
for cmd in runcmd:
if 'systemctl restart sshd' in str(cmd):
if "systemctl restart sshd" in str(cmd):
ssh_restart_found = True
break
@ -202,18 +196,18 @@ class TestCloudInitTemplate:
cloud_init_content = create_expected_cloud_init()
# Extract the sshd_config content lines
lines = cloud_init_content.split('\n')
lines = cloud_init_content.split("\n")
in_sshd_content = False
sshd_lines = []
for line in lines:
if 'content: |' in line:
if "content: |" in line:
in_sshd_content = True
continue
elif in_sshd_content:
if line.strip() == '' and len(sshd_lines) > 0:
if line.strip() == "" and len(sshd_lines) > 0:
break
if line.startswith('runcmd:'):
if line.startswith("runcmd:"):
break
sshd_lines.append(line)
@ -225,11 +219,13 @@ class TestCloudInitTemplate:
for line in non_empty_lines:
# Each line should start with exactly 6 spaces
assert line.startswith('      ') and not line.startswith('       '), \
assert line.startswith("      ") and not line.startswith("       "), (
f"Line should have exactly 6 spaces indentation: {repr(line)}"
)
print("✅ Indentation is consistent")
def run_tests():
"""Run all tests manually (for non-pytest usage)."""
print("🚀 Cloud-init template validation tests")
@ -258,6 +254,7 @@ def run_tests():
print(f"❌ Unexpected error: {e}")
return False
if __name__ == "__main__":
success = run_tests()
sys.exit(0 if success else 1)


@ -30,49 +30,49 @@ packages:
rendered = self.packages_template.render({})
# Should only have sudo package
self.assertIn('- sudo', rendered)
self.assertNotIn('- git', rendered)
self.assertNotIn('- screen', rendered)
self.assertNotIn('- apparmor-utils', rendered)
self.assertIn("- sudo", rendered)
self.assertNotIn("- git", rendered)
self.assertNotIn("- screen", rendered)
self.assertNotIn("- apparmor-utils", rendered)
def test_preinstall_enabled(self):
"""Test that package pre-installation works when enabled."""
# Test with pre-installation enabled
rendered = self.packages_template.render({'performance_preinstall_packages': True})
rendered = self.packages_template.render({"performance_preinstall_packages": True})
# Should have sudo and all universal packages
self.assertIn('- sudo', rendered)
self.assertIn('- git', rendered)
self.assertIn('- screen', rendered)
self.assertIn('- apparmor-utils', rendered)
self.assertIn('- uuid-runtime', rendered)
self.assertIn('- coreutils', rendered)
self.assertIn('- iptables-persistent', rendered)
self.assertIn('- cgroup-tools', rendered)
self.assertIn("- sudo", rendered)
self.assertIn("- git", rendered)
self.assertIn("- screen", rendered)
self.assertIn("- apparmor-utils", rendered)
self.assertIn("- uuid-runtime", rendered)
self.assertIn("- coreutils", rendered)
self.assertIn("- iptables-persistent", rendered)
self.assertIn("- cgroup-tools", rendered)
def test_preinstall_disabled_explicitly(self):
"""Test that package pre-installation is disabled when set to false."""
# Test with pre-installation explicitly disabled
rendered = self.packages_template.render({'performance_preinstall_packages': False})
rendered = self.packages_template.render({"performance_preinstall_packages": False})
# Should only have sudo package
self.assertIn('- sudo', rendered)
self.assertNotIn('- git', rendered)
self.assertNotIn('- screen', rendered)
self.assertNotIn('- apparmor-utils', rendered)
self.assertIn("- sudo", rendered)
self.assertNotIn("- git", rendered)
self.assertNotIn("- screen", rendered)
self.assertNotIn("- apparmor-utils", rendered)
def test_package_count(self):
"""Test that the correct number of packages are included."""
# Default: should only have sudo (1 package)
rendered_default = self.packages_template.render({})
lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')]
lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_default), 1)
# Enabled: should have sudo + 7 universal packages (8 total)
rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True})
lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')]
rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True})
lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_enabled), 8)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -2,6 +2,7 @@
"""
Basic sanity tests for Algo VPN that don't require deployment
"""
import os
import subprocess
import sys
@ -44,11 +45,7 @@ def test_config_file_valid():
def test_ansible_syntax():
"""Check that main playbook has valid syntax"""
result = subprocess.run(
["ansible-playbook", "main.yml", "--syntax-check"],
capture_output=True,
text=True
)
result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True)
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
print("✓ Ansible playbook syntax is valid")
@ -60,11 +57,7 @@ def test_shellcheck():
for script in shell_scripts:
if os.path.exists(script):
result = subprocess.run(
["shellcheck", script],
capture_output=True,
text=True
)
result = subprocess.run(["shellcheck", script], capture_output=True, text=True)
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
print(f"{script} passed shellcheck")
@ -87,7 +80,7 @@ def test_cloud_init_header_format():
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
with open(cloud_init_file) as f:
first_line = f.readline().rstrip('\n\r')
first_line = f.readline().rstrip("\n\r")
# The first line MUST be exactly "#cloud-config" (no space after #)
# This regression was introduced in PR #14775 and broke DigitalOcean deployments


@ -4,31 +4,30 @@ Test cloud provider instance type configurations
Focused on validating that configured instance types are current/valid
Based on issue #14730 - Hetzner changed from cx11 to cx22
"""
import sys
def test_hetzner_server_types():
"""Test Hetzner server type configurations (issue #14730)"""
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22
deprecated_types = ['cx11', 'cpx11']
current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42']
deprecated_types = ["cx11", "cpx11"]
current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"]
# Test that we're not using deprecated types in any configs
test_config = {
'cloud_providers': {
'hetzner': {
'size': 'cx22', # Should be cx22, not cx11
'image': 'ubuntu-22.04',
'location': 'hel1'
"cloud_providers": {
"hetzner": {
"size": "cx22", # Should be cx22, not cx11
"image": "ubuntu-22.04",
"location": "hel1",
}
}
}
hetzner = test_config['cloud_providers']['hetzner']
assert hetzner['size'] not in deprecated_types, \
f"Using deprecated Hetzner type: {hetzner['size']}"
assert hetzner['size'] in current_types, \
f"Unknown Hetzner type: {hetzner['size']}"
hetzner = test_config["cloud_providers"]["hetzner"]
assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}"
assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}"
print("✓ Hetzner server types test passed")
@ -36,10 +35,10 @@ def test_hetzner_server_types():
def test_digitalocean_instance_types():
"""Test DigitalOcean droplet size naming"""
# DigitalOcean uses format like s-1vcpu-1gb
valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb']
deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme
valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"]
deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme
test_size = 's-2vcpu-2gb'
test_size = "s-2vcpu-2gb"
assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
@ -49,10 +48,10 @@ def test_digitalocean_instance_types():
def test_aws_instance_types():
"""Test AWS EC2 instance type naming"""
# Common valid instance types
valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large']
deprecated_types = ['t1.micro', 'm1.small'] # Very old types
valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"]
deprecated_types = ["t1.micro", "m1.small"] # Very old types
test_type = 't3.micro'
test_type = "t3.micro"
assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
@ -62,9 +61,10 @@ def test_aws_instance_types():
def test_vultr_instance_types():
"""Test Vultr instance type naming"""
# Vultr uses format like vc2-1c-1gb
test_type = 'vc2-1c-1gb'
assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \
test_type = "vc2-1c-1gb"
assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), (
f"Invalid Vultr type format: {test_type}"
)
print("✓ Vultr instance types test passed")


@ -2,6 +2,7 @@
"""
Test configuration file validation without deployment
"""
import configparser
import os
import re
@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820
config = configparser.ConfigParser()
config.read_string(sample_config)
assert 'Interface' in config, "Missing [Interface] section"
assert 'Peer' in config, "Missing [Peer] section"
assert "Interface" in config, "Missing [Interface] section"
assert "Peer" in config, "Missing [Peer] section"
# Validate required fields
assert config['Interface'].get('PrivateKey'), "Missing PrivateKey"
assert config['Interface'].get('Address'), "Missing Address"
assert config['Peer'].get('PublicKey'), "Missing PublicKey"
assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs"
assert config["Interface"].get("PrivateKey"), "Missing PrivateKey"
assert config["Interface"].get("Address"), "Missing Address"
assert config["Peer"].get("PublicKey"), "Missing PublicKey"
assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs"
print("✓ WireGuard config format validation passed")
@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820
def test_base64_key_format():
"""Test that keys are in valid base64 format"""
# Base64 keys can have variable length, just check format
key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$')
key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$")
test_keys = [
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
@ -58,8 +59,8 @@ def test_base64_key_format():
def test_ip_address_format():
"""Test IP address and CIDR notation validation"""
ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$')
endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$')
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$")
endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$")
# Test CIDR notation
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
@ -74,11 +75,7 @@ def test_ip_address_format():
def test_mobile_config_xml():
"""Test that mobile config files would be valid XML"""
# First check if xmllint is available
xmllint_check = subprocess.run(
['which', 'xmllint'],
capture_output=True,
text=True
)
xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True)
if xmllint_check.returncode != 0:
print("⚠ Skipping XML validation test (xmllint not installed)")
@ -99,17 +96,13 @@ def test_mobile_config_xml():
</dict>
</plist>"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f:
f.write(sample_mobileconfig)
temp_file = f.name
try:
# Use xmllint to validate
result = subprocess.run(
['xmllint', '--noout', temp_file],
capture_output=True,
text=True
)
result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True)
assert result.returncode == 0, f"XML validation failed: {result.stderr}"
print("✓ Mobile config XML validation passed")


@ -3,6 +3,7 @@
Simplified Docker-based localhost deployment tests
Verifies services can start and config files exist in expected locations
"""
import os
import subprocess
import sys
@ -11,7 +12,7 @@ import sys
def check_docker_available():
"""Check if Docker is available"""
try:
result = subprocess.run(['docker', '--version'], capture_output=True, text=True)
result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
return result.returncode == 0
except FileNotFoundError:
return False
@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128
"""
# Just validate the format
required_sections = ['[Interface]', '[Peer]']
required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs']
required_sections = ["[Interface]", "[Peer]"]
required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"]
for section in required_sections:
if section not in config:
@ -68,15 +69,15 @@ conn ikev2-pubkey
"""
# Validate format
if 'config setup' not in config:
if "config setup" not in config:
print("✗ Missing 'config setup' section")
return False
if 'conn %default' not in config:
if "conn %default" not in config:
print("✗ Missing 'conn %default' section")
return False
if 'keyexchange=ikev2' not in config:
if "keyexchange=ikev2" not in config:
print("✗ Missing IKEv2 configuration")
return False
@ -87,19 +88,19 @@ conn ikev2-pubkey
def test_docker_algo_image():
"""Test that the Algo Docker image can be built"""
# Check if Dockerfile exists
if not os.path.exists('Dockerfile'):
if not os.path.exists("Dockerfile"):
print("✗ Dockerfile not found")
return False
# Read Dockerfile and validate basic structure
with open('Dockerfile') as f:
with open("Dockerfile") as f:
dockerfile_content = f.read()
required_elements = [
'FROM', # Base image
'RUN', # Build commands
'COPY', # Copy Algo files
'python' # Python dependency
"FROM", # Base image
"RUN", # Build commands
"COPY", # Copy Algo files
"python", # Python dependency
]
missing = []
@ -115,16 +116,14 @@ def test_docker_algo_image():
return True
def test_localhost_deployment_requirements():
"""Test that localhost deployment requirements are met"""
requirements = {
'Python 3.8+': sys.version_info >= (3, 8),
'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0,
'Main playbook exists': os.path.exists('main.yml'),
'Project config exists': os.path.exists('pyproject.toml'),
'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'),
"Python 3.8+": sys.version_info >= (3, 8),
"Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0,
"Main playbook exists": os.path.exists("main.yml"),
"Project config exists": os.path.exists("pyproject.toml"),
"Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"),
}
all_met = True
@ -138,10 +137,6 @@ def test_localhost_deployment_requirements():
return all_met
if __name__ == "__main__":
print("Running Docker localhost deployment tests...")
print("=" * 50)


@ -3,6 +3,7 @@
Test that generated configuration files have valid syntax
This validates WireGuard, StrongSwan, SSH, and other configs
"""
import re
import subprocess
import sys
@ -11,7 +12,7 @@ import sys
def check_command_available(cmd):
"""Check if a command is available on the system"""
try:
subprocess.run([cmd, '--version'], capture_output=True, check=False)
subprocess.run([cmd, "--version"], capture_output=True, check=False)
return True
except FileNotFoundError:
return False
@ -37,51 +38,50 @@ PersistentKeepalive = 25
errors = []
# Check for required sections
if '[Interface]' not in sample_config:
if "[Interface]" not in sample_config:
errors.append("Missing [Interface] section")
if '[Peer]' not in sample_config:
if "[Peer]" not in sample_config:
errors.append("Missing [Peer] section")
# Validate Interface section
interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL)
interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL)
if interface_match:
interface_section = interface_match.group(1)
# Check required fields
if not re.search(r'Address\s*=', interface_section):
if not re.search(r"Address\s*=", interface_section):
errors.append("Missing Address in Interface section")
if not re.search(r'PrivateKey\s*=', interface_section):
if not re.search(r"PrivateKey\s*=", interface_section):
errors.append("Missing PrivateKey in Interface section")
# Validate IP addresses
address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section)
address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section)
if address_match:
addresses = address_match.group(1).split(',')
addresses = address_match.group(1).split(",")
for addr in addresses:
addr = addr.strip()
# Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \
not re.match(r'^[0-9a-fA-F:]+/\d+$', addr):
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr):
errors.append(f"Invalid IP address format: {addr}")
# Validate Peer section
peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL)
peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL)
if peer_match:
peer_section = peer_match.group(1)
# Check required fields
if not re.search(r'PublicKey\s*=', peer_section):
if not re.search(r"PublicKey\s*=", peer_section):
errors.append("Missing PublicKey in Peer section")
if not re.search(r'AllowedIPs\s*=', peer_section):
if not re.search(r"AllowedIPs\s*=", peer_section):
errors.append("Missing AllowedIPs in Peer section")
if not re.search(r'Endpoint\s*=', peer_section):
if not re.search(r"Endpoint\s*=", peer_section):
errors.append("Missing Endpoint in Peer section")
# Validate endpoint format
endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section)
endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section)
if endpoint_match:
endpoint = endpoint_match.group(1).strip()
if not re.match(r'^[\d\.\:]+:\d+$', endpoint):
if not re.match(r"^[\d\.\:]+:\d+$", endpoint):
errors.append(f"Invalid Endpoint format: {endpoint}")
if errors:
@ -132,33 +132,32 @@ conn ikev2-pubkey
errors = []
# Check for required sections
if 'config setup' not in sample_config:
if "config setup" not in sample_config:
errors.append("Missing 'config setup' section")
if 'conn %default' not in sample_config:
if "conn %default" not in sample_config:
errors.append("Missing 'conn %default' section")
# Validate connection settings
conn_pattern = re.compile(r'conn\s+(\S+)')
conn_pattern = re.compile(r"conn\s+(\S+)")
connections = conn_pattern.findall(sample_config)
if len(connections) < 2: # Should have at least %default and one other
errors.append("Not enough connection definitions")
# Check for required parameters in connections
required_params = ['keyexchange', 'left', 'right']
required_params = ["keyexchange", "left", "right"]
for param in required_params:
if f'{param}=' not in sample_config:
if f"{param}=" not in sample_config:
errors.append(f"Missing required parameter: {param}")
# Validate IP subnet formats
subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)')
subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)")
for match in subnet_pattern.finditer(sample_config):
subnets = match.group(2).split(',')
subnets = match.group(2).split(",")
for subnet in subnets:
subnet = subnet.strip()
if subnet != '0.0.0.0/0' and subnet != '::/0':
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \
not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet):
if subnet != "0.0.0.0/0" and subnet != "::/0":
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet):
errors.append(f"Invalid subnet format: {subnet}")
if errors:
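To make the subnet validation concrete: the two default routes are whitelisted verbatim, and everything else must parse as an IPv4 or IPv6 CIDR. A self-contained sketch with assumed values:

```python
import re

line = "leftsubnet=0.0.0.0/0,::/0,10.48.0.0/16"
match = re.search(r"(left|right)subnet\s*=\s*([^\n]+)", line)
for subnet in (s.strip() for s in match.group(2).split(",")):
    if subnet in ("0.0.0.0/0", "::/0"):
        continue  # default routes are accepted as-is
    # anything else must look like a CIDR (IPv4 dotted quad or IPv6)
    assert re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) or re.match(
        r"^[0-9a-fA-F:]+/\d+$", subnet
    ), f"Invalid subnet format: {subnet}"
```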
@@ -188,21 +187,21 @@ def test_ssh_config_syntax():
errors = []
# Parse SSH config format
lines = sample_config.strip().split('\n')
lines = sample_config.strip().split("\n")
current_host = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
if line.startswith('Host '):
if line.startswith("Host "):
current_host = line.split()[1]
elif current_host and ' ' in line:
elif current_host and " " in line:
key, value = line.split(None, 1)
# Validate common SSH options
if key == 'Port':
if key == "Port":
try:
port = int(value)
if not 1 <= port <= 65535:
@@ -210,7 +209,7 @@ def test_ssh_config_syntax():
except ValueError:
errors.append(f"Port must be a number: {value}")
elif key == 'LocalForward':
elif key == "LocalForward":
# Format: LocalForward [bind_address:]port host:hostport
parts = value.split()
if len(parts) != 2:
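The `LocalForward` branch expects exactly two whitespace-separated fields after the keyword; a tiny sketch of the accepted shape (values invented):

```python
value = "8080 localhost:80"  # LocalForward [bind_address:]port host:hostport
parts = value.split()
assert len(parts) == 2  # e.g. a bare "LocalForward 8080" gets flagged
```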
@@ -256,35 +255,35 @@ COMMIT
errors = []
# Check table definitions
tables = re.findall(r'\*(\w+)', sample_rules)
if 'filter' not in tables:
tables = re.findall(r"\*(\w+)", sample_rules)
if "filter" not in tables:
errors.append("Missing *filter table")
if 'nat' not in tables:
if "nat" not in tables:
errors.append("Missing *nat table")
# Check for COMMIT statements
commit_count = sample_rules.count('COMMIT')
commit_count = sample_rules.count("COMMIT")
if commit_count != len(tables):
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
# Validate chain policies
chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE)
chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE)
chains = chain_pattern.findall(sample_rules)
required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')]
required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")]
for chain, _policy in required_chains:
if not any(c[0] == chain for c in chains):
errors.append(f"Missing required chain: {chain}")
# Validate rule syntax
rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE)
rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE)
rules = rule_pattern.findall(sample_rules)
if len(rules) < 5:
errors.append("Insufficient firewall rules")
# Check for essential security rules
if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules:
if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules:
errors.append("Missing stateful connection tracking rule")
if errors:
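A minimal ruleset that passes all of these structural checks looks roughly like the following. This is an illustration of what the validator accepts, not Algo's actual template output:

```python
sample_rules = """*filter
:INPUT DROP [0:0]
:FORWARD DROP [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT
-A INPUT -p udp --dport 51820 -j ACCEPT
-A INPUT -p tcp --dport 22 -j ACCEPT
-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT
COMMIT
*nat
:POSTROUTING ACCEPT [0:0]
-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE
COMMIT
"""
# two tables with matching COMMITs, the three required chains with policies,
# five rules, and the exact stateful conntrack rule the test insists on
```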
@@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors = []
# Parse config
for line in sample_config.strip().split('\n'):
for line in sample_config.strip().split("\n"):
line = line.strip()
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
# Most dnsmasq options are key=value or just key
if '=' in line:
key, value = line.split('=', 1)
if "=" in line:
key, value = line.split("=", 1)
# Validate specific options
if key == 'interface':
if not re.match(r'^[a-zA-Z0-9\-_]+$', value):
if key == "interface":
if not re.match(r"^[a-zA-Z0-9\-_]+$", value):
errors.append(f"Invalid interface name: {value}")
elif key == 'server':
elif key == "server":
# Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \
not re.match(r'^[0-9a-fA-F:]+$', value):
if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value):
errors.append(f"Invalid DNS server IP: {value}")
elif key == 'cache-size':
elif key == "cache-size":
try:
size = int(value)
if size < 0:
@@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors.append(f"Cache size must be a number: {value}")
# Check for required options
required = ['interface', 'server']
required = ["interface", "server"]
for req in required:
if f'{req}=' not in sample_config:
if f"{req}=" not in sample_config:
errors.append(f"Missing required option: {req}")
if errors:
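For example, a fragment that satisfies every branch of this parser (values are placeholders):

```python
sample = "\n".join([
    "interface=eth0",      # must match ^[a-zA-Z0-9\-_]+$
    "server=1.1.1.1",      # IPv4 here; an IPv6 like 2606:4700::1111 also passes
    "cache-size=1000",     # must be a non-negative integer
    "addn-hosts=/var/lib/algo/dns/adblock.hosts",  # keys without checks pass through
])
```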


@@ -14,175 +14,233 @@ from jinja2 import Environment, FileSystemLoader
def load_template(template_name):
"""Load a Jinja2 template from the roles/common/templates directory."""
template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates'
template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates"
env = Environment(loader=FileSystemLoader(str(template_dir)))
return env.get_template(template_name)
def test_wireguard_nat_rules_ipv4():
"""Test that WireGuard traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
# Test with WireGuard enabled
result = template.render(
ipsec_enabled=False,
wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16',
wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Verify NAT rule exists without policy matching
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' in result
# Verify NAT rule exists with output interface and without policy matching
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in WireGuard NAT rules
assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result
assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result
def test_ipsec_nat_rules_ipv4():
"""Test that IPsec traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
# Test with IPsec enabled
result = template.render(
ipsec_enabled=True,
wireguard_enabled=False,
strongswan_network='10.48.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
ansible_default_ipv4={'interface': 'eth0'},
strongswan_network="10.48.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.48.0.1',
local_service_ip="10.48.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Verify NAT rule exists without policy matching
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' in result
# Verify NAT rule exists with output interface and without policy matching
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in IPsec NAT rules (this was the bug)
assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result
def test_both_vpns_nat_rules_ipv4():
"""Test NAT rules when both VPN types are enabled."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network='10.48.0.0/16',
wireguard_network_ipv4='10.49.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
wireguard_network_ipv6='2001:db8:a160::/48',
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Both should have NAT rules
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' in result
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' in result
# Both should have NAT rules with output interface
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Neither should have policy matching
assert '-m policy --pol none' not in result
assert "-m policy --pol none" not in result
def test_alternative_ingress_snat():
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network='10.48.0.0/16',
wireguard_network_ipv4='10.49.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
wireguard_network_ipv6='2001:db8:a160::/48',
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
snat_aipv4='192.168.1.100', # Alternative ingress IP
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4="192.168.1.100", # Alternative ingress IP
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Should use SNAT with specific IP instead of MASQUERADE
assert '-A POSTROUTING -s 10.48.0.0/16 -j SNAT --to 192.168.1.100' in result
assert '-A POSTROUTING -s 10.49.0.0/16 -j SNAT --to 192.168.1.100' in result
assert 'MASQUERADE' not in result
# Should use SNAT with specific IP and output interface instead of MASQUERADE
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert "MASQUERADE" not in result
def test_ipsec_forward_rule_has_policy_match():
"""Test that IPsec FORWARD rules still use policy matching (this is correct)."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=False,
strongswan_network='10.48.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
ansible_default_ipv4={'interface': 'eth0'},
strongswan_network="10.48.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.48.0.1',
local_service_ip="10.48.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# FORWARD rule should have policy match (this is correct and should stay)
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result
def test_wireguard_forward_rule_no_policy_match():
"""Test that WireGuard FORWARD rules don't use policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=False,
wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16',
wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# WireGuard FORWARD rule should NOT have any policy match
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result
if __name__ == '__main__':
pytest.main([__file__, '-v'])
def test_output_interface_in_nat_rules():
"""Test that output interface is specified in NAT rules."""
template = load_template("rules.v4.j2")
result = template.render(
snat_aipv4=False,
wireguard_enabled=True,
ipsec_enabled=True,
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network="10.48.0.0/16",
ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"},
ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"},
wireguard_port_actual=51820,
wireguard_port_avoid=53,
wireguard_port=51820,
ansible_ssh_port=22,
reduce_mtu=0,
)
# Check that output interface is specified for both VPNs
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Ensure we don't have rules without output interface
assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result
def test_dns_firewall_restricted_to_vpn():
"""Test that DNS access is restricted to VPN clients only."""
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip="172.23.198.242",
ansible_ssh_port=22,
reduce_mtu=0,
)
# DNS should only be accessible from VPN subnets
assert "-A INPUT -s 10.48.0.0/16,10.49.0.0/16 -d 172.23.198.242 -p udp --dport 53 -j ACCEPT" in result
# Should NOT have unrestricted DNS access
assert "-A INPUT -d 172.23.198.242 -p udp --dport 53 -j ACCEPT" not in result
if __name__ == "__main__":
pytest.main([__file__, "-v"])
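When one of these assertions fails, it helps to see the whole rendered ruleset rather than a single substring. A debugging sketch that reuses `load_template` with the same variables the WireGuard test passes:

```python
template = load_template("rules.v4.j2")
print(template.render(
    ipsec_enabled=False,
    wireguard_enabled=True,
    wireguard_network_ipv4="10.49.0.0/16",
    wireguard_port=51820,
    wireguard_port_avoid=53,
    wireguard_port_actual=51820,
    ansible_default_ipv4={"interface": "eth0"},
    snat_aipv4=None,
    BetweenClients_DROP=True,
    block_smb=True,
    block_netbios=True,
    local_service_ip="10.49.0.1",
    ansible_ssh_port=22,
    reduce_mtu=0,
))
```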


@@ -12,7 +12,7 @@ import unittest
from unittest.mock import MagicMock, patch
# Add the library directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library"))
class TestLightsailBoto3Fix(unittest.TestCase):
@@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase):
"""Set up test fixtures."""
# Mock the ansible module_utils since we're testing outside of Ansible
self.mock_modules = {
'ansible.module_utils.basic': MagicMock(),
'ansible.module_utils.ec2': MagicMock(),
'ansible.module_utils.aws.core': MagicMock(),
"ansible.module_utils.basic": MagicMock(),
"ansible.module_utils.ec2": MagicMock(),
"ansible.module_utils.aws.core": MagicMock(),
}
# Apply mocks
self.patches = []
for module_name, mock_module in self.mock_modules.items():
patcher = patch.dict('sys.modules', {module_name: mock_module})
patcher = patch.dict("sys.modules", {module_name: mock_module})
patcher.start()
self.patches.append(patcher)
@@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Import the module
spec = importlib.util.spec_from_file_location(
"lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
)
module = importlib.util.module_from_spec(spec)
@@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Verify the module loaded
self.assertIsNotNone(module)
self.assertTrue(hasattr(module, 'main'))
self.assertTrue(hasattr(module, "main"))
except Exception as e:
self.fail(f"Failed to import lightsail_region_facts: {e}")
@@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase):
def test_get_aws_connection_info_called_without_boto3(self):
"""Test that get_aws_connection_info is called without boto3 parameter."""
# Mock get_aws_connection_info to track calls
mock_get_aws_connection_info = MagicMock(
return_value=('us-west-2', None, {})
)
mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {}))
with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info):
with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info):
# Import the module
spec = importlib.util.spec_from_file_location(
"lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
)
module = importlib.util.module_from_spec(spec)
@@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
mock_ansible_module.params = {}
mock_ansible_module.check_mode = False
with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module):
with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module):
# Execute the module
try:
spec.loader.exec_module(module)
@@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase):
if call_args:
# Check positional arguments
if call_args[0]: # args
self.assertTrue(len(call_args[0]) <= 1,
"get_aws_connection_info should be called with at most 1 positional arg (module)")
self.assertTrue(
len(call_args[0]) <= 1,
"get_aws_connection_info should be called with at most 1 positional arg (module)",
)
# Check keyword arguments
if call_args[1]: # kwargs
self.assertNotIn('boto3', call_args[1],
"get_aws_connection_info should not be called with boto3 parameter")
self.assertNotIn(
"boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter"
)
def test_no_boto3_parameter_in_source(self):
"""Verify that boto3 parameter is not present in the source code."""
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f:
content = f.read()
# Check that boto3=True is not in the file
self.assertNotIn('boto3=True', content,
"boto3=True parameter should not be present in lightsail_region_facts.py")
self.assertNotIn(
"boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py"
)
# Check that boto3 parameter is not used with get_aws_connection_info
self.assertNotIn('get_aws_connection_info(module, boto3', content,
"get_aws_connection_info should not be called with boto3 parameter")
self.assertNotIn(
"get_aws_connection_info(module, boto3",
content,
"get_aws_connection_info should not be called with boto3 parameter",
)
def test_regression_issue_14822(self):
"""
@@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# The boto3 parameter was deprecated and removed in amazon.aws collection
# that comes with Ansible 11.x
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f:
lines = f.readlines()
# Find the line that calls get_aws_connection_info
for line_num, line in enumerate(lines, 1):
if 'get_aws_connection_info' in line and 'region' in line:
if "get_aws_connection_info" in line and "region" in line:
# This should be around line 85
# Verify it doesn't have boto3=True
self.assertNotIn('boto3', line,
f"Line {line_num} should not contain boto3 parameter")
self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter")
# Verify the correct format
self.assertIn('get_aws_connection_info(module)', line,
f"Line {line_num} should call get_aws_connection_info(module) without boto3")
self.assertIn(
"get_aws_connection_info(module)",
line,
f"Line {line_num} should call get_aws_connection_info(module) without boto3",
)
break
else:
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
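The call shape these tests lock in, reduced to a sketch (`module` stands for the usual `AnsibleModule` instance; the 3-tuple mirrors the mocked return value above):

```python
from ansible.module_utils.ec2 import get_aws_connection_info

def lookup_region(module):
    # e.g. ("us-west-2", None, {}); passing boto3=True here is what
    # broke on Ansible 11.x (issue #14822)
    region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module)
    return region
```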


@@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa
Based on issues #14755, #14718 - Apple device compatibility
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
"""
import glob
import os
import re
@@ -22,7 +23,7 @@ def find_generated_certificates():
config_patterns = [
"configs/*/ipsec/.pki/cacert.pem",
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
"../../configs/*/ipsec/.pki/cacert.pem", # Alternative path
]
for pattern in config_patterns:
@@ -30,26 +31,23 @@ def find_generated_certificates():
if ca_certs:
base_path = os.path.dirname(ca_certs[0])
return {
'ca_cert': ca_certs[0],
'base_path': base_path,
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
"ca_cert": ca_certs[0],
"base_path": base_path,
"server_certs": glob.glob(f"{base_path}/certs/*.crt"),
"p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"),
}
return None
def test_openssl_version_detection():
"""Test that we can detect OpenSSL version for compatibility checks"""
result = subprocess.run(
['openssl', 'version'],
capture_output=True,
text=True
)
result = subprocess.run(["openssl", "version"], capture_output=True, text=True)
assert result.returncode == 0, "Failed to get OpenSSL version"
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout)
version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout)
assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
major = int(version_match.group(1))
@@ -62,7 +60,7 @@ def test_openssl_version_detection():
def validate_ca_certificate_real(cert_files):
"""Validate actual Ansible-generated CA certificate"""
# Read the actual CA certificate generated by Ansible
with open(cert_files['ca_cert'], 'rb') as f:
with open(cert_files["ca_cert"], "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@@ -89,30 +87,34 @@
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
# Verify public domains are excluded
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.DNSName)]
excluded_dns = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)
]
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains:
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
# Verify private IP ranges are excluded (Issue #75)
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.IPAddress)]
excluded_ips = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)
]
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
# Verify email domains are also excluded (Issue #153)
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.RFC822Name)]
excluded_emails = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)
]
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains:
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
def validate_ca_certificate_config():
"""Validate CA certificate configuration in Ansible files (CI mode)"""
# Check that the Ansible task file has proper CA certificate configuration
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@@ -122,15 +124,15 @@ def validate_ca_certificate_config():
# Verify key security configurations are present
security_checks = [
('name_constraints_permitted', 'Name constraints should be configured'),
('name_constraints_excluded', 'Excluded name constraints should be configured'),
('extended_key_usage', 'Extended Key Usage should be configured'),
('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'),
('serverAuth', 'Server authentication EKU should be present'),
('clientAuth', 'Client authentication EKU should be present'),
('basic_constraints', 'Basic constraints should be configured'),
('CA:TRUE', 'CA certificate should be marked as CA'),
('pathlen:0', 'Path length constraint should be set')
("name_constraints_permitted", "Name constraints should be configured"),
("name_constraints_excluded", "Excluded name constraints should be configured"),
("extended_key_usage", "Extended Key Usage should be configured"),
("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"),
("serverAuth", "Server authentication EKU should be present"),
("clientAuth", "Client authentication EKU should be present"),
("basic_constraints", "Basic constraints should be configured"),
("CA:TRUE", "CA certificate should be marked as CA"),
("pathlen:0", "Path length constraint should be set"),
]
for check, message in security_checks:
@@ -140,7 +142,9 @@ def validate_ca_certificate_config():
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains:
# Handle both double quotes and single quotes in YAML
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded"
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, (
f"Public domain {domain} should be excluded"
)
# Verify private IP ranges are excluded
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
@@ -151,13 +155,16 @@ def validate_ca_certificate_config():
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains:
# Handle both double quotes and single quotes in YAML
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded"
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, (
f"Email domain {domain} should be excluded"
)
# Verify IPv6 constraints are present (Issue #153)
assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
print("✓ CA certificate configuration has proper security constraints")
def test_ca_certificate():
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates()
@@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files):
# Filter to only actual server certificates (not client certs)
# Server certificates contain IP addresses in the filename
import re
server_certs = [f for f in cert_files['server_certs']
if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)]
server_certs = [
f
for f in cert_files["server_certs"]
if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f)
]
if not server_certs:
print("⚠ No server certificates found")
return
for server_cert_path in server_certs:
with open(server_cert_path, 'rb') as f:
with open(server_cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files):
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, (
"Server cert should NOT have clientAuth EKU for role separation"
)
# Check SAN extension exists (required for Apple devices)
try:
@@ -204,9 +217,10 @@
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
def validate_server_certificates_config():
"""Validate server certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@@ -215,7 +229,7 @@ def validate_server_certificates_config():
content = f.read()
# Look for server certificate CSR section
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL)
if not server_csr_section:
print("⚠ Could not find server certificate CSR section")
return
@@ -224,11 +238,11 @@ def validate_server_certificates_config():
# Check server certificate CSR configuration
server_checks = [
('subject_alt_name', 'Server certificates should have SAN extension'),
('serverAuth', 'Server certificates should have serverAuth EKU'),
('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'),
('digitalSignature', 'Server certificates should have digital signature usage'),
('keyEncipherment', 'Server certificates should have key encipherment usage')
("subject_alt_name", "Server certificates should have SAN extension"),
("serverAuth", "Server certificates should have serverAuth EKU"),
("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"),
("digitalSignature", "Server certificates should have digital signature usage"),
("keyEncipherment", "Server certificates should have key encipherment usage"),
]
for check, message in server_checks:
@@ -236,15 +250,20 @@ def validate_server_certificates_config():
# Security check: Server certificates should NOT have clientAuth (Issue #153)
# Look for clientAuth in extended_key_usage section, not in comments
eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)]
has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- '))
eku_lines = [
line
for line in server_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line)
]
has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
# Verify SAN extension is configured for Apple compatibility
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility"
print("✓ Server certificate configuration has proper EKU and SAN settings")
def test_server_certificates():
"""Test server certificates - uses real certs if available, else validates config"""
cert_files = find_generated_certificates()
@@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files):
"""Validate actual Ansible-generated client certificates"""
# Find client certificates (not CA cert, not server cert with IP/DNS name)
client_certs = []
for cert_path in cert_files['server_certs']:
if 'cacert.pem' in cert_path:
for cert_path in cert_files["server_certs"]:
if "cacert.pem" in cert_path:
continue
with open(cert_path, 'rb') as f:
with open(cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
# Check if this looks like a client cert vs server cert
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Server certs typically have IP addresses or domain names as CN
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4):
client_certs.append((cert_path, certificate))
if not client_certs:
@@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files):
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, (
"Client cert must NOT have serverAuth EKU to prevent server impersonation"
)
# Check SAN extension for email
try:
@@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files):
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
def validate_client_certificates_config():
"""Validate client certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@@ -310,7 +332,9 @@ def validate_client_certificates_config():
content = f.read()
# Look for client certificate CSR section
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
client_csr_section = re.search(
r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL
)
if not client_csr_section:
print("⚠ Could not find client certificate CSR section")
return
@@ -319,11 +343,11 @@ def validate_client_certificates_config():
# Check client certificate configuration
client_checks = [
('clientAuth', 'Client certificates should have clientAuth EKU'),
('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'),
('digitalSignature', 'Client certificates should have digital signature usage'),
('keyEncipherment', 'Client certificates should have key encipherment usage'),
('email:', 'Client certificates should have email SAN')
("clientAuth", "Client certificates should have clientAuth EKU"),
("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"),
("digitalSignature", "Client certificates should have digital signature usage"),
("keyEncipherment", "Client certificates should have key encipherment usage"),
("email:", "Client certificates should have email SAN"),
]
for check, message in client_checks:
@@ -331,15 +355,22 @@ def validate_client_certificates_config():
# Security check: Client certificates should NOT have serverAuth (Issue #153)
# Look for serverAuth in extended_key_usage section, not in comments
eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)]
has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- '))
eku_lines = [
line
for line in client_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line)
]
has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
# Verify client certificates use unique email domains (Issue #153)
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
assert "openssl_constraint_random_id" in client_section, (
"Client certificates should use unique email domain per deployment"
)
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
def test_client_certificates():
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates()
@@ -351,24 +382,33 @@ def test_client_certificates():
def validate_pkcs12_files_real(cert_files):
"""Validate actual Ansible-generated PKCS#12 files"""
if not cert_files.get('p12_files'):
if not cert_files.get("p12_files"):
print("⚠ No PKCS#12 files found")
return
major, minor = test_openssl_version_detection()
for p12_file in cert_files['p12_files']:
for p12_file in cert_files["p12_files"]:
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
# Test that PKCS#12 file can be read (validates format)
legacy_flag = ['-legacy'] if major >= 3 else []
legacy_flag = ["-legacy"] if major >= 3 else []
result = subprocess.run([
'openssl', 'pkcs12', '-info',
'-in', p12_file,
'-passin', 'pass:', # Try empty password first
'-noout'
] + legacy_flag, capture_output=True, text=True)
result = subprocess.run(
[
"openssl",
"pkcs12",
"-info",
"-in",
p12_file,
"-passin",
"pass:", # Try empty password first
"-noout",
]
+ legacy_flag,
capture_output=True,
text=True,
)
# PKCS#12 files should be readable (even if password-protected)
# We're just testing format validity, not trying to extract contents
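A note on the version gate (my gloss, not from the PR): OpenSSL 3.x moved the RC2/3DES-era algorithms traditionally used in PKCS#12 bundles into the optional legacy provider, so probing older bundles needs `-legacy`, while OpenSSL 1.1.x has no such flag. The probe command built above, as a standalone sketch:

```python
def pkcs12_probe_cmd(p12_file, openssl_major):
    cmd = ["openssl", "pkcs12", "-info", "-in", p12_file, "-passin", "pass:", "-noout"]
    if openssl_major >= 3:
        cmd.append("-legacy")  # load the legacy provider for old ciphers
    return cmd
```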
@@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files):
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
def validate_pkcs12_files_config():
"""Validate PKCS#12 file configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@@ -390,13 +431,13 @@ def validate_pkcs12_files_config():
# Check PKCS#12 generation configuration
p12_checks = [
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
('encryption_level', 'PKCS#12 encryption level should be configured'),
('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'),
('friendly_name', 'PKCS#12 should have friendly names'),
('other_certificates', 'PKCS#12 should include CA certificate for full chain'),
('passphrase', 'PKCS#12 files should be password protected'),
('mode: "0600"', 'PKCS#12 files should have secure permissions')
("openssl_pkcs12", "PKCS#12 generation should be configured"),
("encryption_level", "PKCS#12 encryption level should be configured"),
("compatibility2022", "PKCS#12 should use Apple-compatible encryption"),
("friendly_name", "PKCS#12 should have friendly names"),
("other_certificates", "PKCS#12 should include CA certificate for full chain"),
("passphrase", "PKCS#12 files should be password protected"),
('mode: "0600"', "PKCS#12 files should have secure permissions"),
]
for check, message in p12_checks:
@@ -404,6 +445,7 @@ def validate_pkcs12_files_config():
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
def test_pkcs12_files():
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
cert_files = find_generated_certificates()
@@ -416,19 +458,19 @@ def test_pkcs12_files():
def validate_certificate_chain_real(cert_files):
"""Validate actual Ansible-generated certificate chain"""
# Load CA certificate
with open(cert_files['ca_cert'], 'rb') as f:
with open(cert_files["ca_cert"], "rb") as f:
ca_cert_data = f.read()
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
# Test that all other certificates are signed by the CA
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]]
if not other_certs:
print("⚠ No client/server certificates found to validate")
return
for cert_path in other_certs:
with open(cert_path, 'rb') as f:
with open(cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files):
# Verify certificate is currently valid (not expired)
from datetime import datetime
now = datetime.now(UTC)
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
@@ -445,9 +488,10 @@
print("✓ All real certificates properly signed by CA")
def validate_certificate_chain_config():
"""Validate certificate chain configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@@ -457,15 +501,18 @@ def validate_certificate_chain_config():
# Check certificate signing configuration
chain_checks = [
('provider: ownca', 'Certificates should be signed by own CA'),
('ownca_path', 'CA certificate path should be specified'),
('ownca_privatekey_path', 'CA private key path should be specified'),
('ownca_privatekey_passphrase', 'CA private key should be password protected'),
('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'),
('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'),
('ownca_not_before: "-1d"', 'Certificates should have backdated start time'),
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
('type: ECC', 'Should use elliptic curve keys for better security')
("provider: ownca", "Certificates should be signed by own CA"),
("ownca_path", "CA certificate path should be specified"),
("ownca_privatekey_path", "CA private key path should be specified"),
("ownca_privatekey_passphrase", "CA private key should be password protected"),
("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"),
(
'ownca_not_after: "+{{ certificate_validity_days }}d"',
"Certificates should use configurable validity period",
),
('ownca_not_before: "-1d"', "Certificates should have backdated start time"),
("curve: secp384r1", "Should use strong elliptic curve cryptography"),
("type: ECC", "Should use elliptic curve keys for better security"),
]
for check, message in chain_checks:
@@ -473,6 +520,7 @@ def validate_certificate_chain_config():
print("✓ Certificate chain configuration properly set up for CA signing")
def test_certificate_chain():
"""Test certificate chain - uses real certs if available, else validates config"""
cert_files = find_generated_certificates()
@ -499,6 +547,7 @@ def find_ansible_file(relative_path):
return None
if __name__ == "__main__":
tests = [
test_openssl_version_detection,


@@ -3,6 +3,7 @@
Enhanced tests for StrongSwan templates.
Tests all strongswan role templates with various configurations.
"""
import os
import sys
import uuid
@@ -21,7 +22,7 @@ def mock_to_uuid(value):
def mock_bool(value):
"""Mock the bool filter"""
return str(value).lower() in ('true', '1', 'yes', 'on')
return str(value).lower() in ("true", "1", "yes", "on")
def mock_version(version_string, comparison):
@@ -33,67 +34,67 @@ def mock_version(version_string, comparison):
def mock_b64encode(value):
"""Mock base64 encoding"""
import base64
if isinstance(value, str):
value = value.encode('utf-8')
return base64.b64encode(value).decode('ascii')
value = value.encode("utf-8")
return base64.b64encode(value).decode("ascii")
def mock_b64decode(value):
"""Mock base64 decoding"""
import base64
return base64.b64decode(value).decode('utf-8')
return base64.b64decode(value).decode("utf-8")
def get_strongswan_test_variables(scenario='default'):
def get_strongswan_test_variables(scenario="default"):
"""Get test variables for StrongSwan templates with different scenarios."""
base_vars = load_test_variables()
# Add StrongSwan specific variables
strongswan_vars = {
'ipsec_config_path': '/etc/ipsec.d',
'ipsec_pki_path': '/etc/ipsec.d',
'strongswan_enabled': True,
'strongswan_network': '10.19.48.0/24',
'strongswan_network_ipv6': 'fd9d:bc11:4021::/64',
'strongswan_log_level': '2',
'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()),
'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a',
'subjectAltName_type': 'IP',
'subjectAltName_client': 'IP:10.0.0.1',
'ansible_default_ipv6': {
'address': '2600:3c01::f03c:91ff:fedf:3b2a'
},
'openssl_version': '3.0.0',
'p12_export_password': 'test-password',
'ike_lifetime': '24h',
'ipsec_lifetime': '8h',
'ike_dpd': '30s',
'ipsec_dead_peer_detection': True,
'rekey_margin': '3m',
'rekeymargin': '3m',
'dpddelay': '35s',
'keyexchange': 'ikev2',
'ike_cipher': 'aes128gcm16-prfsha512-ecp256',
'esp_cipher': 'aes128gcm16-ecp256',
'leftsourceip': '10.19.48.1',
'leftsubnet': '0.0.0.0/0,::/0',
'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64',
"ipsec_config_path": "/etc/ipsec.d",
"ipsec_pki_path": "/etc/ipsec.d",
"strongswan_enabled": True,
"strongswan_network": "10.19.48.0/24",
"strongswan_network_ipv6": "fd9d:bc11:4021::/64",
"strongswan_log_level": "2",
"openssl_constraint_random_id": "test-" + str(uuid.uuid4()),
"subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a",
"subjectAltName_type": "IP",
"subjectAltName_client": "IP:10.0.0.1",
"ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"},
"openssl_version": "3.0.0",
"p12_export_password": "test-password",
"ike_lifetime": "24h",
"ipsec_lifetime": "8h",
"ike_dpd": "30s",
"ipsec_dead_peer_detection": True,
"rekey_margin": "3m",
"rekeymargin": "3m",
"dpddelay": "35s",
"keyexchange": "ikev2",
"ike_cipher": "aes128gcm16-prfsha512-ecp256",
"esp_cipher": "aes128gcm16-ecp256",
"leftsourceip": "10.19.48.1",
"leftsubnet": "0.0.0.0/0,::/0",
"rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64",
}
# Merge with base variables
test_vars = {**base_vars, **strongswan_vars}
# Apply scenario-specific overrides
if scenario == 'ipv4_only':
test_vars['ipv6_support'] = False
test_vars['subjectAltName'] = 'IP:10.0.0.1'
test_vars['ansible_default_ipv6'] = None
elif scenario == 'dns_hostname':
test_vars['IP_subject_alt_name'] = 'vpn.example.com'
test_vars['subjectAltName'] = 'DNS:vpn.example.com'
test_vars['subjectAltName_type'] = 'DNS'
elif scenario == 'openssl_legacy':
test_vars['openssl_version'] = '1.1.1'
if scenario == "ipv4_only":
test_vars["ipv6_support"] = False
test_vars["subjectAltName"] = "IP:10.0.0.1"
test_vars["ansible_default_ipv6"] = None
elif scenario == "dns_hostname":
test_vars["IP_subject_alt_name"] = "vpn.example.com"
test_vars["subjectAltName"] = "DNS:vpn.example.com"
test_vars["subjectAltName_type"] = "DNS"
elif scenario == "openssl_legacy":
test_vars["openssl_version"] = "1.1.1"
return test_vars
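Usage sketch: each scenario starts from the shared fixture and overrides only what it needs, so for example:

```python
vars_ipv4 = get_strongswan_test_variables("ipv4_only")
assert vars_ipv4["subjectAltName"] == "IP:10.0.0.1"
assert vars_ipv4["ansible_default_ipv6"] is None

vars_dns = get_strongswan_test_variables("dns_hostname")
assert vars_dns["subjectAltName"] == "DNS:vpn.example.com"
```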
@@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'):
def test_strongswan_templates():
"""Test all StrongSwan templates with various configurations."""
templates = [
'roles/strongswan/templates/ipsec.conf.j2',
'roles/strongswan/templates/ipsec.secrets.j2',
'roles/strongswan/templates/strongswan.conf.j2',
'roles/strongswan/templates/charon.conf.j2',
'roles/strongswan/templates/client_ipsec.conf.j2',
'roles/strongswan/templates/client_ipsec.secrets.j2',
'roles/strongswan/templates/100-CustomLimitations.conf.j2',
"roles/strongswan/templates/ipsec.conf.j2",
"roles/strongswan/templates/ipsec.secrets.j2",
"roles/strongswan/templates/strongswan.conf.j2",
"roles/strongswan/templates/charon.conf.j2",
"roles/strongswan/templates/client_ipsec.conf.j2",
"roles/strongswan/templates/client_ipsec.secrets.j2",
"roles/strongswan/templates/100-CustomLimitations.conf.j2",
]
scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy']
scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"]
errors = []
tested = 0
@@ -127,21 +128,18 @@ def test_strongswan_templates():
test_vars = get_strongswan_test_variables(scenario)
try:
env = Environment(
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock filters
env.filters['to_uuid'] = mock_to_uuid
env.filters['bool'] = mock_bool
env.filters['b64encode'] = mock_b64encode
env.filters['b64decode'] = mock_b64decode
env.tests['version'] = mock_version
env.filters["to_uuid"] = mock_to_uuid
env.filters["bool"] = mock_bool
env.filters["b64encode"] = mock_b64encode
env.filters["b64decode"] = mock_b64decode
env.tests["version"] = mock_version
# For client templates, add item context
if 'client' in template_name:
test_vars['item'] = 'testuser'
if "client" in template_name:
test_vars["item"] = "testuser"
template = env.get_template(template_name)
output = template.render(**test_vars)
@@ -150,16 +148,16 @@ def test_strongswan_templates():
assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
# Specific validations based on template
if 'ipsec.conf' in template_name and 'client' not in template_name:
assert 'conn' in output, "Missing connection definition"
if scenario != 'ipv4_only' and test_vars.get('ipv6_support'):
assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration"
if "ipsec.conf" in template_name and "client" not in template_name:
assert "conn" in output, "Missing connection definition"
if scenario != "ipv4_only" and test_vars.get("ipv6_support"):
assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration"
if 'ipsec.secrets' in template_name:
assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method"
if "ipsec.secrets" in template_name:
assert "PSK" in output or "ECDSA" in output, "Missing authentication method"
if 'strongswan.conf' in template_name:
assert 'charon' in output, "Missing charon configuration"
if "strongswan.conf" in template_name:
assert "charon" in output, "Missing charon configuration"
print(f"{template_name} ({scenario})")
@@ -182,7 +180,7 @@ def test_openssl_template_constraints():
# This tests the actual openssl.yml task file to ensure our fix works
import yaml
openssl_path = 'roles/strongswan/tasks/openssl.yml'
openssl_path = "roles/strongswan/tasks/openssl.yml"
if not os.path.exists(openssl_path):
print("⚠️ OpenSSL tasks file not found")
return True
@@ -194,22 +192,23 @@ def test_openssl_template_constraints():
# Find the CA CSR task
ca_csr_task = None
for task in content:
if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'):
if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"):
ca_csr_task = task
break
if ca_csr_task:
# Check that name_constraints_permitted is properly formatted
csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {})
constraints = csr_module.get('name_constraints_permitted', '')
csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {})
constraints = csr_module.get("name_constraints_permitted", "")
# The constraints should be a Jinja2 template without inline comments
if '#' in str(constraints):
if "#" in str(constraints):
# Check if the # is within {{ }}
import re
jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL)
jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL)
for block in jinja_blocks:
if '#' in block:
if "#" in block:
print("❌ Found inline comment in Jinja2 expression")
return False
@@ -223,7 +222,7 @@ def test_openssl_template_constraints():
def test_mobileconfig_template():
"""Test the mobileconfig template with various scenarios."""
template_path = 'roles/strongswan/templates/mobileconfig.j2'
template_path = "roles/strongswan/templates/mobileconfig.j2"
if not os.path.exists(template_path):
print("⚠️ Mobileconfig template not found")
@@ -237,20 +236,20 @@ def test_mobileconfig_template():
test_cases = [
{
'name': 'iPhone with cellular on-demand',
'algo_ondemand_cellular': 'true',
'algo_ondemand_wifi': 'false',
"name": "iPhone with cellular on-demand",
"algo_ondemand_cellular": "true",
"algo_ondemand_wifi": "false",
},
{
'name': 'iPad with WiFi on-demand',
'algo_ondemand_cellular': 'false',
'algo_ondemand_wifi': 'true',
'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi',
"name": "iPad with WiFi on-demand",
"algo_ondemand_cellular": "false",
"algo_ondemand_wifi": "true",
"algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi",
},
{
'name': 'Mac without on-demand',
'algo_ondemand_cellular': 'false',
'algo_ondemand_wifi': 'false',
"name": "Mac without on-demand",
"algo_ondemand_cellular": "false",
"algo_ondemand_wifi": "false",
},
]
@@ -258,43 +257,41 @@ def test_mobileconfig_template():
for test_case in test_cases:
test_vars = get_strongswan_test_variables()
test_vars.update(test_case)
# Mock Ansible task result format for item
class MockTaskResult:
def __init__(self, content):
self.stdout = content
test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result
test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64
test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64
test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4())
test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12
test_vars['algo_server_name'] = 'test-algo-vpn'
test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4())
test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4())
test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64
test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result
test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64
test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64
test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4())
test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12
test_vars["algo_server_name"] = "test-algo-vpn"
test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4())
test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4())
test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64
try:
env = Environment(
loader=FileSystemLoader('roles/strongswan/templates'),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined)
# Add mock filters
env.filters['to_uuid'] = mock_to_uuid
env.filters['b64encode'] = mock_b64encode
env.filters['b64decode'] = mock_b64decode
env.filters["to_uuid"] = mock_to_uuid
env.filters["b64encode"] = mock_b64encode
env.filters["b64decode"] = mock_b64decode
template = env.get_template('mobileconfig.j2')
template = env.get_template("mobileconfig.j2")
output = template.render(**test_vars)
# Validate output
assert '<?xml' in output, "Missing XML declaration"
assert '<plist' in output, "Missing plist element"
assert 'PayloadType' in output, "Missing PayloadType"
assert "<?xml" in output, "Missing XML declaration"
assert "<plist" in output, "Missing plist element"
assert "PayloadType" in output, "Missing PayloadType"
# Check on-demand configuration
if test_case.get('algo_ondemand_cellular') == 'true' or test_case.get('algo_ondemand_wifi') == 'true':
assert 'OnDemandEnabled' in output, f"Missing OnDemand config for {test_case['name']}"
if test_case.get("algo_ondemand_cellular") == "true" or test_case.get("algo_ondemand_wifi") == "true":
assert "OnDemandEnabled" in output, f"Missing OnDemand config for {test_case['name']}"
print(f" ✅ Mobileconfig: {test_case['name']}")


@@ -3,6 +3,7 @@
Test that Ansible templates render correctly
This catches undefined variables, syntax errors, and logic bugs
"""
import os
import sys
from pathlib import Path
@@ -22,20 +23,20 @@ def mock_to_uuid(value):
def mock_bool(value):
"""Mock the bool filter"""
- return str(value).lower() in ('true', '1', 'yes', 'on')
+ return str(value).lower() in ("true", "1", "yes", "on")
def mock_lookup(type, path):
"""Mock the lookup function"""
# Return fake data for file lookups
- if type == 'file':
- if 'private' in path:
- return 'MOCK_PRIVATE_KEY_BASE64=='
- elif 'public' in path:
- return 'MOCK_PUBLIC_KEY_BASE64=='
- elif 'preshared' in path:
- return 'MOCK_PRESHARED_KEY_BASE64=='
- return 'MOCK_LOOKUP_DATA'
+ if type == "file":
+ if "private" in path:
+ return "MOCK_PRIVATE_KEY_BASE64=="
+ elif "public" in path:
+ return "MOCK_PUBLIC_KEY_BASE64=="
+ elif "preshared" in path:
+ return "MOCK_PRESHARED_KEY_BASE64=="
+ return "MOCK_LOOKUP_DATA"
def get_test_variables():
@@ -47,8 +48,8 @@ def get_test_variables():
def find_templates():
"""Find all Jinja2 template files in the repo"""
templates = []
- for pattern in ['**/*.j2', '**/*.jinja2', '**/*.yml.j2']:
- templates.extend(Path('.').glob(pattern))
+ for pattern in ["**/*.j2", "**/*.jinja2", "**/*.yml.j2"]:
+ templates.extend(Path(".").glob(pattern))
return templates
@@ -57,10 +58,10 @@ def test_template_syntax():
templates = find_templates()
# Skip some paths that aren't real templates
- skip_paths = ['.git/', 'venv/', '.venv/', '.env/', 'configs/']
+ skip_paths = [".git/", "venv/", ".venv/", ".env/", "configs/"]
# Skip templates that use Ansible-specific filters
- skip_templates = ['vpn-dict.j2', 'mobileconfig.j2', 'dnscrypt-proxy.toml.j2']
+ skip_templates = ["vpn-dict.j2", "mobileconfig.j2", "dnscrypt-proxy.toml.j2"]
errors = []
skipped = 0
@@ -76,10 +77,7 @@ def test_template_syntax():
try:
template_dir = template_path.parent
- env = Environment(
- loader=FileSystemLoader(template_dir),
- undefined=StrictUndefined
- )
+ env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Just try to load the template - this checks syntax
env.get_template(template_path.name)
@@ -103,13 +101,13 @@ def test_template_syntax():
def test_critical_templates():
"""Test that critical templates render with test data"""
critical_templates = [
- 'roles/wireguard/templates/client.conf.j2',
- 'roles/strongswan/templates/ipsec.conf.j2',
- 'roles/strongswan/templates/ipsec.secrets.j2',
- 'roles/dns/templates/adblock.sh.j2',
- 'roles/dns/templates/dnsmasq.conf.j2',
- 'roles/common/templates/rules.v4.j2',
- 'roles/common/templates/rules.v6.j2',
+ "roles/wireguard/templates/client.conf.j2",
+ "roles/strongswan/templates/ipsec.conf.j2",
+ "roles/strongswan/templates/ipsec.secrets.j2",
+ "roles/dns/templates/adblock.sh.j2",
+ "roles/dns/templates/dnsmasq.conf.j2",
+ "roles/common/templates/rules.v4.j2",
+ "roles/common/templates/rules.v6.j2",
]
test_vars = get_test_variables()
@@ -123,21 +121,18 @@ def test_critical_templates():
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
- env = Environment(
- loader=FileSystemLoader(template_dir),
- undefined=StrictUndefined
- )
+ env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock functions
- env.globals['lookup'] = mock_lookup
- env.filters['to_uuid'] = mock_to_uuid
- env.filters['bool'] = mock_bool
+ env.globals["lookup"] = mock_lookup
+ env.filters["to_uuid"] = mock_to_uuid
+ env.filters["bool"] = mock_bool
template = env.get_template(template_name)
# Add item context for templates that use loops
- if 'client' in template_name:
- test_vars['item'] = ('test-user', 'test-user')
+ if "client" in template_name:
+ test_vars["item"] = ("test-user", "test-user")
# Try to render
output = template.render(**test_vars)
@@ -163,17 +158,17 @@ def test_variable_consistency():
"""Check that commonly used variables are defined consistently"""
# Variables that should be used consistently across templates
common_vars = [
- 'server_name',
- 'IP_subject_alt_name',
- 'wireguard_port',
- 'wireguard_network',
- 'dns_servers',
- 'users',
+ "server_name",
+ "IP_subject_alt_name",
+ "wireguard_port",
+ "wireguard_network",
+ "dns_servers",
+ "users",
]
# Check if main.yml defines these
- if os.path.exists('main.yml'):
- with open('main.yml') as f:
+ if os.path.exists("main.yml"):
+ with open("main.yml") as f:
content = f.read()
missing = []
@@ -192,28 +187,19 @@ def test_wireguard_ipv6_endpoints():
"""Test that WireGuard client configs properly format IPv6 endpoints"""
test_cases = [
# IPv4 address - should not be bracketed
- {
- 'IP_subject_alt_name': '192.168.1.100',
- 'expected_endpoint': 'Endpoint = 192.168.1.100:51820'
- },
+ {"IP_subject_alt_name": "192.168.1.100", "expected_endpoint": "Endpoint = 192.168.1.100:51820"},
# IPv6 address - should be bracketed
{
- 'IP_subject_alt_name': '2600:3c01::f03c:91ff:fedf:3b2a',
- 'expected_endpoint': 'Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820'
+ "IP_subject_alt_name": "2600:3c01::f03c:91ff:fedf:3b2a",
+ "expected_endpoint": "Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820",
},
# Hostname - should not be bracketed
- {
- 'IP_subject_alt_name': 'vpn.example.com',
- 'expected_endpoint': 'Endpoint = vpn.example.com:51820'
- },
+ {"IP_subject_alt_name": "vpn.example.com", "expected_endpoint": "Endpoint = vpn.example.com:51820"},
# IPv6 with zone ID - should be bracketed
- {
- 'IP_subject_alt_name': 'fe80::1%eth0',
- 'expected_endpoint': 'Endpoint = [fe80::1%eth0]:51820'
- },
+ {"IP_subject_alt_name": "fe80::1%eth0", "expected_endpoint": "Endpoint = [fe80::1%eth0]:51820"},
]
- template_path = 'roles/wireguard/templates/client.conf.j2'
+ template_path = "roles/wireguard/templates/client.conf.j2"
if not os.path.exists(template_path):
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
return
@@ -225,24 +211,23 @@ def test_wireguard_ipv6_endpoints():
try:
# Set up test variables
test_vars = {**base_vars, **test_case}
- test_vars['item'] = ('test-user', 'test-user')
+ test_vars["item"] = ("test-user", "test-user")
# Render template
- env = Environment(
- loader=FileSystemLoader('roles/wireguard/templates'),
- undefined=StrictUndefined
- )
- env.globals['lookup'] = mock_lookup
+ env = Environment(loader=FileSystemLoader("roles/wireguard/templates"), undefined=StrictUndefined)
+ env.globals["lookup"] = mock_lookup
- template = env.get_template('client.conf.j2')
+ template = env.get_template("client.conf.j2")
output = template.render(**test_vars)
# Check if the expected endpoint format is in the output
- if test_case['expected_endpoint'] not in output:
- errors.append(f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output")
+ if test_case["expected_endpoint"] not in output:
+ errors.append(
+ f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output"
+ )
# Print relevant part of output for debugging
- for line in output.split('\n'):
- if 'Endpoint' in line:
+ for line in output.split("\n"):
+ if "Endpoint" in line:
errors.append(f" Found: {line.strip()}")
except Exception as e:
@@ -262,27 +247,27 @@ def test_template_conditionals():
test_cases = [
# WireGuard enabled, IPsec disabled
{
- 'wireguard_enabled': True,
- 'ipsec_enabled': False,
- 'dns_encryption': True,
- 'dns_adblocking': True,
- 'algo_ssh_tunneling': False,
+ "wireguard_enabled": True,
+ "ipsec_enabled": False,
+ "dns_encryption": True,
+ "dns_adblocking": True,
+ "algo_ssh_tunneling": False,
},
# IPsec enabled, WireGuard disabled
{
- 'wireguard_enabled': False,
- 'ipsec_enabled': True,
- 'dns_encryption': False,
- 'dns_adblocking': False,
- 'algo_ssh_tunneling': True,
+ "wireguard_enabled": False,
+ "ipsec_enabled": True,
+ "dns_encryption": False,
+ "dns_adblocking": False,
+ "algo_ssh_tunneling": True,
},
# Both enabled
{
- 'wireguard_enabled': True,
- 'ipsec_enabled': True,
- 'dns_encryption': True,
- 'dns_adblocking': True,
- 'algo_ssh_tunneling': True,
+ "wireguard_enabled": True,
+ "ipsec_enabled": True,
+ "dns_encryption": True,
+ "dns_adblocking": True,
+ "algo_ssh_tunneling": True,
},
]
@@ -294,7 +279,7 @@ def test_template_conditionals():
# Test a few templates that have conditionals
conditional_templates = [
- 'roles/common/templates/rules.v4.j2',
+ "roles/common/templates/rules.v4.j2",
]
for template_path in conditional_templates:
@@ -305,23 +290,19 @@ def test_template_conditionals():
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
- env = Environment(
- loader=FileSystemLoader(template_dir),
- undefined=StrictUndefined
- )
+ env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock functions
- env.globals['lookup'] = mock_lookup
- env.filters['to_uuid'] = mock_to_uuid
- env.filters['bool'] = mock_bool
+ env.globals["lookup"] = mock_lookup
+ env.filters["to_uuid"] = mock_to_uuid
+ env.filters["bool"] = mock_bool
template = env.get_template(template_name)
output = template.render(**test_vars)
# Verify conditionals work
- if test_case.get('wireguard_enabled'):
- assert str(test_vars['wireguard_port']) in output, \
- f"WireGuard port missing when enabled (case {i})"
+ if test_case.get("wireguard_enabled"):
+ assert str(test_vars["wireguard_port"]) in output, f"WireGuard port missing when enabled (case {i})"
except Exception as e:
print(f"✗ Conditional test failed for {template_path} case {i}: {e}")


@@ -3,6 +3,7 @@
Test user management functionality without deployment
Based on issues #14745, #14746, #14738, #14726
"""
import os
import re
import sys
@@ -23,15 +24,15 @@ users:
"""
config = yaml.safe_load(test_config)
- users = config.get('users', [])
+ users = config.get("users", [])
assert len(users) == 5, f"Expected 5 users, got {len(users)}"
- assert 'alice' in users, "Missing user 'alice'"
- assert 'user-with-dash' in users, "Dash in username not handled"
- assert 'user_with_underscore' in users, "Underscore in username not handled"
+ assert "alice" in users, "Missing user 'alice'"
+ assert "user-with-dash" in users, "Dash in username not handled"
+ assert "user_with_underscore" in users, "Underscore in username not handled"
# Test that usernames are valid
- username_pattern = re.compile(r'^[a-zA-Z0-9_-]+$')
+ username_pattern = re.compile(r"^[a-zA-Z0-9_-]+$")
for user in users:
assert username_pattern.match(user), f"Invalid username format: {user}"
@@ -42,35 +43,27 @@ def test_server_selection_format():
"""Test server selection string parsing (issue #14727)"""
# Test various server display formats
test_cases = [
{"display": "1. 192.168.1.100 (algo-server)", "expected_ip": "192.168.1.100", "expected_name": "algo-server"},
{"display": "2. 10.0.0.1 (production-vpn)", "expected_ip": "10.0.0.1", "expected_name": "production-vpn"},
{
'display': '1. 192.168.1.100 (algo-server)',
'expected_ip': '192.168.1.100',
'expected_name': 'algo-server'
"display": "3. vpn.example.com (example-server)",
"expected_ip": "vpn.example.com",
"expected_name": "example-server",
},
{
'display': '2. 10.0.0.1 (production-vpn)',
'expected_ip': '10.0.0.1',
'expected_name': 'production-vpn'
},
{
'display': '3. vpn.example.com (example-server)',
'expected_ip': 'vpn.example.com',
'expected_name': 'example-server'
}
]
# Pattern to extract IP and name from display string
- pattern = re.compile(r'^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$')
+ pattern = re.compile(r"^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$")
for case in test_cases:
- match = pattern.match(case['display'])
+ match = pattern.match(case["display"])
assert match, f"Failed to parse: {case['display']}"
ip_or_host = match.group(1)
name = match.group(2)
- assert ip_or_host == case['expected_ip'], f"Wrong IP extracted: {ip_or_host}"
- assert name == case['expected_name'], f"Wrong name extracted: {name}"
+ assert ip_or_host == case["expected_ip"], f"Wrong IP extracted: {ip_or_host}"
+ assert name == case["expected_name"], f"Wrong name extracted: {name}"
print("✓ Server selection format test passed")
@@ -78,12 +71,12 @@ def test_server_selection_format():
def test_ssh_key_preservation():
"""Test that SSH keys aren't regenerated unnecessarily"""
with tempfile.TemporaryDirectory() as tmpdir:
- ssh_key_path = os.path.join(tmpdir, 'test_key')
+ ssh_key_path = os.path.join(tmpdir, "test_key")
# Simulate existing SSH key
- with open(ssh_key_path, 'w') as f:
+ with open(ssh_key_path, "w") as f:
f.write("EXISTING_SSH_KEY_CONTENT")
with open(f"{ssh_key_path}.pub", 'w') as f:
with open(f"{ssh_key_path}.pub", "w") as f:
f.write("ssh-rsa EXISTING_PUBLIC_KEY")
# Record original content
@@ -105,11 +98,7 @@ def test_ssh_key_preservation():
def test_ca_password_handling():
"""Test CA password validation and handling"""
# Test password requirements
- valid_passwords = [
- "SecurePassword123!",
- "Algo-VPN-2024",
- "Complex#Pass@Word999"
- ]
+ valid_passwords = ["SecurePassword123!", "Algo-VPN-2024", "Complex#Pass@Word999"]
invalid_passwords = [
"", # Empty
@@ -120,13 +109,13 @@ def test_ca_password_handling():
# Basic password validation
for pwd in valid_passwords:
assert len(pwd) >= 12, f"Password too short: {pwd}"
- assert ' ' not in pwd, f"Password contains spaces: {pwd}"
+ assert " " not in pwd, f"Password contains spaces: {pwd}"
for pwd in invalid_passwords:
issues = []
if len(pwd) < 12:
issues.append("too short")
- if ' ' in pwd:
+ if " " in pwd:
issues.append("contains spaces")
if not pwd:
issues.append("empty")
@@ -137,8 +126,8 @@ def test_user_config_generation():
def test_user_config_generation():
"""Test that user configs would be generated correctly"""
- users = ['alice', 'bob', 'charlie']
- server_name = 'test-server'
+ users = ["alice", "bob", "charlie"]
+ server_name = "test-server"
# Simulate config file structure
for user in users:
@@ -168,7 +157,7 @@ users:
"""
config = yaml.safe_load(test_config)
- users = config.get('users', [])
+ users = config.get("users", [])
# Check for duplicates
unique_users = list(set(users))
@@ -182,7 +171,7 @@ users:
duplicates.append(user)
seen.add(user)
- assert 'alice' in duplicates, "Duplicate 'alice' not detected"
+ assert "alice" in duplicates, "Duplicate 'alice' not detected"
print("✓ Duplicate user handling test passed")


@@ -3,6 +3,7 @@
Test WireGuard key generation - focused on x25519_pubkey module integration
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
"""
import base64
import os
import subprocess
@@ -10,13 +11,13 @@ import sys
import tempfile
# Add library directory to path to import our custom module
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library"))
def test_wireguard_tools_available():
"""Test that WireGuard tools are available for validation"""
try:
- result = subprocess.run(['wg', '--version'], capture_output=True, text=True)
+ result = subprocess.run(["wg", "--version"], capture_output=True, text=True)
assert result.returncode == 0, "WireGuard tools not available"
print(f"✓ WireGuard tools available: {result.stdout.strip()}")
return True
@@ -29,6 +30,7 @@ def test_x25519_module_import():
"""Test that our custom x25519_pubkey module can be imported and used"""
try:
import x25519_pubkey # noqa: F401
print("✓ x25519_pubkey module imports successfully")
return True
except ImportError as e:
@@ -37,16 +39,17 @@ def test_x25519_module_import():
def generate_test_private_key():
"""Generate a test private key using the same method as Algo"""
- with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
+ with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file:
raw_key_path = temp_file.name
try:
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
import secrets
raw_data = secrets.token_bytes(32)
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
- with open(raw_key_path, 'wb') as f:
+ with open(raw_key_path, "wb") as f:
f.write(raw_data)
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
@@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file():
def exit_json(self, **kwargs):
self.result = kwargs
- with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
+ with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub:
public_key_path = temp_pub.name
try:
@@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file():
try:
# Mock the module call
- mock_module = MockModule({
- 'private_key_path': raw_key_path,
- 'public_key_path': public_key_path,
- 'private_key_b64': None
- })
+ mock_module = MockModule(
+ {"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None}
+ )
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file():
run_module()
# Check the result
- assert 'public_key' in mock_module.result
- assert mock_module.result['changed']
+ assert "public_key" in mock_module.result
+ assert mock_module.result["changed"]
assert os.path.exists(public_key_path)
with open(public_key_path) as f:
@@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
- mock_module = MockModule({
- 'private_key_b64': b64_key,
- 'private_key_path': None,
- 'public_key_path': None
- })
+ mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string():
run_module()
# Check the result
- assert 'public_key' in mock_module.result
- derived_pubkey = mock_module.result['public_key']
+ assert "public_key" in mock_module.result
+ derived_pubkey = mock_module.result["public_key"]
# Validate base64 format
try:
@@ -222,21 +219,17 @@ def test_wireguard_validation():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
- mock_module = MockModule({
- 'private_key_b64': b64_key,
- 'private_key_path': None,
- 'public_key_path': None
- })
+ mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module()
- derived_pubkey = mock_module.result['public_key']
+ derived_pubkey = mock_module.result["public_key"]
finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule
- with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config:
# Create a WireGuard config using our keys
wg_config = f"""[Interface]
PrivateKey = {b64_key}
@@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32
try:
# Test that WireGuard can parse our config
- result = subprocess.run([
- 'wg-quick', 'strip', config_path
- ], capture_output=True, text=True)
+ result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True)
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
# Test key derivation with wg pubkey command
- wg_result = subprocess.run([
- 'wg', 'pubkey'
- ], input=b64_key, capture_output=True, text=True)
+ wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True)
if wg_result.returncode == 0:
wg_derived = wg_result.stdout.strip()
@@ -286,8 +275,8 @@ def test_key_consistency():
raw_key_path, b64_key = generate_test_private_key()
try:
- def derive_pubkey_from_same_key():
+ def derive_pubkey_from_same_key():
class MockModule:
def __init__(self, params):
self.params = params
@@ -305,16 +294,18 @@ def test_key_consistency():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
- mock_module = MockModule({
- 'private_key_b64': b64_key, # SAME key each time
- 'private_key_path': None,
- 'public_key_path': None
- })
+ mock_module = MockModule(
+ {
+ "private_key_b64": b64_key, # SAME key each time
+ "private_key_path": None,
+ "public_key_path": None,
+ }
+ )
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module()
- return mock_module.result['public_key']
+ return mock_module.result["public_key"]
finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule


@@ -14,13 +14,13 @@ from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
- def find_jinja2_templates(root_dir: str = '.') -> list[Path]:
+ def find_jinja2_templates(root_dir: str = ".") -> list[Path]:
"""Find all Jinja2 template files in the project."""
templates = []
- patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2']
+ patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"]
# Skip these directories
- skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'}
+ skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"}
for pattern in patterns:
for path in Path(root_dir).glob(pattern):
@@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P
errors = []
# Pattern to find Jinja2 expressions
- jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL)
+ jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
for match in jinja_pattern.finditer(template_content):
expression = match.group()
- lines = expression.split('\n')
+ lines = expression.split("\n")
for i, line in enumerate(lines):
# Check for # that's not in a string
# Simple heuristic: if # appears after non-whitespace and not in quotes
- if '#' in line:
+ if "#" in line:
# Remove quoted strings to avoid false positives
cleaned = re.sub(r'"[^"]*"', '', line)
cleaned = re.sub(r"'[^']*'", '', cleaned)
cleaned = re.sub(r'"[^"]*"', "", line)
cleaned = re.sub(r"'[^']*'", "", cleaned)
- if '#' in cleaned:
+ if "#" in cleaned:
# Check if it's likely a comment (has text after it)
- hash_pos = cleaned.index('#')
- if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\':
- line_num = template_content[:match.start()].count('\n') + i + 1
+ hash_pos = cleaned.index("#")
+ if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\":
+ line_num = template_content[: match.start()].count("\n") + i + 1
errors.append(
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
f"Move comments outside the expression."
@@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Common Ansible variables that are always available
ansible_builtins = {
- 'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname',
- 'ansible_distribution', 'ansible_distribution_version', 'ansible_facts',
- 'inventory_hostname', 'hostvars', 'groups', 'group_names',
- 'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host',
- 'item', 'ansible_loop', 'ansible_index', 'lookup'
+ "ansible_default_ipv4",
+ "ansible_default_ipv6",
+ "ansible_hostname",
+ "ansible_distribution",
+ "ansible_distribution_version",
+ "ansible_facts",
+ "inventory_hostname",
+ "hostvars",
+ "groups",
+ "group_names",
+ "play_hosts",
+ "ansible_version",
+ "ansible_user",
+ "ansible_host",
+ "item",
+ "ansible_loop",
+ "ansible_index",
+ "lookup",
}
# Filter out known Ansible variables
@@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Only report if there are truly unknown variables
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
- errors.append(
- f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}"
- )
+ errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}")
except Exception:
# Don't report parse errors here, they're handled elsewhere
@@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
# Skip full parsing for templates that use Ansible-specific features heavily
# We still check for inline comments but skip full template parsing
ansible_specific_templates = {
- 'dnscrypt-proxy.toml.j2', # Uses |bool filter
- 'mobileconfig.j2', # Uses |to_uuid filter and complex item structures
- 'vpn-dict.j2', # Uses |to_uuid filter
+ "dnscrypt-proxy.toml.j2", # Uses |bool filter
+ "mobileconfig.j2", # Uses |to_uuid filter and complex item structures
+ "vpn-dict.j2", # Uses |to_uuid filter
}
if template_path.name in ansible_specific_templates:
@@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
errors.extend(check_inline_comments_in_expressions(template_content, template_path))
# Try to parse the template
- env = Environment(
- loader=FileSystemLoader(template_path.parent),
- undefined=StrictUndefined
- )
+ env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined)
# Add mock Ansible filters to avoid syntax errors
- env.filters['bool'] = lambda x: x
- env.filters['to_uuid'] = lambda x: x
- env.filters['b64encode'] = lambda x: x
- env.filters['b64decode'] = lambda x: x
- env.filters['regex_replace'] = lambda x, y, z: x
- env.filters['default'] = lambda x, d: x if x else d
+ env.filters["bool"] = lambda x: x
+ env.filters["to_uuid"] = lambda x: x
+ env.filters["b64encode"] = lambda x: x
+ env.filters["b64decode"] = lambda x: x
+ env.filters["regex_replace"] = lambda x, y, z: x
+ env.filters["default"] = lambda x, d: x if x else d
# This will raise TemplateSyntaxError if there's a syntax problem
env.get_template(template_path.name)
@@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]:
content = f.read()
# Check for missing spaces around filters
- if re.search(r'\{\{[^}]+\|[^ ]', content):
+ if re.search(r"\{\{[^}]+\|[^ ]", content):
warnings.append(f"{template_path}: Missing space after filter pipe (|)")
# Check for deprecated 'when' in Jinja2 (should use if)
- if re.search(r'\{%\s*when\s+', content):
+ if re.search(r"\{%\s*when\s+", content):
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
# Check for extremely long expressions (harder to debug)
- for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL):
+ for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL):
if len(match.group(1)) > 200:
- line_num = content[:match.start()].count('\n') + 1
- warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up")
+ line_num = content[: match.start()].count("\n") + 1
+ warnings.append(
+ f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up"
+ )
except Exception:
pass # Ignore errors in anti-pattern checking