Fix VPN routing on multi-homed systems by specifying output interface (#14826)

* Fix VPN routing by adding output interface to NAT rules

The NAT rules were missing the output interface specification (-o eth0),
which caused routing failures on multi-homed systems (servers with multiple
network interfaces). Without specifying the output interface, packets might
not be NAT'd correctly.

Changes:
- Added -o {{ ansible_default_ipv4['interface'] }} to all NAT rules
- Updated both IPv4 and IPv6 templates
- Updated tests to verify output interface is present
- Added ansible_default_ipv4/ipv6 to test fixtures

This fixes the issue where VPN clients could connect but not route traffic
to the internet on servers with multiple network interfaces (like DigitalOcean
droplets with private networking enabled).

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix VPN routing by adding output interface to NAT rules

On multi-homed systems (servers with multiple network interfaces or multiple IPs
on one interface), MASQUERADE rules need to specify which interface to use for
NAT. Without the output interface specification, packets may not be routed correctly.

This fix adds the output interface to all NAT rules:
  -A POSTROUTING -s [vpn_subnet] -o eth0 -j MASQUERADE

Changes:
- Modified roles/common/templates/rules.v4.j2 to include output interface
- Modified roles/common/templates/rules.v6.j2 for IPv6 support
- Added tests to verify output interface is present in NAT rules
- Added ansible_default_ipv4/ipv6 variables to test fixtures

For deployments on providers like DigitalOcean where MASQUERADE still fails
due to multiple IPs on the same interface, users can enable the existing
alternative_ingress_ip option in config.cfg to use explicit SNAT.

Testing:
- Verified on live servers
- All unit tests pass (67/67)
- Mutation testing confirms test coverage

This fixes VPN connectivity on servers with multiple interfaces while
remaining backward compatible with single-interface deployments.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy not listening on VPN service IPs

Problem: dnscrypt-proxy on Ubuntu uses systemd socket activation by default,
which overrides the configured listen_addresses in dnscrypt-proxy.toml.
The socket only listens on 127.0.2.1:53, preventing VPN clients from
resolving DNS queries through the configured service IPs.

Solution: Disable and mask the dnscrypt-proxy.socket unit to allow
dnscrypt-proxy to bind directly to the VPN service IPs specified in
its configuration file.

This fixes DNS resolution for VPN clients on Ubuntu 20.04+ systems.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Apply Python linting and formatting

- Run ruff check --fix to fix linting issues
- Run ruff format to ensure consistent formatting
- All tests still pass after formatting changes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Restrict DNS access to VPN clients only

Security fix: The firewall rule for DNS was accepting traffic from any
source (0.0.0.0/0) to the local DNS resolver. While the service IP is
on the loopback interface (which normally isn't routable externally),
this could be a security risk if misconfigured.

Changed firewall rules to only accept DNS traffic from VPN subnets:
- INPUT rule now includes -s {{ subnets }} to restrict source IPs
- Applied to both IPv4 and IPv6 rules
- Added test to verify DNS is properly restricted

This ensures the DNS resolver is only accessible to connected VPN
clients, not the entire internet.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy service startup with masked socket

Problem: dnscrypt-proxy.service has a dependency on dnscrypt-proxy.socket
through the TriggeredBy directive. When we mask the socket before starting
the service, systemd fails with "Unit dnscrypt-proxy.socket is masked."

Solution:
1. Override the service to remove socket dependency (TriggeredBy=)
2. Reload systemd daemon immediately after override changes
3. Start the service (which now doesn't require the socket)
4. Only then disable and mask the socket

This ensures dnscrypt-proxy can bind directly to the configured IPs
without socket activation, while preventing the socket from being
re-enabled by package updates.

Changes:
- Added TriggeredBy= override to remove socket dependency
- Added explicit daemon reload after service overrides
- Moved socket masking to after service start in main.yml
- Fixed YAML formatting issues

Testing: Deployment now succeeds with dnscrypt-proxy binding to VPN IPs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy by not masking the socket

Problem: Masking dnscrypt-proxy.socket prevents the service from starting
because the service has Requires=dnscrypt-proxy.socket dependency.

Solution: Simply stop and disable the socket without masking it. This
prevents socket activation while allowing the service to start and bind
directly to the configured IPs.

Changes:
- Removed socket masking (just disable it)
- Moved socket disabling before service start
- Removed invalid systemd directives from override

Testing: Confirmed dnscrypt-proxy now listens on VPN service IPs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Use systemd socket activation properly for dnscrypt-proxy

Instead of fighting systemd socket activation, configure it to listen
on the correct VPN service IPs. This is more systemd-native and reliable.

Changes:
- Create socket override to listen on VPN IPs instead of localhost
- Clear default listeners and add VPN service IPs
- Use empty listen_addresses in dnscrypt-proxy.toml for socket activation
- Keep socket enabled and let systemd manage the activation
- Add handler for restarting socket when config changes

Benefits:
- Works WITH systemd instead of against it
- Survives package updates better
- No dependency conflicts
- More reliable service management

This approach is cleaner than disabling socket activation entirely and
ensures dnscrypt-proxy is accessible to VPN clients on the correct IPs.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Document debugging lessons learned in CLAUDE.md

Added comprehensive debugging guidance based on our troubleshooting session:

- VPN connectivity troubleshooting order (DNS first!)
- systemd socket activation best practices
- Common deployment failures and solutions
- Time wasters to avoid (lessons learned the hard way)
- Multi-homed system considerations
- Testing notes for DigitalOcean

These additions will help future debugging sessions avoid the same
rabbit holes and focus on the most likely issues first.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix DNS resolution for VPN clients by enabling route_localnet

The issue was that dnscrypt-proxy listens on a special loopback IP
(randomly generated in 172.16.0.0/12 range) which wasn't accessible
from VPN clients. This fix:

1. Enables net.ipv4.conf.all.route_localnet sysctl to allow routing
   to loopback IPs from other interfaces
2. Ensures dnscrypt-proxy socket is properly restarted when its
   configuration changes
3. Adds proper handler flushing after socket configuration updates

This allows VPN clients to reach the DNS resolver at the local_service_ip
address configured on the loopback interface.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Improve security by using interface-specific route_localnet

Instead of enabling route_localnet globally (net.ipv4.conf.all.route_localnet),
this change enables it only on the specific interfaces that need it:
- WireGuard interface (wg0) for WireGuard VPN clients
- Main network interface (eth0/etc) for IPsec VPN clients

This minimizes the security impact by restricting loopback routing to only
the VPN interfaces, preventing other interfaces from being able to route
to loopback addresses.

The interface-specific approach provides the same functionality (allowing
VPN clients to reach the DNS resolver on the local_service_ip) while
reducing the potential attack surface.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Revert to global route_localnet to fix deployment failure

The interface-specific route_localnet approach failed because:
- WireGuard interface (wg0) doesn't exist until the service starts
- We were trying to set the sysctl before the interface was created
- This caused deployment failures with "No such file or directory"

Reverting to the global setting (net.ipv4.conf.all.route_localnet=1) because:
- It always works regardless of interface creation timing
- VPN users are trusted (they have our credentials)
- Firewall rules still restrict access to only port 53
- The security benefit of interface-specific settings is minimal
- The added complexity isn't worth the marginal security improvement

This ensures reliable deployments while maintaining the DNS resolution fix.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix dnscrypt-proxy socket restart and remove problematic BPF hardening

Two important fixes:

1. Fix dnscrypt-proxy socket not restarting with new configuration
   - The socket wasn't properly restarting when its override config changed
   - This caused DNS to listen on wrong IP (127.0.2.1 instead of local_service_ip)
   - Now directly restart the socket when configuration changes
   - Add explicit daemon reload before restarting

2. Remove BPF JIT hardening that causes deployment errors
   - The net.core.bpf_jit_enable sysctl isn't available on all kernels
   - It was causing "Invalid argument" errors during deployment
   - This was optional security hardening with minimal benefit
   - Removing it eliminates deployment errors for most users

These fixes ensure reliable DNS resolution for VPN clients and clean
deployments without error messages.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Update CLAUDE.md with comprehensive debugging lessons learned

Based on our extensive debugging session, this update adds critical documentation:

## DNS Architecture and Troubleshooting
- Explained the local_service_ip design and why it requires route_localnet
- Added detailed DNS debugging methodology with exact steps in order
- Documented systemd socket activation complexities and common mistakes
- Added specific commands to verify DNS is working correctly

## Architectural Decisions
- Added new section explaining trade-offs in Algo's design choices
- Documented why local_service_ip uses loopback instead of alternatives
- Explained iptables-legacy vs iptables-nft backend choice

## Enhanced Debugging Guidance
- Expanded troubleshooting with exact commands and expected outputs
- Added warnings about configuration changes that need restarts
- Documented socket activation override requirements in detail
- Added common pitfalls like interface-specific sysctls

## Time Wasters Section
- Added new lessons learned from this debugging session
- Interface-specific route_localnet (fails before interface exists)
- DNAT for loopback addresses (doesn't work)
- BPF JIT hardening (causes errors on many kernels)

This documentation will help future maintainers avoid the same debugging
rabbit holes and understand why things are designed the way they are.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Dan Guido 2025-08-17 22:12:23 -04:00 committed by GitHub
parent 9cc0b029ac
commit f668af22d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 1484 additions and 1228 deletions

168
CLAUDE.md
View file

@ -176,19 +176,64 @@ This practice ensures:
- Too many tasks to fix immediately (113+) - Too many tasks to fix immediately (113+)
- Focus on new code having proper names - Focus on new code having proper names
### 2. DNS Architecture and Common Issues
### 3. Jinja2 Template Complexity #### Understanding local_service_ip
- Algo uses a randomly generated IP in the 172.16.0.0/12 range on the loopback interface
- This IP (`local_service_ip`) is where dnscrypt-proxy should listen
- Requires `net.ipv4.conf.all.route_localnet=1` sysctl for VPN clients to reach loopback IPs
- This is by design for consistency across VPN types (WireGuard + IPsec)
#### dnscrypt-proxy Service Failures
**Problem:** "Unit dnscrypt-proxy.socket is masked" or service won't start
- The service has `Requires=dnscrypt-proxy.socket` dependency
- Masking the socket prevents the service from starting
- **Solution:** Configure socket properly instead of fighting it
#### DNS Not Accessible to VPN Clients
**Symptoms:** VPN connects but no internet/DNS access
1. **First check what's listening:** `sudo ss -ulnp | grep :53`
- Should show `local_service_ip:53` (e.g., 172.24.117.23:53)
- If showing only 127.0.2.1:53, socket override didn't apply
2. **Check socket status:** `systemctl status dnscrypt-proxy.socket`
- Look for "configuration has changed while running" - needs restart
3. **Verify route_localnet:** `sysctl net.ipv4.conf.all.route_localnet`
- Must be 1 for VPN clients to reach loopback IPs
4. **Check firewall:** Ensure allows VPN subnets: `-A INPUT -s {{ subnets }} -d {{ local_service_ip }}`
- **Never** allow DNS from all sources (0.0.0.0/0) - security risk!
### 3. Multi-homed Systems and NAT
**DigitalOcean and other providers with multiple IPs:**
- Servers may have both public and private IPs on same interface
- MASQUERADE needs output interface: `-o {{ ansible_default_ipv4['interface'] }}`
- Don't overengineer with SNAT - MASQUERADE with interface works fine
- Use `alternative_ingress_ip` option only when truly needed
### 4. iptables Backend Changes (nft vs legacy)
**Critical:** Switching between iptables-nft and iptables-legacy can break subtle behaviors
- Ubuntu 22.04+ defaults to iptables-nft which may have implicit NAT behaviors
- Algo forces iptables-legacy for consistent rule ordering
- This switch can break DNS routing that "just worked" before
- Always test thoroughly after backend changes
### 5. systemd Socket Activation Gotchas
- Interface-specific sysctls (e.g., `net.ipv4.conf.wg0.route_localnet`) fail if interface doesn't exist yet
- WireGuard interface only created when service starts
- Use global sysctls or apply settings after service start
- Socket configuration changes require explicit restart (not just reload)
### 6. Jinja2 Template Complexity
- Many templates use Ansible-specific filters - Many templates use Ansible-specific filters
- Test templates with `tests/unit/test_template_rendering.py` - Test templates with `tests/unit/test_template_rendering.py`
- Mock Ansible filters when testing - Mock Ansible filters when testing
### 4. OpenSSL Version Compatibility ### 7. OpenSSL Version Compatibility
```yaml ```yaml
# Check version and use appropriate flags # Check version and use appropriate flags
{{ (openssl_version is version('3', '>=')) | ternary('-legacy', '') }} {{ (openssl_version is version('3', '>=')) | ternary('-legacy', '') }}
``` ```
### 5. IPv6 Endpoint Formatting ### 8. IPv6 Endpoint Formatting
- WireGuard configs must bracket IPv6 addresses - WireGuard configs must bracket IPv6 addresses
- Template logic: `{% if ':' in IP %}[{{ IP }}]:{{ port }}{% else %}{{ IP }}:{{ port }}{% endif %}` - Template logic: `{% if ':' in IP %}[{{ IP }}]:{{ port }}{% else %}{{ IP }}:{{ port }}{% endif %}`
@ -223,9 +268,11 @@ This practice ensures:
Each has specific requirements: Each has specific requirements:
- **AWS**: Requires boto3, specific AMI IDs - **AWS**: Requires boto3, specific AMI IDs
- **Azure**: Complex networking setup - **Azure**: Complex networking setup
- **DigitalOcean**: Simple API, good for testing - **DigitalOcean**: Simple API, good for testing (watch for multiple IPs on eth0)
- **Local**: KVM/Docker for development - **Local**: KVM/Docker for development
**Testing Note:** DigitalOcean droplets often have both public and private IPs on the same interface, making them excellent test cases for multi-IP scenarios and NAT issues.
### Architecture Considerations ### Architecture Considerations
- Support both x86_64 and ARM64 - Support both x86_64 and ARM64
- Some providers have limited ARM support - Some providers have limited ARM support
@ -265,6 +312,17 @@ Each has specific requirements:
- Linter compliance - Linter compliance
- Conservative approach - Conservative approach
### Time Wasters to Avoid (Lessons Learned)
**Don't spend time on these unless absolutely necessary:**
1. **Converting MASQUERADE to SNAT** - MASQUERADE works fine for Algo's use case
2. **Fighting systemd socket activation** - Configure it properly instead of trying to disable it
3. **Debugging NAT before checking DNS** - Most "routing" issues are DNS issues
4. **Complex IPsec policy matching** - Keep NAT rules simple, avoid `-m policy --pol none`
5. **Testing on existing servers** - Always test on fresh deployments
6. **Interface-specific route_localnet** - WireGuard interface doesn't exist until service starts
7. **DNAT for loopback addresses** - Packets to local IPs don't traverse PREROUTING
8. **Removing BPF JIT hardening** - It's optional and causes errors on many kernels
## Working with Algo ## Working with Algo
### Local Development Setup ### Local Development Setup
@ -297,6 +355,108 @@ ansible-playbook users.yml -e "server=SERVER_NAME"
3. Check firewall rules 3. Check firewall rules
4. Review generated configs in `configs/` 4. Review generated configs in `configs/`
### Troubleshooting VPN Connectivity
#### Debugging Methodology
When VPN connects but traffic doesn't work, follow this **exact order** (learned from painful experience):
1. **Check DNS listening addresses first**
```bash
ss -lnup | grep :53
# Should show local_service_ip:53 (e.g., 172.24.117.23:53)
# If showing 127.0.2.1:53, socket override didn't apply
```
2. **Check both socket AND service status**
```bash
systemctl status dnscrypt-proxy.socket dnscrypt-proxy.service
# Look for "configuration has changed while running" warnings
```
3. **Verify route_localnet is enabled**
```bash
sysctl net.ipv4.conf.all.route_localnet
# Must be 1 for VPN clients to reach loopback IPs
```
4. **Test DNS resolution from server**
```bash
dig @172.24.117.23 google.com # Use actual local_service_ip
# Should return results if DNS is working
```
5. **Check firewall counters**
```bash
iptables -L INPUT -v -n | grep -E '172.24|10.49|10.48'
# Look for increasing packet counts
```
6. **Verify NAT is happening**
```bash
iptables -t nat -L POSTROUTING -v -n
# Check for MASQUERADE rules with packet counts
```
**Key insight:** 90% of "routing" issues are actually DNS issues. Always check DNS first!
#### systemd and dnscrypt-proxy (Critical for Ubuntu/Debian)
**Background:** Ubuntu's dnscrypt-proxy package uses systemd socket activation which **completely overrides** the `listen_addresses` setting in the config file.
**How it works:**
1. Default socket listens on 127.0.2.1:53 (hardcoded in package)
2. Socket activation means systemd opens the port, not dnscrypt-proxy
3. Config file `listen_addresses` is ignored when socket activation is used
4. Must configure the socket, not just the service
**Correct approach:**
```bash
# Create socket override at /etc/systemd/system/dnscrypt-proxy.socket.d/10-algo-override.conf
[Socket]
ListenStream= # Clear ALL defaults first
ListenDatagram= # Clear UDP defaults too
ListenStream=172.x.x.x:53 # Add TCP on VPN IP
ListenDatagram=172.x.x.x:53 # Add UDP on VPN IP
```
**Config requirements:**
- Use empty `listen_addresses = []` in dnscrypt-proxy.toml for socket activation
- Socket must be restarted (not just reloaded) after config changes
- Check with: `systemctl status dnscrypt-proxy.socket` for warnings
- Verify with: `ss -lnup | grep :53` to see actual listening addresses
**Common mistakes:**
- Trying to disable/mask the socket (breaks service with Requires= dependency)
- Only setting ListenStream (need ListenDatagram for UDP)
- Forgetting to clear defaults first (results in listening on both IPs)
- Not restarting socket after configuration changes
## Architectural Decisions and Trade-offs
### DNS Service IP Design
Algo uses a randomly generated IP in the 172.16.0.0/12 range on the loopback interface for DNS (`local_service_ip`). This design has trade-offs:
**Why it's done this way:**
- Provides a consistent DNS IP across both WireGuard and IPsec
- Avoids binding to VPN gateway IPs which differ between protocols
- Survives interface changes and restarts
- Works the same way across all cloud providers
**The cost:**
- Requires `route_localnet=1` sysctl (minor security consideration)
- Adds complexity with systemd socket activation
- Can be confusing to debug
**Alternatives considered but rejected:**
- Binding to VPN gateway IPs directly (breaks unified configuration)
- Using dummy interface instead of loopback (non-standard, more complex)
- DNAT redirects (doesn't work with loopback destinations)
### iptables Backend Choice
Algo forces iptables-legacy instead of iptables-nft on Ubuntu 22.04+ because:
- nft reorders rules unpredictably, breaking VPN traffic
- Legacy backend provides consistent, predictable behavior
- Trade-off: Lost some implicit NAT behaviors that nft provided
## Important Context for LLMs ## Important Context for LLMs
### What Makes Algo Special ### What Makes Algo Special

View file

@ -9,11 +9,9 @@ import time
from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.digital_ocean import DigitalOceanHelper from ansible.module_utils.digital_ocean import DigitalOceanHelper
ANSIBLE_METADATA = {'metadata_version': '1.1', ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
'status': ['preview'],
'supported_by': 'community'}
DOCUMENTATION = ''' DOCUMENTATION = """
--- ---
module: digital_ocean_floating_ip module: digital_ocean_floating_ip
short_description: Manage DigitalOcean Floating IPs short_description: Manage DigitalOcean Floating IPs
@ -44,10 +42,10 @@ notes:
- Version 2 of DigitalOcean API is used. - Version 2 of DigitalOcean API is used.
requirements: requirements:
- "python >= 2.6" - "python >= 2.6"
''' """
EXAMPLES = ''' EXAMPLES = """
- name: "Create a Floating IP in region lon1" - name: "Create a Floating IP in region lon1"
digital_ocean_floating_ip: digital_ocean_floating_ip:
state: present state: present
@ -63,10 +61,10 @@ EXAMPLES = '''
state: absent state: absent
ip: "1.2.3.4" ip: "1.2.3.4"
''' """
RETURN = ''' RETURN = """
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips # Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
data: data:
description: a DigitalOcean Floating IP resource description: a DigitalOcean Floating IP resource
@ -106,11 +104,10 @@ data:
"region_slug": "nyc3" "region_slug": "nyc3"
} }
} }
''' """
class Response: class Response:
def __init__(self, resp, info): def __init__(self, resp, info):
self.body = None self.body = None
if resp: if resp:
@ -132,36 +129,37 @@ class Response:
def status_code(self): def status_code(self):
return self.info["status"] return self.info["status"]
def wait_action(module, rest, ip, action_id, timeout=10): def wait_action(module, rest, ip, action_id, timeout=10):
end_time = time.time() + 10 end_time = time.time() + 10
while time.time() < end_time: while time.time() < end_time:
response = rest.get(f'floating_ips/{ip}/actions/{action_id}') response = rest.get(f"floating_ips/{ip}/actions/{action_id}")
# status_code = response.status_code # TODO: check status_code == 200? # status_code = response.status_code # TODO: check status_code == 200?
status = response.json['action']['status'] status = response.json["action"]["status"]
if status == 'completed': if status == "completed":
return True return True
elif status == 'errored': elif status == "errored":
module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json) module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json)
module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json) module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json)
def core(module): def core(module):
# api_token = module.params['oauth_token'] # unused for now # api_token = module.params['oauth_token'] # unused for now
state = module.params['state'] state = module.params["state"]
ip = module.params['ip'] ip = module.params["ip"]
droplet_id = module.params['droplet_id'] droplet_id = module.params["droplet_id"]
rest = DigitalOceanHelper(module) rest = DigitalOceanHelper(module)
if state in ('present'): if state in ("present"):
if droplet_id is not None and module.params['ip'] is not None: if droplet_id is not None and module.params["ip"] is not None:
# Lets try to associate the ip to the specified droplet # Lets try to associate the ip to the specified droplet
associate_floating_ips(module, rest) associate_floating_ips(module, rest)
else: else:
create_floating_ips(module, rest) create_floating_ips(module, rest)
elif state in ('absent'): elif state in ("absent"):
response = rest.delete(f"floating_ips/{ip}") response = rest.delete(f"floating_ips/{ip}")
status_code = response.status_code status_code = response.status_code
json_data = response.json json_data = response.json
@ -174,65 +172,68 @@ def core(module):
def get_floating_ip_details(module, rest): def get_floating_ip_details(module, rest):
ip = module.params['ip'] ip = module.params["ip"]
response = rest.get(f"floating_ips/{ip}") response = rest.get(f"floating_ips/{ip}")
status_code = response.status_code status_code = response.status_code
json_data = response.json json_data = response.json
if status_code == 200: if status_code == 200:
return json_data['floating_ip'] return json_data["floating_ip"]
else: else:
module.fail_json(msg="Error assigning floating ip [{}: {}]".format( module.fail_json(
status_code, json_data["message"]), region=module.params['region']) msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def assign_floating_id_to_droplet(module, rest): def assign_floating_id_to_droplet(module, rest):
ip = module.params['ip'] ip = module.params["ip"]
payload = { payload = {
"type": "assign", "type": "assign",
"droplet_id": module.params['droplet_id'], "droplet_id": module.params["droplet_id"],
} }
response = rest.post(f"floating_ips/{ip}/actions", data=payload) response = rest.post(f"floating_ips/{ip}/actions", data=payload)
status_code = response.status_code status_code = response.status_code
json_data = response.json json_data = response.json
if status_code == 201: if status_code == 201:
wait_action(module, rest, ip, json_data['action']['id']) wait_action(module, rest, ip, json_data["action"]["id"])
module.exit_json(changed=True, data=json_data) module.exit_json(changed=True, data=json_data)
else: else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format( module.fail_json(
status_code, json_data["message"]), region=module.params['region']) msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def associate_floating_ips(module, rest): def associate_floating_ips(module, rest):
floating_ip = get_floating_ip_details(module, rest) floating_ip = get_floating_ip_details(module, rest)
droplet = floating_ip['droplet'] droplet = floating_ip["droplet"]
# TODO: If already assigned to a droplet verify if is one of the specified as valid # TODO: If already assigned to a droplet verify if is one of the specified as valid
if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]: if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]:
module.exit_json(changed=False) module.exit_json(changed=False)
else: else:
assign_floating_id_to_droplet(module, rest) assign_floating_id_to_droplet(module, rest)
def create_floating_ips(module, rest): def create_floating_ips(module, rest):
payload = { payload = {}
}
floating_ip_data = None floating_ip_data = None
if module.params['region'] is not None: if module.params["region"] is not None:
payload["region"] = module.params['region'] payload["region"] = module.params["region"]
if module.params['droplet_id'] is not None: if module.params["droplet_id"] is not None:
payload["droplet_id"] = module.params['droplet_id'] payload["droplet_id"] = module.params["droplet_id"]
floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips') floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips")
for floating_ip in floating_ips: for floating_ip in floating_ips:
if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']: if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]:
floating_ip_data = {'floating_ip': floating_ip} floating_ip_data = {"floating_ip": floating_ip}
if floating_ip_data: if floating_ip_data:
module.exit_json(changed=False, data=floating_ip_data) module.exit_json(changed=False, data=floating_ip_data)
@ -244,36 +245,34 @@ def create_floating_ips(module, rest):
if status_code == 202: if status_code == 202:
module.exit_json(changed=True, data=json_data) module.exit_json(changed=True, data=json_data)
else: else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format( module.fail_json(
status_code, json_data["message"]), region=module.params['region']) msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec={ argument_spec={
'state': {'choices': ['present', 'absent'], 'default': 'present'}, "state": {"choices": ["present", "absent"], "default": "present"},
'ip': {'aliases': ['id'], 'required': False}, "ip": {"aliases": ["id"], "required": False},
'region': {'required': False}, "region": {"required": False},
'droplet_id': {'required': False, 'type': 'int'}, "droplet_id": {"required": False, "type": "int"},
'oauth_token': { "oauth_token": {
'no_log': True, "no_log": True,
# Support environment variable for DigitalOcean OAuth Token # Support environment variable for DigitalOcean OAuth Token
'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']), "fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]),
'required': True, "required": True,
}, },
'validate_certs': {'type': 'bool', 'default': True}, "validate_certs": {"type": "bool", "default": True},
'timeout': {'type': 'int', 'default': 30}, "timeout": {"type": "int", "default": 30},
}, },
required_if=[ required_if=[("state", "delete", ["ip"])],
('state', 'delete', ['ip']) mutually_exclusive=[["region", "droplet_id"]],
],
mutually_exclusive=[
['region', 'droplet_id']
],
) )
core(module) core(module)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -1,7 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import json import json
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
# Documentation # Documentation
################################################################################ ################################################################################
ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'} ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
################################################################################ ################################################################################
# Main # Main
@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported
def main(): def main():
module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}}) module = GcpModule(
argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}}
)
if module._name == 'gcp_compute_image_facts': if module._name == "gcp_compute_image_facts":
module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13') module.deprecate(
"The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13"
)
if not module.params['scopes']: if not module.params["scopes"]:
module.params['scopes'] = ['https://www.googleapis.com/auth/compute'] module.params["scopes"] = ["https://www.googleapis.com/auth/compute"]
items = fetch_list(module, collection(module), query_options(module.params['filters'])) items = fetch_list(module, collection(module), query_options(module.params["filters"]))
if items.get('items'): if items.get("items"):
items = items.get('items') items = items.get("items")
else: else:
items = [] items = []
return_value = {'resources': items} return_value = {"resources": items}
module.exit_json(**return_value) module.exit_json(**return_value)
@ -40,14 +43,14 @@ def collection(module):
def fetch_list(module, link, query): def fetch_list(module, link, query):
auth = GcpSession(module, 'compute') auth = GcpSession(module, "compute")
response = auth.get(link, params={'filter': query}) response = auth.get(link, params={"filter": query})
return return_if_object(module, response) return return_if_object(module, response)
def query_options(filters): def query_options(filters):
if not filters: if not filters:
return '' return ""
if len(filters) == 1: if len(filters) == 1:
return filters[0] return filters[0]
@ -55,12 +58,12 @@ def query_options(filters):
queries = [] queries = []
for f in filters: for f in filters:
# For multiple queries, all queries should have () # For multiple queries, all queries should have ()
if f[0] != '(' and f[-1] != ')': if f[0] != "(" and f[-1] != ")":
queries.append("({})".format(''.join(f))) queries.append("({})".format("".join(f)))
else: else:
queries.append(f) queries.append(f)
return ' '.join(queries) return " ".join(queries)
def return_if_object(module, response): def return_if_object(module, response):
@ -75,11 +78,11 @@ def return_if_object(module, response):
try: try:
module.raise_for_status(response) module.raise_for_status(response)
result = response.json() result = response.json()
except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst: except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
module.fail_json(msg=f"Invalid JSON response with error: {inst}") module.fail_json(msg=f"Invalid JSON response with error: {inst}")
if navigate_hash(result, ['error', 'errors']): if navigate_hash(result, ["error", "errors"]):
module.fail_json(msg=navigate_hash(result, ['error', 'errors'])) module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
return result return result

View file

@ -3,12 +3,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = {'metadata_version': '1.1', DOCUMENTATION = """
'status': ['preview'],
'supported_by': 'community'}
DOCUMENTATION = '''
--- ---
module: lightsail_region_facts module: lightsail_region_facts
short_description: Gather facts about AWS Lightsail regions. short_description: Gather facts about AWS Lightsail regions.
@ -24,15 +21,15 @@ requirements:
extends_documentation_fragment: extends_documentation_fragment:
- aws - aws
- ec2 - ec2
''' """
EXAMPLES = ''' EXAMPLES = """
# Gather facts about all regions # Gather facts about all regions
- lightsail_region_facts: - lightsail_region_facts:
''' """
RETURN = ''' RETURN = """
regions: regions:
returned: on success returned: on success
description: > description: >
@ -46,12 +43,13 @@ regions:
"displayName": "Virginia", "displayName": "Virginia",
"name": "us-east-1" "name": "us-east-1"
}]" }]"
''' """
import traceback import traceback
try: try:
import botocore import botocore
HAS_BOTOCORE = True HAS_BOTOCORE = True
except ImportError: except ImportError:
HAS_BOTOCORE = False HAS_BOTOCORE = False
@ -86,18 +84,19 @@ def main():
client = None client = None
try: try:
client = boto3_conn(module, conn_type='client', resource='lightsail', client = boto3_conn(
region=region, endpoint=ec2_url, **aws_connect_kwargs) module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs
)
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e: except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc()) module.fail_json(
msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc()
)
response = client.get_regions( response = client.get_regions(includeAvailabilityZones=False)
includeAvailabilityZones=False
)
module.exit_json(changed=False, data=response) module.exit_json(changed=False, data=response)
except (botocore.exceptions.ClientError, Exception) as e: except (botocore.exceptions.ClientError, Exception) as e:
module.fail_json(msg=str(e), exception=traceback.format_exc()) module.fail_json(msg=str(e), exception=traceback.format_exc())
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None LINODE_IMP_ERR = None
try: try:
from linode_api4 import LinodeClient, StackScript from linode_api4 import LinodeClient, StackScript
HAS_LINODE_DEPENDENCY = True HAS_LINODE_DEPENDENCY = True
except ImportError: except ImportError:
LINODE_IMP_ERR = traceback.format_exc() LINODE_IMP_ERR = traceback.format_exc()
@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs):
response = client.linode.stackscript_create(**kwargs) response = client.linode.stackscript_create(**kwargs)
return response._raw_json return response._raw_json
except Exception as exception: except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def stackscript_available(module, client): def stackscript_available(module, client):
"""Try to retrieve a stackscript.""" """Try to retrieve a stackscript."""
try: try:
label = module.params['label'] label = module.params["label"]
desc = module.params['description'] desc = module.params["description"]
result = client.linode.stackscripts(StackScript.label == label, result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True)
StackScript.description == desc,
mine_only=True
)
return result[0] return result[0]
except IndexError: except IndexError:
return None return None
except Exception as exception: except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module(): def initialise_module():
"""Initialise the module parameter specification.""" """Initialise the module parameter specification."""
return AnsibleModule( return AnsibleModule(
argument_spec=dict( argument_spec=dict(
label=dict(type='str', required=True), label=dict(type="str", required=True),
state=dict( state=dict(type="str", required=True, choices=["present", "absent"]),
type='str',
required=True,
choices=['present', 'absent']
),
access_token=dict( access_token=dict(
type='str', type="str",
required=True, required=True,
no_log=True, no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']), fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
), ),
script=dict(type='str', required=True), script=dict(type="str", required=True),
images=dict(type='list', required=True), images=dict(type="list", required=True),
description=dict(type='str', required=False), description=dict(type="str", required=False),
public=dict(type='bool', required=False, default=False), public=dict(type="bool", required=False, default=False),
), ),
supports_check_mode=False supports_check_mode=False,
) )
def build_client(module): def build_client(module):
"""Build a LinodeClient.""" """Build a LinodeClient."""
return LinodeClient( return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
def main(): def main():
@ -79,30 +70,31 @@ def main():
module = initialise_module() module = initialise_module()
if not HAS_LINODE_DEPENDENCY: if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR) module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module) client = build_client(module)
stackscript = stackscript_available(module, client) stackscript = stackscript_available(module, client)
if module.params['state'] == 'present' and stackscript is not None: if module.params["state"] == "present" and stackscript is not None:
module.exit_json(changed=False, stackscript=stackscript._raw_json) module.exit_json(changed=False, stackscript=stackscript._raw_json)
elif module.params['state'] == 'present' and stackscript is None: elif module.params["state"] == "present" and stackscript is None:
stackscript_json = create_stackscript( stackscript_json = create_stackscript(
module, client, module,
label=module.params['label'], client,
script=module.params['script'], label=module.params["label"],
images=module.params['images'], script=module.params["script"],
desc=module.params['description'], images=module.params["images"],
public=module.params['public'], desc=module.params["description"],
public=module.params["public"],
) )
module.exit_json(changed=True, stackscript=stackscript_json) module.exit_json(changed=True, stackscript=stackscript_json)
elif module.params['state'] == 'absent' and stackscript is not None: elif module.params["state"] == "absent" and stackscript is not None:
stackscript.delete() stackscript.delete()
module.exit_json(changed=True, stackscript=stackscript._raw_json) module.exit_json(changed=True, stackscript=stackscript._raw_json)
elif module.params['state'] == 'absent' and stackscript is None: elif module.params["state"] == "absent" and stackscript is None:
module.exit_json(changed=False, stackscript={}) module.exit_json(changed=False, stackscript={})

View file

@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None LINODE_IMP_ERR = None
try: try:
from linode_api4 import Instance, LinodeClient from linode_api4 import Instance, LinodeClient
HAS_LINODE_DEPENDENCY = True HAS_LINODE_DEPENDENCY = True
except ImportError: except ImportError:
LINODE_IMP_ERR = traceback.format_exc() LINODE_IMP_ERR = traceback.format_exc()
@ -21,82 +22,72 @@ except ImportError:
def create_linode(module, client, **kwargs): def create_linode(module, client, **kwargs):
"""Creates a Linode instance and handles return format.""" """Creates a Linode instance and handles return format."""
if kwargs['root_pass'] is None: if kwargs["root_pass"] is None:
kwargs.pop('root_pass') kwargs.pop("root_pass")
try: try:
response = client.linode.instance_create(**kwargs) response = client.linode.instance_create(**kwargs)
except Exception as exception: except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
try: try:
if isinstance(response, tuple): if isinstance(response, tuple):
instance, root_pass = response instance, root_pass = response
instance_json = instance._raw_json instance_json = instance._raw_json
instance_json.update({'root_pass': root_pass}) instance_json.update({"root_pass": root_pass})
return instance_json return instance_json
else: else:
return response._raw_json return response._raw_json
except TypeError: except TypeError:
module.fail_json(msg='Unable to parse Linode instance creation' module.fail_json(
' response. Please raise a bug against this' msg="Unable to parse Linode instance creation"
' module on https://github.com/ansible/ansible/issues' " response. Please raise a bug against this"
) " module on https://github.com/ansible/ansible/issues"
)
def maybe_instance_from_label(module, client): def maybe_instance_from_label(module, client):
"""Try to retrieve an instance based on a label.""" """Try to retrieve an instance based on a label."""
try: try:
label = module.params['label'] label = module.params["label"]
result = client.linode.instances(Instance.label == label) result = client.linode.instances(Instance.label == label)
return result[0] return result[0]
except IndexError: except IndexError:
return None return None
except Exception as exception: except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module(): def initialise_module():
"""Initialise the module parameter specification.""" """Initialise the module parameter specification."""
return AnsibleModule( return AnsibleModule(
argument_spec=dict( argument_spec=dict(
label=dict(type='str', required=True), label=dict(type="str", required=True),
state=dict( state=dict(type="str", required=True, choices=["present", "absent"]),
type='str',
required=True,
choices=['present', 'absent']
),
access_token=dict( access_token=dict(
type='str', type="str",
required=True, required=True,
no_log=True, no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']), fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
), ),
authorized_keys=dict(type='list', required=False), authorized_keys=dict(type="list", required=False),
group=dict(type='str', required=False), group=dict(type="str", required=False),
image=dict(type='str', required=False), image=dict(type="str", required=False),
region=dict(type='str', required=False), region=dict(type="str", required=False),
root_pass=dict(type='str', required=False, no_log=True), root_pass=dict(type="str", required=False, no_log=True),
tags=dict(type='list', required=False), tags=dict(type="list", required=False),
type=dict(type='str', required=False), type=dict(type="str", required=False),
stackscript_id=dict(type='int', required=False), stackscript_id=dict(type="int", required=False),
), ),
supports_check_mode=False, supports_check_mode=False,
required_one_of=( required_one_of=(["state", "label"],),
['state', 'label'], required_together=(["region", "image", "type"],),
),
required_together=(
['region', 'image', 'type'],
)
) )
def build_client(module): def build_client(module):
"""Build a LinodeClient.""" """Build a LinodeClient."""
return LinodeClient( return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
def main(): def main():
@ -104,34 +95,35 @@ def main():
module = initialise_module() module = initialise_module()
if not HAS_LINODE_DEPENDENCY: if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR) module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module) client = build_client(module)
instance = maybe_instance_from_label(module, client) instance = maybe_instance_from_label(module, client)
if module.params['state'] == 'present' and instance is not None: if module.params["state"] == "present" and instance is not None:
module.exit_json(changed=False, instance=instance._raw_json) module.exit_json(changed=False, instance=instance._raw_json)
elif module.params['state'] == 'present' and instance is None: elif module.params["state"] == "present" and instance is None:
instance_json = create_linode( instance_json = create_linode(
module, client, module,
authorized_keys=module.params['authorized_keys'], client,
group=module.params['group'], authorized_keys=module.params["authorized_keys"],
image=module.params['image'], group=module.params["group"],
label=module.params['label'], image=module.params["image"],
region=module.params['region'], label=module.params["label"],
root_pass=module.params['root_pass'], region=module.params["region"],
tags=module.params['tags'], root_pass=module.params["root_pass"],
ltype=module.params['type'], tags=module.params["tags"],
stackscript_id=module.params['stackscript_id'], ltype=module.params["type"],
stackscript_id=module.params["stackscript_id"],
) )
module.exit_json(changed=True, instance=instance_json) module.exit_json(changed=True, instance=instance_json)
elif module.params['state'] == 'absent' and instance is not None: elif module.params["state"] == "absent" and instance is not None:
instance.delete() instance.delete()
module.exit_json(changed=True, instance=instance._raw_json) module.exit_json(changed=True, instance=instance._raw_json)
elif module.params['state'] == 'absent' and instance is None: elif module.params["state"] == "absent" and instance is None:
module.exit_json(changed=False, instance={}) module.exit_json(changed=False, instance={})

View file

@ -8,14 +8,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = { DOCUMENTATION = """
'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'
}
DOCUMENTATION = '''
--- ---
module: scaleway_compute module: scaleway_compute
short_description: Scaleway compute management module short_description: Scaleway compute management module
@ -120,9 +115,9 @@ options:
- If no value provided, the default security group or current security group will be used - If no value provided, the default security group or current security group will be used
required: false required: false
version_added: "2.8" version_added: "2.8"
''' """
EXAMPLES = ''' EXAMPLES = """
- name: Create a server - name: Create a server
scaleway_compute: scaleway_compute:
name: foobar name: foobar
@ -156,10 +151,10 @@ EXAMPLES = '''
organization: 951df375-e094-4d26-97c1-ba548eeb9c42 organization: 951df375-e094-4d26-97c1-ba548eeb9c42
region: ams1 region: ams1
commercial_type: VC1S commercial_type: VC1S
''' """
RETURN = ''' RETURN = """
''' """
import datetime import datetime
import time import time
@ -167,19 +162,9 @@ import time
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
SCALEWAY_SERVER_STATES = ( SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked")
'stopped',
'stopping',
'starting',
'running',
'locked'
)
SCALEWAY_TRANSITIONS_STATES = ( SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending")
"stopping",
"starting",
"pending"
)
def check_image_id(compute_api, image_id): def check_image_id(compute_api, image_id):
@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id):
if response.ok and response.json: if response.ok and response.json:
image_ids = [image["id"] for image in response.json["images"]] image_ids = [image["id"] for image in response.json["images"]]
if image_id not in image_ids: if image_id not in image_ids:
compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url'))) compute_api.module.fail_json(
msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url"))
)
else: else:
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url')) compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url"))
def fetch_state(compute_api, server): def fetch_state(compute_api, server):
@ -201,7 +188,7 @@ def fetch_state(compute_api, server):
return "absent" return "absent"
if not response.ok: if not response.ok:
msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json) msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
try: try:
@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip):
# We check that the IP we want to attach exists, if so its ID is returned # We check that the IP we want to attach exists, if so its ID is returned
response = compute_api.get("ips") response = compute_api.get("ips")
if not response.ok: if not response.ok:
msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json) msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
ip_list = [] ip_list = []
@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip):
def create_server(compute_api, server): def create_server(compute_api, server):
compute_api.module.debug("Starting a create_server") compute_api.module.debug("Starting a create_server")
target_server = None target_server = None
data = {"enable_ipv6": server["enable_ipv6"], data = {
"tags": server["tags"], "enable_ipv6": server["enable_ipv6"],
"commercial_type": server["commercial_type"], "tags": server["tags"],
"image": server["image"], "commercial_type": server["commercial_type"],
"dynamic_ip_required": server["dynamic_ip_required"], "image": server["image"],
"name": server["name"], "dynamic_ip_required": server["dynamic_ip_required"],
"organization": server["organization"] "name": server["name"],
} "organization": server["organization"],
}
if server["boot_type"]: if server["boot_type"]:
data["boot_type"] = server["boot_type"] data["boot_type"] = server["boot_type"]
@ -278,7 +266,7 @@ def create_server(compute_api, server):
response = compute_api.post(path="servers", data=data) response = compute_api.post(path="servers", data=data)
if not response.ok: if not response.ok:
msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json) msg = "Error during server creation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
try: try:
@ -304,10 +292,9 @@ def start_server(compute_api, server):
def perform_action(compute_api, server, action): def perform_action(compute_api, server, action):
response = compute_api.post(path="servers/%s/action" % server["id"], response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action})
data={"action": action})
if not response.ok: if not response.ok:
msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json) msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server) wait_to_complete_state_transition(compute_api=compute_api, server=server)
@ -319,7 +306,7 @@ def remove_server(compute_api, server):
compute_api.module.debug("Starting remove server strategy") compute_api.module.debug("Starting remove server strategy")
response = compute_api.delete(path="servers/%s" % server["id"]) response = compute_api.delete(path="servers/%s" % server["id"])
if not response.ok: if not response.ok:
msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json) msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server) wait_to_complete_state_transition(compute_api=compute_api, server=server)
@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server):
else: else:
target_server = query_results[0] target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, if server_attributes_should_be_changed(
wished_server=wished_server): compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True changed = True
if compute_api.module.check_mode: if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]} return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
return changed, target_server return changed, target_server
@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server):
response = stop_server(compute_api=compute_api, server=target_server) response = stop_server(compute_api=compute_api, server=target_server)
if not response.ok: if not response.ok:
err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]' err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg) compute_api.module.fail_json(msg=err_msg)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server):
response = remove_server(compute_api=compute_api, server=target_server) response = remove_server(compute_api=compute_api, server=target_server)
if not response.ok: if not response.ok:
err_msg = f'Error while removing server [{response.status_code}: {response.json}]' err_msg = f"Error while removing server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg) compute_api.module.fail_json(msg=err_msg)
return changed, {"status": "Server %s deleted" % target_server["id"]} return changed, {"status": "Server %s deleted" % target_server["id"]}
@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server):
else: else:
target_server = query_results[0] target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, if server_attributes_should_be_changed(
wished_server=wished_server): compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True changed = True
if compute_api.module.check_mode: if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]} return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
current_state = fetch_state(compute_api=compute_api, server=target_server) current_state = fetch_state(compute_api=compute_api, server=target_server)
if current_state not in ("running", "starting"): if current_state not in ("running", "starting"):
@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server):
response = start_server(compute_api=compute_api, server=target_server) response = start_server(compute_api=compute_api, server=target_server)
if not response.ok: if not response.ok:
msg = f'Error while running server [{response.status_code}: {response.json}]' msg = f"Error while running server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
return changed, target_server return changed, target_server
@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server):
changed = False changed = False
if not query_results: if not query_results:
if compute_api.module.check_mode: if compute_api.module.check_mode:
return changed, {"status": "A server would be created before being stopped."} return changed, {"status": "A server would be created before being stopped."}
@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug("stop_strategy: Servers are found.") compute_api.module.debug("stop_strategy: Servers are found.")
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, if server_attributes_should_be_changed(
wished_server=wished_server): compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True changed = True
if compute_api.module.check_mode: if compute_api.module.check_mode:
return changed, { return changed, {
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]} "status": "Server %s attributes would be changed before stopping it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug(response.ok) compute_api.module.debug(response.ok)
if not response.ok: if not response.ok:
msg = f'Error while stopping server [{response.status_code}: {response.json}]' msg = f"Error while stopping server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
return changed, target_server return changed, target_server
@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server):
else: else:
target_server = query_results[0] target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, if server_attributes_should_be_changed(
target_server=target_server, compute_api=compute_api, target_server=target_server, wished_server=wished_server
wished_server=wished_server): ):
changed = True changed = True
if compute_api.module.check_mode: if compute_api.module.check_mode:
return changed, { return changed, {
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]} "status": "Server %s attributes would be changed before rebooting it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
changed = True changed = True
if compute_api.module.check_mode: if compute_api.module.check_mode:
@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server):
response = restart_server(compute_api=compute_api, server=target_server) response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok: if not response.ok:
msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].' msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",): if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
response = restart_server(compute_api=compute_api, server=target_server) response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok: if not response.ok:
msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].' msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
return changed, target_server return changed, target_server
@ -531,18 +530,17 @@ state_strategy = {
"restarted": restart_strategy, "restarted": restart_strategy,
"stopped": stop_strategy, "stopped": stop_strategy,
"running": running_strategy, "running": running_strategy,
"absent": absent_strategy "absent": absent_strategy,
} }
def find(compute_api, wished_server, per_page=1): def find(compute_api, wished_server, per_page=1):
compute_api.module.debug("Getting inside find") compute_api.module.debug("Getting inside find")
# Only the name attribute is accepted in the Compute query API # Only the name attribute is accepted in the Compute query API
response = compute_api.get("servers", params={"name": wished_server["name"], response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page})
"per_page": per_page})
if not response.ok: if not response.ok:
msg = 'Error during server search: (%s) %s' % (response.status_code, response.json) msg = "Error during server search: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
search_results = response.json["servers"] search_results = response.json["servers"]
@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve
compute_api.module.debug("Checking if server attributes should be changed") compute_api.module.debug("Checking if server attributes should be changed")
compute_api.module.debug("Current Server: %s" % target_server) compute_api.module.debug("Current Server: %s" % target_server)
compute_api.module.debug("Wished Server: %s" % wished_server) compute_api.module.debug("Wished Server: %s" % wished_server)
debug_dict = dict((x, (target_server[x], wished_server[x])) debug_dict = dict(
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES (x, (target_server[x], wished_server[x]))
if x in target_server and x in wished_server) for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
if x in target_server and x in wished_server
)
compute_api.module.debug("Debug dict %s" % debug_dict) compute_api.module.debug("Debug dict %s" % debug_dict)
try: try:
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES: for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
if key in target_server and key in wished_server: if key in target_server and key in wished_server:
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook # When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys( if (
) and target_server[key]["id"] != wished_server[key]: isinstance(target_server[key], dict)
and wished_server[key]
and "id" in target_server[key].keys()
and target_server[key]["id"] != wished_server[key]
):
return True return True
# Handling other structure compare simply the two objects content # Handling other structure compare simply the two objects content
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]: elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server):
elif not isinstance(target_server[key], dict): elif not isinstance(target_server[key], dict):
patch_payload[key] = wished_server[key] patch_payload[key] = wished_server[key]
response = compute_api.patch(path="servers/%s" % target_server["id"], response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload)
data=patch_payload)
if not response.ok: if not response.ok:
msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json) msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg) compute_api.module.fail_json(msg=msg)
try: try:
@ -625,9 +628,9 @@ def core(module):
"boot_type": module.params["boot_type"], "boot_type": module.params["boot_type"],
"tags": module.params["tags"], "tags": module.params["tags"],
"organization": module.params["organization"], "organization": module.params["organization"],
"security_group": module.params["security_group"] "security_group": module.params["security_group"],
} }
module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"] module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"]
compute_api = Scaleway(module=module) compute_api = Scaleway(module=module)
@ -643,22 +646,24 @@ def core(module):
def main(): def main():
argument_spec = scaleway_argument_spec() argument_spec = scaleway_argument_spec()
argument_spec.update(dict( argument_spec.update(
image=dict(required=True), dict(
name=dict(), image=dict(required=True),
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()), name=dict(),
commercial_type=dict(required=True), region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
enable_ipv6=dict(default=False, type="bool"), commercial_type=dict(required=True),
boot_type=dict(choices=['bootscript', 'local']), enable_ipv6=dict(default=False, type="bool"),
public_ip=dict(default="absent"), boot_type=dict(choices=["bootscript", "local"]),
state=dict(choices=state_strategy.keys(), default='present'), public_ip=dict(default="absent"),
tags=dict(type="list", default=[]), state=dict(choices=state_strategy.keys(), default="present"),
organization=dict(required=True), tags=dict(type="list", default=[]),
wait=dict(type="bool", default=False), organization=dict(required=True),
wait_timeout=dict(type="int", default=300), wait=dict(type="bool", default=False),
wait_sleep_time=dict(type="int", default=3), wait_timeout=dict(type="int", default=300),
security_group=dict(), wait_sleep_time=dict(type="int", default=3),
)) security_group=dict(),
)
)
module = AnsibleModule( module = AnsibleModule(
argument_spec=argument_spec, argument_spec=argument_spec,
supports_check_mode=True, supports_check_mode=True,
@ -667,5 +672,5 @@ def main():
core(module) core(module)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -32,32 +32,30 @@ Returns:
def run_module(): def run_module():
""" """
Main execution function for the x25519_pubkey Ansible module. Main execution function for the x25519_pubkey Ansible module.
Handles parameter validation, private key processing, public key derivation, Handles parameter validation, private key processing, public key derivation,
and optional file output with idempotent behavior. and optional file output with idempotent behavior.
""" """
module_args = { module_args = {
'private_key_b64': {'type': 'str', 'required': False}, "private_key_b64": {"type": "str", "required": False},
'private_key_path': {'type': 'path', 'required': False}, "private_key_path": {"type": "path", "required": False},
'public_key_path': {'type': 'path', 'required': False}, "public_key_path": {"type": "path", "required": False},
} }
result = { result = {
'changed': False, "changed": False,
'public_key': '', "public_key": "",
} }
module = AnsibleModule( module = AnsibleModule(
argument_spec=module_args, argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True
required_one_of=[['private_key_b64', 'private_key_path']],
supports_check_mode=True
) )
priv_b64 = None priv_b64 = None
if module.params['private_key_path']: if module.params["private_key_path"]:
try: try:
with open(module.params['private_key_path'], 'rb') as f: with open(module.params["private_key_path"], "rb") as f:
data = f.read() data = f.read()
try: try:
# First attempt: assume file contains base64 text data # First attempt: assume file contains base64 text data
@ -71,12 +69,14 @@ def run_module():
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved # whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
# Stripping would corrupt the key and cause "got 31 bytes" errors # Stripping would corrupt the key and cause "got 31 bytes" errors
if len(data) != 32: if len(data) != 32:
module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes") module.fail_json(
msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes"
)
priv_b64 = base64.b64encode(data).decode() priv_b64 = base64.b64encode(data).decode()
except OSError as e: except OSError as e:
module.fail_json(msg=f"Failed to read private key file: {e}") module.fail_json(msg=f"Failed to read private key file: {e}")
else: else:
priv_b64 = module.params['private_key_b64'] priv_b64 = module.params["private_key_b64"]
# Validate input parameters # Validate input parameters
if not priv_b64: if not priv_b64:
@ -93,15 +93,12 @@ def run_module():
try: try:
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw) priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
pub_key = priv_key.public_key() pub_key = priv_key.public_key()
pub_raw = pub_key.public_bytes( pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
pub_b64 = base64.b64encode(pub_raw).decode() pub_b64 = base64.b64encode(pub_raw).decode()
result['public_key'] = pub_b64 result["public_key"] = pub_b64
if module.params['public_key_path']: if module.params["public_key_path"]:
pub_path = module.params['public_key_path'] pub_path = module.params["public_key_path"]
existing = None existing = None
try: try:
@ -112,13 +109,13 @@ def run_module():
if existing != pub_b64: if existing != pub_b64:
try: try:
with open(pub_path, 'w') as f: with open(pub_path, "w") as f:
f.write(pub_b64) f.write(pub_b64)
result['changed'] = True result["changed"] = True
except OSError as e: except OSError as e:
module.fail_json(msg=f"Failed to write public key file: {e}") module.fail_json(msg=f"Failed to write public key file: {e}")
result['public_key_path'] = pub_path result["public_key_path"] = pub_path
except Exception as e: except Exception as e:
module.fail_json(msg=f"Failed to derive public key: {e}") module.fail_json(msg=f"Failed to derive public key: {e}")
@ -131,5 +128,5 @@ def main():
run_module() run_module()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -136,6 +136,8 @@
value: 1 value: 1
- item: "{{ 'net.ipv6.conf.all.forwarding' if ipv6_support else none }}" - item: "{{ 'net.ipv6.conf.all.forwarding' if ipv6_support else none }}"
value: 1 value: 1
- item: net.ipv4.conf.all.route_localnet
value: 1
- name: Install packages (batch optimization) - name: Install packages (batch optimization)
include_tasks: packages.yml include_tasks: packages.yml

View file

@ -38,11 +38,11 @@ COMMIT
# Allow traffic from the VPN network to the outside world, and replies # Allow traffic from the VPN network to the outside world, and replies
{% if ipsec_enabled %} {% if ipsec_enabled %}
# For IPsec traffic - NAT the decrypted packets from the VPN subnet # For IPsec traffic - NAT the decrypted packets from the VPN subnet
-A POSTROUTING -s {{ strongswan_network }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }} -A POSTROUTING -s {{ strongswan_network }} -o {{ ansible_default_ipv4['interface'] }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
{% endif %} {% endif %}
{% if wireguard_enabled %} {% if wireguard_enabled %}
# For WireGuard traffic - NAT packets from the VPN subnet # For WireGuard traffic - NAT packets from the VPN subnet
-A POSTROUTING -s {{ wireguard_network_ipv4 }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }} -A POSTROUTING -s {{ wireguard_network_ipv4 }} -o {{ ansible_default_ipv4['interface'] }} {{ '-j SNAT --to ' + snat_aipv4 if snat_aipv4 else '-j MASQUERADE' }}
{% endif %} {% endif %}
@ -85,8 +85,8 @@ COMMIT
# DUMMY interfaces are the proper way to install IPs without assigning them any # DUMMY interfaces are the proper way to install IPs without assigning them any
# particular virtual (tun,tap,...) or physical (ethernet) interface. # particular virtual (tun,tap,...) or physical (ethernet) interface.
# Accept DNS traffic to the local DNS resolver # Accept DNS traffic to the local DNS resolver from VPN clients only
-A INPUT -d {{ local_service_ip }} -p udp --dport 53 -j ACCEPT -A INPUT -s {{ subnets | join(',') }} -d {{ local_service_ip }} -p udp --dport 53 -j ACCEPT
# Drop traffic between VPN clients # Drop traffic between VPN clients
-A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }} -A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }}

View file

@ -37,11 +37,11 @@ COMMIT
# Allow traffic from the VPN network to the outside world, and replies # Allow traffic from the VPN network to the outside world, and replies
{% if ipsec_enabled %} {% if ipsec_enabled %}
# For IPsec traffic - NAT the decrypted packets from the VPN subnet # For IPsec traffic - NAT the decrypted packets from the VPN subnet
-A POSTROUTING -s {{ strongswan_network_ipv6 }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }} -A POSTROUTING -s {{ strongswan_network_ipv6 }} -o {{ ansible_default_ipv6['interface'] }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
{% endif %} {% endif %}
{% if wireguard_enabled %} {% if wireguard_enabled %}
# For WireGuard traffic - NAT packets from the VPN subnet # For WireGuard traffic - NAT packets from the VPN subnet
-A POSTROUTING -s {{ wireguard_network_ipv6 }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }} -A POSTROUTING -s {{ wireguard_network_ipv6 }} -o {{ ansible_default_ipv6['interface'] }} {{ '-j SNAT --to ' + ipv6_egress_ip | ansible.utils.ipaddr('address') if alternative_ingress_ip else '-j MASQUERADE' }}
{% endif %} {% endif %}
COMMIT COMMIT
@ -95,8 +95,8 @@ COMMIT
# DUMMY interfaces are the proper way to install IPs without assigning them any # DUMMY interfaces are the proper way to install IPs without assigning them any
# particular virtual (tun,tap,...) or physical (ethernet) interface. # particular virtual (tun,tap,...) or physical (ethernet) interface.
# Accept DNS traffic to the local DNS resolver # Accept DNS traffic to the local DNS resolver from VPN clients only
-A INPUT -d {{ local_service_ipv6 }}/128 -p udp --dport 53 -j ACCEPT -A INPUT -s {{ subnets | join(',') }} -d {{ local_service_ipv6 }}/128 -p udp --dport 53 -j ACCEPT
# Drop traffic between VPN clients # Drop traffic between VPN clients
-A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }} -A FORWARD -s {{ subnets | join(',') }} -d {{ subnets | join(',') }} -j {{ "DROP" if BetweenClients_DROP else "ACCEPT" }}

View file

@ -3,9 +3,16 @@
systemd: systemd:
daemon_reload: true daemon_reload: true
- name: restart dnscrypt-proxy.socket
systemd:
name: dnscrypt-proxy.socket
state: restarted
daemon_reload: true
when: ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian'
- name: restart dnscrypt-proxy - name: restart dnscrypt-proxy
systemd: systemd:
name: dnscrypt-proxy name: dnscrypt-proxy
state: restarted state: restarted
daemon_reload: true daemon_reload: true
when: ansible_distribution == 'Ubuntu' when: ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian'

View file

@ -3,7 +3,6 @@
include_tasks: ubuntu.yml include_tasks: ubuntu.yml
when: ansible_distribution == 'Debian' or ansible_distribution == 'Ubuntu' when: ansible_distribution == 'Debian' or ansible_distribution == 'Ubuntu'
- name: dnscrypt-proxy ip-blacklist configured - name: dnscrypt-proxy ip-blacklist configured
template: template:
src: ip-blacklist.txt.j2 src: ip-blacklist.txt.j2
@ -26,6 +25,14 @@
- meta: flush_handlers - meta: flush_handlers
- name: Ubuntu | Ensure dnscrypt-proxy socket is enabled and started
systemd:
name: dnscrypt-proxy.socket
enabled: true
state: started
daemon_reload: true
when: ansible_distribution == 'Debian' or ansible_distribution == 'Ubuntu'
- name: dnscrypt-proxy enabled and started - name: dnscrypt-proxy enabled and started
service: service:
name: dnscrypt-proxy name: dnscrypt-proxy

View file

@ -50,6 +50,49 @@
owner: root owner: root
group: root group: root
- name: Ubuntu | Ensure socket override directory exists
file:
path: /etc/systemd/system/dnscrypt-proxy.socket.d/
state: directory
mode: '0755'
owner: root
group: root
- name: Ubuntu | Configure dnscrypt-proxy socket to listen on VPN IPs
copy:
dest: /etc/systemd/system/dnscrypt-proxy.socket.d/10-algo-override.conf
content: |
[Socket]
# Clear default listeners
ListenStream=
ListenDatagram=
# Add VPN service IPs
ListenStream={{ local_service_ip }}:53
ListenDatagram={{ local_service_ip }}:53
{% if ipv6_support %}
ListenStream=[{{ local_service_ipv6 }}]:53
ListenDatagram=[{{ local_service_ipv6 }}]:53
{% endif %}
NoDelay=true
DeferAcceptSec=1
mode: '0644'
register: socket_override
notify:
- daemon-reload
- restart dnscrypt-proxy.socket
- restart dnscrypt-proxy
- name: Ubuntu | Reload systemd daemon after socket configuration
systemd:
daemon_reload: true
when: socket_override.changed
- name: Ubuntu | Restart dnscrypt-proxy socket to apply configuration
systemd:
name: dnscrypt-proxy.socket
state: restarted
when: socket_override.changed
- name: Ubuntu | Add custom requirements to successfully start the unit - name: Ubuntu | Add custom requirements to successfully start the unit
copy: copy:
dest: /etc/systemd/system/dnscrypt-proxy.service.d/99-algo.conf dest: /etc/systemd/system/dnscrypt-proxy.service.d/99-algo.conf
@ -61,8 +104,12 @@
[Service] [Service]
AmbientCapabilities=CAP_NET_BIND_SERVICE AmbientCapabilities=CAP_NET_BIND_SERVICE
notify: register: dnscrypt_override
- restart dnscrypt-proxy
- name: Ubuntu | Reload systemd daemon if override changed
systemd:
daemon_reload: true
when: dnscrypt_override.changed
- name: Ubuntu | Apply systemd security hardening for dnscrypt-proxy - name: Ubuntu | Apply systemd security hardening for dnscrypt-proxy
copy: copy:
@ -95,6 +142,9 @@
owner: root owner: root
group: root group: root
mode: '0644' mode: '0644'
notify: register: dnscrypt_hardening
- daemon-reload
- restart dnscrypt-proxy - name: Ubuntu | Reload systemd daemon if hardening changed
systemd:
daemon_reload: true
when: dnscrypt_hardening.changed

View file

@ -37,10 +37,16 @@
## List of local addresses and ports to listen to. Can be IPv4 and/or IPv6. ## List of local addresses and ports to listen to. Can be IPv4 and/or IPv6.
## Note: When using systemd socket activation, choose an empty set (i.e. [] ). ## Note: When using systemd socket activation, choose an empty set (i.e. [] ).
{% if ansible_distribution == 'Ubuntu' or ansible_distribution == 'Debian' %}
# Using systemd socket activation on Ubuntu/Debian
listen_addresses = []
{% else %}
# Direct binding on non-systemd systems
listen_addresses = [ listen_addresses = [
'{{ local_service_ip }}:53'{% if ipv6_support %}, '{{ local_service_ip }}:53'{% if ipv6_support %},
'[{{ local_service_ipv6 }}]:53'{% endif %} '[{{ local_service_ipv6 }}]:53'{% endif %}
] ]
{% endif %}
## Maximum number of simultaneous client connections to accept ## Maximum number of simultaneous client connections to accept

View file

@ -12,15 +12,6 @@
- { name: 'kernel.dmesg_restrict', value: '1' } - { name: 'kernel.dmesg_restrict', value: '1' }
when: privacy_advanced.reduce_kernel_verbosity | bool when: privacy_advanced.reduce_kernel_verbosity | bool
- name: Disable BPF JIT if available (optional security hardening)
sysctl:
name: net.core.bpf_jit_enable
value: '0'
state: present
reload: yes
when: privacy_advanced.reduce_kernel_verbosity | bool
ignore_errors: yes
- name: Configure kernel parameters for privacy - name: Configure kernel parameters for privacy
lineinfile: lineinfile:
path: /etc/sysctl.d/99-privacy.conf path: /etc/sysctl.d/99-privacy.conf
@ -31,18 +22,8 @@
- "# Privacy enhancements - reduce kernel logging" - "# Privacy enhancements - reduce kernel logging"
- "kernel.printk = 3 4 1 3" - "kernel.printk = 3 4 1 3"
- "kernel.dmesg_restrict = 1" - "kernel.dmesg_restrict = 1"
- "# Note: net.core.bpf_jit_enable may not be available on all kernels"
when: privacy_advanced.reduce_kernel_verbosity | bool when: privacy_advanced.reduce_kernel_verbosity | bool
- name: Add BPF JIT disable to sysctl config if kernel supports it
lineinfile:
path: /etc/sysctl.d/99-privacy.conf
line: "net.core.bpf_jit_enable = 0 # Disable BPF JIT to reduce attack surface"
create: yes
mode: '0644'
when: privacy_advanced.reduce_kernel_verbosity | bool
ignore_errors: yes
- name: Configure journal settings for privacy - name: Configure journal settings for privacy
lineinfile: lineinfile:
path: /etc/systemd/journald.conf path: /etc/systemd/journald.conf

View file

@ -3,6 +3,7 @@
Track test effectiveness by analyzing CI failures and correlating with issues/PRs Track test effectiveness by analyzing CI failures and correlating with issues/PRs
This helps identify which tests actually catch bugs vs just failing randomly This helps identify which tests actually catch bugs vs just failing randomly
""" """
import json import json
import subprocess import subprocess
from collections import defaultdict from collections import defaultdict
@ -12,7 +13,7 @@ from pathlib import Path
def get_github_api_data(endpoint): def get_github_api_data(endpoint):
"""Fetch data from GitHub API""" """Fetch data from GitHub API"""
cmd = ['gh', 'api', endpoint] cmd = ["gh", "api", endpoint]
result = subprocess.run(cmd, capture_output=True, text=True) result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0: if result.returncode != 0:
print(f"Error fetching {endpoint}: {result.stderr}") print(f"Error fetching {endpoint}: {result.stderr}")
@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30):
since = (datetime.now() - timedelta(days=days_back)).isoformat() since = (datetime.now() - timedelta(days=days_back)).isoformat()
# Get workflow runs # Get workflow runs
runs = get_github_api_data( runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure")
f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure'
)
if not runs: if not runs:
return {} return {}
test_failures = defaultdict(list) test_failures = defaultdict(list)
for run in runs.get('workflow_runs', []): for run in runs.get("workflow_runs", []):
# Get jobs for this run # Get jobs for this run
jobs = get_github_api_data( jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs")
f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs'
)
if not jobs: if not jobs:
continue continue
for job in jobs.get('jobs', []): for job in jobs.get("jobs", []):
if job['conclusion'] == 'failure': if job["conclusion"] == "failure":
# Try to extract which test failed from logs # Try to extract which test failed from logs
logs_url = job.get('logs_url') logs_url = job.get("logs_url")
if logs_url: if logs_url:
# Parse logs to find test failures # Parse logs to find test failures
test_name = extract_failed_test(job['name'], run['id']) test_name = extract_failed_test(job["name"], run["id"])
if test_name: if test_name:
test_failures[test_name].append({ test_failures[test_name].append(
'run_id': run['id'], {
'run_number': run['run_number'], "run_id": run["id"],
'date': run['created_at'], "run_number": run["run_number"],
'branch': run['head_branch'], "date": run["created_at"],
'commit': run['head_sha'][:7], "branch": run["head_branch"],
'pr': extract_pr_number(run) "commit": run["head_sha"][:7],
}) "pr": extract_pr_number(run),
}
)
return test_failures return test_failures
@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id):
"""Extract test name from job - this is simplified""" """Extract test name from job - this is simplified"""
# Map job names to test categories # Map job names to test categories
job_to_tests = { job_to_tests = {
'Basic sanity tests': 'test_basic_sanity', "Basic sanity tests": "test_basic_sanity",
'Ansible syntax check': 'ansible_syntax', "Ansible syntax check": "ansible_syntax",
'Docker build test': 'docker_tests', "Docker build test": "docker_tests",
'Configuration generation test': 'config_generation', "Configuration generation test": "config_generation",
'Ansible dry-run validation': 'ansible_dry_run' "Ansible dry-run validation": "ansible_dry_run",
} }
return job_to_tests.get(job_name, job_name) return job_to_tests.get(job_name, job_name)
def extract_pr_number(run): def extract_pr_number(run):
"""Extract PR number from workflow run""" """Extract PR number from workflow run"""
for pr in run.get('pull_requests', []): for pr in run.get("pull_requests", []):
return pr['number'] return pr["number"]
return None return None
def correlate_with_issues(repo_owner, repo_name, test_failures): def correlate_with_issues(repo_owner, repo_name, test_failures):
"""Correlate test failures with issues/PRs that fixed them""" """Correlate test failures with issues/PRs that fixed them"""
correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0}) correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0})
for test_name, failures in test_failures.items(): for test_name, failures in test_failures.items():
for failure in failures: for failure in failures:
if failure['pr']: if failure["pr"]:
# Check if PR was merged (indicating it fixed a real issue) # Check if PR was merged (indicating it fixed a real issue)
pr = get_github_api_data( pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}")
f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}'
)
if pr and pr.get('merged'): if pr and pr.get("merged"):
# Check PR title/body for bug indicators # Check PR title/body for bug indicators
title = pr.get('title', '').lower() title = pr.get("title", "").lower()
body = pr.get('body', '').lower() body = pr.get("body", "").lower()
bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail'] bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"]
is_bug_fix = any(keyword in title or keyword in body is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords)
for keyword in bug_keywords)
if is_bug_fix: if is_bug_fix:
correlations[test_name]['caught_bugs'] += 1 correlations[test_name]["caught_bugs"] += 1
else: else:
correlations[test_name]['false_positives'] += 1 correlations[test_name]["false_positives"] += 1
return correlations return correlations
@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations):
scores = [] scores = []
for test_name, failures in test_failures.items(): for test_name, failures in test_failures.items():
failure_count = len(failures) failure_count = len(failures)
caught = correlations[test_name]['caught_bugs'] caught = correlations[test_name]["caught_bugs"]
false_pos = correlations[test_name]['false_positives'] false_pos = correlations[test_name]["false_positives"]
# Calculate effectiveness (bugs caught / total failures) # Calculate effectiveness (bugs caught / total failures)
if failure_count > 0: if failure_count > 0:
@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations):
elif effectiveness > 0.8: elif effectiveness > 0.8:
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})") report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
return '\n'.join(report) return "\n".join(report)
def save_metrics(test_failures, correlations): def save_metrics(test_failures, correlations):
"""Save metrics to JSON for historical tracking""" """Save metrics to JSON for historical tracking"""
metrics_file = Path('.metrics/test-effectiveness.json') metrics_file = Path(".metrics/test-effectiveness.json")
metrics_file.parent.mkdir(exist_ok=True) metrics_file.parent.mkdir(exist_ok=True)
# Load existing metrics # Load existing metrics
@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations):
# Add current metrics # Add current metrics
current = { current = {
'date': datetime.now().isoformat(), "date": datetime.now().isoformat(),
'test_failures': { "test_failures": {test: len(failures) for test, failures in test_failures.items()},
test: len(failures) for test, failures in test_failures.items() "effectiveness": {
},
'effectiveness': {
test: { test: {
'caught_bugs': data['caught_bugs'], "caught_bugs": data["caught_bugs"],
'false_positives': data['false_positives'], "false_positives": data["false_positives"],
'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives']) "score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"])
if (data['caught_bugs'] + data['false_positives']) > 0 else 0 if (data["caught_bugs"] + data["false_positives"]) > 0
else 0,
} }
for test, data in correlations.items() for test, data in correlations.items()
} },
} }
historical.append(current) historical.append(current)
# Keep last 12 months of data # Keep last 12 months of data
cutoff = datetime.now() - timedelta(days=365) cutoff = datetime.now() - timedelta(days=365)
historical = [ historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff]
h for h in historical
if datetime.fromisoformat(h['date']) > cutoff
]
with open(metrics_file, 'w') as f: with open(metrics_file, "w") as f:
json.dump(historical, f, indent=2) json.dump(historical, f, indent=2)
if __name__ == '__main__': if __name__ == "__main__":
# Configure these for your repo # Configure these for your repo
REPO_OWNER = 'trailofbits' REPO_OWNER = "trailofbits"
REPO_NAME = 'algo' REPO_NAME = "algo"
print("Analyzing test effectiveness...") print("Analyzing test effectiveness...")
@ -223,9 +215,9 @@ if __name__ == '__main__':
print("\n" + report) print("\n" + report)
# Save report # Save report
report_file = Path('.metrics/test-effectiveness-report.md') report_file = Path(".metrics/test-effectiveness-report.md")
report_file.parent.mkdir(exist_ok=True) report_file.parent.mkdir(exist_ok=True)
with open(report_file, 'w') as f: with open(report_file, "w") as f:
f.write(report) f.write(report)
print(f"\nReport saved to: {report_file}") print(f"\nReport saved to: {report_file}")

View file

@ -1,4 +1,5 @@
"""Test fixtures for Algo unit tests""" """Test fixtures for Algo unit tests"""
from pathlib import Path from pathlib import Path
import yaml import yaml
@ -6,7 +7,7 @@ import yaml
def load_test_variables(): def load_test_variables():
"""Load test variables from YAML fixture""" """Load test variables from YAML fixture"""
fixture_path = Path(__file__).parent / 'test_variables.yml' fixture_path = Path(__file__).parent / "test_variables.yml"
with open(fixture_path) as f: with open(fixture_path) as f:
return yaml.safe_load(f) return yaml.safe_load(f)

View file

@ -72,6 +72,12 @@ CA_password: test-ca-pass
# System # System
ansible_ssh_port: 4160 ansible_ssh_port: 4160
ansible_python_interpreter: /usr/bin/python3 ansible_python_interpreter: /usr/bin/python3
ansible_default_ipv4:
interface: eth0
address: 10.0.0.1
ansible_default_ipv6:
interface: eth0
address: 'fd9d:bc11:4020::1'
BetweenClients_DROP: 'Y' BetweenClients_DROP: 'Y'
ssh_tunnels_config_path: /etc/ssh/ssh_tunnels ssh_tunnels_config_path: /etc/ssh/ssh_tunnels
config_prefix: /etc/algo config_prefix: /etc/algo

View file

@ -2,49 +2,56 @@
""" """
Wrapper for Ansible's service module that always succeeds for known services Wrapper for Ansible's service module that always succeeds for known services
""" """
import json import json
import sys import sys
# Parse module arguments # Parse module arguments
args = json.loads(sys.stdin.read()) args = json.loads(sys.stdin.read())
module_args = args.get('ANSIBLE_MODULE_ARGS', {}) module_args = args.get("ANSIBLE_MODULE_ARGS", {})
service_name = module_args.get('name', '') service_name = module_args.get("name", "")
state = module_args.get('state', 'started') state = module_args.get("state", "started")
# Known services that should always succeed # Known services that should always succeed
known_services = [ known_services = [
'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter', "netfilter-persistent",
'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd', "iptables",
'systemd-resolved', 'rsyslog', 'ipfw', 'cron' "wg-quick@wg0",
"strongswan-starter",
"ipsec",
"apparmor",
"unattended-upgrades",
"systemd-networkd",
"systemd-resolved",
"rsyslog",
"ipfw",
"cron",
] ]
# Check if it's a known service # Check if it's a known service
service_found = False service_found = False
for svc in known_services: for svc in known_services:
if service_name == svc or service_name.startswith(svc + '.'): if service_name == svc or service_name.startswith(svc + "."):
service_found = True service_found = True
break break
if service_found: if service_found:
# Return success # Return success
result = { result = {
'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False, "changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False,
'name': service_name, "name": service_name,
'state': state, "state": state,
'status': { "status": {
'LoadState': 'loaded', "LoadState": "loaded",
'ActiveState': 'active' if state != 'stopped' else 'inactive', "ActiveState": "active" if state != "stopped" else "inactive",
'SubState': 'running' if state != 'stopped' else 'dead' "SubState": "running" if state != "stopped" else "dead",
} },
} }
print(json.dumps(result)) print(json.dumps(result))
sys.exit(0) sys.exit(0)
else: else:
# Service not found # Service not found
error = { error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "}
'failed': True,
'msg': f'Could not find the requested service {service_name}: '
}
print(json.dumps(error)) print(json.dumps(error))
sys.exit(1) sys.exit(1)

View file

@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec={ argument_spec={
'name': {'type': 'list', 'aliases': ['pkg', 'package']}, "name": {"type": "list", "aliases": ["pkg", "package"]},
'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']}, "state": {
'update_cache': {'type': 'bool', 'default': False}, "type": "str",
'cache_valid_time': {'type': 'int', 'default': 0}, "default": "present",
'install_recommends': {'type': 'bool'}, "choices": ["present", "absent", "latest", "build-dep", "fixed"],
'force': {'type': 'bool', 'default': False}, },
'allow_unauthenticated': {'type': 'bool', 'default': False}, "update_cache": {"type": "bool", "default": False},
'allow_downgrade': {'type': 'bool', 'default': False}, "cache_valid_time": {"type": "int", "default": 0},
'allow_change_held_packages': {'type': 'bool', 'default': False}, "install_recommends": {"type": "bool"},
'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'}, "force": {"type": "bool", "default": False},
'autoremove': {'type': 'bool', 'default': False}, "allow_unauthenticated": {"type": "bool", "default": False},
'purge': {'type': 'bool', 'default': False}, "allow_downgrade": {"type": "bool", "default": False},
'force_apt_get': {'type': 'bool', 'default': False}, "allow_change_held_packages": {"type": "bool", "default": False},
"dpkg_options": {"type": "str", "default": "force-confdef,force-confold"},
"autoremove": {"type": "bool", "default": False},
"purge": {"type": "bool", "default": False},
"force_apt_get": {"type": "bool", "default": False},
}, },
supports_check_mode=True supports_check_mode=True,
) )
name = module.params['name'] name = module.params["name"]
state = module.params['state'] state = module.params["state"]
update_cache = module.params['update_cache'] update_cache = module.params["update_cache"]
result = { result = {"changed": False, "cache_updated": False, "cache_update_time": 0}
'changed': False,
'cache_updated': False,
'cache_update_time': 0
}
# Log the operation # Log the operation
with open('/var/log/mock-apt-module.log', 'a') as f: with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n") f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
# Handle cache update # Handle cache update
if update_cache: if update_cache:
# In Docker, apt-get update was already run in entrypoint # In Docker, apt-get update was already run in entrypoint
# Just pretend it succeeded # Just pretend it succeeded
result['cache_updated'] = True result["cache_updated"] = True
result['cache_update_time'] = 1754231778 # Fixed timestamp result["cache_update_time"] = 1754231778 # Fixed timestamp
result['changed'] = True result["changed"] = True
# Handle package installation/removal # Handle package installation/removal
if name: if name:
@ -56,40 +56,41 @@ def main():
installed_packages = [] installed_packages = []
for pkg in packages: for pkg in packages:
# Use dpkg to check if package is installed # Use dpkg to check if package is installed
check_cmd = ['dpkg', '-s', pkg] check_cmd = ["dpkg", "-s", pkg]
rc = subprocess.run(check_cmd, capture_output=True) rc = subprocess.run(check_cmd, capture_output=True)
if rc.returncode == 0: if rc.returncode == 0:
installed_packages.append(pkg) installed_packages.append(pkg)
if state in ['present', 'latest']: if state in ["present", "latest"]:
# Check if we need to install anything # Check if we need to install anything
missing_packages = [p for p in packages if p not in installed_packages] missing_packages = [p for p in packages if p not in installed_packages]
if missing_packages: if missing_packages:
# Log what we would install # Log what we would install
with open('/var/log/mock-apt-module.log', 'a') as f: with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"Would install packages: {missing_packages}\n") f.write(f"Would install packages: {missing_packages}\n")
# For our test purposes, these packages are pre-installed in Docker # For our test purposes, these packages are pre-installed in Docker
# Just report success # Just report success
result['changed'] = True result["changed"] = True
result['stdout'] = f"Mock: Packages {missing_packages} are already available" result["stdout"] = f"Mock: Packages {missing_packages} are already available"
result['stderr'] = "" result["stderr"] = ""
else: else:
result['stdout'] = "All packages are already installed" result["stdout"] = "All packages are already installed"
elif state == 'absent': elif state == "absent":
# Check if we need to remove anything # Check if we need to remove anything
present_packages = [p for p in packages if p in installed_packages] present_packages = [p for p in packages if p in installed_packages]
if present_packages: if present_packages:
result['changed'] = True result["changed"] = True
result['stdout'] = f"Mock: Would remove packages {present_packages}" result["stdout"] = f"Mock: Would remove packages {present_packages}"
else: else:
result['stdout'] = "No packages to remove" result["stdout"] = "No packages to remove"
# Always report success for our testing # Always report success for our testing
module.exit_json(**result) module.exit_json(**result)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View file

@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec={ argument_spec={
'_raw_params': {'type': 'str'}, "_raw_params": {"type": "str"},
'cmd': {'type': 'str'}, "cmd": {"type": "str"},
'creates': {'type': 'path'}, "creates": {"type": "path"},
'removes': {'type': 'path'}, "removes": {"type": "path"},
'chdir': {'type': 'path'}, "chdir": {"type": "path"},
'executable': {'type': 'path'}, "executable": {"type": "path"},
'warn': {'type': 'bool', 'default': False}, "warn": {"type": "bool", "default": False},
'stdin': {'type': 'str'}, "stdin": {"type": "str"},
'stdin_add_newline': {'type': 'bool', 'default': True}, "stdin_add_newline": {"type": "bool", "default": True},
'strip_empty_ends': {'type': 'bool', 'default': True}, "strip_empty_ends": {"type": "bool", "default": True},
'_uses_shell': {'type': 'bool', 'default': False}, "_uses_shell": {"type": "bool", "default": False},
}, },
supports_check_mode=True supports_check_mode=True,
) )
# Get the command # Get the command
raw_params = module.params.get('_raw_params') raw_params = module.params.get("_raw_params")
cmd = module.params.get('cmd') or raw_params cmd = module.params.get("cmd") or raw_params
if not cmd: if not cmd:
module.fail_json(msg="no command given") module.fail_json(msg="no command given")
result = { result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
# Log the operation # Log the operation
with open('/var/log/mock-command-module.log', 'a') as f: with open("/var/log/mock-command-module.log", "a") as f:
f.write(f"command module called: cmd={cmd}\n") f.write(f"command module called: cmd={cmd}\n")
# Handle specific commands # Handle specific commands
if 'apparmor_status' in cmd: if "apparmor_status" in cmd:
# Pretend apparmor is not installed/active # Pretend apparmor is not installed/active
result['rc'] = 127 result["rc"] = 127
result['stderr'] = "apparmor_status: command not found" result["stderr"] = "apparmor_status: command not found"
result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'" result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'"
module.fail_json(msg=result['msg'], **result) module.fail_json(msg=result["msg"], **result)
elif 'netplan apply' in cmd: elif "netplan apply" in cmd:
# Pretend netplan succeeded # Pretend netplan succeeded
result['stdout'] = "Mock: netplan configuration applied" result["stdout"] = "Mock: netplan configuration applied"
result['changed'] = True result["changed"] = True
elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd: elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush # Routing cache flush
result['stdout'] = "1" result["stdout"] = "1"
result['changed'] = True result["changed"] = True
else: else:
# For other commands, try to run them # For other commands, try to run them
try: try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir')) proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir"))
result['rc'] = proc.returncode result["rc"] = proc.returncode
result['stdout'] = proc.stdout result["stdout"] = proc.stdout
result['stderr'] = proc.stderr result["stderr"] = proc.stderr
result['stdout_lines'] = proc.stdout.splitlines() result["stdout_lines"] = proc.stdout.splitlines()
result['stderr_lines'] = proc.stderr.splitlines() result["stderr_lines"] = proc.stderr.splitlines()
result['changed'] = True result["changed"] = True
except Exception as e: except Exception as e:
result['rc'] = 1 result["rc"] = 1
result['stderr'] = str(e) result["stderr"] = str(e)
result['msg'] = str(e) result["msg"] = str(e)
module.fail_json(msg=result['msg'], **result) module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0: if result["rc"] == 0:
module.exit_json(**result) module.exit_json(**result)
else: else:
if 'msg' not in result: if "msg" not in result:
result['msg'] = f"Command failed with return code {result['rc']}" result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result) module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View file

@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec={ argument_spec={
'_raw_params': {'type': 'str'}, "_raw_params": {"type": "str"},
'cmd': {'type': 'str'}, "cmd": {"type": "str"},
'creates': {'type': 'path'}, "creates": {"type": "path"},
'removes': {'type': 'path'}, "removes": {"type": "path"},
'chdir': {'type': 'path'}, "chdir": {"type": "path"},
'executable': {'type': 'path', 'default': '/bin/sh'}, "executable": {"type": "path", "default": "/bin/sh"},
'warn': {'type': 'bool', 'default': False}, "warn": {"type": "bool", "default": False},
'stdin': {'type': 'str'}, "stdin": {"type": "str"},
'stdin_add_newline': {'type': 'bool', 'default': True}, "stdin_add_newline": {"type": "bool", "default": True},
}, },
supports_check_mode=True supports_check_mode=True,
) )
# Get the command # Get the command
raw_params = module.params.get('_raw_params') raw_params = module.params.get("_raw_params")
cmd = module.params.get('cmd') or raw_params cmd = module.params.get("cmd") or raw_params
if not cmd: if not cmd:
module.fail_json(msg="no command given") module.fail_json(msg="no command given")
result = { result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
# Log the operation # Log the operation
with open('/var/log/mock-shell-module.log', 'a') as f: with open("/var/log/mock-shell-module.log", "a") as f:
f.write(f"shell module called: cmd={cmd}\n") f.write(f"shell module called: cmd={cmd}\n")
# Handle specific commands # Handle specific commands
if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd: if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush - just pretend it worked # Routing cache flush - just pretend it worked
result['stdout'] = "" result["stdout"] = ""
result['changed'] = True result["changed"] = True
elif 'ifconfig lo100' in cmd: elif "ifconfig lo100" in cmd:
# BSD loopback commands - simulate success # BSD loopback commands - simulate success
result['stdout'] = "0" result["stdout"] = "0"
result['changed'] = True result["changed"] = True
else: else:
# For other commands, try to run them # For other commands, try to run them
try: try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, proc = subprocess.run(
executable=module.params.get('executable'), cmd,
cwd=module.params.get('chdir')) shell=True,
result['rc'] = proc.returncode capture_output=True,
result['stdout'] = proc.stdout text=True,
result['stderr'] = proc.stderr executable=module.params.get("executable"),
result['stdout_lines'] = proc.stdout.splitlines() cwd=module.params.get("chdir"),
result['stderr_lines'] = proc.stderr.splitlines() )
result['changed'] = True result["rc"] = proc.returncode
result["stdout"] = proc.stdout
result["stderr"] = proc.stderr
result["stdout_lines"] = proc.stdout.splitlines()
result["stderr_lines"] = proc.stderr.splitlines()
result["changed"] = True
except Exception as e: except Exception as e:
result['rc'] = 1 result["rc"] = 1
result['stderr'] = str(e) result["stderr"] = str(e)
result['msg'] = str(e) result["msg"] = str(e)
module.fail_json(msg=result['msg'], **result) module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0: if result["rc"] == 0:
module.exit_json(**result) module.exit_json(**result)
else: else:
if 'msg' not in result: if "msg" not in result:
result['msg'] = f"Command failed with return code {result['rc']}" result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result) module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View file

@ -24,6 +24,7 @@ import yaml
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT)) sys.path.insert(0, str(PROJECT_ROOT))
def create_expected_cloud_init(): def create_expected_cloud_init():
""" """
Create the expected cloud-init content that should be generated Create the expected cloud-init content that should be generated
@ -74,6 +75,7 @@ runcmd:
- systemctl restart sshd.service - systemctl restart sshd.service
""" """
class TestCloudInitTemplate: class TestCloudInitTemplate:
"""Test class for cloud-init template validation.""" """Test class for cloud-init template validation."""
@ -98,10 +100,7 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity() parsed = self.test_yaml_validity()
required_sections = [ required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"]
'package_update', 'package_upgrade', 'packages',
'users', 'write_files', 'runcmd'
]
missing = [section for section in required_sections if section not in parsed] missing = [section for section in required_sections if section not in parsed]
assert not missing, f"Missing required sections: {missing}" assert not missing, f"Missing required sections: {missing}"
@ -114,35 +113,30 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity() parsed = self.test_yaml_validity()
write_files = parsed.get('write_files', []) write_files = parsed.get("write_files", [])
assert write_files, "write_files section should be present" assert write_files, "write_files section should be present"
# Find sshd_config file # Find sshd_config file
sshd_config = None sshd_config = None
for file_entry in write_files: for file_entry in write_files:
if file_entry.get('path') == '/etc/ssh/sshd_config': if file_entry.get("path") == "/etc/ssh/sshd_config":
sshd_config = file_entry sshd_config = file_entry
break break
assert sshd_config, "sshd_config file should be in write_files" assert sshd_config, "sshd_config file should be in write_files"
content = sshd_config.get('content', '') content = sshd_config.get("content", "")
assert content, "sshd_config should have content" assert content, "sshd_config should have content"
# Check required SSH configurations # Check required SSH configurations
required_configs = [ required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"]
'Port 4160',
'AllowGroups algo',
'PermitRootLogin no',
'PasswordAuthentication no'
]
missing = [config for config in required_configs if config not in content] missing = [config for config in required_configs if config not in content]
assert not missing, f"Missing SSH configurations: {missing}" assert not missing, f"Missing SSH configurations: {missing}"
# Verify proper formatting - first line should be Port directive # Verify proper formatting - first line should be Port directive
lines = content.strip().split('\n') lines = content.strip().split("\n")
assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}" assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}"
print("✅ SSH configuration correct") print("✅ SSH configuration correct")
@ -152,26 +146,26 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity() parsed = self.test_yaml_validity()
users = parsed.get('users', []) users = parsed.get("users", [])
assert users, "users section should be present" assert users, "users section should be present"
# Find algo user # Find algo user
algo_user = None algo_user = None
for user in users: for user in users:
if isinstance(user, dict) and user.get('name') == 'algo': if isinstance(user, dict) and user.get("name") == "algo":
algo_user = user algo_user = user
break break
assert algo_user, "algo user should be defined" assert algo_user, "algo user should be defined"
# Check required user properties # Check required user properties
required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys'] required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"]
missing = [prop for prop in required_props if prop not in algo_user] missing = [prop for prop in required_props if prop not in algo_user]
assert not missing, f"algo user missing properties: {missing}" assert not missing, f"algo user missing properties: {missing}"
# Verify sudo configuration # Verify sudo configuration
sudo_config = algo_user.get('sudo', '') sudo_config = algo_user.get("sudo", "")
assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}" assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
print("✅ User creation correct") print("✅ User creation correct")
@ -181,13 +175,13 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity() parsed = self.test_yaml_validity()
runcmd = parsed.get('runcmd', []) runcmd = parsed.get("runcmd", [])
assert runcmd, "runcmd section should be present" assert runcmd, "runcmd section should be present"
# Check for SSH restart command # Check for SSH restart command
ssh_restart_found = False ssh_restart_found = False
for cmd in runcmd: for cmd in runcmd:
if 'systemctl restart sshd' in str(cmd): if "systemctl restart sshd" in str(cmd):
ssh_restart_found = True ssh_restart_found = True
break break
@ -202,18 +196,18 @@ class TestCloudInitTemplate:
cloud_init_content = create_expected_cloud_init() cloud_init_content = create_expected_cloud_init()
# Extract the sshd_config content lines # Extract the sshd_config content lines
lines = cloud_init_content.split('\n') lines = cloud_init_content.split("\n")
in_sshd_content = False in_sshd_content = False
sshd_lines = [] sshd_lines = []
for line in lines: for line in lines:
if 'content: |' in line: if "content: |" in line:
in_sshd_content = True in_sshd_content = True
continue continue
elif in_sshd_content: elif in_sshd_content:
if line.strip() == '' and len(sshd_lines) > 0: if line.strip() == "" and len(sshd_lines) > 0:
break break
if line.startswith('runcmd:'): if line.startswith("runcmd:"):
break break
sshd_lines.append(line) sshd_lines.append(line)
@ -225,11 +219,13 @@ class TestCloudInitTemplate:
for line in non_empty_lines: for line in non_empty_lines:
# Each line should start with exactly 6 spaces # Each line should start with exactly 6 spaces
assert line.startswith(' ') and not line.startswith(' '), \ assert line.startswith(" ") and not line.startswith(" "), (
f"Line should have exactly 6 spaces indentation: {repr(line)}" f"Line should have exactly 6 spaces indentation: {repr(line)}"
)
print("✅ Indentation is consistent") print("✅ Indentation is consistent")
def run_tests(): def run_tests():
"""Run all tests manually (for non-pytest usage).""" """Run all tests manually (for non-pytest usage)."""
print("🚀 Cloud-init template validation tests") print("🚀 Cloud-init template validation tests")
@ -258,6 +254,7 @@ def run_tests():
print(f"❌ Unexpected error: {e}") print(f"❌ Unexpected error: {e}")
return False return False
if __name__ == "__main__": if __name__ == "__main__":
success = run_tests() success = run_tests()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)

View file

@ -30,49 +30,49 @@ packages:
rendered = self.packages_template.render({}) rendered = self.packages_template.render({})
# Should only have sudo package # Should only have sudo package
self.assertIn('- sudo', rendered) self.assertIn("- sudo", rendered)
self.assertNotIn('- git', rendered) self.assertNotIn("- git", rendered)
self.assertNotIn('- screen', rendered) self.assertNotIn("- screen", rendered)
self.assertNotIn('- apparmor-utils', rendered) self.assertNotIn("- apparmor-utils", rendered)
def test_preinstall_enabled(self): def test_preinstall_enabled(self):
"""Test that package pre-installation works when enabled.""" """Test that package pre-installation works when enabled."""
# Test with pre-installation enabled # Test with pre-installation enabled
rendered = self.packages_template.render({'performance_preinstall_packages': True}) rendered = self.packages_template.render({"performance_preinstall_packages": True})
# Should have sudo and all universal packages # Should have sudo and all universal packages
self.assertIn('- sudo', rendered) self.assertIn("- sudo", rendered)
self.assertIn('- git', rendered) self.assertIn("- git", rendered)
self.assertIn('- screen', rendered) self.assertIn("- screen", rendered)
self.assertIn('- apparmor-utils', rendered) self.assertIn("- apparmor-utils", rendered)
self.assertIn('- uuid-runtime', rendered) self.assertIn("- uuid-runtime", rendered)
self.assertIn('- coreutils', rendered) self.assertIn("- coreutils", rendered)
self.assertIn('- iptables-persistent', rendered) self.assertIn("- iptables-persistent", rendered)
self.assertIn('- cgroup-tools', rendered) self.assertIn("- cgroup-tools", rendered)
def test_preinstall_disabled_explicitly(self): def test_preinstall_disabled_explicitly(self):
"""Test that package pre-installation is disabled when set to false.""" """Test that package pre-installation is disabled when set to false."""
# Test with pre-installation explicitly disabled # Test with pre-installation explicitly disabled
rendered = self.packages_template.render({'performance_preinstall_packages': False}) rendered = self.packages_template.render({"performance_preinstall_packages": False})
# Should only have sudo package # Should only have sudo package
self.assertIn('- sudo', rendered) self.assertIn("- sudo", rendered)
self.assertNotIn('- git', rendered) self.assertNotIn("- git", rendered)
self.assertNotIn('- screen', rendered) self.assertNotIn("- screen", rendered)
self.assertNotIn('- apparmor-utils', rendered) self.assertNotIn("- apparmor-utils", rendered)
def test_package_count(self): def test_package_count(self):
"""Test that the correct number of packages are included.""" """Test that the correct number of packages are included."""
# Default: should only have sudo (1 package) # Default: should only have sudo (1 package)
rendered_default = self.packages_template.render({}) rendered_default = self.packages_template.render({})
lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')] lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_default), 1) self.assertEqual(len(lines_default), 1)
# Enabled: should have sudo + 7 universal packages (8 total) # Enabled: should have sudo + 7 universal packages (8 total)
rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True}) rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True})
lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')] lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_enabled), 8) self.assertEqual(len(lines_enabled), 8)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -2,6 +2,7 @@
""" """
Basic sanity tests for Algo VPN that don't require deployment Basic sanity tests for Algo VPN that don't require deployment
""" """
import os import os
import subprocess import subprocess
import sys import sys
@ -44,11 +45,7 @@ def test_config_file_valid():
def test_ansible_syntax(): def test_ansible_syntax():
"""Check that main playbook has valid syntax""" """Check that main playbook has valid syntax"""
result = subprocess.run( result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True)
["ansible-playbook", "main.yml", "--syntax-check"],
capture_output=True,
text=True
)
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}" assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
print("✓ Ansible playbook syntax is valid") print("✓ Ansible playbook syntax is valid")
@ -60,11 +57,7 @@ def test_shellcheck():
for script in shell_scripts: for script in shell_scripts:
if os.path.exists(script): if os.path.exists(script):
result = subprocess.run( result = subprocess.run(["shellcheck", script], capture_output=True, text=True)
["shellcheck", script],
capture_output=True,
text=True
)
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}" assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
print(f"{script} passed shellcheck") print(f"{script} passed shellcheck")
@ -87,7 +80,7 @@ def test_cloud_init_header_format():
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found" assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
with open(cloud_init_file) as f: with open(cloud_init_file) as f:
first_line = f.readline().rstrip('\n\r') first_line = f.readline().rstrip("\n\r")
# The first line MUST be exactly "#cloud-config" (no space after #) # The first line MUST be exactly "#cloud-config" (no space after #)
# This regression was introduced in PR #14775 and broke DigitalOcean deployments # This regression was introduced in PR #14775 and broke DigitalOcean deployments

View file

@ -4,31 +4,30 @@ Test cloud provider instance type configurations
Focused on validating that configured instance types are current/valid Focused on validating that configured instance types are current/valid
Based on issues #14730 - Hetzner changed from cx11 to cx22 Based on issues #14730 - Hetzner changed from cx11 to cx22
""" """
import sys import sys
def test_hetzner_server_types(): def test_hetzner_server_types():
"""Test Hetzner server type configurations (issue #14730)""" """Test Hetzner server type configurations (issue #14730)"""
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22 # Hetzner deprecated cx11 and cpx11 - smallest is now cx22
deprecated_types = ['cx11', 'cpx11'] deprecated_types = ["cx11", "cpx11"]
current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42'] current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"]
# Test that we're not using deprecated types in any configs # Test that we're not using deprecated types in any configs
test_config = { test_config = {
'cloud_providers': { "cloud_providers": {
'hetzner': { "hetzner": {
'size': 'cx22', # Should be cx22, not cx11 "size": "cx22", # Should be cx22, not cx11
'image': 'ubuntu-22.04', "image": "ubuntu-22.04",
'location': 'hel1' "location": "hel1",
} }
} }
} }
hetzner = test_config['cloud_providers']['hetzner'] hetzner = test_config["cloud_providers"]["hetzner"]
assert hetzner['size'] not in deprecated_types, \ assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}"
f"Using deprecated Hetzner type: {hetzner['size']}" assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}"
assert hetzner['size'] in current_types, \
f"Unknown Hetzner type: {hetzner['size']}"
print("✓ Hetzner server types test passed") print("✓ Hetzner server types test passed")
@ -36,10 +35,10 @@ def test_hetzner_server_types():
def test_digitalocean_instance_types(): def test_digitalocean_instance_types():
"""Test DigitalOcean droplet size naming""" """Test DigitalOcean droplet size naming"""
# DigitalOcean uses format like s-1vcpu-1gb # DigitalOcean uses format like s-1vcpu-1gb
valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb'] valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"]
deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme
test_size = 's-2vcpu-2gb' test_size = "s-2vcpu-2gb"
assert test_size in valid_sizes, f"Invalid DO size: {test_size}" assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}" assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
@ -49,10 +48,10 @@ def test_digitalocean_instance_types():
def test_aws_instance_types(): def test_aws_instance_types():
"""Test AWS EC2 instance type naming""" """Test AWS EC2 instance type naming"""
# Common valid instance types # Common valid instance types
valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large'] valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"]
deprecated_types = ['t1.micro', 'm1.small'] # Very old types deprecated_types = ["t1.micro", "m1.small"] # Very old types
test_type = 't3.micro' test_type = "t3.micro"
assert test_type in valid_types, f"Unknown EC2 type: {test_type}" assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}" assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
@ -62,9 +61,10 @@ def test_aws_instance_types():
def test_vultr_instance_types(): def test_vultr_instance_types():
"""Test Vultr instance type naming""" """Test Vultr instance type naming"""
# Vultr uses format like vc2-1c-1gb # Vultr uses format like vc2-1c-1gb
test_type = 'vc2-1c-1gb' test_type = "vc2-1c-1gb"
assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \ assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), (
f"Invalid Vultr type format: {test_type}" f"Invalid Vultr type format: {test_type}"
)
print("✓ Vultr instance types test passed") print("✓ Vultr instance types test passed")

View file

@ -2,6 +2,7 @@
""" """
Test configuration file validation without deployment Test configuration file validation without deployment
""" """
import configparser import configparser
import os import os
import re import re
@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read_string(sample_config) config.read_string(sample_config)
assert 'Interface' in config, "Missing [Interface] section" assert "Interface" in config, "Missing [Interface] section"
assert 'Peer' in config, "Missing [Peer] section" assert "Peer" in config, "Missing [Peer] section"
# Validate required fields # Validate required fields
assert config['Interface'].get('PrivateKey'), "Missing PrivateKey" assert config["Interface"].get("PrivateKey"), "Missing PrivateKey"
assert config['Interface'].get('Address'), "Missing Address" assert config["Interface"].get("Address"), "Missing Address"
assert config['Peer'].get('PublicKey'), "Missing PublicKey" assert config["Peer"].get("PublicKey"), "Missing PublicKey"
assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs" assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs"
print("✓ WireGuard config format validation passed") print("✓ WireGuard config format validation passed")
@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820
def test_base64_key_format(): def test_base64_key_format():
"""Test that keys are in valid base64 format""" """Test that keys are in valid base64 format"""
# Base64 keys can have variable length, just check format # Base64 keys can have variable length, just check format
key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$') key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$")
test_keys = [ test_keys = [
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=", "aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
@ -58,8 +59,8 @@ def test_base64_key_format():
def test_ip_address_format(): def test_ip_address_format():
"""Test IP address and CIDR notation validation""" """Test IP address and CIDR notation validation"""
ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$') ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$")
endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$') endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$")
# Test CIDR notation # Test CIDR notation
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation" assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
@ -74,11 +75,7 @@ def test_ip_address_format():
def test_mobile_config_xml(): def test_mobile_config_xml():
"""Test that mobile config files would be valid XML""" """Test that mobile config files would be valid XML"""
# First check if xmllint is available # First check if xmllint is available
xmllint_check = subprocess.run( xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True)
['which', 'xmllint'],
capture_output=True,
text=True
)
if xmllint_check.returncode != 0: if xmllint_check.returncode != 0:
print("⚠ Skipping XML validation test (xmllint not installed)") print("⚠ Skipping XML validation test (xmllint not installed)")
@ -99,17 +96,13 @@ def test_mobile_config_xml():
</dict> </dict>
</plist>""" </plist>"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f:
f.write(sample_mobileconfig) f.write(sample_mobileconfig)
temp_file = f.name temp_file = f.name
try: try:
# Use xmllint to validate # Use xmllint to validate
result = subprocess.run( result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True)
['xmllint', '--noout', temp_file],
capture_output=True,
text=True
)
assert result.returncode == 0, f"XML validation failed: {result.stderr}" assert result.returncode == 0, f"XML validation failed: {result.stderr}"
print("✓ Mobile config XML validation passed") print("✓ Mobile config XML validation passed")

View file

@ -3,6 +3,7 @@
Simplified Docker-based localhost deployment tests Simplified Docker-based localhost deployment tests
Verifies services can start and config files exist in expected locations Verifies services can start and config files exist in expected locations
""" """
import os import os
import subprocess import subprocess
import sys import sys
@ -11,7 +12,7 @@ import sys
def check_docker_available(): def check_docker_available():
"""Check if Docker is available""" """Check if Docker is available"""
try: try:
result = subprocess.run(['docker', '--version'], capture_output=True, text=True) result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
return result.returncode == 0 return result.returncode == 0
except FileNotFoundError: except FileNotFoundError:
return False return False
@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128
""" """
# Just validate the format # Just validate the format
required_sections = ['[Interface]', '[Peer]'] required_sections = ["[Interface]", "[Peer]"]
required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs'] required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"]
for section in required_sections: for section in required_sections:
if section not in config: if section not in config:
@ -68,15 +69,15 @@ conn ikev2-pubkey
""" """
# Validate format # Validate format
if 'config setup' not in config: if "config setup" not in config:
print("✗ Missing 'config setup' section") print("✗ Missing 'config setup' section")
return False return False
if 'conn %default' not in config: if "conn %default" not in config:
print("✗ Missing 'conn %default' section") print("✗ Missing 'conn %default' section")
return False return False
if 'keyexchange=ikev2' not in config: if "keyexchange=ikev2" not in config:
print("✗ Missing IKEv2 configuration") print("✗ Missing IKEv2 configuration")
return False return False
@ -87,19 +88,19 @@ conn ikev2-pubkey
def test_docker_algo_image(): def test_docker_algo_image():
"""Test that the Algo Docker image can be built""" """Test that the Algo Docker image can be built"""
# Check if Dockerfile exists # Check if Dockerfile exists
if not os.path.exists('Dockerfile'): if not os.path.exists("Dockerfile"):
print("✗ Dockerfile not found") print("✗ Dockerfile not found")
return False return False
# Read Dockerfile and validate basic structure # Read Dockerfile and validate basic structure
with open('Dockerfile') as f: with open("Dockerfile") as f:
dockerfile_content = f.read() dockerfile_content = f.read()
required_elements = [ required_elements = [
'FROM', # Base image "FROM", # Base image
'RUN', # Build commands "RUN", # Build commands
'COPY', # Copy Algo files "COPY", # Copy Algo files
'python' # Python dependency "python", # Python dependency
] ]
missing = [] missing = []
@ -115,16 +116,14 @@ def test_docker_algo_image():
return True return True
def test_localhost_deployment_requirements(): def test_localhost_deployment_requirements():
"""Test that localhost deployment requirements are met""" """Test that localhost deployment requirements are met"""
requirements = { requirements = {
'Python 3.8+': sys.version_info >= (3, 8), "Python 3.8+": sys.version_info >= (3, 8),
'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0, "Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0,
'Main playbook exists': os.path.exists('main.yml'), "Main playbook exists": os.path.exists("main.yml"),
'Project config exists': os.path.exists('pyproject.toml'), "Project config exists": os.path.exists("pyproject.toml"),
'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'), "Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"),
} }
all_met = True all_met = True
@ -138,10 +137,6 @@ def test_localhost_deployment_requirements():
return all_met return all_met
if __name__ == "__main__": if __name__ == "__main__":
print("Running Docker localhost deployment tests...") print("Running Docker localhost deployment tests...")
print("=" * 50) print("=" * 50)

View file

@ -3,6 +3,7 @@
Test that generated configuration files have valid syntax Test that generated configuration files have valid syntax
This validates WireGuard, StrongSwan, SSH, and other configs This validates WireGuard, StrongSwan, SSH, and other configs
""" """
import re import re
import subprocess import subprocess
import sys import sys
@ -11,7 +12,7 @@ import sys
def check_command_available(cmd): def check_command_available(cmd):
"""Check if a command is available on the system""" """Check if a command is available on the system"""
try: try:
subprocess.run([cmd, '--version'], capture_output=True, check=False) subprocess.run([cmd, "--version"], capture_output=True, check=False)
return True return True
except FileNotFoundError: except FileNotFoundError:
return False return False
@ -37,51 +38,50 @@ PersistentKeepalive = 25
errors = [] errors = []
# Check for required sections # Check for required sections
if '[Interface]' not in sample_config: if "[Interface]" not in sample_config:
errors.append("Missing [Interface] section") errors.append("Missing [Interface] section")
if '[Peer]' not in sample_config: if "[Peer]" not in sample_config:
errors.append("Missing [Peer] section") errors.append("Missing [Peer] section")
# Validate Interface section # Validate Interface section
interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL) interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL)
if interface_match: if interface_match:
interface_section = interface_match.group(1) interface_section = interface_match.group(1)
# Check required fields # Check required fields
if not re.search(r'Address\s*=', interface_section): if not re.search(r"Address\s*=", interface_section):
errors.append("Missing Address in Interface section") errors.append("Missing Address in Interface section")
if not re.search(r'PrivateKey\s*=', interface_section): if not re.search(r"PrivateKey\s*=", interface_section):
errors.append("Missing PrivateKey in Interface section") errors.append("Missing PrivateKey in Interface section")
# Validate IP addresses # Validate IP addresses
address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section) address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section)
if address_match: if address_match:
addresses = address_match.group(1).split(',') addresses = address_match.group(1).split(",")
for addr in addresses: for addr in addresses:
addr = addr.strip() addr = addr.strip()
# Basic IP validation # Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \ if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr):
not re.match(r'^[0-9a-fA-F:]+/\d+$', addr):
errors.append(f"Invalid IP address format: {addr}") errors.append(f"Invalid IP address format: {addr}")
# Validate Peer section # Validate Peer section
peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL) peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL)
if peer_match: if peer_match:
peer_section = peer_match.group(1) peer_section = peer_match.group(1)
# Check required fields # Check required fields
if not re.search(r'PublicKey\s*=', peer_section): if not re.search(r"PublicKey\s*=", peer_section):
errors.append("Missing PublicKey in Peer section") errors.append("Missing PublicKey in Peer section")
if not re.search(r'AllowedIPs\s*=', peer_section): if not re.search(r"AllowedIPs\s*=", peer_section):
errors.append("Missing AllowedIPs in Peer section") errors.append("Missing AllowedIPs in Peer section")
if not re.search(r'Endpoint\s*=', peer_section): if not re.search(r"Endpoint\s*=", peer_section):
errors.append("Missing Endpoint in Peer section") errors.append("Missing Endpoint in Peer section")
# Validate endpoint format # Validate endpoint format
endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section) endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section)
if endpoint_match: if endpoint_match:
endpoint = endpoint_match.group(1).strip() endpoint = endpoint_match.group(1).strip()
if not re.match(r'^[\d\.\:]+:\d+$', endpoint): if not re.match(r"^[\d\.\:]+:\d+$", endpoint):
errors.append(f"Invalid Endpoint format: {endpoint}") errors.append(f"Invalid Endpoint format: {endpoint}")
if errors: if errors:
@ -132,33 +132,32 @@ conn ikev2-pubkey
errors = [] errors = []
# Check for required sections # Check for required sections
if 'config setup' not in sample_config: if "config setup" not in sample_config:
errors.append("Missing 'config setup' section") errors.append("Missing 'config setup' section")
if 'conn %default' not in sample_config: if "conn %default" not in sample_config:
errors.append("Missing 'conn %default' section") errors.append("Missing 'conn %default' section")
# Validate connection settings # Validate connection settings
conn_pattern = re.compile(r'conn\s+(\S+)') conn_pattern = re.compile(r"conn\s+(\S+)")
connections = conn_pattern.findall(sample_config) connections = conn_pattern.findall(sample_config)
if len(connections) < 2: # Should have at least %default and one other if len(connections) < 2: # Should have at least %default and one other
errors.append("Not enough connection definitions") errors.append("Not enough connection definitions")
# Check for required parameters in connections # Check for required parameters in connections
required_params = ['keyexchange', 'left', 'right'] required_params = ["keyexchange", "left", "right"]
for param in required_params: for param in required_params:
if f'{param}=' not in sample_config: if f"{param}=" not in sample_config:
errors.append(f"Missing required parameter: {param}") errors.append(f"Missing required parameter: {param}")
# Validate IP subnet formats # Validate IP subnet formats
subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)') subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)")
for match in subnet_pattern.finditer(sample_config): for match in subnet_pattern.finditer(sample_config):
subnets = match.group(2).split(',') subnets = match.group(2).split(",")
for subnet in subnets: for subnet in subnets:
subnet = subnet.strip() subnet = subnet.strip()
if subnet != '0.0.0.0/0' and subnet != '::/0': if subnet != "0.0.0.0/0" and subnet != "::/0":
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \ if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet):
not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet):
errors.append(f"Invalid subnet format: {subnet}") errors.append(f"Invalid subnet format: {subnet}")
if errors: if errors:
@ -188,21 +187,21 @@ def test_ssh_config_syntax():
errors = [] errors = []
# Parse SSH config format # Parse SSH config format
lines = sample_config.strip().split('\n') lines = sample_config.strip().split("\n")
current_host = None current_host = None
for line in lines: for line in lines:
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
if line.startswith('Host '): if line.startswith("Host "):
current_host = line.split()[1] current_host = line.split()[1]
elif current_host and ' ' in line: elif current_host and " " in line:
key, value = line.split(None, 1) key, value = line.split(None, 1)
# Validate common SSH options # Validate common SSH options
if key == 'Port': if key == "Port":
try: try:
port = int(value) port = int(value)
if not 1 <= port <= 65535: if not 1 <= port <= 65535:
@ -210,7 +209,7 @@ def test_ssh_config_syntax():
except ValueError: except ValueError:
errors.append(f"Port must be a number: {value}") errors.append(f"Port must be a number: {value}")
elif key == 'LocalForward': elif key == "LocalForward":
# Format: LocalForward [bind_address:]port host:hostport # Format: LocalForward [bind_address:]port host:hostport
parts = value.split() parts = value.split()
if len(parts) != 2: if len(parts) != 2:
@ -256,35 +255,35 @@ COMMIT
errors = [] errors = []
# Check table definitions # Check table definitions
tables = re.findall(r'\*(\w+)', sample_rules) tables = re.findall(r"\*(\w+)", sample_rules)
if 'filter' not in tables: if "filter" not in tables:
errors.append("Missing *filter table") errors.append("Missing *filter table")
if 'nat' not in tables: if "nat" not in tables:
errors.append("Missing *nat table") errors.append("Missing *nat table")
# Check for COMMIT statements # Check for COMMIT statements
commit_count = sample_rules.count('COMMIT') commit_count = sample_rules.count("COMMIT")
if commit_count != len(tables): if commit_count != len(tables):
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})") errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
# Validate chain policies # Validate chain policies
chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE) chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE)
chains = chain_pattern.findall(sample_rules) chains = chain_pattern.findall(sample_rules)
required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')] required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")]
for chain, _policy in required_chains: for chain, _policy in required_chains:
if not any(c[0] == chain for c in chains): if not any(c[0] == chain for c in chains):
errors.append(f"Missing required chain: {chain}") errors.append(f"Missing required chain: {chain}")
# Validate rule syntax # Validate rule syntax
rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE) rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE)
rules = rule_pattern.findall(sample_rules) rules = rule_pattern.findall(sample_rules)
if len(rules) < 5: if len(rules) < 5:
errors.append("Insufficient firewall rules") errors.append("Insufficient firewall rules")
# Check for essential security rules # Check for essential security rules
if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules: if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules:
errors.append("Missing stateful connection tracking rule") errors.append("Missing stateful connection tracking rule")
if errors: if errors:
@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors = [] errors = []
# Parse config # Parse config
for line in sample_config.strip().split('\n'): for line in sample_config.strip().split("\n"):
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
# Most dnsmasq options are key=value or just key # Most dnsmasq options are key=value or just key
if '=' in line: if "=" in line:
key, value = line.split('=', 1) key, value = line.split("=", 1)
# Validate specific options # Validate specific options
if key == 'interface': if key == "interface":
if not re.match(r'^[a-zA-Z0-9\-_]+$', value): if not re.match(r"^[a-zA-Z0-9\-_]+$", value):
errors.append(f"Invalid interface name: {value}") errors.append(f"Invalid interface name: {value}")
elif key == 'server': elif key == "server":
# Basic IP validation # Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \ if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value):
not re.match(r'^[0-9a-fA-F:]+$', value):
errors.append(f"Invalid DNS server IP: {value}") errors.append(f"Invalid DNS server IP: {value}")
elif key == 'cache-size': elif key == "cache-size":
try: try:
size = int(value) size = int(value)
if size < 0: if size < 0:
@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors.append(f"Cache size must be a number: {value}") errors.append(f"Cache size must be a number: {value}")
# Check for required options # Check for required options
required = ['interface', 'server'] required = ["interface", "server"]
for req in required: for req in required:
if f'{req}=' not in sample_config: if f"{req}=" not in sample_config:
errors.append(f"Missing required option: {req}") errors.append(f"Missing required option: {req}")
if errors: if errors:

View file

@ -14,175 +14,233 @@ from jinja2 import Environment, FileSystemLoader
def load_template(template_name): def load_template(template_name):
"""Load a Jinja2 template from the roles/common/templates directory.""" """Load a Jinja2 template from the roles/common/templates directory."""
template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates' template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates"
env = Environment(loader=FileSystemLoader(str(template_dir))) env = Environment(loader=FileSystemLoader(str(template_dir)))
return env.get_template(template_name) return env.get_template(template_name)
def test_wireguard_nat_rules_ipv4(): def test_wireguard_nat_rules_ipv4():
"""Test that WireGuard traffic gets proper NAT rules without policy matching.""" """Test that WireGuard traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
# Test with WireGuard enabled # Test with WireGuard enabled
result = template.render( result = template.render(
ipsec_enabled=False, ipsec_enabled=False,
wireguard_enabled=True, wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16', wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820, wireguard_port=51820,
wireguard_port_avoid=53, wireguard_port_avoid=53,
wireguard_port_actual=51820, wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None, snat_aipv4=None,
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.49.0.1', local_service_ip="10.49.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# Verify NAT rule exists without policy matching # Verify NAT rule exists with output interface and without policy matching
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' in result assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in WireGuard NAT rules # Verify no policy matching in WireGuard NAT rules
assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result
def test_ipsec_nat_rules_ipv4(): def test_ipsec_nat_rules_ipv4():
"""Test that IPsec traffic gets proper NAT rules without policy matching.""" """Test that IPsec traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
# Test with IPsec enabled # Test with IPsec enabled
result = template.render( result = template.render(
ipsec_enabled=True, ipsec_enabled=True,
wireguard_enabled=False, wireguard_enabled=False,
strongswan_network='10.48.0.0/16', strongswan_network="10.48.0.0/16",
strongswan_network_ipv6='2001:db8::/48', strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None, snat_aipv4=None,
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.48.0.1', local_service_ip="10.48.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# Verify NAT rule exists without policy matching # Verify NAT rule exists with output interface and without policy matching
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' in result assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in IPsec NAT rules (this was the bug) # Verify no policy matching in IPsec NAT rules (this was the bug)
assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result
def test_both_vpns_nat_rules_ipv4(): def test_both_vpns_nat_rules_ipv4():
"""Test NAT rules when both VPN types are enabled.""" """Test NAT rules when both VPN types are enabled."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
result = template.render( result = template.render(
ipsec_enabled=True, ipsec_enabled=True,
wireguard_enabled=True, wireguard_enabled=True,
strongswan_network='10.48.0.0/16', strongswan_network="10.48.0.0/16",
wireguard_network_ipv4='10.49.0.0/16', wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6='2001:db8::/48', strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6='2001:db8:a160::/48', wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820, wireguard_port=51820,
wireguard_port_avoid=53, wireguard_port_avoid=53,
wireguard_port_actual=51820, wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None, snat_aipv4=None,
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.49.0.1', local_service_ip="10.49.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# Both should have NAT rules # Both should have NAT rules with output interface
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' in result assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' in result assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Neither should have policy matching # Neither should have policy matching
assert '-m policy --pol none' not in result assert "-m policy --pol none" not in result
def test_alternative_ingress_snat(): def test_alternative_ingress_snat():
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE.""" """Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
result = template.render( result = template.render(
ipsec_enabled=True, ipsec_enabled=True,
wireguard_enabled=True, wireguard_enabled=True,
strongswan_network='10.48.0.0/16', strongswan_network="10.48.0.0/16",
wireguard_network_ipv4='10.49.0.0/16', wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6='2001:db8::/48', strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6='2001:db8:a160::/48', wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820, wireguard_port=51820,
wireguard_port_avoid=53, wireguard_port_avoid=53,
wireguard_port_actual=51820, wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4='192.168.1.100', # Alternative ingress IP snat_aipv4="192.168.1.100", # Alternative ingress IP
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.49.0.1', local_service_ip="10.49.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# Should use SNAT with specific IP instead of MASQUERADE # Should use SNAT with specific IP and output interface instead of MASQUERADE
assert '-A POSTROUTING -s 10.48.0.0/16 -j SNAT --to 192.168.1.100' in result assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert '-A POSTROUTING -s 10.49.0.0/16 -j SNAT --to 192.168.1.100' in result assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert 'MASQUERADE' not in result assert "MASQUERADE" not in result
def test_ipsec_forward_rule_has_policy_match(): def test_ipsec_forward_rule_has_policy_match():
"""Test that IPsec FORWARD rules still use policy matching (this is correct).""" """Test that IPsec FORWARD rules still use policy matching (this is correct)."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
result = template.render( result = template.render(
ipsec_enabled=True, ipsec_enabled=True,
wireguard_enabled=False, wireguard_enabled=False,
strongswan_network='10.48.0.0/16', strongswan_network="10.48.0.0/16",
strongswan_network_ipv6='2001:db8::/48', strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None, snat_aipv4=None,
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.48.0.1', local_service_ip="10.48.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# FORWARD rule should have policy match (this is correct and should stay) # FORWARD rule should have policy match (this is correct and should stay)
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result
def test_wireguard_forward_rule_no_policy_match(): def test_wireguard_forward_rule_no_policy_match():
"""Test that WireGuard FORWARD rules don't use policy matching.""" """Test that WireGuard FORWARD rules don't use policy matching."""
template = load_template('rules.v4.j2') template = load_template("rules.v4.j2")
result = template.render( result = template.render(
ipsec_enabled=False, ipsec_enabled=False,
wireguard_enabled=True, wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16', wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820, wireguard_port=51820,
wireguard_port_avoid=53, wireguard_port_avoid=53,
wireguard_port_actual=51820, wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'}, ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None, snat_aipv4=None,
BetweenClients_DROP=True, BetweenClients_DROP=True,
block_smb=True, block_smb=True,
block_netbios=True, block_netbios=True,
local_service_ip='10.49.0.1', local_service_ip="10.49.0.1",
ansible_ssh_port=22, ansible_ssh_port=22,
reduce_mtu=0 reduce_mtu=0,
) )
# WireGuard FORWARD rule should NOT have any policy match # WireGuard FORWARD rule should NOT have any policy match
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result
if __name__ == '__main__': def test_output_interface_in_nat_rules():
pytest.main([__file__, '-v']) """Test that output interface is specified in NAT rules."""
template = load_template("rules.v4.j2")
result = template.render(
snat_aipv4=False,
wireguard_enabled=True,
ipsec_enabled=True,
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network="10.48.0.0/16",
ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"},
ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"},
wireguard_port_actual=51820,
wireguard_port_avoid=53,
wireguard_port=51820,
ansible_ssh_port=22,
reduce_mtu=0,
)
# Check that output interface is specified for both VPNs
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Ensure we don't have rules without output interface
assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result
def test_dns_firewall_restricted_to_vpn():
"""Test that DNS access is restricted to VPN clients only."""
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip="172.23.198.242",
ansible_ssh_port=22,
reduce_mtu=0,
)
# DNS should only be accessible from VPN subnets
assert "-A INPUT -s 10.48.0.0/16,10.49.0.0/16 -d 172.23.198.242 -p udp --dport 53 -j ACCEPT" in result
# Should NOT have unrestricted DNS access
assert "-A INPUT -d 172.23.198.242 -p udp --dport 53 -j ACCEPT" not in result
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -12,7 +12,7 @@ import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
# Add the library directory to the path # Add the library directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library"))
class TestLightsailBoto3Fix(unittest.TestCase): class TestLightsailBoto3Fix(unittest.TestCase):
@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase):
"""Set up test fixtures.""" """Set up test fixtures."""
# Mock the ansible module_utils since we're testing outside of Ansible # Mock the ansible module_utils since we're testing outside of Ansible
self.mock_modules = { self.mock_modules = {
'ansible.module_utils.basic': MagicMock(), "ansible.module_utils.basic": MagicMock(),
'ansible.module_utils.ec2': MagicMock(), "ansible.module_utils.ec2": MagicMock(),
'ansible.module_utils.aws.core': MagicMock(), "ansible.module_utils.aws.core": MagicMock(),
} }
# Apply mocks # Apply mocks
self.patches = [] self.patches = []
for module_name, mock_module in self.mock_modules.items(): for module_name, mock_module in self.mock_modules.items():
patcher = patch.dict('sys.modules', {module_name: mock_module}) patcher = patch.dict("sys.modules", {module_name: mock_module})
patcher.start() patcher.start()
self.patches.append(patcher) self.patches.append(patcher)
@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Import the module # Import the module
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
"lightsail_region_facts", "lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
) )
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Verify the module loaded # Verify the module loaded
self.assertIsNotNone(module) self.assertIsNotNone(module)
self.assertTrue(hasattr(module, 'main')) self.assertTrue(hasattr(module, "main"))
except Exception as e: except Exception as e:
self.fail(f"Failed to import lightsail_region_facts: {e}") self.fail(f"Failed to import lightsail_region_facts: {e}")
@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase):
def test_get_aws_connection_info_called_without_boto3(self): def test_get_aws_connection_info_called_without_boto3(self):
"""Test that get_aws_connection_info is called without boto3 parameter.""" """Test that get_aws_connection_info is called without boto3 parameter."""
# Mock get_aws_connection_info to track calls # Mock get_aws_connection_info to track calls
mock_get_aws_connection_info = MagicMock( mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {}))
return_value=('us-west-2', None, {})
)
with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info): with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info):
# Import the module # Import the module
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
"lightsail_region_facts", "lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
) )
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
mock_ansible_module.params = {} mock_ansible_module.params = {}
mock_ansible_module.check_mode = False mock_ansible_module.check_mode = False
with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module): with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module):
# Execute the module # Execute the module
try: try:
spec.loader.exec_module(module) spec.loader.exec_module(module)
@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase):
if call_args: if call_args:
# Check positional arguments # Check positional arguments
if call_args[0]: # args if call_args[0]: # args
self.assertTrue(len(call_args[0]) <= 1, self.assertTrue(
"get_aws_connection_info should be called with at most 1 positional arg (module)") len(call_args[0]) <= 1,
"get_aws_connection_info should be called with at most 1 positional arg (module)",
)
# Check keyword arguments # Check keyword arguments
if call_args[1]: # kwargs if call_args[1]: # kwargs
self.assertNotIn('boto3', call_args[1], self.assertNotIn(
"get_aws_connection_info should not be called with boto3 parameter") "boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter"
)
def test_no_boto3_parameter_in_source(self): def test_no_boto3_parameter_in_source(self):
"""Verify that boto3 parameter is not present in the source code.""" """Verify that boto3 parameter is not present in the source code."""
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f: with open(lightsail_path) as f:
content = f.read() content = f.read()
# Check that boto3=True is not in the file # Check that boto3=True is not in the file
self.assertNotIn('boto3=True', content, self.assertNotIn(
"boto3=True parameter should not be present in lightsail_region_facts.py") "boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py"
)
# Check that boto3 parameter is not used with get_aws_connection_info # Check that boto3 parameter is not used with get_aws_connection_info
self.assertNotIn('get_aws_connection_info(module, boto3', content, self.assertNotIn(
"get_aws_connection_info should not be called with boto3 parameter") "get_aws_connection_info(module, boto3",
content,
"get_aws_connection_info should not be called with boto3 parameter",
)
def test_regression_issue_14822(self): def test_regression_issue_14822(self):
""" """
@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# The boto3 parameter was deprecated and removed in amazon.aws collection # The boto3 parameter was deprecated and removed in amazon.aws collection
# that comes with Ansible 11.x # that comes with Ansible 11.x
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f: with open(lightsail_path) as f:
lines = f.readlines() lines = f.readlines()
# Find the line that calls get_aws_connection_info # Find the line that calls get_aws_connection_info
for line_num, line in enumerate(lines, 1): for line_num, line in enumerate(lines, 1):
if 'get_aws_connection_info' in line and 'region' in line: if "get_aws_connection_info" in line and "region" in line:
# This should be around line 85 # This should be around line 85
# Verify it doesn't have boto3=True # Verify it doesn't have boto3=True
self.assertNotIn('boto3', line, self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter")
f"Line {line_num} should not contain boto3 parameter")
# Verify the correct format # Verify the correct format
self.assertIn('get_aws_connection_info(module)', line, self.assertIn(
f"Line {line_num} should call get_aws_connection_info(module) without boto3") "get_aws_connection_info(module)",
line,
f"Line {line_num} should call get_aws_connection_info(module) without boto3",
)
break break
else: else:
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py") self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa
Based on issues #14755, #14718 - Apple device compatibility Based on issues #14755, #14718 - Apple device compatibility
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions) Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
""" """
import glob import glob
import os import os
import re import re
@ -22,7 +23,7 @@ def find_generated_certificates():
config_patterns = [ config_patterns = [
"configs/*/ipsec/.pki/cacert.pem", "configs/*/ipsec/.pki/cacert.pem",
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory "../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path "../../configs/*/ipsec/.pki/cacert.pem", # Alternative path
] ]
for pattern in config_patterns: for pattern in config_patterns:
@ -30,26 +31,23 @@ def find_generated_certificates():
if ca_certs: if ca_certs:
base_path = os.path.dirname(ca_certs[0]) base_path = os.path.dirname(ca_certs[0])
return { return {
'ca_cert': ca_certs[0], "ca_cert": ca_certs[0],
'base_path': base_path, "base_path": base_path,
'server_certs': glob.glob(f"{base_path}/certs/*.crt"), "server_certs": glob.glob(f"{base_path}/certs/*.crt"),
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12") "p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"),
} }
return None return None
def test_openssl_version_detection(): def test_openssl_version_detection():
"""Test that we can detect OpenSSL version for compatibility checks""" """Test that we can detect OpenSSL version for compatibility checks"""
result = subprocess.run( result = subprocess.run(["openssl", "version"], capture_output=True, text=True)
['openssl', 'version'],
capture_output=True,
text=True
)
assert result.returncode == 0, "Failed to get OpenSSL version" assert result.returncode == 0, "Failed to get OpenSSL version"
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022" # Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout) version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout)
assert version_match, f"Can't parse OpenSSL version: {result.stdout}" assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
major = int(version_match.group(1)) major = int(version_match.group(1))
@ -62,7 +60,7 @@ def test_openssl_version_detection():
def validate_ca_certificate_real(cert_files): def validate_ca_certificate_real(cert_files):
"""Validate actual Ansible-generated CA certificate""" """Validate actual Ansible-generated CA certificate"""
# Read the actual CA certificate generated by Ansible # Read the actual CA certificate generated by Ansible
with open(cert_files['ca_cert'], 'rb') as f: with open(cert_files["ca_cert"], "rb") as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
@ -89,30 +87,34 @@ def validate_ca_certificate_real(cert_files):
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints" assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
# Verify public domains are excluded # Verify public domains are excluded
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_dns = [
if isinstance(constraint, x509.DNSName)] constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)
]
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains: for domain in public_domains:
assert domain in excluded_dns, f"CA should exclude public domain {domain}" assert domain in excluded_dns, f"CA should exclude public domain {domain}"
# Verify private IP ranges are excluded (Issue #75) # Verify private IP ranges are excluded (Issue #75)
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_ips = [
if isinstance(constraint, x509.IPAddress)] constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)
]
assert len(excluded_ips) > 0, "CA should exclude private IP ranges" assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
# Verify email domains are also excluded (Issue #153) # Verify email domains are also excluded (Issue #153)
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees excluded_emails = [
if isinstance(constraint, x509.RFC822Name)] constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)
]
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains: for domain in email_domains:
assert domain in excluded_emails, f"CA should exclude email domain {domain}" assert domain in excluded_emails, f"CA should exclude email domain {domain}"
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}") print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
def validate_ca_certificate_config(): def validate_ca_certificate_config():
"""Validate CA certificate configuration in Ansible files (CI mode)""" """Validate CA certificate configuration in Ansible files (CI mode)"""
# Check that the Ansible task file has proper CA certificate configuration # Check that the Ansible task file has proper CA certificate configuration
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
@ -122,15 +124,15 @@ def validate_ca_certificate_config():
# Verify key security configurations are present # Verify key security configurations are present
security_checks = [ security_checks = [
('name_constraints_permitted', 'Name constraints should be configured'), ("name_constraints_permitted", "Name constraints should be configured"),
('name_constraints_excluded', 'Excluded name constraints should be configured'), ("name_constraints_excluded", "Excluded name constraints should be configured"),
('extended_key_usage', 'Extended Key Usage should be configured'), ("extended_key_usage", "Extended Key Usage should be configured"),
('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'), ("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"),
('serverAuth', 'Server authentication EKU should be present'), ("serverAuth", "Server authentication EKU should be present"),
('clientAuth', 'Client authentication EKU should be present'), ("clientAuth", "Client authentication EKU should be present"),
('basic_constraints', 'Basic constraints should be configured'), ("basic_constraints", "Basic constraints should be configured"),
('CA:TRUE', 'CA certificate should be marked as CA'), ("CA:TRUE", "CA certificate should be marked as CA"),
('pathlen:0', 'Path length constraint should be set') ("pathlen:0", "Path length constraint should be set"),
] ]
for check, message in security_checks: for check, message in security_checks:
@ -140,7 +142,9 @@ def validate_ca_certificate_config():
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains: for domain in public_domains:
# Handle both double quotes and single quotes in YAML # Handle both double quotes and single quotes in YAML
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded" assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, (
f"Public domain {domain} should be excluded"
)
# Verify private IP ranges are excluded # Verify private IP ranges are excluded
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"] private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
@ -151,13 +155,16 @@ def validate_ca_certificate_config():
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains: for domain in email_domains:
# Handle both double quotes and single quotes in YAML # Handle both double quotes and single quotes in YAML
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded" assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, (
f"Email domain {domain} should be excluded"
)
# Verify IPv6 constraints are present (Issue #153) # Verify IPv6 constraints are present (Issue #153)
assert "IP:::/0" in content, "IPv6 all addresses should be excluded" assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
print("✓ CA certificate configuration has proper security constraints") print("✓ CA certificate configuration has proper security constraints")
def test_ca_certificate(): def test_ca_certificate():
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)""" """Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates() cert_files = find_generated_certificates()
@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files):
# Filter to only actual server certificates (not client certs) # Filter to only actual server certificates (not client certs)
# Server certificates contain IP addresses in the filename # Server certificates contain IP addresses in the filename
import re import re
server_certs = [f for f in cert_files['server_certs']
if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)] server_certs = [
f
for f in cert_files["server_certs"]
if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f)
]
if not server_certs: if not server_certs:
print("⚠ No server certificates found") print("⚠ No server certificates found")
return return
for server_cert_path in server_certs: for server_cert_path in server_certs:
with open(server_cert_path, 'rb') as f: with open(server_cert_path, "rb") as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files):
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU" assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153) # Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation" assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, (
"Server cert should NOT have clientAuth EKU for role separation"
)
# Check SAN extension exists (required for Apple devices) # Check SAN extension exists (required for Apple devices)
try: try:
@ -204,9 +217,10 @@ def validate_server_certificates_real(cert_files):
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}") print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
def validate_server_certificates_config(): def validate_server_certificates_config():
"""Validate server certificate configuration in Ansible files (CI mode)""" """Validate server certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
@ -215,7 +229,7 @@ def validate_server_certificates_config():
content = f.read() content = f.read()
# Look for server certificate CSR section # Look for server certificate CSR section
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL) server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL)
if not server_csr_section: if not server_csr_section:
print("⚠ Could not find server certificate CSR section") print("⚠ Could not find server certificate CSR section")
return return
@ -224,11 +238,11 @@ def validate_server_certificates_config():
# Check server certificate CSR configuration # Check server certificate CSR configuration
server_checks = [ server_checks = [
('subject_alt_name', 'Server certificates should have SAN extension'), ("subject_alt_name", "Server certificates should have SAN extension"),
('serverAuth', 'Server certificates should have serverAuth EKU'), ("serverAuth", "Server certificates should have serverAuth EKU"),
('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'), ("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"),
('digitalSignature', 'Server certificates should have digital signature usage'), ("digitalSignature", "Server certificates should have digital signature usage"),
('keyEncipherment', 'Server certificates should have key encipherment usage') ("keyEncipherment", "Server certificates should have key encipherment usage"),
] ]
for check, message in server_checks: for check, message in server_checks:
@ -236,15 +250,20 @@ def validate_server_certificates_config():
# Security check: Server certificates should NOT have clientAuth (Issue #153) # Security check: Server certificates should NOT have clientAuth (Issue #153)
# Look for clientAuth in extended_key_usage section, not in comments # Look for clientAuth in extended_key_usage section, not in comments
eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)] eku_lines = [
has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- ')) line
for line in server_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line)
]
has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation" assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
# Verify SAN extension is configured for Apple compatibility # Verify SAN extension is configured for Apple compatibility
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility" assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility"
print("✓ Server certificate configuration has proper EKU and SAN settings") print("✓ Server certificate configuration has proper EKU and SAN settings")
def test_server_certificates(): def test_server_certificates():
"""Test server certificates - uses real certs if available, else validates config""" """Test server certificates - uses real certs if available, else validates config"""
cert_files = find_generated_certificates() cert_files = find_generated_certificates()
@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files):
"""Validate actual Ansible-generated client certificates""" """Validate actual Ansible-generated client certificates"""
# Find client certificates (not CA cert, not server cert with IP/DNS name) # Find client certificates (not CA cert, not server cert with IP/DNS name)
client_certs = [] client_certs = []
for cert_path in cert_files['server_certs']: for cert_path in cert_files["server_certs"]:
if 'cacert.pem' in cert_path: if "cacert.pem" in cert_path:
continue continue
with open(cert_path, 'rb') as f: with open(cert_path, "rb") as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
# Check if this looks like a client cert vs server cert # Check if this looks like a client cert vs server cert
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Server certs typically have IP addresses or domain names as CN # Server certs typically have IP addresses or domain names as CN
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4): if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4):
client_certs.append((cert_path, certificate)) client_certs.append((cert_path, certificate))
if not client_certs: if not client_certs:
@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files):
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153) # Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation" assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, (
"Client cert must NOT have serverAuth EKU to prevent server impersonation"
)
# Check SAN extension for email # Check SAN extension for email
try: try:
@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files):
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}") print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
def validate_client_certificates_config(): def validate_client_certificates_config():
"""Validate client certificate configuration in Ansible files (CI mode)""" """Validate client certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
@ -310,7 +332,9 @@ def validate_client_certificates_config():
content = f.read() content = f.read()
# Look for client certificate CSR section # Look for client certificate CSR section
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL) client_csr_section = re.search(
r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL
)
if not client_csr_section: if not client_csr_section:
print("⚠ Could not find client certificate CSR section") print("⚠ Could not find client certificate CSR section")
return return
@ -319,11 +343,11 @@ def validate_client_certificates_config():
# Check client certificate configuration # Check client certificate configuration
client_checks = [ client_checks = [
('clientAuth', 'Client certificates should have clientAuth EKU'), ("clientAuth", "Client certificates should have clientAuth EKU"),
('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'), ("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"),
('digitalSignature', 'Client certificates should have digital signature usage'), ("digitalSignature", "Client certificates should have digital signature usage"),
('keyEncipherment', 'Client certificates should have key encipherment usage'), ("keyEncipherment", "Client certificates should have key encipherment usage"),
('email:', 'Client certificates should have email SAN') ("email:", "Client certificates should have email SAN"),
] ]
for check, message in client_checks: for check, message in client_checks:
@ -331,15 +355,22 @@ def validate_client_certificates_config():
# Security check: Client certificates should NOT have serverAuth (Issue #153) # Security check: Client certificates should NOT have serverAuth (Issue #153)
# Look for serverAuth in extended_key_usage section, not in comments # Look for serverAuth in extended_key_usage section, not in comments
eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)] eku_lines = [
has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- ')) line
for line in client_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line)
]
has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation" assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
# Verify client certificates use unique email domains (Issue #153) # Verify client certificates use unique email domains (Issue #153)
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment" assert "openssl_constraint_random_id" in client_section, (
"Client certificates should use unique email domain per deployment"
)
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)") print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
def test_client_certificates(): def test_client_certificates():
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)""" """Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates() cert_files = find_generated_certificates()
@ -351,24 +382,33 @@ def test_client_certificates():
def validate_pkcs12_files_real(cert_files): def validate_pkcs12_files_real(cert_files):
"""Validate actual Ansible-generated PKCS#12 files""" """Validate actual Ansible-generated PKCS#12 files"""
if not cert_files.get('p12_files'): if not cert_files.get("p12_files"):
print("⚠ No PKCS#12 files found") print("⚠ No PKCS#12 files found")
return return
major, minor = test_openssl_version_detection() major, minor = test_openssl_version_detection()
for p12_file in cert_files['p12_files']: for p12_file in cert_files["p12_files"]:
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}" assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
# Test that PKCS#12 file can be read (validates format) # Test that PKCS#12 file can be read (validates format)
legacy_flag = ['-legacy'] if major >= 3 else [] legacy_flag = ["-legacy"] if major >= 3 else []
result = subprocess.run([ result = subprocess.run(
'openssl', 'pkcs12', '-info', [
'-in', p12_file, "openssl",
'-passin', 'pass:', # Try empty password first "pkcs12",
'-noout' "-info",
] + legacy_flag, capture_output=True, text=True) "-in",
p12_file,
"-passin",
"pass:", # Try empty password first
"-noout",
]
+ legacy_flag,
capture_output=True,
text=True,
)
# PKCS#12 files should be readable (even if password-protected) # PKCS#12 files should be readable (even if password-protected)
# We're just testing format validity, not trying to extract contents # We're just testing format validity, not trying to extract contents
@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files):
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}") print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
def validate_pkcs12_files_config(): def validate_pkcs12_files_config():
"""Validate PKCS#12 file configuration in Ansible files (CI mode)""" """Validate PKCS#12 file configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
@ -390,13 +431,13 @@ def validate_pkcs12_files_config():
# Check PKCS#12 generation configuration # Check PKCS#12 generation configuration
p12_checks = [ p12_checks = [
('openssl_pkcs12', 'PKCS#12 generation should be configured'), ("openssl_pkcs12", "PKCS#12 generation should be configured"),
('encryption_level', 'PKCS#12 encryption level should be configured'), ("encryption_level", "PKCS#12 encryption level should be configured"),
('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'), ("compatibility2022", "PKCS#12 should use Apple-compatible encryption"),
('friendly_name', 'PKCS#12 should have friendly names'), ("friendly_name", "PKCS#12 should have friendly names"),
('other_certificates', 'PKCS#12 should include CA certificate for full chain'), ("other_certificates", "PKCS#12 should include CA certificate for full chain"),
('passphrase', 'PKCS#12 files should be password protected'), ("passphrase", "PKCS#12 files should be password protected"),
('mode: "0600"', 'PKCS#12 files should have secure permissions') ('mode: "0600"', "PKCS#12 files should have secure permissions"),
] ]
for check, message in p12_checks: for check, message in p12_checks:
@ -404,6 +445,7 @@ def validate_pkcs12_files_config():
print("✓ PKCS#12 configuration has proper Apple device compatibility settings") print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
def test_pkcs12_files(): def test_pkcs12_files():
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)""" """Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
cert_files = find_generated_certificates() cert_files = find_generated_certificates()
@ -416,19 +458,19 @@ def test_pkcs12_files():
def validate_certificate_chain_real(cert_files): def validate_certificate_chain_real(cert_files):
"""Validate actual Ansible-generated certificate chain""" """Validate actual Ansible-generated certificate chain"""
# Load CA certificate # Load CA certificate
with open(cert_files['ca_cert'], 'rb') as f: with open(cert_files["ca_cert"], "rb") as f:
ca_cert_data = f.read() ca_cert_data = f.read()
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data) ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
# Test that all other certificates are signed by the CA # Test that all other certificates are signed by the CA
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']] other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]]
if not other_certs: if not other_certs:
print("⚠ No client/server certificates found to validate") print("⚠ No client/server certificates found to validate")
return return
for cert_path in other_certs: for cert_path in other_certs:
with open(cert_path, 'rb') as f: with open(cert_path, "rb") as f:
cert_data = f.read() cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data) certificate = x509.load_pem_x509_certificate(cert_data)
@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files):
# Verify certificate is currently valid (not expired) # Verify certificate is currently valid (not expired)
from datetime import datetime from datetime import datetime
now = datetime.now(UTC) now = datetime.now(UTC)
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid" assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired" assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
@ -445,9 +488,10 @@ def validate_certificate_chain_real(cert_files):
print("✓ All real certificates properly signed by CA") print("✓ All real certificates properly signed by CA")
def validate_certificate_chain_config(): def validate_certificate_chain_config():
"""Validate certificate chain configuration in Ansible files (CI mode)""" """Validate certificate chain configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file: if not openssl_task_file:
print("⚠ Could not find openssl.yml task file") print("⚠ Could not find openssl.yml task file")
return return
@ -457,15 +501,18 @@ def validate_certificate_chain_config():
# Check certificate signing configuration # Check certificate signing configuration
chain_checks = [ chain_checks = [
('provider: ownca', 'Certificates should be signed by own CA'), ("provider: ownca", "Certificates should be signed by own CA"),
('ownca_path', 'CA certificate path should be specified'), ("ownca_path", "CA certificate path should be specified"),
('ownca_privatekey_path', 'CA private key path should be specified'), ("ownca_privatekey_path", "CA private key path should be specified"),
('ownca_privatekey_passphrase', 'CA private key should be password protected'), ("ownca_privatekey_passphrase", "CA private key should be password protected"),
('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'), ("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"),
('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'), (
('ownca_not_before: "-1d"', 'Certificates should have backdated start time'), 'ownca_not_after: "+{{ certificate_validity_days }}d"',
('curve: secp384r1', 'Should use strong elliptic curve cryptography'), "Certificates should use configurable validity period",
('type: ECC', 'Should use elliptic curve keys for better security') ),
('ownca_not_before: "-1d"', "Certificates should have backdated start time"),
("curve: secp384r1", "Should use strong elliptic curve cryptography"),
("type: ECC", "Should use elliptic curve keys for better security"),
] ]
for check, message in chain_checks: for check, message in chain_checks:
@ -473,6 +520,7 @@ def validate_certificate_chain_config():
print("✓ Certificate chain configuration properly set up for CA signing") print("✓ Certificate chain configuration properly set up for CA signing")
def test_certificate_chain(): def test_certificate_chain():
"""Test certificate chain - uses real certs if available, else validates config""" """Test certificate chain - uses real certs if available, else validates config"""
cert_files = find_generated_certificates() cert_files = find_generated_certificates()
@ -499,6 +547,7 @@ def find_ansible_file(relative_path):
return None return None
if __name__ == "__main__": if __name__ == "__main__":
tests = [ tests = [
test_openssl_version_detection, test_openssl_version_detection,

View file

@ -3,6 +3,7 @@
Enhanced tests for StrongSwan templates. Enhanced tests for StrongSwan templates.
Tests all strongswan role templates with various configurations. Tests all strongswan role templates with various configurations.
""" """
import os import os
import sys import sys
import uuid import uuid
@ -21,7 +22,7 @@ def mock_to_uuid(value):
def mock_bool(value): def mock_bool(value):
"""Mock the bool filter""" """Mock the bool filter"""
return str(value).lower() in ('true', '1', 'yes', 'on') return str(value).lower() in ("true", "1", "yes", "on")
def mock_version(version_string, comparison): def mock_version(version_string, comparison):
@ -33,67 +34,67 @@ def mock_version(version_string, comparison):
def mock_b64encode(value): def mock_b64encode(value):
"""Mock base64 encoding""" """Mock base64 encoding"""
import base64 import base64
if isinstance(value, str): if isinstance(value, str):
value = value.encode('utf-8') value = value.encode("utf-8")
return base64.b64encode(value).decode('ascii') return base64.b64encode(value).decode("ascii")
def mock_b64decode(value): def mock_b64decode(value):
"""Mock base64 decoding""" """Mock base64 decoding"""
import base64 import base64
return base64.b64decode(value).decode('utf-8')
return base64.b64decode(value).decode("utf-8")
def get_strongswan_test_variables(scenario='default'): def get_strongswan_test_variables(scenario="default"):
"""Get test variables for StrongSwan templates with different scenarios.""" """Get test variables for StrongSwan templates with different scenarios."""
base_vars = load_test_variables() base_vars = load_test_variables()
# Add StrongSwan specific variables # Add StrongSwan specific variables
strongswan_vars = { strongswan_vars = {
'ipsec_config_path': '/etc/ipsec.d', "ipsec_config_path": "/etc/ipsec.d",
'ipsec_pki_path': '/etc/ipsec.d', "ipsec_pki_path": "/etc/ipsec.d",
'strongswan_enabled': True, "strongswan_enabled": True,
'strongswan_network': '10.19.48.0/24', "strongswan_network": "10.19.48.0/24",
'strongswan_network_ipv6': 'fd9d:bc11:4021::/64', "strongswan_network_ipv6": "fd9d:bc11:4021::/64",
'strongswan_log_level': '2', "strongswan_log_level": "2",
'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()), "openssl_constraint_random_id": "test-" + str(uuid.uuid4()),
'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a', "subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a",
'subjectAltName_type': 'IP', "subjectAltName_type": "IP",
'subjectAltName_client': 'IP:10.0.0.1', "subjectAltName_client": "IP:10.0.0.1",
'ansible_default_ipv6': { "ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"},
'address': '2600:3c01::f03c:91ff:fedf:3b2a' "openssl_version": "3.0.0",
}, "p12_export_password": "test-password",
'openssl_version': '3.0.0', "ike_lifetime": "24h",
'p12_export_password': 'test-password', "ipsec_lifetime": "8h",
'ike_lifetime': '24h', "ike_dpd": "30s",
'ipsec_lifetime': '8h', "ipsec_dead_peer_detection": True,
'ike_dpd': '30s', "rekey_margin": "3m",
'ipsec_dead_peer_detection': True, "rekeymargin": "3m",
'rekey_margin': '3m', "dpddelay": "35s",
'rekeymargin': '3m', "keyexchange": "ikev2",
'dpddelay': '35s', "ike_cipher": "aes128gcm16-prfsha512-ecp256",
'keyexchange': 'ikev2', "esp_cipher": "aes128gcm16-ecp256",
'ike_cipher': 'aes128gcm16-prfsha512-ecp256', "leftsourceip": "10.19.48.1",
'esp_cipher': 'aes128gcm16-ecp256', "leftsubnet": "0.0.0.0/0,::/0",
'leftsourceip': '10.19.48.1', "rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64",
'leftsubnet': '0.0.0.0/0,::/0',
'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64',
} }
# Merge with base variables # Merge with base variables
test_vars = {**base_vars, **strongswan_vars} test_vars = {**base_vars, **strongswan_vars}
# Apply scenario-specific overrides # Apply scenario-specific overrides
if scenario == 'ipv4_only': if scenario == "ipv4_only":
test_vars['ipv6_support'] = False test_vars["ipv6_support"] = False
test_vars['subjectAltName'] = 'IP:10.0.0.1' test_vars["subjectAltName"] = "IP:10.0.0.1"
test_vars['ansible_default_ipv6'] = None test_vars["ansible_default_ipv6"] = None
elif scenario == 'dns_hostname': elif scenario == "dns_hostname":
test_vars['IP_subject_alt_name'] = 'vpn.example.com' test_vars["IP_subject_alt_name"] = "vpn.example.com"
test_vars['subjectAltName'] = 'DNS:vpn.example.com' test_vars["subjectAltName"] = "DNS:vpn.example.com"
test_vars['subjectAltName_type'] = 'DNS' test_vars["subjectAltName_type"] = "DNS"
elif scenario == 'openssl_legacy': elif scenario == "openssl_legacy":
test_vars['openssl_version'] = '1.1.1' test_vars["openssl_version"] = "1.1.1"
return test_vars return test_vars
@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'):
def test_strongswan_templates(): def test_strongswan_templates():
"""Test all StrongSwan templates with various configurations.""" """Test all StrongSwan templates with various configurations."""
templates = [ templates = [
'roles/strongswan/templates/ipsec.conf.j2', "roles/strongswan/templates/ipsec.conf.j2",
'roles/strongswan/templates/ipsec.secrets.j2', "roles/strongswan/templates/ipsec.secrets.j2",
'roles/strongswan/templates/strongswan.conf.j2', "roles/strongswan/templates/strongswan.conf.j2",
'roles/strongswan/templates/charon.conf.j2', "roles/strongswan/templates/charon.conf.j2",
'roles/strongswan/templates/client_ipsec.conf.j2', "roles/strongswan/templates/client_ipsec.conf.j2",
'roles/strongswan/templates/client_ipsec.secrets.j2', "roles/strongswan/templates/client_ipsec.secrets.j2",
'roles/strongswan/templates/100-CustomLimitations.conf.j2', "roles/strongswan/templates/100-CustomLimitations.conf.j2",
] ]
scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy'] scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"]
errors = [] errors = []
tested = 0 tested = 0
@ -127,21 +128,18 @@ def test_strongswan_templates():
test_vars = get_strongswan_test_variables(scenario) test_vars = get_strongswan_test_variables(scenario)
try: try:
env = Environment( env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
# Add mock filters # Add mock filters
env.filters['to_uuid'] = mock_to_uuid env.filters["to_uuid"] = mock_to_uuid
env.filters['bool'] = mock_bool env.filters["bool"] = mock_bool
env.filters['b64encode'] = mock_b64encode env.filters["b64encode"] = mock_b64encode
env.filters['b64decode'] = mock_b64decode env.filters["b64decode"] = mock_b64decode
env.tests['version'] = mock_version env.tests["version"] = mock_version
# For client templates, add item context # For client templates, add item context
if 'client' in template_name: if "client" in template_name:
test_vars['item'] = 'testuser' test_vars["item"] = "testuser"
template = env.get_template(template_name) template = env.get_template(template_name)
output = template.render(**test_vars) output = template.render(**test_vars)
@ -150,16 +148,16 @@ def test_strongswan_templates():
assert len(output) > 0, f"Empty output from {template_path} ({scenario})" assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
# Specific validations based on template # Specific validations based on template
if 'ipsec.conf' in template_name and 'client' not in template_name: if "ipsec.conf" in template_name and "client" not in template_name:
assert 'conn' in output, "Missing connection definition" assert "conn" in output, "Missing connection definition"
if scenario != 'ipv4_only' and test_vars.get('ipv6_support'): if scenario != "ipv4_only" and test_vars.get("ipv6_support"):
assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration" assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration"
if 'ipsec.secrets' in template_name: if "ipsec.secrets" in template_name:
assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method" assert "PSK" in output or "ECDSA" in output, "Missing authentication method"
if 'strongswan.conf' in template_name: if "strongswan.conf" in template_name:
assert 'charon' in output, "Missing charon configuration" assert "charon" in output, "Missing charon configuration"
print(f"{template_name} ({scenario})") print(f"{template_name} ({scenario})")
@ -182,7 +180,7 @@ def test_openssl_template_constraints():
# This tests the actual openssl.yml task file to ensure our fix works # This tests the actual openssl.yml task file to ensure our fix works
import yaml import yaml
openssl_path = 'roles/strongswan/tasks/openssl.yml' openssl_path = "roles/strongswan/tasks/openssl.yml"
if not os.path.exists(openssl_path): if not os.path.exists(openssl_path):
print("⚠️ OpenSSL tasks file not found") print("⚠️ OpenSSL tasks file not found")
return True return True
@ -194,22 +192,23 @@ def test_openssl_template_constraints():
# Find the CA CSR task # Find the CA CSR task
ca_csr_task = None ca_csr_task = None
for task in content: for task in content:
if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'): if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"):
ca_csr_task = task ca_csr_task = task
break break
if ca_csr_task: if ca_csr_task:
# Check that name_constraints_permitted is properly formatted # Check that name_constraints_permitted is properly formatted
csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {}) csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {})
constraints = csr_module.get('name_constraints_permitted', '') constraints = csr_module.get("name_constraints_permitted", "")
# The constraints should be a Jinja2 template without inline comments # The constraints should be a Jinja2 template without inline comments
if '#' in str(constraints): if "#" in str(constraints):
# Check if the # is within {{ }} # Check if the # is within {{ }}
import re import re
jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL)
jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL)
for block in jinja_blocks: for block in jinja_blocks:
if '#' in block: if "#" in block:
print("❌ Found inline comment in Jinja2 expression") print("❌ Found inline comment in Jinja2 expression")
return False return False
@ -223,7 +222,7 @@ def test_openssl_template_constraints():
def test_mobileconfig_template(): def test_mobileconfig_template():
"""Test the mobileconfig template with various scenarios.""" """Test the mobileconfig template with various scenarios."""
template_path = 'roles/strongswan/templates/mobileconfig.j2' template_path = "roles/strongswan/templates/mobileconfig.j2"
if not os.path.exists(template_path): if not os.path.exists(template_path):
print("⚠️ Mobileconfig template not found") print("⚠️ Mobileconfig template not found")
@ -237,20 +236,20 @@ def test_mobileconfig_template():
test_cases = [ test_cases = [
{ {
'name': 'iPhone with cellular on-demand', "name": "iPhone with cellular on-demand",
'algo_ondemand_cellular': 'true', "algo_ondemand_cellular": "true",
'algo_ondemand_wifi': 'false', "algo_ondemand_wifi": "false",
}, },
{ {
'name': 'iPad with WiFi on-demand', "name": "iPad with WiFi on-demand",
'algo_ondemand_cellular': 'false', "algo_ondemand_cellular": "false",
'algo_ondemand_wifi': 'true', "algo_ondemand_wifi": "true",
'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi', "algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi",
}, },
{ {
'name': 'Mac without on-demand', "name": "Mac without on-demand",
'algo_ondemand_cellular': 'false', "algo_ondemand_cellular": "false",
'algo_ondemand_wifi': 'false', "algo_ondemand_wifi": "false",
}, },
] ]
@ -258,43 +257,41 @@ def test_mobileconfig_template():
for test_case in test_cases: for test_case in test_cases:
test_vars = get_strongswan_test_variables() test_vars = get_strongswan_test_variables()
test_vars.update(test_case) test_vars.update(test_case)
# Mock Ansible task result format for item # Mock Ansible task result format for item
class MockTaskResult: class MockTaskResult:
def __init__(self, content): def __init__(self, content):
self.stdout = content self.stdout = content
test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result
test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64 test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64
test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64 test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64
test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4()) test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4())
test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12 test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12
test_vars['algo_server_name'] = 'test-algo-vpn' test_vars["algo_server_name"] = "test-algo-vpn"
test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4()) test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4())
test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4()) test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4())
test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64 test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64
try: try:
env = Environment( env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined)
loader=FileSystemLoader('roles/strongswan/templates'),
undefined=StrictUndefined
)
# Add mock filters # Add mock filters
env.filters['to_uuid'] = mock_to_uuid env.filters["to_uuid"] = mock_to_uuid
env.filters['b64encode'] = mock_b64encode env.filters["b64encode"] = mock_b64encode
env.filters['b64decode'] = mock_b64decode env.filters["b64decode"] = mock_b64decode
template = env.get_template('mobileconfig.j2') template = env.get_template("mobileconfig.j2")
output = template.render(**test_vars) output = template.render(**test_vars)
# Validate output # Validate output
assert '<?xml' in output, "Missing XML declaration" assert "<?xml" in output, "Missing XML declaration"
assert '<plist' in output, "Missing plist element" assert "<plist" in output, "Missing plist element"
assert 'PayloadType' in output, "Missing PayloadType" assert "PayloadType" in output, "Missing PayloadType"
# Check on-demand configuration # Check on-demand configuration
if test_case.get('algo_ondemand_cellular') == 'true' or test_case.get('algo_ondemand_wifi') == 'true': if test_case.get("algo_ondemand_cellular") == "true" or test_case.get("algo_ondemand_wifi") == "true":
assert 'OnDemandEnabled' in output, f"Missing OnDemand config for {test_case['name']}" assert "OnDemandEnabled" in output, f"Missing OnDemand config for {test_case['name']}"
print(f" ✅ Mobileconfig: {test_case['name']}") print(f" ✅ Mobileconfig: {test_case['name']}")

View file

@ -3,6 +3,7 @@
Test that Ansible templates render correctly Test that Ansible templates render correctly
This catches undefined variables, syntax errors, and logic bugs This catches undefined variables, syntax errors, and logic bugs
""" """
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
@ -22,20 +23,20 @@ def mock_to_uuid(value):
def mock_bool(value): def mock_bool(value):
"""Mock the bool filter""" """Mock the bool filter"""
return str(value).lower() in ('true', '1', 'yes', 'on') return str(value).lower() in ("true", "1", "yes", "on")
def mock_lookup(type, path): def mock_lookup(type, path):
"""Mock the lookup function""" """Mock the lookup function"""
# Return fake data for file lookups # Return fake data for file lookups
if type == 'file': if type == "file":
if 'private' in path: if "private" in path:
return 'MOCK_PRIVATE_KEY_BASE64==' return "MOCK_PRIVATE_KEY_BASE64=="
elif 'public' in path: elif "public" in path:
return 'MOCK_PUBLIC_KEY_BASE64==' return "MOCK_PUBLIC_KEY_BASE64=="
elif 'preshared' in path: elif "preshared" in path:
return 'MOCK_PRESHARED_KEY_BASE64==' return "MOCK_PRESHARED_KEY_BASE64=="
return 'MOCK_LOOKUP_DATA' return "MOCK_LOOKUP_DATA"
def get_test_variables(): def get_test_variables():
@ -47,8 +48,8 @@ def get_test_variables():
def find_templates(): def find_templates():
"""Find all Jinja2 template files in the repo""" """Find all Jinja2 template files in the repo"""
templates = [] templates = []
for pattern in ['**/*.j2', '**/*.jinja2', '**/*.yml.j2']: for pattern in ["**/*.j2", "**/*.jinja2", "**/*.yml.j2"]:
templates.extend(Path('.').glob(pattern)) templates.extend(Path(".").glob(pattern))
return templates return templates
@ -57,10 +58,10 @@ def test_template_syntax():
templates = find_templates() templates = find_templates()
# Skip some paths that aren't real templates # Skip some paths that aren't real templates
skip_paths = ['.git/', 'venv/', '.venv/', '.env/', 'configs/'] skip_paths = [".git/", "venv/", ".venv/", ".env/", "configs/"]
# Skip templates that use Ansible-specific filters # Skip templates that use Ansible-specific filters
skip_templates = ['vpn-dict.j2', 'mobileconfig.j2', 'dnscrypt-proxy.toml.j2'] skip_templates = ["vpn-dict.j2", "mobileconfig.j2", "dnscrypt-proxy.toml.j2"]
errors = [] errors = []
skipped = 0 skipped = 0
@ -76,10 +77,7 @@ def test_template_syntax():
try: try:
template_dir = template_path.parent template_dir = template_path.parent
env = Environment( env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
# Just try to load the template - this checks syntax # Just try to load the template - this checks syntax
env.get_template(template_path.name) env.get_template(template_path.name)
@ -103,13 +101,13 @@ def test_template_syntax():
def test_critical_templates(): def test_critical_templates():
"""Test that critical templates render with test data""" """Test that critical templates render with test data"""
critical_templates = [ critical_templates = [
'roles/wireguard/templates/client.conf.j2', "roles/wireguard/templates/client.conf.j2",
'roles/strongswan/templates/ipsec.conf.j2', "roles/strongswan/templates/ipsec.conf.j2",
'roles/strongswan/templates/ipsec.secrets.j2', "roles/strongswan/templates/ipsec.secrets.j2",
'roles/dns/templates/adblock.sh.j2', "roles/dns/templates/adblock.sh.j2",
'roles/dns/templates/dnsmasq.conf.j2', "roles/dns/templates/dnsmasq.conf.j2",
'roles/common/templates/rules.v4.j2', "roles/common/templates/rules.v4.j2",
'roles/common/templates/rules.v6.j2', "roles/common/templates/rules.v6.j2",
] ]
test_vars = get_test_variables() test_vars = get_test_variables()
@ -123,21 +121,18 @@ def test_critical_templates():
template_dir = os.path.dirname(template_path) template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path) template_name = os.path.basename(template_path)
env = Environment( env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
# Add mock functions # Add mock functions
env.globals['lookup'] = mock_lookup env.globals["lookup"] = mock_lookup
env.filters['to_uuid'] = mock_to_uuid env.filters["to_uuid"] = mock_to_uuid
env.filters['bool'] = mock_bool env.filters["bool"] = mock_bool
template = env.get_template(template_name) template = env.get_template(template_name)
# Add item context for templates that use loops # Add item context for templates that use loops
if 'client' in template_name: if "client" in template_name:
test_vars['item'] = ('test-user', 'test-user') test_vars["item"] = ("test-user", "test-user")
# Try to render # Try to render
output = template.render(**test_vars) output = template.render(**test_vars)
@ -163,17 +158,17 @@ def test_variable_consistency():
"""Check that commonly used variables are defined consistently""" """Check that commonly used variables are defined consistently"""
# Variables that should be used consistently across templates # Variables that should be used consistently across templates
common_vars = [ common_vars = [
'server_name', "server_name",
'IP_subject_alt_name', "IP_subject_alt_name",
'wireguard_port', "wireguard_port",
'wireguard_network', "wireguard_network",
'dns_servers', "dns_servers",
'users', "users",
] ]
# Check if main.yml defines these # Check if main.yml defines these
if os.path.exists('main.yml'): if os.path.exists("main.yml"):
with open('main.yml') as f: with open("main.yml") as f:
content = f.read() content = f.read()
missing = [] missing = []
@ -192,28 +187,19 @@ def test_wireguard_ipv6_endpoints():
"""Test that WireGuard client configs properly format IPv6 endpoints""" """Test that WireGuard client configs properly format IPv6 endpoints"""
test_cases = [ test_cases = [
# IPv4 address - should not be bracketed # IPv4 address - should not be bracketed
{ {"IP_subject_alt_name": "192.168.1.100", "expected_endpoint": "Endpoint = 192.168.1.100:51820"},
'IP_subject_alt_name': '192.168.1.100',
'expected_endpoint': 'Endpoint = 192.168.1.100:51820'
},
# IPv6 address - should be bracketed # IPv6 address - should be bracketed
{ {
'IP_subject_alt_name': '2600:3c01::f03c:91ff:fedf:3b2a', "IP_subject_alt_name": "2600:3c01::f03c:91ff:fedf:3b2a",
'expected_endpoint': 'Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820' "expected_endpoint": "Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820",
}, },
# Hostname - should not be bracketed # Hostname - should not be bracketed
{ {"IP_subject_alt_name": "vpn.example.com", "expected_endpoint": "Endpoint = vpn.example.com:51820"},
'IP_subject_alt_name': 'vpn.example.com',
'expected_endpoint': 'Endpoint = vpn.example.com:51820'
},
# IPv6 with zone ID - should be bracketed # IPv6 with zone ID - should be bracketed
{ {"IP_subject_alt_name": "fe80::1%eth0", "expected_endpoint": "Endpoint = [fe80::1%eth0]:51820"},
'IP_subject_alt_name': 'fe80::1%eth0',
'expected_endpoint': 'Endpoint = [fe80::1%eth0]:51820'
},
] ]
template_path = 'roles/wireguard/templates/client.conf.j2' template_path = "roles/wireguard/templates/client.conf.j2"
if not os.path.exists(template_path): if not os.path.exists(template_path):
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found") print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
return return
@ -225,24 +211,23 @@ def test_wireguard_ipv6_endpoints():
try: try:
# Set up test variables # Set up test variables
test_vars = {**base_vars, **test_case} test_vars = {**base_vars, **test_case}
test_vars['item'] = ('test-user', 'test-user') test_vars["item"] = ("test-user", "test-user")
# Render template # Render template
env = Environment( env = Environment(loader=FileSystemLoader("roles/wireguard/templates"), undefined=StrictUndefined)
loader=FileSystemLoader('roles/wireguard/templates'), env.globals["lookup"] = mock_lookup
undefined=StrictUndefined
)
env.globals['lookup'] = mock_lookup
template = env.get_template('client.conf.j2') template = env.get_template("client.conf.j2")
output = template.render(**test_vars) output = template.render(**test_vars)
# Check if the expected endpoint format is in the output # Check if the expected endpoint format is in the output
if test_case['expected_endpoint'] not in output: if test_case["expected_endpoint"] not in output:
errors.append(f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output") errors.append(
f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output"
)
# Print relevant part of output for debugging # Print relevant part of output for debugging
for line in output.split('\n'): for line in output.split("\n"):
if 'Endpoint' in line: if "Endpoint" in line:
errors.append(f" Found: {line.strip()}") errors.append(f" Found: {line.strip()}")
except Exception as e: except Exception as e:
@ -262,27 +247,27 @@ def test_template_conditionals():
test_cases = [ test_cases = [
# WireGuard enabled, IPsec disabled # WireGuard enabled, IPsec disabled
{ {
'wireguard_enabled': True, "wireguard_enabled": True,
'ipsec_enabled': False, "ipsec_enabled": False,
'dns_encryption': True, "dns_encryption": True,
'dns_adblocking': True, "dns_adblocking": True,
'algo_ssh_tunneling': False, "algo_ssh_tunneling": False,
}, },
# IPsec enabled, WireGuard disabled # IPsec enabled, WireGuard disabled
{ {
'wireguard_enabled': False, "wireguard_enabled": False,
'ipsec_enabled': True, "ipsec_enabled": True,
'dns_encryption': False, "dns_encryption": False,
'dns_adblocking': False, "dns_adblocking": False,
'algo_ssh_tunneling': True, "algo_ssh_tunneling": True,
}, },
# Both enabled # Both enabled
{ {
'wireguard_enabled': True, "wireguard_enabled": True,
'ipsec_enabled': True, "ipsec_enabled": True,
'dns_encryption': True, "dns_encryption": True,
'dns_adblocking': True, "dns_adblocking": True,
'algo_ssh_tunneling': True, "algo_ssh_tunneling": True,
}, },
] ]
@ -294,7 +279,7 @@ def test_template_conditionals():
# Test a few templates that have conditionals # Test a few templates that have conditionals
conditional_templates = [ conditional_templates = [
'roles/common/templates/rules.v4.j2', "roles/common/templates/rules.v4.j2",
] ]
for template_path in conditional_templates: for template_path in conditional_templates:
@ -305,23 +290,19 @@ def test_template_conditionals():
template_dir = os.path.dirname(template_path) template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path) template_name = os.path.basename(template_path)
env = Environment( env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
# Add mock functions # Add mock functions
env.globals['lookup'] = mock_lookup env.globals["lookup"] = mock_lookup
env.filters['to_uuid'] = mock_to_uuid env.filters["to_uuid"] = mock_to_uuid
env.filters['bool'] = mock_bool env.filters["bool"] = mock_bool
template = env.get_template(template_name) template = env.get_template(template_name)
output = template.render(**test_vars) output = template.render(**test_vars)
# Verify conditionals work # Verify conditionals work
if test_case.get('wireguard_enabled'): if test_case.get("wireguard_enabled"):
assert str(test_vars['wireguard_port']) in output, \ assert str(test_vars["wireguard_port"]) in output, f"WireGuard port missing when enabled (case {i})"
f"WireGuard port missing when enabled (case {i})"
except Exception as e: except Exception as e:
print(f"✗ Conditional test failed for {template_path} case {i}: {e}") print(f"✗ Conditional test failed for {template_path} case {i}: {e}")

View file

@ -3,6 +3,7 @@
Test user management functionality without deployment Test user management functionality without deployment
Based on issues #14745, #14746, #14738, #14726 Based on issues #14745, #14746, #14738, #14726
""" """
import os import os
import re import re
import sys import sys
@ -23,15 +24,15 @@ users:
""" """
config = yaml.safe_load(test_config) config = yaml.safe_load(test_config)
users = config.get('users', []) users = config.get("users", [])
assert len(users) == 5, f"Expected 5 users, got {len(users)}" assert len(users) == 5, f"Expected 5 users, got {len(users)}"
assert 'alice' in users, "Missing user 'alice'" assert "alice" in users, "Missing user 'alice'"
assert 'user-with-dash' in users, "Dash in username not handled" assert "user-with-dash" in users, "Dash in username not handled"
assert 'user_with_underscore' in users, "Underscore in username not handled" assert "user_with_underscore" in users, "Underscore in username not handled"
# Test that usernames are valid # Test that usernames are valid
username_pattern = re.compile(r'^[a-zA-Z0-9_-]+$') username_pattern = re.compile(r"^[a-zA-Z0-9_-]+$")
for user in users: for user in users:
assert username_pattern.match(user), f"Invalid username format: {user}" assert username_pattern.match(user), f"Invalid username format: {user}"
@ -42,35 +43,27 @@ def test_server_selection_format():
"""Test server selection string parsing (issue #14727)""" """Test server selection string parsing (issue #14727)"""
# Test various server display formats # Test various server display formats
test_cases = [ test_cases = [
{"display": "1. 192.168.1.100 (algo-server)", "expected_ip": "192.168.1.100", "expected_name": "algo-server"},
{"display": "2. 10.0.0.1 (production-vpn)", "expected_ip": "10.0.0.1", "expected_name": "production-vpn"},
{ {
'display': '1. 192.168.1.100 (algo-server)', "display": "3. vpn.example.com (example-server)",
'expected_ip': '192.168.1.100', "expected_ip": "vpn.example.com",
'expected_name': 'algo-server' "expected_name": "example-server",
}, },
{
'display': '2. 10.0.0.1 (production-vpn)',
'expected_ip': '10.0.0.1',
'expected_name': 'production-vpn'
},
{
'display': '3. vpn.example.com (example-server)',
'expected_ip': 'vpn.example.com',
'expected_name': 'example-server'
}
] ]
# Pattern to extract IP and name from display string # Pattern to extract IP and name from display string
pattern = re.compile(r'^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$') pattern = re.compile(r"^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$")
for case in test_cases: for case in test_cases:
match = pattern.match(case['display']) match = pattern.match(case["display"])
assert match, f"Failed to parse: {case['display']}" assert match, f"Failed to parse: {case['display']}"
ip_or_host = match.group(1) ip_or_host = match.group(1)
name = match.group(2) name = match.group(2)
assert ip_or_host == case['expected_ip'], f"Wrong IP extracted: {ip_or_host}" assert ip_or_host == case["expected_ip"], f"Wrong IP extracted: {ip_or_host}"
assert name == case['expected_name'], f"Wrong name extracted: {name}" assert name == case["expected_name"], f"Wrong name extracted: {name}"
print("✓ Server selection format test passed") print("✓ Server selection format test passed")
@ -78,12 +71,12 @@ def test_server_selection_format():
def test_ssh_key_preservation(): def test_ssh_key_preservation():
"""Test that SSH keys aren't regenerated unnecessarily""" """Test that SSH keys aren't regenerated unnecessarily"""
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
ssh_key_path = os.path.join(tmpdir, 'test_key') ssh_key_path = os.path.join(tmpdir, "test_key")
# Simulate existing SSH key # Simulate existing SSH key
with open(ssh_key_path, 'w') as f: with open(ssh_key_path, "w") as f:
f.write("EXISTING_SSH_KEY_CONTENT") f.write("EXISTING_SSH_KEY_CONTENT")
with open(f"{ssh_key_path}.pub", 'w') as f: with open(f"{ssh_key_path}.pub", "w") as f:
f.write("ssh-rsa EXISTING_PUBLIC_KEY") f.write("ssh-rsa EXISTING_PUBLIC_KEY")
# Record original content # Record original content
@ -105,11 +98,7 @@ def test_ssh_key_preservation():
def test_ca_password_handling(): def test_ca_password_handling():
"""Test CA password validation and handling""" """Test CA password validation and handling"""
# Test password requirements # Test password requirements
valid_passwords = [ valid_passwords = ["SecurePassword123!", "Algo-VPN-2024", "Complex#Pass@Word999"]
"SecurePassword123!",
"Algo-VPN-2024",
"Complex#Pass@Word999"
]
invalid_passwords = [ invalid_passwords = [
"", # Empty "", # Empty
@ -120,13 +109,13 @@ def test_ca_password_handling():
# Basic password validation # Basic password validation
for pwd in valid_passwords: for pwd in valid_passwords:
assert len(pwd) >= 12, f"Password too short: {pwd}" assert len(pwd) >= 12, f"Password too short: {pwd}"
assert ' ' not in pwd, f"Password contains spaces: {pwd}" assert " " not in pwd, f"Password contains spaces: {pwd}"
for pwd in invalid_passwords: for pwd in invalid_passwords:
issues = [] issues = []
if len(pwd) < 12: if len(pwd) < 12:
issues.append("too short") issues.append("too short")
if ' ' in pwd: if " " in pwd:
issues.append("contains spaces") issues.append("contains spaces")
if not pwd: if not pwd:
issues.append("empty") issues.append("empty")
@ -137,8 +126,8 @@ def test_ca_password_handling():
def test_user_config_generation(): def test_user_config_generation():
"""Test that user configs would be generated correctly""" """Test that user configs would be generated correctly"""
users = ['alice', 'bob', 'charlie'] users = ["alice", "bob", "charlie"]
server_name = 'test-server' server_name = "test-server"
# Simulate config file structure # Simulate config file structure
for user in users: for user in users:
@ -168,7 +157,7 @@ users:
""" """
config = yaml.safe_load(test_config) config = yaml.safe_load(test_config)
users = config.get('users', []) users = config.get("users", [])
# Check for duplicates # Check for duplicates
unique_users = list(set(users)) unique_users = list(set(users))
@ -182,7 +171,7 @@ users:
duplicates.append(user) duplicates.append(user)
seen.add(user) seen.add(user)
assert 'alice' in duplicates, "Duplicate 'alice' not detected" assert "alice" in duplicates, "Duplicate 'alice' not detected"
print("✓ Duplicate user handling test passed") print("✓ Duplicate user handling test passed")

View file

@ -3,6 +3,7 @@
Test WireGuard key generation - focused on x25519_pubkey module integration Test WireGuard key generation - focused on x25519_pubkey module integration
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
""" """
import base64 import base64
import os import os
import subprocess import subprocess
@ -10,13 +11,13 @@ import sys
import tempfile import tempfile
# Add library directory to path to import our custom module # Add library directory to path to import our custom module
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library"))
def test_wireguard_tools_available(): def test_wireguard_tools_available():
"""Test that WireGuard tools are available for validation""" """Test that WireGuard tools are available for validation"""
try: try:
result = subprocess.run(['wg', '--version'], capture_output=True, text=True) result = subprocess.run(["wg", "--version"], capture_output=True, text=True)
assert result.returncode == 0, "WireGuard tools not available" assert result.returncode == 0, "WireGuard tools not available"
print(f"✓ WireGuard tools available: {result.stdout.strip()}") print(f"✓ WireGuard tools available: {result.stdout.strip()}")
return True return True
@ -29,6 +30,7 @@ def test_x25519_module_import():
"""Test that our custom x25519_pubkey module can be imported and used""" """Test that our custom x25519_pubkey module can be imported and used"""
try: try:
import x25519_pubkey # noqa: F401 import x25519_pubkey # noqa: F401
print("✓ x25519_pubkey module imports successfully") print("✓ x25519_pubkey module imports successfully")
return True return True
except ImportError as e: except ImportError as e:
@ -37,16 +39,17 @@ def test_x25519_module_import():
def generate_test_private_key(): def generate_test_private_key():
"""Generate a test private key using the same method as Algo""" """Generate a test private key using the same method as Algo"""
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file:
raw_key_path = temp_file.name raw_key_path = temp_file.name
try: try:
# Generate 32 random bytes for X25519 private key (same as community.crypto does) # Generate 32 random bytes for X25519 private key (same as community.crypto does)
import secrets import secrets
raw_data = secrets.token_bytes(32) raw_data = secrets.token_bytes(32)
# Write raw key to file (like community.crypto openssl_privatekey with format: raw) # Write raw key to file (like community.crypto openssl_privatekey with format: raw)
with open(raw_key_path, 'wb') as f: with open(raw_key_path, "wb") as f:
f.write(raw_data) f.write(raw_data)
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}" assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file():
def exit_json(self, **kwargs): def exit_json(self, **kwargs):
self.result = kwargs self.result = kwargs
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub: with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub:
public_key_path = temp_pub.name public_key_path = temp_pub.name
try: try:
@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file():
try: try:
# Mock the module call # Mock the module call
mock_module = MockModule({ mock_module = MockModule(
'private_key_path': raw_key_path, {"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None}
'public_key_path': public_key_path, )
'private_key_b64': None
})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file():
run_module() run_module()
# Check the result # Check the result
assert 'public_key' in mock_module.result assert "public_key" in mock_module.result
assert mock_module.result['changed'] assert mock_module.result["changed"]
assert os.path.exists(public_key_path) assert os.path.exists(public_key_path)
with open(public_key_path) as f: with open(public_key_path) as f:
@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string():
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
'private_key_b64': b64_key,
'private_key_path': None,
'public_key_path': None
})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string():
run_module() run_module()
# Check the result # Check the result
assert 'public_key' in mock_module.result assert "public_key" in mock_module.result
derived_pubkey = mock_module.result['public_key'] derived_pubkey = mock_module.result["public_key"]
# Validate base64 format # Validate base64 format
try: try:
@ -222,21 +219,17 @@ def test_wireguard_validation():
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
'private_key_b64': b64_key,
'private_key_path': None,
'public_key_path': None
})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module() run_module()
derived_pubkey = mock_module.result['public_key'] derived_pubkey = mock_module.result["public_key"]
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config: with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config:
# Create a WireGuard config using our keys # Create a WireGuard config using our keys
wg_config = f"""[Interface] wg_config = f"""[Interface]
PrivateKey = {b64_key} PrivateKey = {b64_key}
@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32
try: try:
# Test that WireGuard can parse our config # Test that WireGuard can parse our config
result = subprocess.run([ result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True)
'wg-quick', 'strip', config_path
], capture_output=True, text=True)
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}" assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
# Test key derivation with wg pubkey command # Test key derivation with wg pubkey command
wg_result = subprocess.run([ wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True)
'wg', 'pubkey'
], input=b64_key, capture_output=True, text=True)
if wg_result.returncode == 0: if wg_result.returncode == 0:
wg_derived = wg_result.stdout.strip() wg_derived = wg_result.stdout.strip()
@ -286,8 +275,8 @@ def test_key_consistency():
raw_key_path, b64_key = generate_test_private_key() raw_key_path, b64_key = generate_test_private_key()
try: try:
def derive_pubkey_from_same_key():
def derive_pubkey_from_same_key():
class MockModule: class MockModule:
def __init__(self, params): def __init__(self, params):
self.params = params self.params = params
@ -305,16 +294,18 @@ def test_key_consistency():
original_AnsibleModule = x25519_pubkey.AnsibleModule original_AnsibleModule = x25519_pubkey.AnsibleModule
try: try:
mock_module = MockModule({ mock_module = MockModule(
'private_key_b64': b64_key, # SAME key each time {
'private_key_path': None, "private_key_b64": b64_key, # SAME key each time
'public_key_path': None "private_key_path": None,
}) "public_key_path": None,
}
)
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module() run_module()
return mock_module.result['public_key'] return mock_module.result["public_key"]
finally: finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule x25519_pubkey.AnsibleModule = original_AnsibleModule

View file

@ -14,13 +14,13 @@ from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
def find_jinja2_templates(root_dir: str = '.') -> list[Path]: def find_jinja2_templates(root_dir: str = ".") -> list[Path]:
"""Find all Jinja2 template files in the project.""" """Find all Jinja2 template files in the project."""
templates = [] templates = []
patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2'] patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"]
# Skip these directories # Skip these directories
skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'} skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"}
for pattern in patterns: for pattern in patterns:
for path in Path(root_dir).glob(pattern): for path in Path(root_dir).glob(pattern):
@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P
errors = [] errors = []
# Pattern to find Jinja2 expressions # Pattern to find Jinja2 expressions
jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL) jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
for match in jinja_pattern.finditer(template_content): for match in jinja_pattern.finditer(template_content):
expression = match.group() expression = match.group()
lines = expression.split('\n') lines = expression.split("\n")
for i, line in enumerate(lines): for i, line in enumerate(lines):
# Check for # that's not in a string # Check for # that's not in a string
# Simple heuristic: if # appears after non-whitespace and not in quotes # Simple heuristic: if # appears after non-whitespace and not in quotes
if '#' in line: if "#" in line:
# Remove quoted strings to avoid false positives # Remove quoted strings to avoid false positives
cleaned = re.sub(r'"[^"]*"', '', line) cleaned = re.sub(r'"[^"]*"', "", line)
cleaned = re.sub(r"'[^']*'", '', cleaned) cleaned = re.sub(r"'[^']*'", "", cleaned)
if '#' in cleaned: if "#" in cleaned:
# Check if it's likely a comment (has text after it) # Check if it's likely a comment (has text after it)
hash_pos = cleaned.index('#') hash_pos = cleaned.index("#")
if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\': if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\":
line_num = template_content[:match.start()].count('\n') + i + 1 line_num = template_content[: match.start()].count("\n") + i + 1
errors.append( errors.append(
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. " f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
f"Move comments outside the expression." f"Move comments outside the expression."
@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Common Ansible variables that are always available # Common Ansible variables that are always available
ansible_builtins = { ansible_builtins = {
'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname', "ansible_default_ipv4",
'ansible_distribution', 'ansible_distribution_version', 'ansible_facts', "ansible_default_ipv6",
'inventory_hostname', 'hostvars', 'groups', 'group_names', "ansible_hostname",
'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host', "ansible_distribution",
'item', 'ansible_loop', 'ansible_index', 'lookup' "ansible_distribution_version",
"ansible_facts",
"inventory_hostname",
"hostvars",
"groups",
"group_names",
"play_hosts",
"ansible_version",
"ansible_user",
"ansible_host",
"item",
"ansible_loop",
"ansible_index",
"lookup",
} }
# Filter out known Ansible variables # Filter out known Ansible variables
@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Only report if there are truly unknown variables # Only report if there are truly unknown variables
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
errors.append( errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}")
f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}"
)
except Exception: except Exception:
# Don't report parse errors here, they're handled elsewhere # Don't report parse errors here, they're handled elsewhere
@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
# Skip full parsing for templates that use Ansible-specific features heavily # Skip full parsing for templates that use Ansible-specific features heavily
# We still check for inline comments but skip full template parsing # We still check for inline comments but skip full template parsing
ansible_specific_templates = { ansible_specific_templates = {
'dnscrypt-proxy.toml.j2', # Uses |bool filter "dnscrypt-proxy.toml.j2", # Uses |bool filter
'mobileconfig.j2', # Uses |to_uuid filter and complex item structures "mobileconfig.j2", # Uses |to_uuid filter and complex item structures
'vpn-dict.j2', # Uses |to_uuid filter "vpn-dict.j2", # Uses |to_uuid filter
} }
if template_path.name in ansible_specific_templates: if template_path.name in ansible_specific_templates:
@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
errors.extend(check_inline_comments_in_expressions(template_content, template_path)) errors.extend(check_inline_comments_in_expressions(template_content, template_path))
# Try to parse the template # Try to parse the template
env = Environment( env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined)
loader=FileSystemLoader(template_path.parent),
undefined=StrictUndefined
)
# Add mock Ansible filters to avoid syntax errors # Add mock Ansible filters to avoid syntax errors
env.filters['bool'] = lambda x: x env.filters["bool"] = lambda x: x
env.filters['to_uuid'] = lambda x: x env.filters["to_uuid"] = lambda x: x
env.filters['b64encode'] = lambda x: x env.filters["b64encode"] = lambda x: x
env.filters['b64decode'] = lambda x: x env.filters["b64decode"] = lambda x: x
env.filters['regex_replace'] = lambda x, y, z: x env.filters["regex_replace"] = lambda x, y, z: x
env.filters['default'] = lambda x, d: x if x else d env.filters["default"] = lambda x, d: x if x else d
# This will raise TemplateSyntaxError if there's a syntax problem # This will raise TemplateSyntaxError if there's a syntax problem
env.get_template(template_path.name) env.get_template(template_path.name)
@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]:
content = f.read() content = f.read()
# Check for missing spaces around filters # Check for missing spaces around filters
if re.search(r'\{\{[^}]+\|[^ ]', content): if re.search(r"\{\{[^}]+\|[^ ]", content):
warnings.append(f"{template_path}: Missing space after filter pipe (|)") warnings.append(f"{template_path}: Missing space after filter pipe (|)")
# Check for deprecated 'when' in Jinja2 (should use if) # Check for deprecated 'when' in Jinja2 (should use if)
if re.search(r'\{%\s*when\s+', content): if re.search(r"\{%\s*when\s+", content):
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates") warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
# Check for extremely long expressions (harder to debug) # Check for extremely long expressions (harder to debug)
for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL): for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL):
if len(match.group(1)) > 200: if len(match.group(1)) > 200:
line_num = content[:match.start()].count('\n') + 1 line_num = content[: match.start()].count("\n") + 1
warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up") warnings.append(
f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up"
)
except Exception: except Exception:
pass # Ignore errors in anti-pattern checking pass # Ignore errors in anti-pattern checking