Apply Python linting and formatting

- Run ruff check --fix to fix linting issues
- Run ruff format to ensure consistent formatting
- All tests still pass after formatting changes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Dan Guido 2025-08-17 19:08:34 -04:00
parent 51847f3fbf
commit 15be88d28b
28 changed files with 1178 additions and 1199 deletions

View file

@ -9,11 +9,9 @@ import time
from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.digital_ocean import DigitalOceanHelper
ANSIBLE_METADATA = {'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'}
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: digital_ocean_floating_ip
short_description: Manage DigitalOcean Floating IPs
@ -44,10 +42,10 @@ notes:
- Version 2 of DigitalOcean API is used.
requirements:
- "python >= 2.6"
'''
"""
EXAMPLES = '''
EXAMPLES = """
- name: "Create a Floating IP in region lon1"
digital_ocean_floating_ip:
state: present
@ -63,10 +61,10 @@ EXAMPLES = '''
state: absent
ip: "1.2.3.4"
'''
"""
RETURN = '''
RETURN = """
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
data:
description: a DigitalOcean Floating IP resource
@ -106,11 +104,10 @@ data:
"region_slug": "nyc3"
}
}
'''
"""
class Response:
def __init__(self, resp, info):
self.body = None
if resp:
@ -132,36 +129,37 @@ class Response:
def status_code(self):
return self.info["status"]
def wait_action(module, rest, ip, action_id, timeout=10):
end_time = time.time() + 10
while time.time() < end_time:
response = rest.get(f'floating_ips/{ip}/actions/{action_id}')
response = rest.get(f"floating_ips/{ip}/actions/{action_id}")
# status_code = response.status_code # TODO: check status_code == 200?
status = response.json['action']['status']
if status == 'completed':
status = response.json["action"]["status"]
if status == "completed":
return True
elif status == 'errored':
module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json)
elif status == "errored":
module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json)
module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json)
module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json)
def core(module):
# api_token = module.params['oauth_token'] # unused for now
state = module.params['state']
ip = module.params['ip']
droplet_id = module.params['droplet_id']
state = module.params["state"]
ip = module.params["ip"]
droplet_id = module.params["droplet_id"]
rest = DigitalOceanHelper(module)
if state in ('present'):
if droplet_id is not None and module.params['ip'] is not None:
if state in ("present"):
if droplet_id is not None and module.params["ip"] is not None:
# Lets try to associate the ip to the specified droplet
associate_floating_ips(module, rest)
else:
create_floating_ips(module, rest)
elif state in ('absent'):
elif state in ("absent"):
response = rest.delete(f"floating_ips/{ip}")
status_code = response.status_code
json_data = response.json
@ -174,65 +172,68 @@ def core(module):
def get_floating_ip_details(module, rest):
ip = module.params['ip']
ip = module.params["ip"]
response = rest.get(f"floating_ips/{ip}")
status_code = response.status_code
json_data = response.json
if status_code == 200:
return json_data['floating_ip']
return json_data["floating_ip"]
else:
module.fail_json(msg="Error assigning floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def assign_floating_id_to_droplet(module, rest):
ip = module.params['ip']
ip = module.params["ip"]
payload = {
"type": "assign",
"droplet_id": module.params['droplet_id'],
"droplet_id": module.params["droplet_id"],
}
response = rest.post(f"floating_ips/{ip}/actions", data=payload)
status_code = response.status_code
json_data = response.json
if status_code == 201:
wait_action(module, rest, ip, json_data['action']['id'])
wait_action(module, rest, ip, json_data["action"]["id"])
module.exit_json(changed=True, data=json_data)
else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def associate_floating_ips(module, rest):
floating_ip = get_floating_ip_details(module, rest)
droplet = floating_ip['droplet']
droplet = floating_ip["droplet"]
# TODO: If already assigned to a droplet verify if is one of the specified as valid
if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]:
if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]:
module.exit_json(changed=False)
else:
assign_floating_id_to_droplet(module, rest)
def create_floating_ips(module, rest):
payload = {
}
payload = {}
floating_ip_data = None
if module.params['region'] is not None:
payload["region"] = module.params['region']
if module.params["region"] is not None:
payload["region"] = module.params["region"]
if module.params['droplet_id'] is not None:
payload["droplet_id"] = module.params['droplet_id']
if module.params["droplet_id"] is not None:
payload["droplet_id"] = module.params["droplet_id"]
floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips')
floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips")
for floating_ip in floating_ips:
if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']:
floating_ip_data = {'floating_ip': floating_ip}
if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]:
floating_ip_data = {"floating_ip": floating_ip}
if floating_ip_data:
module.exit_json(changed=False, data=floating_ip_data)
@ -244,36 +245,34 @@ def create_floating_ips(module, rest):
if status_code == 202:
module.exit_json(changed=True, data=json_data)
else:
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
status_code, json_data["message"]), region=module.params['region'])
module.fail_json(
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
region=module.params["region"],
)
def main():
module = AnsibleModule(
argument_spec={
'state': {'choices': ['present', 'absent'], 'default': 'present'},
'ip': {'aliases': ['id'], 'required': False},
'region': {'required': False},
'droplet_id': {'required': False, 'type': 'int'},
'oauth_token': {
'no_log': True,
"state": {"choices": ["present", "absent"], "default": "present"},
"ip": {"aliases": ["id"], "required": False},
"region": {"required": False},
"droplet_id": {"required": False, "type": "int"},
"oauth_token": {
"no_log": True,
# Support environment variable for DigitalOcean OAuth Token
'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']),
'required': True,
"fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]),
"required": True,
},
'validate_certs': {'type': 'bool', 'default': True},
'timeout': {'type': 'int', 'default': 30},
"validate_certs": {"type": "bool", "default": True},
"timeout": {"type": "int", "default": 30},
},
required_if=[
('state', 'delete', ['ip'])
],
mutually_exclusive=[
['region', 'droplet_id']
],
required_if=[("state", "delete", ["ip"])],
mutually_exclusive=[["region", "droplet_id"]],
)
core(module)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -1,7 +1,6 @@
#!/usr/bin/python
import json
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
# Documentation
################################################################################
ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'}
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
################################################################################
# Main
@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported
def main():
module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}})
module = GcpModule(
argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}}
)
if module._name == 'gcp_compute_image_facts':
module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13')
if module._name == "gcp_compute_image_facts":
module.deprecate(
"The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13"
)
if not module.params['scopes']:
module.params['scopes'] = ['https://www.googleapis.com/auth/compute']
if not module.params["scopes"]:
module.params["scopes"] = ["https://www.googleapis.com/auth/compute"]
items = fetch_list(module, collection(module), query_options(module.params['filters']))
if items.get('items'):
items = items.get('items')
items = fetch_list(module, collection(module), query_options(module.params["filters"]))
if items.get("items"):
items = items.get("items")
else:
items = []
return_value = {'resources': items}
return_value = {"resources": items}
module.exit_json(**return_value)
@ -40,14 +43,14 @@ def collection(module):
def fetch_list(module, link, query):
auth = GcpSession(module, 'compute')
response = auth.get(link, params={'filter': query})
auth = GcpSession(module, "compute")
response = auth.get(link, params={"filter": query})
return return_if_object(module, response)
def query_options(filters):
if not filters:
return ''
return ""
if len(filters) == 1:
return filters[0]
@ -55,12 +58,12 @@ def query_options(filters):
queries = []
for f in filters:
# For multiple queries, all queries should have ()
if f[0] != '(' and f[-1] != ')':
queries.append("({})".format(''.join(f)))
if f[0] != "(" and f[-1] != ")":
queries.append("({})".format("".join(f)))
else:
queries.append(f)
return ' '.join(queries)
return " ".join(queries)
def return_if_object(module, response):
@ -75,11 +78,11 @@ def return_if_object(module, response):
try:
module.raise_for_status(response)
result = response.json()
except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst:
except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
module.fail_json(msg=f"Invalid JSON response with error: {inst}")
if navigate_hash(result, ['error', 'errors']):
module.fail_json(msg=navigate_hash(result, ['error', 'errors']))
if navigate_hash(result, ["error", "errors"]):
module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
return result

View file

@ -3,12 +3,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = {'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: lightsail_region_facts
short_description: Gather facts about AWS Lightsail regions.
@ -24,15 +21,15 @@ requirements:
extends_documentation_fragment:
- aws
- ec2
'''
"""
EXAMPLES = '''
EXAMPLES = """
# Gather facts about all regions
- lightsail_region_facts:
'''
"""
RETURN = '''
RETURN = """
regions:
returned: on success
description: >
@ -46,12 +43,13 @@ regions:
"displayName": "Virginia",
"name": "us-east-1"
}]"
'''
"""
import traceback
try:
import botocore
HAS_BOTOCORE = True
except ImportError:
HAS_BOTOCORE = False
@ -86,18 +84,19 @@ def main():
client = None
try:
client = boto3_conn(module, conn_type='client', resource='lightsail',
region=region, endpoint=ec2_url, **aws_connect_kwargs)
client = boto3_conn(
module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs
)
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc())
module.fail_json(
msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc()
)
response = client.get_regions(
includeAvailabilityZones=False
)
response = client.get_regions(includeAvailabilityZones=False)
module.exit_json(changed=False, data=response)
except (botocore.exceptions.ClientError, Exception) as e:
module.fail_json(msg=str(e), exception=traceback.format_exc())
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None
try:
from linode_api4 import LinodeClient, StackScript
HAS_LINODE_DEPENDENCY = True
except ImportError:
LINODE_IMP_ERR = traceback.format_exc()
@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs):
response = client.linode.stackscript_create(**kwargs)
return response._raw_json
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def stackscript_available(module, client):
"""Try to retrieve a stackscript."""
try:
label = module.params['label']
desc = module.params['description']
label = module.params["label"]
desc = module.params["description"]
result = client.linode.stackscripts(StackScript.label == label,
StackScript.description == desc,
mine_only=True
)
result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True)
return result[0]
except IndexError:
return None
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module():
"""Initialise the module parameter specification."""
return AnsibleModule(
argument_spec=dict(
label=dict(type='str', required=True),
state=dict(
type='str',
required=True,
choices=['present', 'absent']
),
label=dict(type="str", required=True),
state=dict(type="str", required=True, choices=["present", "absent"]),
access_token=dict(
type='str',
type="str",
required=True,
no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
),
script=dict(type='str', required=True),
images=dict(type='list', required=True),
description=dict(type='str', required=False),
public=dict(type='bool', required=False, default=False),
script=dict(type="str", required=True),
images=dict(type="list", required=True),
description=dict(type="str", required=False),
public=dict(type="bool", required=False, default=False),
),
supports_check_mode=False
supports_check_mode=False,
)
def build_client(module):
"""Build a LinodeClient."""
return LinodeClient(
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
def main():
@ -79,30 +70,31 @@ def main():
module = initialise_module()
if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module)
stackscript = stackscript_available(module, client)
if module.params['state'] == 'present' and stackscript is not None:
if module.params["state"] == "present" and stackscript is not None:
module.exit_json(changed=False, stackscript=stackscript._raw_json)
elif module.params['state'] == 'present' and stackscript is None:
elif module.params["state"] == "present" and stackscript is None:
stackscript_json = create_stackscript(
module, client,
label=module.params['label'],
script=module.params['script'],
images=module.params['images'],
desc=module.params['description'],
public=module.params['public'],
module,
client,
label=module.params["label"],
script=module.params["script"],
images=module.params["images"],
desc=module.params["description"],
public=module.params["public"],
)
module.exit_json(changed=True, stackscript=stackscript_json)
elif module.params['state'] == 'absent' and stackscript is not None:
elif module.params["state"] == "absent" and stackscript is not None:
stackscript.delete()
module.exit_json(changed=True, stackscript=stackscript._raw_json)
elif module.params['state'] == 'absent' and stackscript is None:
elif module.params["state"] == "absent" and stackscript is None:
module.exit_json(changed=False, stackscript={})

View file

@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent
LINODE_IMP_ERR = None
try:
from linode_api4 import Instance, LinodeClient
HAS_LINODE_DEPENDENCY = True
except ImportError:
LINODE_IMP_ERR = traceback.format_exc()
@ -21,82 +22,72 @@ except ImportError:
def create_linode(module, client, **kwargs):
"""Creates a Linode instance and handles return format."""
if kwargs['root_pass'] is None:
kwargs.pop('root_pass')
if kwargs["root_pass"] is None:
kwargs.pop("root_pass")
try:
response = client.linode.instance_create(**kwargs)
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
try:
if isinstance(response, tuple):
instance, root_pass = response
instance_json = instance._raw_json
instance_json.update({'root_pass': root_pass})
instance_json.update({"root_pass": root_pass})
return instance_json
else:
return response._raw_json
except TypeError:
module.fail_json(msg='Unable to parse Linode instance creation'
' response. Please raise a bug against this'
' module on https://github.com/ansible/ansible/issues'
)
module.fail_json(
msg="Unable to parse Linode instance creation"
" response. Please raise a bug against this"
" module on https://github.com/ansible/ansible/issues"
)
def maybe_instance_from_label(module, client):
"""Try to retrieve an instance based on a label."""
try:
label = module.params['label']
label = module.params["label"]
result = client.linode.instances(Instance.label == label)
return result[0]
except IndexError:
return None
except Exception as exception:
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
def initialise_module():
"""Initialise the module parameter specification."""
return AnsibleModule(
argument_spec=dict(
label=dict(type='str', required=True),
state=dict(
type='str',
required=True,
choices=['present', 'absent']
),
label=dict(type="str", required=True),
state=dict(type="str", required=True, choices=["present", "absent"]),
access_token=dict(
type='str',
type="str",
required=True,
no_log=True,
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
),
authorized_keys=dict(type='list', required=False),
group=dict(type='str', required=False),
image=dict(type='str', required=False),
region=dict(type='str', required=False),
root_pass=dict(type='str', required=False, no_log=True),
tags=dict(type='list', required=False),
type=dict(type='str', required=False),
stackscript_id=dict(type='int', required=False),
authorized_keys=dict(type="list", required=False),
group=dict(type="str", required=False),
image=dict(type="str", required=False),
region=dict(type="str", required=False),
root_pass=dict(type="str", required=False, no_log=True),
tags=dict(type="list", required=False),
type=dict(type="str", required=False),
stackscript_id=dict(type="int", required=False),
),
supports_check_mode=False,
required_one_of=(
['state', 'label'],
),
required_together=(
['region', 'image', 'type'],
)
required_one_of=(["state", "label"],),
required_together=(["region", "image", "type"],),
)
def build_client(module):
"""Build a LinodeClient."""
return LinodeClient(
module.params['access_token'],
user_agent=get_user_agent('linode_v4_module')
)
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
def main():
@ -104,34 +95,35 @@ def main():
module = initialise_module()
if not HAS_LINODE_DEPENDENCY:
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
client = build_client(module)
instance = maybe_instance_from_label(module, client)
if module.params['state'] == 'present' and instance is not None:
if module.params["state"] == "present" and instance is not None:
module.exit_json(changed=False, instance=instance._raw_json)
elif module.params['state'] == 'present' and instance is None:
elif module.params["state"] == "present" and instance is None:
instance_json = create_linode(
module, client,
authorized_keys=module.params['authorized_keys'],
group=module.params['group'],
image=module.params['image'],
label=module.params['label'],
region=module.params['region'],
root_pass=module.params['root_pass'],
tags=module.params['tags'],
ltype=module.params['type'],
stackscript_id=module.params['stackscript_id'],
module,
client,
authorized_keys=module.params["authorized_keys"],
group=module.params["group"],
image=module.params["image"],
label=module.params["label"],
region=module.params["region"],
root_pass=module.params["root_pass"],
tags=module.params["tags"],
ltype=module.params["type"],
stackscript_id=module.params["stackscript_id"],
)
module.exit_json(changed=True, instance=instance_json)
elif module.params['state'] == 'absent' and instance is not None:
elif module.params["state"] == "absent" and instance is not None:
instance.delete()
module.exit_json(changed=True, instance=instance._raw_json)
elif module.params['state'] == 'absent' and instance is None:
elif module.params["state"] == "absent" and instance is None:
module.exit_json(changed=False, instance={})

View file

@ -8,14 +8,9 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
ANSIBLE_METADATA = {
'metadata_version': '1.1',
'status': ['preview'],
'supported_by': 'community'
}
DOCUMENTATION = '''
DOCUMENTATION = """
---
module: scaleway_compute
short_description: Scaleway compute management module
@ -120,9 +115,9 @@ options:
- If no value provided, the default security group or current security group will be used
required: false
version_added: "2.8"
'''
"""
EXAMPLES = '''
EXAMPLES = """
- name: Create a server
scaleway_compute:
name: foobar
@ -156,10 +151,10 @@ EXAMPLES = '''
organization: 951df375-e094-4d26-97c1-ba548eeb9c42
region: ams1
commercial_type: VC1S
'''
"""
RETURN = '''
'''
RETURN = """
"""
import datetime
import time
@ -167,19 +162,9 @@ import time
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
SCALEWAY_SERVER_STATES = (
'stopped',
'stopping',
'starting',
'running',
'locked'
)
SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked")
SCALEWAY_TRANSITIONS_STATES = (
"stopping",
"starting",
"pending"
)
SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending")
def check_image_id(compute_api, image_id):
@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id):
if response.ok and response.json:
image_ids = [image["id"] for image in response.json["images"]]
if image_id not in image_ids:
compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url')))
compute_api.module.fail_json(
msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url"))
)
else:
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url'))
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url"))
def fetch_state(compute_api, server):
@ -201,7 +188,7 @@ def fetch_state(compute_api, server):
return "absent"
if not response.ok:
msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json)
msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip):
# We check that the IP we want to attach exists, if so its ID is returned
response = compute_api.get("ips")
if not response.ok:
msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json)
msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
ip_list = []
@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip):
def create_server(compute_api, server):
compute_api.module.debug("Starting a create_server")
target_server = None
data = {"enable_ipv6": server["enable_ipv6"],
"tags": server["tags"],
"commercial_type": server["commercial_type"],
"image": server["image"],
"dynamic_ip_required": server["dynamic_ip_required"],
"name": server["name"],
"organization": server["organization"]
}
data = {
"enable_ipv6": server["enable_ipv6"],
"tags": server["tags"],
"commercial_type": server["commercial_type"],
"image": server["image"],
"dynamic_ip_required": server["dynamic_ip_required"],
"name": server["name"],
"organization": server["organization"],
}
if server["boot_type"]:
data["boot_type"] = server["boot_type"]
@ -278,7 +266,7 @@ def create_server(compute_api, server):
response = compute_api.post(path="servers", data=data)
if not response.ok:
msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json)
msg = "Error during server creation: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@ -304,10 +292,9 @@ def start_server(compute_api, server):
def perform_action(compute_api, server, action):
response = compute_api.post(path="servers/%s/action" % server["id"],
data={"action": action})
response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action})
if not response.ok:
msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json)
msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server)
@ -319,7 +306,7 @@ def remove_server(compute_api, server):
compute_api.module.debug("Starting remove server strategy")
response = compute_api.delete(path="servers/%s" % server["id"])
if not response.ok:
msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json)
msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
wait_to_complete_state_transition(compute_api=compute_api, server=server)
@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
return changed, target_server
@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server):
response = stop_server(compute_api=compute_api, server=target_server)
if not response.ok:
err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]'
err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server):
response = remove_server(compute_api=compute_api, server=target_server)
if not response.ok:
err_msg = f'Error while removing server [{response.status_code}: {response.json}]'
err_msg = f"Error while removing server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=err_msg)
return changed, {"status": "Server %s deleted" % target_server["id"]}
@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
current_state = fetch_state(compute_api=compute_api, server=target_server)
if current_state not in ("running", "starting"):
@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server):
response = start_server(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while running server [{response.status_code}: {response.json}]'
msg = f"Error while running server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg)
return changed, target_server
@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server):
changed = False
if not query_results:
if compute_api.module.check_mode:
return changed, {"status": "A server would be created before being stopped."}
@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug("stop_strategy: Servers are found.")
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]}
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server):
compute_api.module.debug(response.ok)
if not response.ok:
msg = f'Error while stopping server [{response.status_code}: {response.json}]'
msg = f"Error while stopping server [{response.status_code}: {response.json}]"
compute_api.module.fail_json(msg=msg)
return changed, target_server
@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server):
else:
target_server = query_results[0]
if server_attributes_should_be_changed(compute_api=compute_api,
target_server=target_server,
wished_server=wished_server):
if server_attributes_should_be_changed(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
):
changed = True
if compute_api.module.check_mode:
return changed, {
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]}
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]
}
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
target_server = server_change_attributes(
compute_api=compute_api, target_server=target_server, wished_server=wished_server
)
changed = True
if compute_api.module.check_mode:
@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server):
response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].'
msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg)
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
response = restart_server(compute_api=compute_api, server=target_server)
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
if not response.ok:
msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].'
msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]."
compute_api.module.fail_json(msg=msg)
return changed, target_server
@ -531,18 +530,17 @@ state_strategy = {
"restarted": restart_strategy,
"stopped": stop_strategy,
"running": running_strategy,
"absent": absent_strategy
"absent": absent_strategy,
}
def find(compute_api, wished_server, per_page=1):
compute_api.module.debug("Getting inside find")
# Only the name attribute is accepted in the Compute query API
response = compute_api.get("servers", params={"name": wished_server["name"],
"per_page": per_page})
response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page})
if not response.ok:
msg = 'Error during server search: (%s) %s' % (response.status_code, response.json)
msg = "Error during server search: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
search_results = response.json["servers"]
@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve
compute_api.module.debug("Checking if server attributes should be changed")
compute_api.module.debug("Current Server: %s" % target_server)
compute_api.module.debug("Wished Server: %s" % wished_server)
debug_dict = dict((x, (target_server[x], wished_server[x]))
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
if x in target_server and x in wished_server)
debug_dict = dict(
(x, (target_server[x], wished_server[x]))
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
if x in target_server and x in wished_server
)
compute_api.module.debug("Debug dict %s" % debug_dict)
try:
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
if key in target_server and key in wished_server:
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys(
) and target_server[key]["id"] != wished_server[key]:
if (
isinstance(target_server[key], dict)
and wished_server[key]
and "id" in target_server[key].keys()
and target_server[key]["id"] != wished_server[key]
):
return True
# Handling other structure compare simply the two objects content
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server):
elif not isinstance(target_server[key], dict):
patch_payload[key] = wished_server[key]
response = compute_api.patch(path="servers/%s" % target_server["id"],
data=patch_payload)
response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload)
if not response.ok:
msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json)
msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json)
compute_api.module.fail_json(msg=msg)
try:
@ -625,9 +628,9 @@ def core(module):
"boot_type": module.params["boot_type"],
"tags": module.params["tags"],
"organization": module.params["organization"],
"security_group": module.params["security_group"]
"security_group": module.params["security_group"],
}
module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"]
module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"]
compute_api = Scaleway(module=module)
@ -643,22 +646,24 @@ def core(module):
def main():
argument_spec = scaleway_argument_spec()
argument_spec.update(dict(
image=dict(required=True),
name=dict(),
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
commercial_type=dict(required=True),
enable_ipv6=dict(default=False, type="bool"),
boot_type=dict(choices=['bootscript', 'local']),
public_ip=dict(default="absent"),
state=dict(choices=state_strategy.keys(), default='present'),
tags=dict(type="list", default=[]),
organization=dict(required=True),
wait=dict(type="bool", default=False),
wait_timeout=dict(type="int", default=300),
wait_sleep_time=dict(type="int", default=3),
security_group=dict(),
))
argument_spec.update(
dict(
image=dict(required=True),
name=dict(),
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
commercial_type=dict(required=True),
enable_ipv6=dict(default=False, type="bool"),
boot_type=dict(choices=["bootscript", "local"]),
public_ip=dict(default="absent"),
state=dict(choices=state_strategy.keys(), default="present"),
tags=dict(type="list", default=[]),
organization=dict(required=True),
wait=dict(type="bool", default=False),
wait_timeout=dict(type="int", default=300),
wait_sleep_time=dict(type="int", default=3),
security_group=dict(),
)
)
module = AnsibleModule(
argument_spec=argument_spec,
supports_check_mode=True,
@ -667,5 +672,5 @@ def main():
core(module)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -32,32 +32,30 @@ Returns:
def run_module():
"""
Main execution function for the x25519_pubkey Ansible module.
Handles parameter validation, private key processing, public key derivation,
and optional file output with idempotent behavior.
"""
module_args = {
'private_key_b64': {'type': 'str', 'required': False},
'private_key_path': {'type': 'path', 'required': False},
'public_key_path': {'type': 'path', 'required': False},
"private_key_b64": {"type": "str", "required": False},
"private_key_path": {"type": "path", "required": False},
"public_key_path": {"type": "path", "required": False},
}
result = {
'changed': False,
'public_key': '',
"changed": False,
"public_key": "",
}
module = AnsibleModule(
argument_spec=module_args,
required_one_of=[['private_key_b64', 'private_key_path']],
supports_check_mode=True
argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True
)
priv_b64 = None
if module.params['private_key_path']:
if module.params["private_key_path"]:
try:
with open(module.params['private_key_path'], 'rb') as f:
with open(module.params["private_key_path"], "rb") as f:
data = f.read()
try:
# First attempt: assume file contains base64 text data
@ -71,12 +69,14 @@ def run_module():
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
# Stripping would corrupt the key and cause "got 31 bytes" errors
if len(data) != 32:
module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes")
module.fail_json(
msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes"
)
priv_b64 = base64.b64encode(data).decode()
except OSError as e:
module.fail_json(msg=f"Failed to read private key file: {e}")
else:
priv_b64 = module.params['private_key_b64']
priv_b64 = module.params["private_key_b64"]
# Validate input parameters
if not priv_b64:
@ -93,15 +93,12 @@ def run_module():
try:
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
pub_key = priv_key.public_key()
pub_raw = pub_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
pub_b64 = base64.b64encode(pub_raw).decode()
result['public_key'] = pub_b64
result["public_key"] = pub_b64
if module.params['public_key_path']:
pub_path = module.params['public_key_path']
if module.params["public_key_path"]:
pub_path = module.params["public_key_path"]
existing = None
try:
@ -112,13 +109,13 @@ def run_module():
if existing != pub_b64:
try:
with open(pub_path, 'w') as f:
with open(pub_path, "w") as f:
f.write(pub_b64)
result['changed'] = True
result["changed"] = True
except OSError as e:
module.fail_json(msg=f"Failed to write public key file: {e}")
result['public_key_path'] = pub_path
result["public_key_path"] = pub_path
except Exception as e:
module.fail_json(msg=f"Failed to derive public key: {e}")
@ -131,5 +128,5 @@ def main():
run_module()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -3,6 +3,7 @@
Track test effectiveness by analyzing CI failures and correlating with issues/PRs
This helps identify which tests actually catch bugs vs just failing randomly
"""
import json
import subprocess
from collections import defaultdict
@ -12,7 +13,7 @@ from pathlib import Path
def get_github_api_data(endpoint):
"""Fetch data from GitHub API"""
cmd = ['gh', 'api', endpoint]
cmd = ["gh", "api", endpoint]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Error fetching {endpoint}: {result.stderr}")
@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30):
since = (datetime.now() - timedelta(days=days_back)).isoformat()
# Get workflow runs
runs = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure'
)
runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure")
if not runs:
return {}
test_failures = defaultdict(list)
for run in runs.get('workflow_runs', []):
for run in runs.get("workflow_runs", []):
# Get jobs for this run
jobs = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs'
)
jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs")
if not jobs:
continue
for job in jobs.get('jobs', []):
if job['conclusion'] == 'failure':
for job in jobs.get("jobs", []):
if job["conclusion"] == "failure":
# Try to extract which test failed from logs
logs_url = job.get('logs_url')
logs_url = job.get("logs_url")
if logs_url:
# Parse logs to find test failures
test_name = extract_failed_test(job['name'], run['id'])
test_name = extract_failed_test(job["name"], run["id"])
if test_name:
test_failures[test_name].append({
'run_id': run['id'],
'run_number': run['run_number'],
'date': run['created_at'],
'branch': run['head_branch'],
'commit': run['head_sha'][:7],
'pr': extract_pr_number(run)
})
test_failures[test_name].append(
{
"run_id": run["id"],
"run_number": run["run_number"],
"date": run["created_at"],
"branch": run["head_branch"],
"commit": run["head_sha"][:7],
"pr": extract_pr_number(run),
}
)
return test_failures
@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id):
"""Extract test name from job - this is simplified"""
# Map job names to test categories
job_to_tests = {
'Basic sanity tests': 'test_basic_sanity',
'Ansible syntax check': 'ansible_syntax',
'Docker build test': 'docker_tests',
'Configuration generation test': 'config_generation',
'Ansible dry-run validation': 'ansible_dry_run'
"Basic sanity tests": "test_basic_sanity",
"Ansible syntax check": "ansible_syntax",
"Docker build test": "docker_tests",
"Configuration generation test": "config_generation",
"Ansible dry-run validation": "ansible_dry_run",
}
return job_to_tests.get(job_name, job_name)
def extract_pr_number(run):
"""Extract PR number from workflow run"""
for pr in run.get('pull_requests', []):
return pr['number']
for pr in run.get("pull_requests", []):
return pr["number"]
return None
def correlate_with_issues(repo_owner, repo_name, test_failures):
"""Correlate test failures with issues/PRs that fixed them"""
correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0})
correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0})
for test_name, failures in test_failures.items():
for failure in failures:
if failure['pr']:
if failure["pr"]:
# Check if PR was merged (indicating it fixed a real issue)
pr = get_github_api_data(
f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}'
)
pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}")
if pr and pr.get('merged'):
if pr and pr.get("merged"):
# Check PR title/body for bug indicators
title = pr.get('title', '').lower()
body = pr.get('body', '').lower()
title = pr.get("title", "").lower()
body = pr.get("body", "").lower()
bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail']
is_bug_fix = any(keyword in title or keyword in body
for keyword in bug_keywords)
bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"]
is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords)
if is_bug_fix:
correlations[test_name]['caught_bugs'] += 1
correlations[test_name]["caught_bugs"] += 1
else:
correlations[test_name]['false_positives'] += 1
correlations[test_name]["false_positives"] += 1
return correlations
@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations):
scores = []
for test_name, failures in test_failures.items():
failure_count = len(failures)
caught = correlations[test_name]['caught_bugs']
false_pos = correlations[test_name]['false_positives']
caught = correlations[test_name]["caught_bugs"]
false_pos = correlations[test_name]["false_positives"]
# Calculate effectiveness (bugs caught / total failures)
if failure_count > 0:
@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations):
elif effectiveness > 0.8:
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
return '\n'.join(report)
return "\n".join(report)
def save_metrics(test_failures, correlations):
"""Save metrics to JSON for historical tracking"""
metrics_file = Path('.metrics/test-effectiveness.json')
metrics_file = Path(".metrics/test-effectiveness.json")
metrics_file.parent.mkdir(exist_ok=True)
# Load existing metrics
@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations):
# Add current metrics
current = {
'date': datetime.now().isoformat(),
'test_failures': {
test: len(failures) for test, failures in test_failures.items()
},
'effectiveness': {
"date": datetime.now().isoformat(),
"test_failures": {test: len(failures) for test, failures in test_failures.items()},
"effectiveness": {
test: {
'caught_bugs': data['caught_bugs'],
'false_positives': data['false_positives'],
'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives'])
if (data['caught_bugs'] + data['false_positives']) > 0 else 0
"caught_bugs": data["caught_bugs"],
"false_positives": data["false_positives"],
"score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"])
if (data["caught_bugs"] + data["false_positives"]) > 0
else 0,
}
for test, data in correlations.items()
}
},
}
historical.append(current)
# Keep last 12 months of data
cutoff = datetime.now() - timedelta(days=365)
historical = [
h for h in historical
if datetime.fromisoformat(h['date']) > cutoff
]
historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff]
with open(metrics_file, 'w') as f:
with open(metrics_file, "w") as f:
json.dump(historical, f, indent=2)
if __name__ == '__main__':
if __name__ == "__main__":
# Configure these for your repo
REPO_OWNER = 'trailofbits'
REPO_NAME = 'algo'
REPO_OWNER = "trailofbits"
REPO_NAME = "algo"
print("Analyzing test effectiveness...")
@ -223,9 +215,9 @@ if __name__ == '__main__':
print("\n" + report)
# Save report
report_file = Path('.metrics/test-effectiveness-report.md')
report_file = Path(".metrics/test-effectiveness-report.md")
report_file.parent.mkdir(exist_ok=True)
with open(report_file, 'w') as f:
with open(report_file, "w") as f:
f.write(report)
print(f"\nReport saved to: {report_file}")

View file

@ -1,4 +1,5 @@
"""Test fixtures for Algo unit tests"""
from pathlib import Path
import yaml
@ -6,7 +7,7 @@ import yaml
def load_test_variables():
"""Load test variables from YAML fixture"""
fixture_path = Path(__file__).parent / 'test_variables.yml'
fixture_path = Path(__file__).parent / "test_variables.yml"
with open(fixture_path) as f:
return yaml.safe_load(f)

View file

@ -2,49 +2,56 @@
"""
Wrapper for Ansible's service module that always succeeds for known services
"""
import json
import sys
# Parse module arguments
args = json.loads(sys.stdin.read())
module_args = args.get('ANSIBLE_MODULE_ARGS', {})
module_args = args.get("ANSIBLE_MODULE_ARGS", {})
service_name = module_args.get('name', '')
state = module_args.get('state', 'started')
service_name = module_args.get("name", "")
state = module_args.get("state", "started")
# Known services that should always succeed
known_services = [
'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter',
'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd',
'systemd-resolved', 'rsyslog', 'ipfw', 'cron'
"netfilter-persistent",
"iptables",
"wg-quick@wg0",
"strongswan-starter",
"ipsec",
"apparmor",
"unattended-upgrades",
"systemd-networkd",
"systemd-resolved",
"rsyslog",
"ipfw",
"cron",
]
# Check if it's a known service
service_found = False
for svc in known_services:
if service_name == svc or service_name.startswith(svc + '.'):
if service_name == svc or service_name.startswith(svc + "."):
service_found = True
break
if service_found:
# Return success
result = {
'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False,
'name': service_name,
'state': state,
'status': {
'LoadState': 'loaded',
'ActiveState': 'active' if state != 'stopped' else 'inactive',
'SubState': 'running' if state != 'stopped' else 'dead'
}
"changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False,
"name": service_name,
"state": state,
"status": {
"LoadState": "loaded",
"ActiveState": "active" if state != "stopped" else "inactive",
"SubState": "running" if state != "stopped" else "dead",
},
}
print(json.dumps(result))
sys.exit(0)
else:
# Service not found
error = {
'failed': True,
'msg': f'Could not find the requested service {service_name}: '
}
error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "}
print(json.dumps(error))
sys.exit(1)

View file

@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'name': {'type': 'list', 'aliases': ['pkg', 'package']},
'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']},
'update_cache': {'type': 'bool', 'default': False},
'cache_valid_time': {'type': 'int', 'default': 0},
'install_recommends': {'type': 'bool'},
'force': {'type': 'bool', 'default': False},
'allow_unauthenticated': {'type': 'bool', 'default': False},
'allow_downgrade': {'type': 'bool', 'default': False},
'allow_change_held_packages': {'type': 'bool', 'default': False},
'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'},
'autoremove': {'type': 'bool', 'default': False},
'purge': {'type': 'bool', 'default': False},
'force_apt_get': {'type': 'bool', 'default': False},
"name": {"type": "list", "aliases": ["pkg", "package"]},
"state": {
"type": "str",
"default": "present",
"choices": ["present", "absent", "latest", "build-dep", "fixed"],
},
"update_cache": {"type": "bool", "default": False},
"cache_valid_time": {"type": "int", "default": 0},
"install_recommends": {"type": "bool"},
"force": {"type": "bool", "default": False},
"allow_unauthenticated": {"type": "bool", "default": False},
"allow_downgrade": {"type": "bool", "default": False},
"allow_change_held_packages": {"type": "bool", "default": False},
"dpkg_options": {"type": "str", "default": "force-confdef,force-confold"},
"autoremove": {"type": "bool", "default": False},
"purge": {"type": "bool", "default": False},
"force_apt_get": {"type": "bool", "default": False},
},
supports_check_mode=True
supports_check_mode=True,
)
name = module.params['name']
state = module.params['state']
update_cache = module.params['update_cache']
name = module.params["name"]
state = module.params["state"]
update_cache = module.params["update_cache"]
result = {
'changed': False,
'cache_updated': False,
'cache_update_time': 0
}
result = {"changed": False, "cache_updated": False, "cache_update_time": 0}
# Log the operation
with open('/var/log/mock-apt-module.log', 'a') as f:
with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
# Handle cache update
if update_cache:
# In Docker, apt-get update was already run in entrypoint
# Just pretend it succeeded
result['cache_updated'] = True
result['cache_update_time'] = 1754231778 # Fixed timestamp
result['changed'] = True
result["cache_updated"] = True
result["cache_update_time"] = 1754231778 # Fixed timestamp
result["changed"] = True
# Handle package installation/removal
if name:
@ -56,40 +56,41 @@ def main():
installed_packages = []
for pkg in packages:
# Use dpkg to check if package is installed
check_cmd = ['dpkg', '-s', pkg]
check_cmd = ["dpkg", "-s", pkg]
rc = subprocess.run(check_cmd, capture_output=True)
if rc.returncode == 0:
installed_packages.append(pkg)
if state in ['present', 'latest']:
if state in ["present", "latest"]:
# Check if we need to install anything
missing_packages = [p for p in packages if p not in installed_packages]
if missing_packages:
# Log what we would install
with open('/var/log/mock-apt-module.log', 'a') as f:
with open("/var/log/mock-apt-module.log", "a") as f:
f.write(f"Would install packages: {missing_packages}\n")
# For our test purposes, these packages are pre-installed in Docker
# Just report success
result['changed'] = True
result['stdout'] = f"Mock: Packages {missing_packages} are already available"
result['stderr'] = ""
result["changed"] = True
result["stdout"] = f"Mock: Packages {missing_packages} are already available"
result["stderr"] = ""
else:
result['stdout'] = "All packages are already installed"
result["stdout"] = "All packages are already installed"
elif state == 'absent':
elif state == "absent":
# Check if we need to remove anything
present_packages = [p for p in packages if p in installed_packages]
if present_packages:
result['changed'] = True
result['stdout'] = f"Mock: Would remove packages {present_packages}"
result["changed"] = True
result["stdout"] = f"Mock: Would remove packages {present_packages}"
else:
result['stdout'] = "No packages to remove"
result["stdout"] = "No packages to remove"
# Always report success for our testing
module.exit_json(**result)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'_raw_params': {'type': 'str'},
'cmd': {'type': 'str'},
'creates': {'type': 'path'},
'removes': {'type': 'path'},
'chdir': {'type': 'path'},
'executable': {'type': 'path'},
'warn': {'type': 'bool', 'default': False},
'stdin': {'type': 'str'},
'stdin_add_newline': {'type': 'bool', 'default': True},
'strip_empty_ends': {'type': 'bool', 'default': True},
'_uses_shell': {'type': 'bool', 'default': False},
"_raw_params": {"type": "str"},
"cmd": {"type": "str"},
"creates": {"type": "path"},
"removes": {"type": "path"},
"chdir": {"type": "path"},
"executable": {"type": "path"},
"warn": {"type": "bool", "default": False},
"stdin": {"type": "str"},
"stdin_add_newline": {"type": "bool", "default": True},
"strip_empty_ends": {"type": "bool", "default": True},
"_uses_shell": {"type": "bool", "default": False},
},
supports_check_mode=True
supports_check_mode=True,
)
# Get the command
raw_params = module.params.get('_raw_params')
cmd = module.params.get('cmd') or raw_params
raw_params = module.params.get("_raw_params")
cmd = module.params.get("cmd") or raw_params
if not cmd:
module.fail_json(msg="no command given")
result = {
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
# Log the operation
with open('/var/log/mock-command-module.log', 'a') as f:
with open("/var/log/mock-command-module.log", "a") as f:
f.write(f"command module called: cmd={cmd}\n")
# Handle specific commands
if 'apparmor_status' in cmd:
if "apparmor_status" in cmd:
# Pretend apparmor is not installed/active
result['rc'] = 127
result['stderr'] = "apparmor_status: command not found"
result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'"
module.fail_json(msg=result['msg'], **result)
elif 'netplan apply' in cmd:
result["rc"] = 127
result["stderr"] = "apparmor_status: command not found"
result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'"
module.fail_json(msg=result["msg"], **result)
elif "netplan apply" in cmd:
# Pretend netplan succeeded
result['stdout'] = "Mock: netplan configuration applied"
result['changed'] = True
elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
result["stdout"] = "Mock: netplan configuration applied"
result["changed"] = True
elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush
result['stdout'] = "1"
result['changed'] = True
result["stdout"] = "1"
result["changed"] = True
else:
# For other commands, try to run them
try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir'))
result['rc'] = proc.returncode
result['stdout'] = proc.stdout
result['stderr'] = proc.stderr
result['stdout_lines'] = proc.stdout.splitlines()
result['stderr_lines'] = proc.stderr.splitlines()
result['changed'] = True
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir"))
result["rc"] = proc.returncode
result["stdout"] = proc.stdout
result["stderr"] = proc.stderr
result["stdout_lines"] = proc.stdout.splitlines()
result["stderr_lines"] = proc.stderr.splitlines()
result["changed"] = True
except Exception as e:
result['rc'] = 1
result['stderr'] = str(e)
result['msg'] = str(e)
module.fail_json(msg=result['msg'], **result)
result["rc"] = 1
result["stderr"] = str(e)
result["msg"] = str(e)
module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0:
if result["rc"] == 0:
module.exit_json(**result)
else:
if 'msg' not in result:
result['msg'] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result)
if "msg" not in result:
result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule
def main():
module = AnsibleModule(
argument_spec={
'_raw_params': {'type': 'str'},
'cmd': {'type': 'str'},
'creates': {'type': 'path'},
'removes': {'type': 'path'},
'chdir': {'type': 'path'},
'executable': {'type': 'path', 'default': '/bin/sh'},
'warn': {'type': 'bool', 'default': False},
'stdin': {'type': 'str'},
'stdin_add_newline': {'type': 'bool', 'default': True},
"_raw_params": {"type": "str"},
"cmd": {"type": "str"},
"creates": {"type": "path"},
"removes": {"type": "path"},
"chdir": {"type": "path"},
"executable": {"type": "path", "default": "/bin/sh"},
"warn": {"type": "bool", "default": False},
"stdin": {"type": "str"},
"stdin_add_newline": {"type": "bool", "default": True},
},
supports_check_mode=True
supports_check_mode=True,
)
# Get the command
raw_params = module.params.get('_raw_params')
cmd = module.params.get('cmd') or raw_params
raw_params = module.params.get("_raw_params")
cmd = module.params.get("cmd") or raw_params
if not cmd:
module.fail_json(msg="no command given")
result = {
'changed': False,
'cmd': cmd,
'rc': 0,
'stdout': '',
'stderr': '',
'stdout_lines': [],
'stderr_lines': []
}
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
# Log the operation
with open('/var/log/mock-shell-module.log', 'a') as f:
with open("/var/log/mock-shell-module.log", "a") as f:
f.write(f"shell module called: cmd={cmd}\n")
# Handle specific commands
if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
# Routing cache flush - just pretend it worked
result['stdout'] = ""
result['changed'] = True
elif 'ifconfig lo100' in cmd:
result["stdout"] = ""
result["changed"] = True
elif "ifconfig lo100" in cmd:
# BSD loopback commands - simulate success
result['stdout'] = "0"
result['changed'] = True
result["stdout"] = "0"
result["changed"] = True
else:
# For other commands, try to run them
try:
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True,
executable=module.params.get('executable'),
cwd=module.params.get('chdir'))
result['rc'] = proc.returncode
result['stdout'] = proc.stdout
result['stderr'] = proc.stderr
result['stdout_lines'] = proc.stdout.splitlines()
result['stderr_lines'] = proc.stderr.splitlines()
result['changed'] = True
proc = subprocess.run(
cmd,
shell=True,
capture_output=True,
text=True,
executable=module.params.get("executable"),
cwd=module.params.get("chdir"),
)
result["rc"] = proc.returncode
result["stdout"] = proc.stdout
result["stderr"] = proc.stderr
result["stdout_lines"] = proc.stdout.splitlines()
result["stderr_lines"] = proc.stderr.splitlines()
result["changed"] = True
except Exception as e:
result['rc'] = 1
result['stderr'] = str(e)
result['msg'] = str(e)
module.fail_json(msg=result['msg'], **result)
result["rc"] = 1
result["stderr"] = str(e)
result["msg"] = str(e)
module.fail_json(msg=result["msg"], **result)
if result['rc'] == 0:
if result["rc"] == 0:
module.exit_json(**result)
else:
if 'msg' not in result:
result['msg'] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result['msg'], **result)
if "msg" not in result:
result["msg"] = f"Command failed with return code {result['rc']}"
module.fail_json(msg=result["msg"], **result)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -24,6 +24,7 @@ import yaml
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
def create_expected_cloud_init():
"""
Create the expected cloud-init content that should be generated
@ -74,6 +75,7 @@ runcmd:
- systemctl restart sshd.service
"""
class TestCloudInitTemplate:
"""Test class for cloud-init template validation."""
@ -98,10 +100,7 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
required_sections = [
'package_update', 'package_upgrade', 'packages',
'users', 'write_files', 'runcmd'
]
required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"]
missing = [section for section in required_sections if section not in parsed]
assert not missing, f"Missing required sections: {missing}"
@ -114,35 +113,30 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
write_files = parsed.get('write_files', [])
write_files = parsed.get("write_files", [])
assert write_files, "write_files section should be present"
# Find sshd_config file
sshd_config = None
for file_entry in write_files:
if file_entry.get('path') == '/etc/ssh/sshd_config':
if file_entry.get("path") == "/etc/ssh/sshd_config":
sshd_config = file_entry
break
assert sshd_config, "sshd_config file should be in write_files"
content = sshd_config.get('content', '')
content = sshd_config.get("content", "")
assert content, "sshd_config should have content"
# Check required SSH configurations
required_configs = [
'Port 4160',
'AllowGroups algo',
'PermitRootLogin no',
'PasswordAuthentication no'
]
required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"]
missing = [config for config in required_configs if config not in content]
assert not missing, f"Missing SSH configurations: {missing}"
# Verify proper formatting - first line should be Port directive
lines = content.strip().split('\n')
assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}"
lines = content.strip().split("\n")
assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}"
print("✅ SSH configuration correct")
@ -152,26 +146,26 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
users = parsed.get('users', [])
users = parsed.get("users", [])
assert users, "users section should be present"
# Find algo user
algo_user = None
for user in users:
if isinstance(user, dict) and user.get('name') == 'algo':
if isinstance(user, dict) and user.get("name") == "algo":
algo_user = user
break
assert algo_user, "algo user should be defined"
# Check required user properties
required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys']
required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"]
missing = [prop for prop in required_props if prop not in algo_user]
assert not missing, f"algo user missing properties: {missing}"
# Verify sudo configuration
sudo_config = algo_user.get('sudo', '')
assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
sudo_config = algo_user.get("sudo", "")
assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
print("✅ User creation correct")
@ -181,13 +175,13 @@ class TestCloudInitTemplate:
parsed = self.test_yaml_validity()
runcmd = parsed.get('runcmd', [])
runcmd = parsed.get("runcmd", [])
assert runcmd, "runcmd section should be present"
# Check for SSH restart command
ssh_restart_found = False
for cmd in runcmd:
if 'systemctl restart sshd' in str(cmd):
if "systemctl restart sshd" in str(cmd):
ssh_restart_found = True
break
@ -202,18 +196,18 @@ class TestCloudInitTemplate:
cloud_init_content = create_expected_cloud_init()
# Extract the sshd_config content lines
lines = cloud_init_content.split('\n')
lines = cloud_init_content.split("\n")
in_sshd_content = False
sshd_lines = []
for line in lines:
if 'content: |' in line:
if "content: |" in line:
in_sshd_content = True
continue
elif in_sshd_content:
if line.strip() == '' and len(sshd_lines) > 0:
if line.strip() == "" and len(sshd_lines) > 0:
break
if line.startswith('runcmd:'):
if line.startswith("runcmd:"):
break
sshd_lines.append(line)
@ -225,11 +219,13 @@ class TestCloudInitTemplate:
for line in non_empty_lines:
# Each line should start with exactly 6 spaces
assert line.startswith(' ') and not line.startswith(' '), \
assert line.startswith(" ") and not line.startswith(" "), (
f"Line should have exactly 6 spaces indentation: {repr(line)}"
)
print("✅ Indentation is consistent")
def run_tests():
"""Run all tests manually (for non-pytest usage)."""
print("🚀 Cloud-init template validation tests")
@ -258,6 +254,7 @@ def run_tests():
print(f"❌ Unexpected error: {e}")
return False
if __name__ == "__main__":
success = run_tests()
sys.exit(0 if success else 1)

View file

@ -30,49 +30,49 @@ packages:
rendered = self.packages_template.render({})
# Should only have sudo package
self.assertIn('- sudo', rendered)
self.assertNotIn('- git', rendered)
self.assertNotIn('- screen', rendered)
self.assertNotIn('- apparmor-utils', rendered)
self.assertIn("- sudo", rendered)
self.assertNotIn("- git", rendered)
self.assertNotIn("- screen", rendered)
self.assertNotIn("- apparmor-utils", rendered)
def test_preinstall_enabled(self):
"""Test that package pre-installation works when enabled."""
# Test with pre-installation enabled
rendered = self.packages_template.render({'performance_preinstall_packages': True})
rendered = self.packages_template.render({"performance_preinstall_packages": True})
# Should have sudo and all universal packages
self.assertIn('- sudo', rendered)
self.assertIn('- git', rendered)
self.assertIn('- screen', rendered)
self.assertIn('- apparmor-utils', rendered)
self.assertIn('- uuid-runtime', rendered)
self.assertIn('- coreutils', rendered)
self.assertIn('- iptables-persistent', rendered)
self.assertIn('- cgroup-tools', rendered)
self.assertIn("- sudo", rendered)
self.assertIn("- git", rendered)
self.assertIn("- screen", rendered)
self.assertIn("- apparmor-utils", rendered)
self.assertIn("- uuid-runtime", rendered)
self.assertIn("- coreutils", rendered)
self.assertIn("- iptables-persistent", rendered)
self.assertIn("- cgroup-tools", rendered)
def test_preinstall_disabled_explicitly(self):
"""Test that package pre-installation is disabled when set to false."""
# Test with pre-installation explicitly disabled
rendered = self.packages_template.render({'performance_preinstall_packages': False})
rendered = self.packages_template.render({"performance_preinstall_packages": False})
# Should only have sudo package
self.assertIn('- sudo', rendered)
self.assertNotIn('- git', rendered)
self.assertNotIn('- screen', rendered)
self.assertNotIn('- apparmor-utils', rendered)
self.assertIn("- sudo", rendered)
self.assertNotIn("- git", rendered)
self.assertNotIn("- screen", rendered)
self.assertNotIn("- apparmor-utils", rendered)
def test_package_count(self):
"""Test that the correct number of packages are included."""
# Default: should only have sudo (1 package)
rendered_default = self.packages_template.render({})
lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')]
lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_default), 1)
# Enabled: should have sudo + 7 universal packages (8 total)
rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True})
lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')]
rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True})
lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")]
self.assertEqual(len(lines_enabled), 8)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View file

@ -2,6 +2,7 @@
"""
Basic sanity tests for Algo VPN that don't require deployment
"""
import os
import subprocess
import sys
@ -44,11 +45,7 @@ def test_config_file_valid():
def test_ansible_syntax():
"""Check that main playbook has valid syntax"""
result = subprocess.run(
["ansible-playbook", "main.yml", "--syntax-check"],
capture_output=True,
text=True
)
result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True)
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
print("✓ Ansible playbook syntax is valid")
@ -60,11 +57,7 @@ def test_shellcheck():
for script in shell_scripts:
if os.path.exists(script):
result = subprocess.run(
["shellcheck", script],
capture_output=True,
text=True
)
result = subprocess.run(["shellcheck", script], capture_output=True, text=True)
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
print(f"{script} passed shellcheck")
@ -87,7 +80,7 @@ def test_cloud_init_header_format():
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
with open(cloud_init_file) as f:
first_line = f.readline().rstrip('\n\r')
first_line = f.readline().rstrip("\n\r")
# The first line MUST be exactly "#cloud-config" (no space after #)
# This regression was introduced in PR #14775 and broke DigitalOcean deployments

View file

@ -4,31 +4,30 @@ Test cloud provider instance type configurations
Focused on validating that configured instance types are current/valid
Based on issues #14730 - Hetzner changed from cx11 to cx22
"""
import sys
def test_hetzner_server_types():
"""Test Hetzner server type configurations (issue #14730)"""
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22
deprecated_types = ['cx11', 'cpx11']
current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42']
deprecated_types = ["cx11", "cpx11"]
current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"]
# Test that we're not using deprecated types in any configs
test_config = {
'cloud_providers': {
'hetzner': {
'size': 'cx22', # Should be cx22, not cx11
'image': 'ubuntu-22.04',
'location': 'hel1'
"cloud_providers": {
"hetzner": {
"size": "cx22", # Should be cx22, not cx11
"image": "ubuntu-22.04",
"location": "hel1",
}
}
}
hetzner = test_config['cloud_providers']['hetzner']
assert hetzner['size'] not in deprecated_types, \
f"Using deprecated Hetzner type: {hetzner['size']}"
assert hetzner['size'] in current_types, \
f"Unknown Hetzner type: {hetzner['size']}"
hetzner = test_config["cloud_providers"]["hetzner"]
assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}"
assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}"
print("✓ Hetzner server types test passed")
@ -36,10 +35,10 @@ def test_hetzner_server_types():
def test_digitalocean_instance_types():
"""Test DigitalOcean droplet size naming"""
# DigitalOcean uses format like s-1vcpu-1gb
valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb']
deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme
valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"]
deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme
test_size = 's-2vcpu-2gb'
test_size = "s-2vcpu-2gb"
assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
@ -49,10 +48,10 @@ def test_digitalocean_instance_types():
def test_aws_instance_types():
"""Test AWS EC2 instance type naming"""
# Common valid instance types
valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large']
deprecated_types = ['t1.micro', 'm1.small'] # Very old types
valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"]
deprecated_types = ["t1.micro", "m1.small"] # Very old types
test_type = 't3.micro'
test_type = "t3.micro"
assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
@ -62,9 +61,10 @@ def test_aws_instance_types():
def test_vultr_instance_types():
"""Test Vultr instance type naming"""
# Vultr uses format like vc2-1c-1gb
test_type = 'vc2-1c-1gb'
assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \
test_type = "vc2-1c-1gb"
assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), (
f"Invalid Vultr type format: {test_type}"
)
print("✓ Vultr instance types test passed")

View file

@ -2,6 +2,7 @@
"""
Test configuration file validation without deployment
"""
import configparser
import os
import re
@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820
config = configparser.ConfigParser()
config.read_string(sample_config)
assert 'Interface' in config, "Missing [Interface] section"
assert 'Peer' in config, "Missing [Peer] section"
assert "Interface" in config, "Missing [Interface] section"
assert "Peer" in config, "Missing [Peer] section"
# Validate required fields
assert config['Interface'].get('PrivateKey'), "Missing PrivateKey"
assert config['Interface'].get('Address'), "Missing Address"
assert config['Peer'].get('PublicKey'), "Missing PublicKey"
assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs"
assert config["Interface"].get("PrivateKey"), "Missing PrivateKey"
assert config["Interface"].get("Address"), "Missing Address"
assert config["Peer"].get("PublicKey"), "Missing PublicKey"
assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs"
print("✓ WireGuard config format validation passed")
@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820
def test_base64_key_format():
"""Test that keys are in valid base64 format"""
# Base64 keys can have variable length, just check format
key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$')
key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$")
test_keys = [
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
@ -58,8 +59,8 @@ def test_base64_key_format():
def test_ip_address_format():
"""Test IP address and CIDR notation validation"""
ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$')
endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$')
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$")
endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$")
# Test CIDR notation
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
@ -74,11 +75,7 @@ def test_ip_address_format():
def test_mobile_config_xml():
"""Test that mobile config files would be valid XML"""
# First check if xmllint is available
xmllint_check = subprocess.run(
['which', 'xmllint'],
capture_output=True,
text=True
)
xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True)
if xmllint_check.returncode != 0:
print("⚠ Skipping XML validation test (xmllint not installed)")
@ -99,17 +96,13 @@ def test_mobile_config_xml():
</dict>
</plist>"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f:
f.write(sample_mobileconfig)
temp_file = f.name
try:
# Use xmllint to validate
result = subprocess.run(
['xmllint', '--noout', temp_file],
capture_output=True,
text=True
)
result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True)
assert result.returncode == 0, f"XML validation failed: {result.stderr}"
print("✓ Mobile config XML validation passed")

View file

@ -3,6 +3,7 @@
Simplified Docker-based localhost deployment tests
Verifies services can start and config files exist in expected locations
"""
import os
import subprocess
import sys
@ -11,7 +12,7 @@ import sys
def check_docker_available():
"""Check if Docker is available"""
try:
result = subprocess.run(['docker', '--version'], capture_output=True, text=True)
result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
return result.returncode == 0
except FileNotFoundError:
return False
@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128
"""
# Just validate the format
required_sections = ['[Interface]', '[Peer]']
required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs']
required_sections = ["[Interface]", "[Peer]"]
required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"]
for section in required_sections:
if section not in config:
@ -68,15 +69,15 @@ conn ikev2-pubkey
"""
# Validate format
if 'config setup' not in config:
if "config setup" not in config:
print("✗ Missing 'config setup' section")
return False
if 'conn %default' not in config:
if "conn %default" not in config:
print("✗ Missing 'conn %default' section")
return False
if 'keyexchange=ikev2' not in config:
if "keyexchange=ikev2" not in config:
print("✗ Missing IKEv2 configuration")
return False
@ -87,19 +88,19 @@ conn ikev2-pubkey
def test_docker_algo_image():
"""Test that the Algo Docker image can be built"""
# Check if Dockerfile exists
if not os.path.exists('Dockerfile'):
if not os.path.exists("Dockerfile"):
print("✗ Dockerfile not found")
return False
# Read Dockerfile and validate basic structure
with open('Dockerfile') as f:
with open("Dockerfile") as f:
dockerfile_content = f.read()
required_elements = [
'FROM', # Base image
'RUN', # Build commands
'COPY', # Copy Algo files
'python' # Python dependency
"FROM", # Base image
"RUN", # Build commands
"COPY", # Copy Algo files
"python", # Python dependency
]
missing = []
@ -115,16 +116,14 @@ def test_docker_algo_image():
return True
def test_localhost_deployment_requirements():
"""Test that localhost deployment requirements are met"""
requirements = {
'Python 3.8+': sys.version_info >= (3, 8),
'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0,
'Main playbook exists': os.path.exists('main.yml'),
'Project config exists': os.path.exists('pyproject.toml'),
'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'),
"Python 3.8+": sys.version_info >= (3, 8),
"Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0,
"Main playbook exists": os.path.exists("main.yml"),
"Project config exists": os.path.exists("pyproject.toml"),
"Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"),
}
all_met = True
@ -138,10 +137,6 @@ def test_localhost_deployment_requirements():
return all_met
if __name__ == "__main__":
print("Running Docker localhost deployment tests...")
print("=" * 50)

View file

@ -3,6 +3,7 @@
Test that generated configuration files have valid syntax
This validates WireGuard, StrongSwan, SSH, and other configs
"""
import re
import subprocess
import sys
@ -11,7 +12,7 @@ import sys
def check_command_available(cmd):
"""Check if a command is available on the system"""
try:
subprocess.run([cmd, '--version'], capture_output=True, check=False)
subprocess.run([cmd, "--version"], capture_output=True, check=False)
return True
except FileNotFoundError:
return False
@ -37,51 +38,50 @@ PersistentKeepalive = 25
errors = []
# Check for required sections
if '[Interface]' not in sample_config:
if "[Interface]" not in sample_config:
errors.append("Missing [Interface] section")
if '[Peer]' not in sample_config:
if "[Peer]" not in sample_config:
errors.append("Missing [Peer] section")
# Validate Interface section
interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL)
interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL)
if interface_match:
interface_section = interface_match.group(1)
# Check required fields
if not re.search(r'Address\s*=', interface_section):
if not re.search(r"Address\s*=", interface_section):
errors.append("Missing Address in Interface section")
if not re.search(r'PrivateKey\s*=', interface_section):
if not re.search(r"PrivateKey\s*=", interface_section):
errors.append("Missing PrivateKey in Interface section")
# Validate IP addresses
address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section)
address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section)
if address_match:
addresses = address_match.group(1).split(',')
addresses = address_match.group(1).split(",")
for addr in addresses:
addr = addr.strip()
# Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \
not re.match(r'^[0-9a-fA-F:]+/\d+$', addr):
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr):
errors.append(f"Invalid IP address format: {addr}")
# Validate Peer section
peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL)
peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL)
if peer_match:
peer_section = peer_match.group(1)
# Check required fields
if not re.search(r'PublicKey\s*=', peer_section):
if not re.search(r"PublicKey\s*=", peer_section):
errors.append("Missing PublicKey in Peer section")
if not re.search(r'AllowedIPs\s*=', peer_section):
if not re.search(r"AllowedIPs\s*=", peer_section):
errors.append("Missing AllowedIPs in Peer section")
if not re.search(r'Endpoint\s*=', peer_section):
if not re.search(r"Endpoint\s*=", peer_section):
errors.append("Missing Endpoint in Peer section")
# Validate endpoint format
endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section)
endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section)
if endpoint_match:
endpoint = endpoint_match.group(1).strip()
if not re.match(r'^[\d\.\:]+:\d+$', endpoint):
if not re.match(r"^[\d\.\:]+:\d+$", endpoint):
errors.append(f"Invalid Endpoint format: {endpoint}")
if errors:
@ -132,33 +132,32 @@ conn ikev2-pubkey
errors = []
# Check for required sections
if 'config setup' not in sample_config:
if "config setup" not in sample_config:
errors.append("Missing 'config setup' section")
if 'conn %default' not in sample_config:
if "conn %default" not in sample_config:
errors.append("Missing 'conn %default' section")
# Validate connection settings
conn_pattern = re.compile(r'conn\s+(\S+)')
conn_pattern = re.compile(r"conn\s+(\S+)")
connections = conn_pattern.findall(sample_config)
if len(connections) < 2: # Should have at least %default and one other
errors.append("Not enough connection definitions")
# Check for required parameters in connections
required_params = ['keyexchange', 'left', 'right']
required_params = ["keyexchange", "left", "right"]
for param in required_params:
if f'{param}=' not in sample_config:
if f"{param}=" not in sample_config:
errors.append(f"Missing required parameter: {param}")
# Validate IP subnet formats
subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)')
subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)")
for match in subnet_pattern.finditer(sample_config):
subnets = match.group(2).split(',')
subnets = match.group(2).split(",")
for subnet in subnets:
subnet = subnet.strip()
if subnet != '0.0.0.0/0' and subnet != '::/0':
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \
not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet):
if subnet != "0.0.0.0/0" and subnet != "::/0":
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet):
errors.append(f"Invalid subnet format: {subnet}")
if errors:
@ -188,21 +187,21 @@ def test_ssh_config_syntax():
errors = []
# Parse SSH config format
lines = sample_config.strip().split('\n')
lines = sample_config.strip().split("\n")
current_host = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
if line.startswith('Host '):
if line.startswith("Host "):
current_host = line.split()[1]
elif current_host and ' ' in line:
elif current_host and " " in line:
key, value = line.split(None, 1)
# Validate common SSH options
if key == 'Port':
if key == "Port":
try:
port = int(value)
if not 1 <= port <= 65535:
@ -210,7 +209,7 @@ def test_ssh_config_syntax():
except ValueError:
errors.append(f"Port must be a number: {value}")
elif key == 'LocalForward':
elif key == "LocalForward":
# Format: LocalForward [bind_address:]port host:hostport
parts = value.split()
if len(parts) != 2:
@ -256,35 +255,35 @@ COMMIT
errors = []
# Check table definitions
tables = re.findall(r'\*(\w+)', sample_rules)
if 'filter' not in tables:
tables = re.findall(r"\*(\w+)", sample_rules)
if "filter" not in tables:
errors.append("Missing *filter table")
if 'nat' not in tables:
if "nat" not in tables:
errors.append("Missing *nat table")
# Check for COMMIT statements
commit_count = sample_rules.count('COMMIT')
commit_count = sample_rules.count("COMMIT")
if commit_count != len(tables):
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
# Validate chain policies
chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE)
chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE)
chains = chain_pattern.findall(sample_rules)
required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')]
required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")]
for chain, _policy in required_chains:
if not any(c[0] == chain for c in chains):
errors.append(f"Missing required chain: {chain}")
# Validate rule syntax
rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE)
rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE)
rules = rule_pattern.findall(sample_rules)
if len(rules) < 5:
errors.append("Insufficient firewall rules")
# Check for essential security rules
if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules:
if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules:
errors.append("Missing stateful connection tracking rule")
if errors:
@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors = []
# Parse config
for line in sample_config.strip().split('\n'):
for line in sample_config.strip().split("\n"):
line = line.strip()
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
# Most dnsmasq options are key=value or just key
if '=' in line:
key, value = line.split('=', 1)
if "=" in line:
key, value = line.split("=", 1)
# Validate specific options
if key == 'interface':
if not re.match(r'^[a-zA-Z0-9\-_]+$', value):
if key == "interface":
if not re.match(r"^[a-zA-Z0-9\-_]+$", value):
errors.append(f"Invalid interface name: {value}")
elif key == 'server':
elif key == "server":
# Basic IP validation
if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \
not re.match(r'^[0-9a-fA-F:]+$', value):
if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value):
errors.append(f"Invalid DNS server IP: {value}")
elif key == 'cache-size':
elif key == "cache-size":
try:
size = int(value)
if size < 0:
@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
errors.append(f"Cache size must be a number: {value}")
# Check for required options
required = ['interface', 'server']
required = ["interface", "server"]
for req in required:
if f'{req}=' not in sample_config:
if f"{req}=" not in sample_config:
errors.append(f"Missing required option: {req}")
if errors:

View file

@ -14,203 +14,203 @@ from jinja2 import Environment, FileSystemLoader
def load_template(template_name):
"""Load a Jinja2 template from the roles/common/templates directory."""
template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates'
template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates"
env = Environment(loader=FileSystemLoader(str(template_dir)))
return env.get_template(template_name)
def test_wireguard_nat_rules_ipv4():
"""Test that WireGuard traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
# Test with WireGuard enabled
result = template.render(
ipsec_enabled=False,
wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16',
wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Verify NAT rule exists with output interface and without policy matching
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in WireGuard NAT rules
assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result
assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result
def test_ipsec_nat_rules_ipv4():
"""Test that IPsec traffic gets proper NAT rules without policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
# Test with IPsec enabled
result = template.render(
ipsec_enabled=True,
wireguard_enabled=False,
strongswan_network='10.48.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
ansible_default_ipv4={'interface': 'eth0'},
strongswan_network="10.48.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.48.0.1',
local_service_ip="10.48.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Verify NAT rule exists with output interface and without policy matching
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Verify no policy matching in IPsec NAT rules (this was the bug)
assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result
def test_both_vpns_nat_rules_ipv4():
"""Test NAT rules when both VPN types are enabled."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network='10.48.0.0/16',
wireguard_network_ipv4='10.49.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
wireguard_network_ipv6='2001:db8:a160::/48',
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Both should have NAT rules with output interface
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
# Neither should have policy matching
assert '-m policy --pol none' not in result
assert "-m policy --pol none" not in result
def test_alternative_ingress_snat():
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=True,
strongswan_network='10.48.0.0/16',
wireguard_network_ipv4='10.49.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
wireguard_network_ipv6='2001:db8:a160::/48',
strongswan_network="10.48.0.0/16",
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
wireguard_network_ipv6="2001:db8:a160::/48",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
snat_aipv4='192.168.1.100', # Alternative ingress IP
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4="192.168.1.100", # Alternative ingress IP
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Should use SNAT with specific IP and output interface instead of MASQUERADE
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
assert 'MASQUERADE' not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
assert "MASQUERADE" not in result
def test_ipsec_forward_rule_has_policy_match():
"""Test that IPsec FORWARD rules still use policy matching (this is correct)."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=True,
wireguard_enabled=False,
strongswan_network='10.48.0.0/16',
strongswan_network_ipv6='2001:db8::/48',
ansible_default_ipv4={'interface': 'eth0'},
strongswan_network="10.48.0.0/16",
strongswan_network_ipv6="2001:db8::/48",
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.48.0.1',
local_service_ip="10.48.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# FORWARD rule should have policy match (this is correct and should stay)
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result
def test_wireguard_forward_rule_no_policy_match():
"""Test that WireGuard FORWARD rules don't use policy matching."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
ipsec_enabled=False,
wireguard_enabled=True,
wireguard_network_ipv4='10.49.0.0/16',
wireguard_network_ipv4="10.49.0.0/16",
wireguard_port=51820,
wireguard_port_avoid=53,
wireguard_port_actual=51820,
ansible_default_ipv4={'interface': 'eth0'},
ansible_default_ipv4={"interface": "eth0"},
snat_aipv4=None,
BetweenClients_DROP=True,
block_smb=True,
block_netbios=True,
local_service_ip='10.49.0.1',
local_service_ip="10.49.0.1",
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# WireGuard FORWARD rule should NOT have any policy match
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result
def test_output_interface_in_nat_rules():
"""Test that output interface is specified in NAT rules."""
template = load_template('rules.v4.j2')
template = load_template("rules.v4.j2")
result = template.render(
snat_aipv4=False,
wireguard_enabled=True,
ipsec_enabled=True,
wireguard_network_ipv4='10.49.0.0/16',
strongswan_network='10.48.0.0/16',
ansible_default_ipv4={'interface': 'eth0', 'address': '10.0.0.1'},
ansible_default_ipv6={'interface': 'eth0', 'address': 'fd9d:bc11:4020::1'},
wireguard_network_ipv4="10.49.0.0/16",
strongswan_network="10.48.0.0/16",
ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"},
ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"},
wireguard_port_actual=51820,
wireguard_port_avoid=53,
wireguard_port=51820,
ansible_ssh_port=22,
reduce_mtu=0
reduce_mtu=0,
)
# Check that output interface is specified for both VPNs
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
# Ensure we don't have rules without output interface
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' not in result
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' not in result
assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result
assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result
if __name__ == '__main__':
pytest.main([__file__, '-v'])
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -12,7 +12,7 @@ import unittest
from unittest.mock import MagicMock, patch
# Add the library directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library"))
class TestLightsailBoto3Fix(unittest.TestCase):
@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase):
"""Set up test fixtures."""
# Mock the ansible module_utils since we're testing outside of Ansible
self.mock_modules = {
'ansible.module_utils.basic': MagicMock(),
'ansible.module_utils.ec2': MagicMock(),
'ansible.module_utils.aws.core': MagicMock(),
"ansible.module_utils.basic": MagicMock(),
"ansible.module_utils.ec2": MagicMock(),
"ansible.module_utils.aws.core": MagicMock(),
}
# Apply mocks
self.patches = []
for module_name, mock_module in self.mock_modules.items():
patcher = patch.dict('sys.modules', {module_name: mock_module})
patcher = patch.dict("sys.modules", {module_name: mock_module})
patcher.start()
self.patches.append(patcher)
@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Import the module
spec = importlib.util.spec_from_file_location(
"lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
)
module = importlib.util.module_from_spec(spec)
@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# Verify the module loaded
self.assertIsNotNone(module)
self.assertTrue(hasattr(module, 'main'))
self.assertTrue(hasattr(module, "main"))
except Exception as e:
self.fail(f"Failed to import lightsail_region_facts: {e}")
@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase):
def test_get_aws_connection_info_called_without_boto3(self):
"""Test that get_aws_connection_info is called without boto3 parameter."""
# Mock get_aws_connection_info to track calls
mock_get_aws_connection_info = MagicMock(
return_value=('us-west-2', None, {})
)
mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {}))
with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info):
with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info):
# Import the module
spec = importlib.util.spec_from_file_location(
"lightsail_region_facts",
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
)
module = importlib.util.module_from_spec(spec)
@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
mock_ansible_module.params = {}
mock_ansible_module.check_mode = False
with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module):
with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module):
# Execute the module
try:
spec.loader.exec_module(module)
@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase):
if call_args:
# Check positional arguments
if call_args[0]: # args
self.assertTrue(len(call_args[0]) <= 1,
"get_aws_connection_info should be called with at most 1 positional arg (module)")
self.assertTrue(
len(call_args[0]) <= 1,
"get_aws_connection_info should be called with at most 1 positional arg (module)",
)
# Check keyword arguments
if call_args[1]: # kwargs
self.assertNotIn('boto3', call_args[1],
"get_aws_connection_info should not be called with boto3 parameter")
self.assertNotIn(
"boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter"
)
def test_no_boto3_parameter_in_source(self):
"""Verify that boto3 parameter is not present in the source code."""
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f:
content = f.read()
# Check that boto3=True is not in the file
self.assertNotIn('boto3=True', content,
"boto3=True parameter should not be present in lightsail_region_facts.py")
self.assertNotIn(
"boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py"
)
# Check that boto3 parameter is not used with get_aws_connection_info
self.assertNotIn('get_aws_connection_info(module, boto3', content,
"get_aws_connection_info should not be called with boto3 parameter")
self.assertNotIn(
"get_aws_connection_info(module, boto3",
content,
"get_aws_connection_info should not be called with boto3 parameter",
)
def test_regression_issue_14822(self):
"""
@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase):
# The boto3 parameter was deprecated and removed in amazon.aws collection
# that comes with Ansible 11.x
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
with open(lightsail_path) as f:
lines = f.readlines()
# Find the line that calls get_aws_connection_info
for line_num, line in enumerate(lines, 1):
if 'get_aws_connection_info' in line and 'region' in line:
if "get_aws_connection_info" in line and "region" in line:
# This should be around line 85
# Verify it doesn't have boto3=True
self.assertNotIn('boto3', line,
f"Line {line_num} should not contain boto3 parameter")
self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter")
# Verify the correct format
self.assertIn('get_aws_connection_info(module)', line,
f"Line {line_num} should call get_aws_connection_info(module) without boto3")
self.assertIn(
"get_aws_connection_info(module)",
line,
f"Line {line_num} should call get_aws_connection_info(module) without boto3",
)
break
else:
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View file

@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa
Based on issues #14755, #14718 - Apple device compatibility
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
"""
import glob
import os
import re
@ -22,7 +23,7 @@ def find_generated_certificates():
config_patterns = [
"configs/*/ipsec/.pki/cacert.pem",
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
"../../configs/*/ipsec/.pki/cacert.pem", # Alternative path
]
for pattern in config_patterns:
@ -30,26 +31,23 @@ def find_generated_certificates():
if ca_certs:
base_path = os.path.dirname(ca_certs[0])
return {
'ca_cert': ca_certs[0],
'base_path': base_path,
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
"ca_cert": ca_certs[0],
"base_path": base_path,
"server_certs": glob.glob(f"{base_path}/certs/*.crt"),
"p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"),
}
return None
def test_openssl_version_detection():
"""Test that we can detect OpenSSL version for compatibility checks"""
result = subprocess.run(
['openssl', 'version'],
capture_output=True,
text=True
)
result = subprocess.run(["openssl", "version"], capture_output=True, text=True)
assert result.returncode == 0, "Failed to get OpenSSL version"
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout)
version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout)
assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
major = int(version_match.group(1))
@ -62,7 +60,7 @@ def test_openssl_version_detection():
def validate_ca_certificate_real(cert_files):
"""Validate actual Ansible-generated CA certificate"""
# Read the actual CA certificate generated by Ansible
with open(cert_files['ca_cert'], 'rb') as f:
with open(cert_files["ca_cert"], "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@ -89,30 +87,34 @@ def validate_ca_certificate_real(cert_files):
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
# Verify public domains are excluded
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.DNSName)]
excluded_dns = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)
]
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains:
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
# Verify private IP ranges are excluded (Issue #75)
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.IPAddress)]
excluded_ips = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)
]
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
# Verify email domains are also excluded (Issue #153)
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
if isinstance(constraint, x509.RFC822Name)]
excluded_emails = [
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)
]
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains:
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
def validate_ca_certificate_config():
"""Validate CA certificate configuration in Ansible files (CI mode)"""
# Check that the Ansible task file has proper CA certificate configuration
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@ -122,15 +124,15 @@ def validate_ca_certificate_config():
# Verify key security configurations are present
security_checks = [
('name_constraints_permitted', 'Name constraints should be configured'),
('name_constraints_excluded', 'Excluded name constraints should be configured'),
('extended_key_usage', 'Extended Key Usage should be configured'),
('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'),
('serverAuth', 'Server authentication EKU should be present'),
('clientAuth', 'Client authentication EKU should be present'),
('basic_constraints', 'Basic constraints should be configured'),
('CA:TRUE', 'CA certificate should be marked as CA'),
('pathlen:0', 'Path length constraint should be set')
("name_constraints_permitted", "Name constraints should be configured"),
("name_constraints_excluded", "Excluded name constraints should be configured"),
("extended_key_usage", "Extended Key Usage should be configured"),
("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"),
("serverAuth", "Server authentication EKU should be present"),
("clientAuth", "Client authentication EKU should be present"),
("basic_constraints", "Basic constraints should be configured"),
("CA:TRUE", "CA certificate should be marked as CA"),
("pathlen:0", "Path length constraint should be set"),
]
for check, message in security_checks:
@ -140,7 +142,9 @@ def validate_ca_certificate_config():
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in public_domains:
# Handle both double quotes and single quotes in YAML
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded"
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, (
f"Public domain {domain} should be excluded"
)
# Verify private IP ranges are excluded
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
@ -151,13 +155,16 @@ def validate_ca_certificate_config():
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
for domain in email_domains:
# Handle both double quotes and single quotes in YAML
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded"
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, (
f"Email domain {domain} should be excluded"
)
# Verify IPv6 constraints are present (Issue #153)
assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
print("✓ CA certificate configuration has proper security constraints")
def test_ca_certificate():
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates()
@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files):
# Filter to only actual server certificates (not client certs)
# Server certificates contain IP addresses in the filename
import re
server_certs = [f for f in cert_files['server_certs']
if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)]
server_certs = [
f
for f in cert_files["server_certs"]
if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f)
]
if not server_certs:
print("⚠ No server certificates found")
return
for server_cert_path in server_certs:
with open(server_cert_path, 'rb') as f:
with open(server_cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files):
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, (
"Server cert should NOT have clientAuth EKU for role separation"
)
# Check SAN extension exists (required for Apple devices)
try:
@ -204,9 +217,10 @@ def validate_server_certificates_real(cert_files):
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
def validate_server_certificates_config():
"""Validate server certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@ -215,7 +229,7 @@ def validate_server_certificates_config():
content = f.read()
# Look for server certificate CSR section
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL)
if not server_csr_section:
print("⚠ Could not find server certificate CSR section")
return
@ -224,11 +238,11 @@ def validate_server_certificates_config():
# Check server certificate CSR configuration
server_checks = [
('subject_alt_name', 'Server certificates should have SAN extension'),
('serverAuth', 'Server certificates should have serverAuth EKU'),
('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'),
('digitalSignature', 'Server certificates should have digital signature usage'),
('keyEncipherment', 'Server certificates should have key encipherment usage')
("subject_alt_name", "Server certificates should have SAN extension"),
("serverAuth", "Server certificates should have serverAuth EKU"),
("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"),
("digitalSignature", "Server certificates should have digital signature usage"),
("keyEncipherment", "Server certificates should have key encipherment usage"),
]
for check, message in server_checks:
@ -236,15 +250,20 @@ def validate_server_certificates_config():
# Security check: Server certificates should NOT have clientAuth (Issue #153)
# Look for clientAuth in extended_key_usage section, not in comments
eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)]
has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- '))
eku_lines = [
line
for line in server_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line)
]
has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
# Verify SAN extension is configured for Apple compatibility
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility"
print("✓ Server certificate configuration has proper EKU and SAN settings")
def test_server_certificates():
"""Test server certificates - uses real certs if available, else validates config"""
cert_files = find_generated_certificates()
@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files):
"""Validate actual Ansible-generated client certificates"""
# Find client certificates (not CA cert, not server cert with IP/DNS name)
client_certs = []
for cert_path in cert_files['server_certs']:
if 'cacert.pem' in cert_path:
for cert_path in cert_files["server_certs"]:
if "cacert.pem" in cert_path:
continue
with open(cert_path, 'rb') as f:
with open(cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
# Check if this looks like a client cert vs server cert
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Server certs typically have IP addresses or domain names as CN
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4):
client_certs.append((cert_path, certificate))
if not client_certs:
@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files):
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, (
"Client cert must NOT have serverAuth EKU to prevent server impersonation"
)
# Check SAN extension for email
try:
@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files):
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
def validate_client_certificates_config():
"""Validate client certificate configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@ -310,7 +332,9 @@ def validate_client_certificates_config():
content = f.read()
# Look for client certificate CSR section
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
client_csr_section = re.search(
r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL
)
if not client_csr_section:
print("⚠ Could not find client certificate CSR section")
return
@ -319,11 +343,11 @@ def validate_client_certificates_config():
# Check client certificate configuration
client_checks = [
('clientAuth', 'Client certificates should have clientAuth EKU'),
('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'),
('digitalSignature', 'Client certificates should have digital signature usage'),
('keyEncipherment', 'Client certificates should have key encipherment usage'),
('email:', 'Client certificates should have email SAN')
("clientAuth", "Client certificates should have clientAuth EKU"),
("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"),
("digitalSignature", "Client certificates should have digital signature usage"),
("keyEncipherment", "Client certificates should have key encipherment usage"),
("email:", "Client certificates should have email SAN"),
]
for check, message in client_checks:
@ -331,15 +355,22 @@ def validate_client_certificates_config():
# Security check: Client certificates should NOT have serverAuth (Issue #153)
# Look for serverAuth in extended_key_usage section, not in comments
eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)]
has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- '))
eku_lines = [
line
for line in client_section.split("\n")
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line)
]
has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- "))
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
# Verify client certificates use unique email domains (Issue #153)
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
assert "openssl_constraint_random_id" in client_section, (
"Client certificates should use unique email domain per deployment"
)
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
def test_client_certificates():
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
cert_files = find_generated_certificates()
@ -351,24 +382,33 @@ def test_client_certificates():
def validate_pkcs12_files_real(cert_files):
"""Validate actual Ansible-generated PKCS#12 files"""
if not cert_files.get('p12_files'):
if not cert_files.get("p12_files"):
print("⚠ No PKCS#12 files found")
return
major, minor = test_openssl_version_detection()
for p12_file in cert_files['p12_files']:
for p12_file in cert_files["p12_files"]:
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
# Test that PKCS#12 file can be read (validates format)
legacy_flag = ['-legacy'] if major >= 3 else []
legacy_flag = ["-legacy"] if major >= 3 else []
result = subprocess.run([
'openssl', 'pkcs12', '-info',
'-in', p12_file,
'-passin', 'pass:', # Try empty password first
'-noout'
] + legacy_flag, capture_output=True, text=True)
result = subprocess.run(
[
"openssl",
"pkcs12",
"-info",
"-in",
p12_file,
"-passin",
"pass:", # Try empty password first
"-noout",
]
+ legacy_flag,
capture_output=True,
text=True,
)
# PKCS#12 files should be readable (even if password-protected)
# We're just testing format validity, not trying to extract contents
@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files):
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
def validate_pkcs12_files_config():
"""Validate PKCS#12 file configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@ -390,13 +431,13 @@ def validate_pkcs12_files_config():
# Check PKCS#12 generation configuration
p12_checks = [
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
('encryption_level', 'PKCS#12 encryption level should be configured'),
('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'),
('friendly_name', 'PKCS#12 should have friendly names'),
('other_certificates', 'PKCS#12 should include CA certificate for full chain'),
('passphrase', 'PKCS#12 files should be password protected'),
('mode: "0600"', 'PKCS#12 files should have secure permissions')
("openssl_pkcs12", "PKCS#12 generation should be configured"),
("encryption_level", "PKCS#12 encryption level should be configured"),
("compatibility2022", "PKCS#12 should use Apple-compatible encryption"),
("friendly_name", "PKCS#12 should have friendly names"),
("other_certificates", "PKCS#12 should include CA certificate for full chain"),
("passphrase", "PKCS#12 files should be password protected"),
('mode: "0600"', "PKCS#12 files should have secure permissions"),
]
for check, message in p12_checks:
@ -404,6 +445,7 @@ def validate_pkcs12_files_config():
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
def test_pkcs12_files():
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
cert_files = find_generated_certificates()
@ -416,19 +458,19 @@ def test_pkcs12_files():
def validate_certificate_chain_real(cert_files):
"""Validate actual Ansible-generated certificate chain"""
# Load CA certificate
with open(cert_files['ca_cert'], 'rb') as f:
with open(cert_files["ca_cert"], "rb") as f:
ca_cert_data = f.read()
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
# Test that all other certificates are signed by the CA
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]]
if not other_certs:
print("⚠ No client/server certificates found to validate")
return
for cert_path in other_certs:
with open(cert_path, 'rb') as f:
with open(cert_path, "rb") as f:
cert_data = f.read()
certificate = x509.load_pem_x509_certificate(cert_data)
@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files):
# Verify certificate is currently valid (not expired)
from datetime import datetime
now = datetime.now(UTC)
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
@ -445,9 +488,10 @@ def validate_certificate_chain_real(cert_files):
print("✓ All real certificates properly signed by CA")
def validate_certificate_chain_config():
"""Validate certificate chain configuration in Ansible files (CI mode)"""
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
if not openssl_task_file:
print("⚠ Could not find openssl.yml task file")
return
@ -457,15 +501,18 @@ def validate_certificate_chain_config():
# Check certificate signing configuration
chain_checks = [
('provider: ownca', 'Certificates should be signed by own CA'),
('ownca_path', 'CA certificate path should be specified'),
('ownca_privatekey_path', 'CA private key path should be specified'),
('ownca_privatekey_passphrase', 'CA private key should be password protected'),
('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'),
('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'),
('ownca_not_before: "-1d"', 'Certificates should have backdated start time'),
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
('type: ECC', 'Should use elliptic curve keys for better security')
("provider: ownca", "Certificates should be signed by own CA"),
("ownca_path", "CA certificate path should be specified"),
("ownca_privatekey_path", "CA private key path should be specified"),
("ownca_privatekey_passphrase", "CA private key should be password protected"),
("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"),
(
'ownca_not_after: "+{{ certificate_validity_days }}d"',
"Certificates should use configurable validity period",
),
('ownca_not_before: "-1d"', "Certificates should have backdated start time"),
("curve: secp384r1", "Should use strong elliptic curve cryptography"),
("type: ECC", "Should use elliptic curve keys for better security"),
]
for check, message in chain_checks:
@ -473,6 +520,7 @@ def validate_certificate_chain_config():
print("✓ Certificate chain configuration properly set up for CA signing")
def test_certificate_chain():
"""Test certificate chain - uses real certs if available, else validates config"""
cert_files = find_generated_certificates()
@ -499,6 +547,7 @@ def find_ansible_file(relative_path):
return None
if __name__ == "__main__":
tests = [
test_openssl_version_detection,

View file

@ -3,6 +3,7 @@
Enhanced tests for StrongSwan templates.
Tests all strongswan role templates with various configurations.
"""
import os
import sys
import uuid
@ -21,7 +22,7 @@ def mock_to_uuid(value):
def mock_bool(value):
"""Mock the bool filter"""
return str(value).lower() in ('true', '1', 'yes', 'on')
return str(value).lower() in ("true", "1", "yes", "on")
def mock_version(version_string, comparison):
@ -33,67 +34,67 @@ def mock_version(version_string, comparison):
def mock_b64encode(value):
"""Mock base64 encoding"""
import base64
if isinstance(value, str):
value = value.encode('utf-8')
return base64.b64encode(value).decode('ascii')
value = value.encode("utf-8")
return base64.b64encode(value).decode("ascii")
def mock_b64decode(value):
"""Mock base64 decoding"""
import base64
return base64.b64decode(value).decode('utf-8')
return base64.b64decode(value).decode("utf-8")
def get_strongswan_test_variables(scenario='default'):
def get_strongswan_test_variables(scenario="default"):
"""Get test variables for StrongSwan templates with different scenarios."""
base_vars = load_test_variables()
# Add StrongSwan specific variables
strongswan_vars = {
'ipsec_config_path': '/etc/ipsec.d',
'ipsec_pki_path': '/etc/ipsec.d',
'strongswan_enabled': True,
'strongswan_network': '10.19.48.0/24',
'strongswan_network_ipv6': 'fd9d:bc11:4021::/64',
'strongswan_log_level': '2',
'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()),
'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a',
'subjectAltName_type': 'IP',
'subjectAltName_client': 'IP:10.0.0.1',
'ansible_default_ipv6': {
'address': '2600:3c01::f03c:91ff:fedf:3b2a'
},
'openssl_version': '3.0.0',
'p12_export_password': 'test-password',
'ike_lifetime': '24h',
'ipsec_lifetime': '8h',
'ike_dpd': '30s',
'ipsec_dead_peer_detection': True,
'rekey_margin': '3m',
'rekeymargin': '3m',
'dpddelay': '35s',
'keyexchange': 'ikev2',
'ike_cipher': 'aes128gcm16-prfsha512-ecp256',
'esp_cipher': 'aes128gcm16-ecp256',
'leftsourceip': '10.19.48.1',
'leftsubnet': '0.0.0.0/0,::/0',
'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64',
"ipsec_config_path": "/etc/ipsec.d",
"ipsec_pki_path": "/etc/ipsec.d",
"strongswan_enabled": True,
"strongswan_network": "10.19.48.0/24",
"strongswan_network_ipv6": "fd9d:bc11:4021::/64",
"strongswan_log_level": "2",
"openssl_constraint_random_id": "test-" + str(uuid.uuid4()),
"subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a",
"subjectAltName_type": "IP",
"subjectAltName_client": "IP:10.0.0.1",
"ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"},
"openssl_version": "3.0.0",
"p12_export_password": "test-password",
"ike_lifetime": "24h",
"ipsec_lifetime": "8h",
"ike_dpd": "30s",
"ipsec_dead_peer_detection": True,
"rekey_margin": "3m",
"rekeymargin": "3m",
"dpddelay": "35s",
"keyexchange": "ikev2",
"ike_cipher": "aes128gcm16-prfsha512-ecp256",
"esp_cipher": "aes128gcm16-ecp256",
"leftsourceip": "10.19.48.1",
"leftsubnet": "0.0.0.0/0,::/0",
"rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64",
}
# Merge with base variables
test_vars = {**base_vars, **strongswan_vars}
# Apply scenario-specific overrides
if scenario == 'ipv4_only':
test_vars['ipv6_support'] = False
test_vars['subjectAltName'] = 'IP:10.0.0.1'
test_vars['ansible_default_ipv6'] = None
elif scenario == 'dns_hostname':
test_vars['IP_subject_alt_name'] = 'vpn.example.com'
test_vars['subjectAltName'] = 'DNS:vpn.example.com'
test_vars['subjectAltName_type'] = 'DNS'
elif scenario == 'openssl_legacy':
test_vars['openssl_version'] = '1.1.1'
if scenario == "ipv4_only":
test_vars["ipv6_support"] = False
test_vars["subjectAltName"] = "IP:10.0.0.1"
test_vars["ansible_default_ipv6"] = None
elif scenario == "dns_hostname":
test_vars["IP_subject_alt_name"] = "vpn.example.com"
test_vars["subjectAltName"] = "DNS:vpn.example.com"
test_vars["subjectAltName_type"] = "DNS"
elif scenario == "openssl_legacy":
test_vars["openssl_version"] = "1.1.1"
return test_vars
@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'):
def test_strongswan_templates():
"""Test all StrongSwan templates with various configurations."""
templates = [
'roles/strongswan/templates/ipsec.conf.j2',
'roles/strongswan/templates/ipsec.secrets.j2',
'roles/strongswan/templates/strongswan.conf.j2',
'roles/strongswan/templates/charon.conf.j2',
'roles/strongswan/templates/client_ipsec.conf.j2',
'roles/strongswan/templates/client_ipsec.secrets.j2',
'roles/strongswan/templates/100-CustomLimitations.conf.j2',
"roles/strongswan/templates/ipsec.conf.j2",
"roles/strongswan/templates/ipsec.secrets.j2",
"roles/strongswan/templates/strongswan.conf.j2",
"roles/strongswan/templates/charon.conf.j2",
"roles/strongswan/templates/client_ipsec.conf.j2",
"roles/strongswan/templates/client_ipsec.secrets.j2",
"roles/strongswan/templates/100-CustomLimitations.conf.j2",
]
scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy']
scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"]
errors = []
tested = 0
@ -127,21 +128,18 @@ def test_strongswan_templates():
test_vars = get_strongswan_test_variables(scenario)
try:
env = Environment(
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock filters
env.filters['to_uuid'] = mock_to_uuid
env.filters['bool'] = mock_bool
env.filters['b64encode'] = mock_b64encode
env.filters['b64decode'] = mock_b64decode
env.tests['version'] = mock_version
env.filters["to_uuid"] = mock_to_uuid
env.filters["bool"] = mock_bool
env.filters["b64encode"] = mock_b64encode
env.filters["b64decode"] = mock_b64decode
env.tests["version"] = mock_version
# For client templates, add item context
if 'client' in template_name:
test_vars['item'] = 'testuser'
if "client" in template_name:
test_vars["item"] = "testuser"
template = env.get_template(template_name)
output = template.render(**test_vars)
@ -150,16 +148,16 @@ def test_strongswan_templates():
assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
# Specific validations based on template
if 'ipsec.conf' in template_name and 'client' not in template_name:
assert 'conn' in output, "Missing connection definition"
if scenario != 'ipv4_only' and test_vars.get('ipv6_support'):
assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration"
if "ipsec.conf" in template_name and "client" not in template_name:
assert "conn" in output, "Missing connection definition"
if scenario != "ipv4_only" and test_vars.get("ipv6_support"):
assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration"
if 'ipsec.secrets' in template_name:
assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method"
if "ipsec.secrets" in template_name:
assert "PSK" in output or "ECDSA" in output, "Missing authentication method"
if 'strongswan.conf' in template_name:
assert 'charon' in output, "Missing charon configuration"
if "strongswan.conf" in template_name:
assert "charon" in output, "Missing charon configuration"
print(f"{template_name} ({scenario})")
@ -182,7 +180,7 @@ def test_openssl_template_constraints():
# This tests the actual openssl.yml task file to ensure our fix works
import yaml
openssl_path = 'roles/strongswan/tasks/openssl.yml'
openssl_path = "roles/strongswan/tasks/openssl.yml"
if not os.path.exists(openssl_path):
print("⚠️ OpenSSL tasks file not found")
return True
@ -194,22 +192,23 @@ def test_openssl_template_constraints():
# Find the CA CSR task
ca_csr_task = None
for task in content:
if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'):
if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"):
ca_csr_task = task
break
if ca_csr_task:
# Check that name_constraints_permitted is properly formatted
csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {})
constraints = csr_module.get('name_constraints_permitted', '')
csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {})
constraints = csr_module.get("name_constraints_permitted", "")
# The constraints should be a Jinja2 template without inline comments
if '#' in str(constraints):
if "#" in str(constraints):
# Check if the # is within {{ }}
import re
jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL)
jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL)
for block in jinja_blocks:
if '#' in block:
if "#" in block:
print("❌ Found inline comment in Jinja2 expression")
return False
@ -223,7 +222,7 @@ def test_openssl_template_constraints():
def test_mobileconfig_template():
"""Test the mobileconfig template with various scenarios."""
template_path = 'roles/strongswan/templates/mobileconfig.j2'
template_path = "roles/strongswan/templates/mobileconfig.j2"
if not os.path.exists(template_path):
print("⚠️ Mobileconfig template not found")
@ -237,20 +236,20 @@ def test_mobileconfig_template():
test_cases = [
{
'name': 'iPhone with cellular on-demand',
'algo_ondemand_cellular': 'true',
'algo_ondemand_wifi': 'false',
"name": "iPhone with cellular on-demand",
"algo_ondemand_cellular": "true",
"algo_ondemand_wifi": "false",
},
{
'name': 'iPad with WiFi on-demand',
'algo_ondemand_cellular': 'false',
'algo_ondemand_wifi': 'true',
'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi',
"name": "iPad with WiFi on-demand",
"algo_ondemand_cellular": "false",
"algo_ondemand_wifi": "true",
"algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi",
},
{
'name': 'Mac without on-demand',
'algo_ondemand_cellular': 'false',
'algo_ondemand_wifi': 'false',
"name": "Mac without on-demand",
"algo_ondemand_cellular": "false",
"algo_ondemand_wifi": "false",
},
]
@ -258,43 +257,41 @@ def test_mobileconfig_template():
for test_case in test_cases:
test_vars = get_strongswan_test_variables()
test_vars.update(test_case)
# Mock Ansible task result format for item
class MockTaskResult:
def __init__(self, content):
self.stdout = content
test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result
test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64
test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64
test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4())
test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12
test_vars['algo_server_name'] = 'test-algo-vpn'
test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4())
test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4())
test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64
test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result
test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64
test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64
test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4())
test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12
test_vars["algo_server_name"] = "test-algo-vpn"
test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4())
test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4())
test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64
try:
env = Environment(
loader=FileSystemLoader('roles/strongswan/templates'),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined)
# Add mock filters
env.filters['to_uuid'] = mock_to_uuid
env.filters['b64encode'] = mock_b64encode
env.filters['b64decode'] = mock_b64decode
env.filters["to_uuid"] = mock_to_uuid
env.filters["b64encode"] = mock_b64encode
env.filters["b64decode"] = mock_b64decode
template = env.get_template('mobileconfig.j2')
template = env.get_template("mobileconfig.j2")
output = template.render(**test_vars)
# Validate output
assert '<?xml' in output, "Missing XML declaration"
assert '<plist' in output, "Missing plist element"
assert 'PayloadType' in output, "Missing PayloadType"
assert "<?xml" in output, "Missing XML declaration"
assert "<plist" in output, "Missing plist element"
assert "PayloadType" in output, "Missing PayloadType"
# Check on-demand configuration
if test_case.get('algo_ondemand_cellular') == 'true' or test_case.get('algo_ondemand_wifi') == 'true':
assert 'OnDemandEnabled' in output, f"Missing OnDemand config for {test_case['name']}"
if test_case.get("algo_ondemand_cellular") == "true" or test_case.get("algo_ondemand_wifi") == "true":
assert "OnDemandEnabled" in output, f"Missing OnDemand config for {test_case['name']}"
print(f" ✅ Mobileconfig: {test_case['name']}")

View file

@ -3,6 +3,7 @@
Test that Ansible templates render correctly
This catches undefined variables, syntax errors, and logic bugs
"""
import os
import sys
from pathlib import Path
@ -22,20 +23,20 @@ def mock_to_uuid(value):
def mock_bool(value):
"""Mock the bool filter"""
return str(value).lower() in ('true', '1', 'yes', 'on')
return str(value).lower() in ("true", "1", "yes", "on")
def mock_lookup(type, path):
"""Mock the lookup function"""
# Return fake data for file lookups
if type == 'file':
if 'private' in path:
return 'MOCK_PRIVATE_KEY_BASE64=='
elif 'public' in path:
return 'MOCK_PUBLIC_KEY_BASE64=='
elif 'preshared' in path:
return 'MOCK_PRESHARED_KEY_BASE64=='
return 'MOCK_LOOKUP_DATA'
if type == "file":
if "private" in path:
return "MOCK_PRIVATE_KEY_BASE64=="
elif "public" in path:
return "MOCK_PUBLIC_KEY_BASE64=="
elif "preshared" in path:
return "MOCK_PRESHARED_KEY_BASE64=="
return "MOCK_LOOKUP_DATA"
def get_test_variables():
@ -47,8 +48,8 @@ def get_test_variables():
def find_templates():
"""Find all Jinja2 template files in the repo"""
templates = []
for pattern in ['**/*.j2', '**/*.jinja2', '**/*.yml.j2']:
templates.extend(Path('.').glob(pattern))
for pattern in ["**/*.j2", "**/*.jinja2", "**/*.yml.j2"]:
templates.extend(Path(".").glob(pattern))
return templates
@ -57,10 +58,10 @@ def test_template_syntax():
templates = find_templates()
# Skip some paths that aren't real templates
skip_paths = ['.git/', 'venv/', '.venv/', '.env/', 'configs/']
skip_paths = [".git/", "venv/", ".venv/", ".env/", "configs/"]
# Skip templates that use Ansible-specific filters
skip_templates = ['vpn-dict.j2', 'mobileconfig.j2', 'dnscrypt-proxy.toml.j2']
skip_templates = ["vpn-dict.j2", "mobileconfig.j2", "dnscrypt-proxy.toml.j2"]
errors = []
skipped = 0
@ -76,10 +77,7 @@ def test_template_syntax():
try:
template_dir = template_path.parent
env = Environment(
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Just try to load the template - this checks syntax
env.get_template(template_path.name)
@ -103,13 +101,13 @@ def test_template_syntax():
def test_critical_templates():
"""Test that critical templates render with test data"""
critical_templates = [
'roles/wireguard/templates/client.conf.j2',
'roles/strongswan/templates/ipsec.conf.j2',
'roles/strongswan/templates/ipsec.secrets.j2',
'roles/dns/templates/adblock.sh.j2',
'roles/dns/templates/dnsmasq.conf.j2',
'roles/common/templates/rules.v4.j2',
'roles/common/templates/rules.v6.j2',
"roles/wireguard/templates/client.conf.j2",
"roles/strongswan/templates/ipsec.conf.j2",
"roles/strongswan/templates/ipsec.secrets.j2",
"roles/dns/templates/adblock.sh.j2",
"roles/dns/templates/dnsmasq.conf.j2",
"roles/common/templates/rules.v4.j2",
"roles/common/templates/rules.v6.j2",
]
test_vars = get_test_variables()
@ -123,21 +121,18 @@ def test_critical_templates():
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
env = Environment(
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock functions
env.globals['lookup'] = mock_lookup
env.filters['to_uuid'] = mock_to_uuid
env.filters['bool'] = mock_bool
env.globals["lookup"] = mock_lookup
env.filters["to_uuid"] = mock_to_uuid
env.filters["bool"] = mock_bool
template = env.get_template(template_name)
# Add item context for templates that use loops
if 'client' in template_name:
test_vars['item'] = ('test-user', 'test-user')
if "client" in template_name:
test_vars["item"] = ("test-user", "test-user")
# Try to render
output = template.render(**test_vars)
@ -163,17 +158,17 @@ def test_variable_consistency():
"""Check that commonly used variables are defined consistently"""
# Variables that should be used consistently across templates
common_vars = [
'server_name',
'IP_subject_alt_name',
'wireguard_port',
'wireguard_network',
'dns_servers',
'users',
"server_name",
"IP_subject_alt_name",
"wireguard_port",
"wireguard_network",
"dns_servers",
"users",
]
# Check if main.yml defines these
if os.path.exists('main.yml'):
with open('main.yml') as f:
if os.path.exists("main.yml"):
with open("main.yml") as f:
content = f.read()
missing = []
@ -192,28 +187,19 @@ def test_wireguard_ipv6_endpoints():
"""Test that WireGuard client configs properly format IPv6 endpoints"""
test_cases = [
# IPv4 address - should not be bracketed
{
'IP_subject_alt_name': '192.168.1.100',
'expected_endpoint': 'Endpoint = 192.168.1.100:51820'
},
{"IP_subject_alt_name": "192.168.1.100", "expected_endpoint": "Endpoint = 192.168.1.100:51820"},
# IPv6 address - should be bracketed
{
'IP_subject_alt_name': '2600:3c01::f03c:91ff:fedf:3b2a',
'expected_endpoint': 'Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820'
"IP_subject_alt_name": "2600:3c01::f03c:91ff:fedf:3b2a",
"expected_endpoint": "Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820",
},
# Hostname - should not be bracketed
{
'IP_subject_alt_name': 'vpn.example.com',
'expected_endpoint': 'Endpoint = vpn.example.com:51820'
},
{"IP_subject_alt_name": "vpn.example.com", "expected_endpoint": "Endpoint = vpn.example.com:51820"},
# IPv6 with zone ID - should be bracketed
{
'IP_subject_alt_name': 'fe80::1%eth0',
'expected_endpoint': 'Endpoint = [fe80::1%eth0]:51820'
},
{"IP_subject_alt_name": "fe80::1%eth0", "expected_endpoint": "Endpoint = [fe80::1%eth0]:51820"},
]
template_path = 'roles/wireguard/templates/client.conf.j2'
template_path = "roles/wireguard/templates/client.conf.j2"
if not os.path.exists(template_path):
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
return
@ -225,24 +211,23 @@ def test_wireguard_ipv6_endpoints():
try:
# Set up test variables
test_vars = {**base_vars, **test_case}
test_vars['item'] = ('test-user', 'test-user')
test_vars["item"] = ("test-user", "test-user")
# Render template
env = Environment(
loader=FileSystemLoader('roles/wireguard/templates'),
undefined=StrictUndefined
)
env.globals['lookup'] = mock_lookup
env = Environment(loader=FileSystemLoader("roles/wireguard/templates"), undefined=StrictUndefined)
env.globals["lookup"] = mock_lookup
template = env.get_template('client.conf.j2')
template = env.get_template("client.conf.j2")
output = template.render(**test_vars)
# Check if the expected endpoint format is in the output
if test_case['expected_endpoint'] not in output:
errors.append(f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output")
if test_case["expected_endpoint"] not in output:
errors.append(
f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output"
)
# Print relevant part of output for debugging
for line in output.split('\n'):
if 'Endpoint' in line:
for line in output.split("\n"):
if "Endpoint" in line:
errors.append(f" Found: {line.strip()}")
except Exception as e:
@ -262,27 +247,27 @@ def test_template_conditionals():
test_cases = [
# WireGuard enabled, IPsec disabled
{
'wireguard_enabled': True,
'ipsec_enabled': False,
'dns_encryption': True,
'dns_adblocking': True,
'algo_ssh_tunneling': False,
"wireguard_enabled": True,
"ipsec_enabled": False,
"dns_encryption": True,
"dns_adblocking": True,
"algo_ssh_tunneling": False,
},
# IPsec enabled, WireGuard disabled
{
'wireguard_enabled': False,
'ipsec_enabled': True,
'dns_encryption': False,
'dns_adblocking': False,
'algo_ssh_tunneling': True,
"wireguard_enabled": False,
"ipsec_enabled": True,
"dns_encryption": False,
"dns_adblocking": False,
"algo_ssh_tunneling": True,
},
# Both enabled
{
'wireguard_enabled': True,
'ipsec_enabled': True,
'dns_encryption': True,
'dns_adblocking': True,
'algo_ssh_tunneling': True,
"wireguard_enabled": True,
"ipsec_enabled": True,
"dns_encryption": True,
"dns_adblocking": True,
"algo_ssh_tunneling": True,
},
]
@ -294,7 +279,7 @@ def test_template_conditionals():
# Test a few templates that have conditionals
conditional_templates = [
'roles/common/templates/rules.v4.j2',
"roles/common/templates/rules.v4.j2",
]
for template_path in conditional_templates:
@ -305,23 +290,19 @@ def test_template_conditionals():
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
env = Environment(
loader=FileSystemLoader(template_dir),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
# Add mock functions
env.globals['lookup'] = mock_lookup
env.filters['to_uuid'] = mock_to_uuid
env.filters['bool'] = mock_bool
env.globals["lookup"] = mock_lookup
env.filters["to_uuid"] = mock_to_uuid
env.filters["bool"] = mock_bool
template = env.get_template(template_name)
output = template.render(**test_vars)
# Verify conditionals work
if test_case.get('wireguard_enabled'):
assert str(test_vars['wireguard_port']) in output, \
f"WireGuard port missing when enabled (case {i})"
if test_case.get("wireguard_enabled"):
assert str(test_vars["wireguard_port"]) in output, f"WireGuard port missing when enabled (case {i})"
except Exception as e:
print(f"✗ Conditional test failed for {template_path} case {i}: {e}")

View file

@ -3,6 +3,7 @@
Test user management functionality without deployment
Based on issues #14745, #14746, #14738, #14726
"""
import os
import re
import sys
@ -23,15 +24,15 @@ users:
"""
config = yaml.safe_load(test_config)
users = config.get('users', [])
users = config.get("users", [])
assert len(users) == 5, f"Expected 5 users, got {len(users)}"
assert 'alice' in users, "Missing user 'alice'"
assert 'user-with-dash' in users, "Dash in username not handled"
assert 'user_with_underscore' in users, "Underscore in username not handled"
assert "alice" in users, "Missing user 'alice'"
assert "user-with-dash" in users, "Dash in username not handled"
assert "user_with_underscore" in users, "Underscore in username not handled"
# Test that usernames are valid
username_pattern = re.compile(r'^[a-zA-Z0-9_-]+$')
username_pattern = re.compile(r"^[a-zA-Z0-9_-]+$")
for user in users:
assert username_pattern.match(user), f"Invalid username format: {user}"
@ -42,35 +43,27 @@ def test_server_selection_format():
"""Test server selection string parsing (issue #14727)"""
# Test various server display formats
test_cases = [
{"display": "1. 192.168.1.100 (algo-server)", "expected_ip": "192.168.1.100", "expected_name": "algo-server"},
{"display": "2. 10.0.0.1 (production-vpn)", "expected_ip": "10.0.0.1", "expected_name": "production-vpn"},
{
'display': '1. 192.168.1.100 (algo-server)',
'expected_ip': '192.168.1.100',
'expected_name': 'algo-server'
"display": "3. vpn.example.com (example-server)",
"expected_ip": "vpn.example.com",
"expected_name": "example-server",
},
{
'display': '2. 10.0.0.1 (production-vpn)',
'expected_ip': '10.0.0.1',
'expected_name': 'production-vpn'
},
{
'display': '3. vpn.example.com (example-server)',
'expected_ip': 'vpn.example.com',
'expected_name': 'example-server'
}
]
# Pattern to extract IP and name from display string
pattern = re.compile(r'^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$')
pattern = re.compile(r"^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$")
for case in test_cases:
match = pattern.match(case['display'])
match = pattern.match(case["display"])
assert match, f"Failed to parse: {case['display']}"
ip_or_host = match.group(1)
name = match.group(2)
assert ip_or_host == case['expected_ip'], f"Wrong IP extracted: {ip_or_host}"
assert name == case['expected_name'], f"Wrong name extracted: {name}"
assert ip_or_host == case["expected_ip"], f"Wrong IP extracted: {ip_or_host}"
assert name == case["expected_name"], f"Wrong name extracted: {name}"
print("✓ Server selection format test passed")
@ -78,12 +71,12 @@ def test_server_selection_format():
def test_ssh_key_preservation():
"""Test that SSH keys aren't regenerated unnecessarily"""
with tempfile.TemporaryDirectory() as tmpdir:
ssh_key_path = os.path.join(tmpdir, 'test_key')
ssh_key_path = os.path.join(tmpdir, "test_key")
# Simulate existing SSH key
with open(ssh_key_path, 'w') as f:
with open(ssh_key_path, "w") as f:
f.write("EXISTING_SSH_KEY_CONTENT")
with open(f"{ssh_key_path}.pub", 'w') as f:
with open(f"{ssh_key_path}.pub", "w") as f:
f.write("ssh-rsa EXISTING_PUBLIC_KEY")
# Record original content
@ -105,11 +98,7 @@ def test_ssh_key_preservation():
def test_ca_password_handling():
"""Test CA password validation and handling"""
# Test password requirements
valid_passwords = [
"SecurePassword123!",
"Algo-VPN-2024",
"Complex#Pass@Word999"
]
valid_passwords = ["SecurePassword123!", "Algo-VPN-2024", "Complex#Pass@Word999"]
invalid_passwords = [
"", # Empty
@ -120,13 +109,13 @@ def test_ca_password_handling():
# Basic password validation
for pwd in valid_passwords:
assert len(pwd) >= 12, f"Password too short: {pwd}"
assert ' ' not in pwd, f"Password contains spaces: {pwd}"
assert " " not in pwd, f"Password contains spaces: {pwd}"
for pwd in invalid_passwords:
issues = []
if len(pwd) < 12:
issues.append("too short")
if ' ' in pwd:
if " " in pwd:
issues.append("contains spaces")
if not pwd:
issues.append("empty")
@ -137,8 +126,8 @@ def test_ca_password_handling():
def test_user_config_generation():
"""Test that user configs would be generated correctly"""
users = ['alice', 'bob', 'charlie']
server_name = 'test-server'
users = ["alice", "bob", "charlie"]
server_name = "test-server"
# Simulate config file structure
for user in users:
@ -168,7 +157,7 @@ users:
"""
config = yaml.safe_load(test_config)
users = config.get('users', [])
users = config.get("users", [])
# Check for duplicates
unique_users = list(set(users))
@ -182,7 +171,7 @@ users:
duplicates.append(user)
seen.add(user)
assert 'alice' in duplicates, "Duplicate 'alice' not detected"
assert "alice" in duplicates, "Duplicate 'alice' not detected"
print("✓ Duplicate user handling test passed")

View file

@ -3,6 +3,7 @@
Test WireGuard key generation - focused on x25519_pubkey module integration
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
"""
import base64
import os
import subprocess
@ -10,13 +11,13 @@ import sys
import tempfile
# Add library directory to path to import our custom module
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library"))
def test_wireguard_tools_available():
"""Test that WireGuard tools are available for validation"""
try:
result = subprocess.run(['wg', '--version'], capture_output=True, text=True)
result = subprocess.run(["wg", "--version"], capture_output=True, text=True)
assert result.returncode == 0, "WireGuard tools not available"
print(f"✓ WireGuard tools available: {result.stdout.strip()}")
return True
@ -29,6 +30,7 @@ def test_x25519_module_import():
"""Test that our custom x25519_pubkey module can be imported and used"""
try:
import x25519_pubkey # noqa: F401
print("✓ x25519_pubkey module imports successfully")
return True
except ImportError as e:
@ -37,16 +39,17 @@ def test_x25519_module_import():
def generate_test_private_key():
"""Generate a test private key using the same method as Algo"""
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file:
raw_key_path = temp_file.name
try:
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
import secrets
raw_data = secrets.token_bytes(32)
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
with open(raw_key_path, 'wb') as f:
with open(raw_key_path, "wb") as f:
f.write(raw_data)
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file():
def exit_json(self, **kwargs):
self.result = kwargs
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub:
public_key_path = temp_pub.name
try:
@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file():
try:
# Mock the module call
mock_module = MockModule({
'private_key_path': raw_key_path,
'public_key_path': public_key_path,
'private_key_b64': None
})
mock_module = MockModule(
{"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None}
)
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file():
run_module()
# Check the result
assert 'public_key' in mock_module.result
assert mock_module.result['changed']
assert "public_key" in mock_module.result
assert mock_module.result["changed"]
assert os.path.exists(public_key_path)
with open(public_key_path) as f:
@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
mock_module = MockModule({
'private_key_b64': b64_key,
'private_key_path': None,
'public_key_path': None
})
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string():
run_module()
# Check the result
assert 'public_key' in mock_module.result
derived_pubkey = mock_module.result['public_key']
assert "public_key" in mock_module.result
derived_pubkey = mock_module.result["public_key"]
# Validate base64 format
try:
@ -222,21 +219,17 @@ def test_wireguard_validation():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
mock_module = MockModule({
'private_key_b64': b64_key,
'private_key_path': None,
'public_key_path': None
})
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module()
derived_pubkey = mock_module.result['public_key']
derived_pubkey = mock_module.result["public_key"]
finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config:
# Create a WireGuard config using our keys
wg_config = f"""[Interface]
PrivateKey = {b64_key}
@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32
try:
# Test that WireGuard can parse our config
result = subprocess.run([
'wg-quick', 'strip', config_path
], capture_output=True, text=True)
result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True)
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
# Test key derivation with wg pubkey command
wg_result = subprocess.run([
'wg', 'pubkey'
], input=b64_key, capture_output=True, text=True)
wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True)
if wg_result.returncode == 0:
wg_derived = wg_result.stdout.strip()
@ -286,8 +275,8 @@ def test_key_consistency():
raw_key_path, b64_key = generate_test_private_key()
try:
def derive_pubkey_from_same_key():
def derive_pubkey_from_same_key():
class MockModule:
def __init__(self, params):
self.params = params
@ -305,16 +294,18 @@ def test_key_consistency():
original_AnsibleModule = x25519_pubkey.AnsibleModule
try:
mock_module = MockModule({
'private_key_b64': b64_key, # SAME key each time
'private_key_path': None,
'public_key_path': None
})
mock_module = MockModule(
{
"private_key_b64": b64_key, # SAME key each time
"private_key_path": None,
"public_key_path": None,
}
)
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
run_module()
return mock_module.result['public_key']
return mock_module.result["public_key"]
finally:
x25519_pubkey.AnsibleModule = original_AnsibleModule

View file

@ -14,13 +14,13 @@ from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
def find_jinja2_templates(root_dir: str = '.') -> list[Path]:
def find_jinja2_templates(root_dir: str = ".") -> list[Path]:
"""Find all Jinja2 template files in the project."""
templates = []
patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2']
patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"]
# Skip these directories
skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'}
skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"}
for pattern in patterns:
for path in Path(root_dir).glob(pattern):
@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P
errors = []
# Pattern to find Jinja2 expressions
jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL)
jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
for match in jinja_pattern.finditer(template_content):
expression = match.group()
lines = expression.split('\n')
lines = expression.split("\n")
for i, line in enumerate(lines):
# Check for # that's not in a string
# Simple heuristic: if # appears after non-whitespace and not in quotes
if '#' in line:
if "#" in line:
# Remove quoted strings to avoid false positives
cleaned = re.sub(r'"[^"]*"', '', line)
cleaned = re.sub(r"'[^']*'", '', cleaned)
cleaned = re.sub(r'"[^"]*"', "", line)
cleaned = re.sub(r"'[^']*'", "", cleaned)
if '#' in cleaned:
if "#" in cleaned:
# Check if it's likely a comment (has text after it)
hash_pos = cleaned.index('#')
if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\':
line_num = template_content[:match.start()].count('\n') + i + 1
hash_pos = cleaned.index("#")
if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\":
line_num = template_content[: match.start()].count("\n") + i + 1
errors.append(
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
f"Move comments outside the expression."
@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Common Ansible variables that are always available
ansible_builtins = {
'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname',
'ansible_distribution', 'ansible_distribution_version', 'ansible_facts',
'inventory_hostname', 'hostvars', 'groups', 'group_names',
'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host',
'item', 'ansible_loop', 'ansible_index', 'lookup'
"ansible_default_ipv4",
"ansible_default_ipv6",
"ansible_hostname",
"ansible_distribution",
"ansible_distribution_version",
"ansible_facts",
"inventory_hostname",
"hostvars",
"groups",
"group_names",
"play_hosts",
"ansible_version",
"ansible_user",
"ansible_host",
"item",
"ansible_loop",
"ansible_index",
"lookup",
}
# Filter out known Ansible variables
@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]:
# Only report if there are truly unknown variables
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
errors.append(
f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}"
)
errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}")
except Exception:
# Don't report parse errors here, they're handled elsewhere
@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
# Skip full parsing for templates that use Ansible-specific features heavily
# We still check for inline comments but skip full template parsing
ansible_specific_templates = {
'dnscrypt-proxy.toml.j2', # Uses |bool filter
'mobileconfig.j2', # Uses |to_uuid filter and complex item structures
'vpn-dict.j2', # Uses |to_uuid filter
"dnscrypt-proxy.toml.j2", # Uses |bool filter
"mobileconfig.j2", # Uses |to_uuid filter and complex item structures
"vpn-dict.j2", # Uses |to_uuid filter
}
if template_path.name in ansible_specific_templates:
@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
errors.extend(check_inline_comments_in_expressions(template_content, template_path))
# Try to parse the template
env = Environment(
loader=FileSystemLoader(template_path.parent),
undefined=StrictUndefined
)
env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined)
# Add mock Ansible filters to avoid syntax errors
env.filters['bool'] = lambda x: x
env.filters['to_uuid'] = lambda x: x
env.filters['b64encode'] = lambda x: x
env.filters['b64decode'] = lambda x: x
env.filters['regex_replace'] = lambda x, y, z: x
env.filters['default'] = lambda x, d: x if x else d
env.filters["bool"] = lambda x: x
env.filters["to_uuid"] = lambda x: x
env.filters["b64encode"] = lambda x: x
env.filters["b64decode"] = lambda x: x
env.filters["regex_replace"] = lambda x, y, z: x
env.filters["default"] = lambda x, d: x if x else d
# This will raise TemplateSyntaxError if there's a syntax problem
env.get_template(template_path.name)
@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]:
content = f.read()
# Check for missing spaces around filters
if re.search(r'\{\{[^}]+\|[^ ]', content):
if re.search(r"\{\{[^}]+\|[^ ]", content):
warnings.append(f"{template_path}: Missing space after filter pipe (|)")
# Check for deprecated 'when' in Jinja2 (should use if)
if re.search(r'\{%\s*when\s+', content):
if re.search(r"\{%\s*when\s+", content):
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
# Check for extremely long expressions (harder to debug)
for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL):
for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL):
if len(match.group(1)) > 200:
line_num = content[:match.start()].count('\n') + 1
warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up")
line_num = content[: match.start()].count("\n") + 1
warnings.append(
f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up"
)
except Exception:
pass # Ignore errors in anti-pattern checking