From 15be88d28ba3511f07acf3cda6620285852227fd Mon Sep 17 00:00:00 2001 From: Dan Guido Date: Sun, 17 Aug 2025 19:08:34 -0400 Subject: [PATCH] Apply Python linting and formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Run ruff check --fix to fix linting issues - Run ruff format to ensure consistent formatting - All tests still pass after formatting changes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- library/digital_ocean_floating_ip.py | 121 +++++----- library/gcp_compute_location_info.py | 43 ++-- library/lightsail_region_facts.py | 33 ++- library/linode_stackscript_v4.py | 64 +++-- library/linode_v4.py | 94 ++++---- library/scaleway_compute.py | 193 +++++++-------- library/x25519_pubkey.py | 45 ++-- scripts/track-test-effectiveness.py | 120 +++++---- tests/fixtures/__init__.py | 3 +- tests/integration/ansible-service-wrapper.py | 45 ++-- tests/integration/mock_modules/apt.py | 77 +++--- tests/integration/mock_modules/command.py | 95 ++++---- tests/integration/mock_modules/shell.py | 90 ++++--- tests/test_cloud_init_template.py | 51 ++-- tests/test_package_preinstall.py | 44 ++-- tests/unit/test_basic_sanity.py | 15 +- tests/unit/test_cloud_provider_configs.py | 40 +-- tests/unit/test_config_validation.py | 33 +-- .../unit/test_docker_localhost_deployment.py | 41 ++-- tests/unit/test_generated_configs.py | 102 ++++---- tests/unit/test_iptables_rules.py | 136 +++++------ tests/unit/test_lightsail_boto3_fix.py | 65 ++--- tests/unit/test_openssl_compatibility.py | 227 +++++++++++------- tests/unit/test_strongswan_templates.py | 217 +++++++++-------- tests/unit/test_template_rendering.py | 167 ++++++------- tests/unit/test_user_management.py | 61 ++--- tests/unit/test_wireguard_key_generation.py | 69 +++--- tests/validate_jinja2_templates.py | 86 ++++--- 28 files changed, 1178 insertions(+), 1199 deletions(-) diff --git a/library/digital_ocean_floating_ip.py b/library/digital_ocean_floating_ip.py index ece51680..19cf54cb 100644 --- a/library/digital_ocean_floating_ip.py +++ b/library/digital_ocean_floating_ip.py @@ -9,11 +9,9 @@ import time from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.digital_ocean import DigitalOceanHelper -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"} -DOCUMENTATION = ''' +DOCUMENTATION = """ --- module: digital_ocean_floating_ip short_description: Manage DigitalOcean Floating IPs @@ -44,10 +42,10 @@ notes: - Version 2 of DigitalOcean API is used. requirements: - "python >= 2.6" -''' +""" -EXAMPLES = ''' +EXAMPLES = """ - name: "Create a Floating IP in region lon1" digital_ocean_floating_ip: state: present @@ -63,10 +61,10 @@ EXAMPLES = ''' state: absent ip: "1.2.3.4" -''' +""" -RETURN = ''' +RETURN = """ # Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips data: description: a DigitalOcean Floating IP resource @@ -106,11 +104,10 @@ data: "region_slug": "nyc3" } } -''' +""" class Response: - def __init__(self, resp, info): self.body = None if resp: @@ -132,36 +129,37 @@ class Response: def status_code(self): return self.info["status"] + def wait_action(module, rest, ip, action_id, timeout=10): end_time = time.time() + 10 while time.time() < end_time: - response = rest.get(f'floating_ips/{ip}/actions/{action_id}') + response = rest.get(f"floating_ips/{ip}/actions/{action_id}") # status_code = response.status_code # TODO: check status_code == 200? - status = response.json['action']['status'] - if status == 'completed': + status = response.json["action"]["status"] + if status == "completed": return True - elif status == 'errored': - module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json) + elif status == "errored": + module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json) - module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json) + module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json) def core(module): # api_token = module.params['oauth_token'] # unused for now - state = module.params['state'] - ip = module.params['ip'] - droplet_id = module.params['droplet_id'] + state = module.params["state"] + ip = module.params["ip"] + droplet_id = module.params["droplet_id"] rest = DigitalOceanHelper(module) - if state in ('present'): - if droplet_id is not None and module.params['ip'] is not None: + if state in ("present"): + if droplet_id is not None and module.params["ip"] is not None: # Lets try to associate the ip to the specified droplet associate_floating_ips(module, rest) else: create_floating_ips(module, rest) - elif state in ('absent'): + elif state in ("absent"): response = rest.delete(f"floating_ips/{ip}") status_code = response.status_code json_data = response.json @@ -174,65 +172,68 @@ def core(module): def get_floating_ip_details(module, rest): - ip = module.params['ip'] + ip = module.params["ip"] response = rest.get(f"floating_ips/{ip}") status_code = response.status_code json_data = response.json if status_code == 200: - return json_data['floating_ip'] + return json_data["floating_ip"] else: - module.fail_json(msg="Error assigning floating ip [{}: {}]".format( - status_code, json_data["message"]), region=module.params['region']) + module.fail_json( + msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]), + region=module.params["region"], + ) def assign_floating_id_to_droplet(module, rest): - ip = module.params['ip'] + ip = module.params["ip"] payload = { "type": "assign", - "droplet_id": module.params['droplet_id'], + "droplet_id": module.params["droplet_id"], } response = rest.post(f"floating_ips/{ip}/actions", data=payload) status_code = response.status_code json_data = response.json if status_code == 201: - wait_action(module, rest, ip, json_data['action']['id']) + wait_action(module, rest, ip, json_data["action"]["id"]) module.exit_json(changed=True, data=json_data) else: - module.fail_json(msg="Error creating floating ip [{}: {}]".format( - status_code, json_data["message"]), region=module.params['region']) + module.fail_json( + msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]), + region=module.params["region"], + ) def associate_floating_ips(module, rest): floating_ip = get_floating_ip_details(module, rest) - droplet = floating_ip['droplet'] + droplet = floating_ip["droplet"] # TODO: If already assigned to a droplet verify if is one of the specified as valid - if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]: + if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]: module.exit_json(changed=False) else: assign_floating_id_to_droplet(module, rest) def create_floating_ips(module, rest): - payload = { - } + payload = {} floating_ip_data = None - if module.params['region'] is not None: - payload["region"] = module.params['region'] + if module.params["region"] is not None: + payload["region"] = module.params["region"] - if module.params['droplet_id'] is not None: - payload["droplet_id"] = module.params['droplet_id'] + if module.params["droplet_id"] is not None: + payload["droplet_id"] = module.params["droplet_id"] - floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips') + floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips") for floating_ip in floating_ips: - if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']: - floating_ip_data = {'floating_ip': floating_ip} + if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]: + floating_ip_data = {"floating_ip": floating_ip} if floating_ip_data: module.exit_json(changed=False, data=floating_ip_data) @@ -244,36 +245,34 @@ def create_floating_ips(module, rest): if status_code == 202: module.exit_json(changed=True, data=json_data) else: - module.fail_json(msg="Error creating floating ip [{}: {}]".format( - status_code, json_data["message"]), region=module.params['region']) + module.fail_json( + msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]), + region=module.params["region"], + ) def main(): module = AnsibleModule( argument_spec={ - 'state': {'choices': ['present', 'absent'], 'default': 'present'}, - 'ip': {'aliases': ['id'], 'required': False}, - 'region': {'required': False}, - 'droplet_id': {'required': False, 'type': 'int'}, - 'oauth_token': { - 'no_log': True, + "state": {"choices": ["present", "absent"], "default": "present"}, + "ip": {"aliases": ["id"], "required": False}, + "region": {"required": False}, + "droplet_id": {"required": False, "type": "int"}, + "oauth_token": { + "no_log": True, # Support environment variable for DigitalOcean OAuth Token - 'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']), - 'required': True, + "fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]), + "required": True, }, - 'validate_certs': {'type': 'bool', 'default': True}, - 'timeout': {'type': 'int', 'default': 30}, + "validate_certs": {"type": "bool", "default": True}, + "timeout": {"type": "int", "default": 30}, }, - required_if=[ - ('state', 'delete', ['ip']) - ], - mutually_exclusive=[ - ['region', 'droplet_id'] - ], + required_if=[("state", "delete", ["ip"])], + mutually_exclusive=[["region", "droplet_id"]], ) core(module) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/library/gcp_compute_location_info.py b/library/gcp_compute_location_info.py index 39118ef6..f208747a 100644 --- a/library/gcp_compute_location_info.py +++ b/library/gcp_compute_location_info.py @@ -1,7 +1,6 @@ #!/usr/bin/python - import json from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash @@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash # Documentation ################################################################################ -ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'} +ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"} ################################################################################ # Main @@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported def main(): - module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}}) + module = GcpModule( + argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}} + ) - if module._name == 'gcp_compute_image_facts': - module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13') + if module._name == "gcp_compute_image_facts": + module.deprecate( + "The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13" + ) - if not module.params['scopes']: - module.params['scopes'] = ['https://www.googleapis.com/auth/compute'] + if not module.params["scopes"]: + module.params["scopes"] = ["https://www.googleapis.com/auth/compute"] - items = fetch_list(module, collection(module), query_options(module.params['filters'])) - if items.get('items'): - items = items.get('items') + items = fetch_list(module, collection(module), query_options(module.params["filters"])) + if items.get("items"): + items = items.get("items") else: items = [] - return_value = {'resources': items} + return_value = {"resources": items} module.exit_json(**return_value) @@ -40,14 +43,14 @@ def collection(module): def fetch_list(module, link, query): - auth = GcpSession(module, 'compute') - response = auth.get(link, params={'filter': query}) + auth = GcpSession(module, "compute") + response = auth.get(link, params={"filter": query}) return return_if_object(module, response) def query_options(filters): if not filters: - return '' + return "" if len(filters) == 1: return filters[0] @@ -55,12 +58,12 @@ def query_options(filters): queries = [] for f in filters: # For multiple queries, all queries should have () - if f[0] != '(' and f[-1] != ')': - queries.append("({})".format(''.join(f))) + if f[0] != "(" and f[-1] != ")": + queries.append("({})".format("".join(f))) else: queries.append(f) - return ' '.join(queries) + return " ".join(queries) def return_if_object(module, response): @@ -75,11 +78,11 @@ def return_if_object(module, response): try: module.raise_for_status(response) result = response.json() - except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst: + except getattr(json.decoder, "JSONDecodeError", ValueError) as inst: module.fail_json(msg=f"Invalid JSON response with error: {inst}") - if navigate_hash(result, ['error', 'errors']): - module.fail_json(msg=navigate_hash(result, ['error', 'errors'])) + if navigate_hash(result, ["error", "errors"]): + module.fail_json(msg=navigate_hash(result, ["error", "errors"])) return result diff --git a/library/lightsail_region_facts.py b/library/lightsail_region_facts.py index c61a4006..5fe7cd99 100644 --- a/library/lightsail_region_facts.py +++ b/library/lightsail_region_facts.py @@ -3,12 +3,9 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"} -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} - -DOCUMENTATION = ''' +DOCUMENTATION = """ --- module: lightsail_region_facts short_description: Gather facts about AWS Lightsail regions. @@ -24,15 +21,15 @@ requirements: extends_documentation_fragment: - aws - ec2 -''' +""" -EXAMPLES = ''' +EXAMPLES = """ # Gather facts about all regions - lightsail_region_facts: -''' +""" -RETURN = ''' +RETURN = """ regions: returned: on success description: > @@ -46,12 +43,13 @@ regions: "displayName": "Virginia", "name": "us-east-1" }]" -''' +""" import traceback try: import botocore + HAS_BOTOCORE = True except ImportError: HAS_BOTOCORE = False @@ -86,18 +84,19 @@ def main(): client = None try: - client = boto3_conn(module, conn_type='client', resource='lightsail', - region=region, endpoint=ec2_url, **aws_connect_kwargs) + client = boto3_conn( + module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs + ) except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e: - module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc()) + module.fail_json( + msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc() + ) - response = client.get_regions( - includeAvailabilityZones=False - ) + response = client.get_regions(includeAvailabilityZones=False) module.exit_json(changed=False, data=response) except (botocore.exceptions.ClientError, Exception) as e: module.fail_json(msg=str(e), exception=traceback.format_exc()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/library/linode_stackscript_v4.py b/library/linode_stackscript_v4.py index 1d29ac5d..1298fc9c 100644 --- a/library/linode_stackscript_v4.py +++ b/library/linode_stackscript_v4.py @@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent LINODE_IMP_ERR = None try: from linode_api4 import LinodeClient, StackScript + HAS_LINODE_DEPENDENCY = True except ImportError: LINODE_IMP_ERR = traceback.format_exc() @@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs): response = client.linode.stackscript_create(**kwargs) return response._raw_json except Exception as exception: - module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) + module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception) def stackscript_available(module, client): """Try to retrieve a stackscript.""" try: - label = module.params['label'] - desc = module.params['description'] + label = module.params["label"] + desc = module.params["description"] - result = client.linode.stackscripts(StackScript.label == label, - StackScript.description == desc, - mine_only=True - ) + result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True) return result[0] except IndexError: return None except Exception as exception: - module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) + module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception) def initialise_module(): """Initialise the module parameter specification.""" return AnsibleModule( argument_spec=dict( - label=dict(type='str', required=True), - state=dict( - type='str', - required=True, - choices=['present', 'absent'] - ), + label=dict(type="str", required=True), + state=dict(type="str", required=True, choices=["present", "absent"]), access_token=dict( - type='str', + type="str", required=True, no_log=True, - fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']), + fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]), ), - script=dict(type='str', required=True), - images=dict(type='list', required=True), - description=dict(type='str', required=False), - public=dict(type='bool', required=False, default=False), + script=dict(type="str", required=True), + images=dict(type="list", required=True), + description=dict(type="str", required=False), + public=dict(type="bool", required=False, default=False), ), - supports_check_mode=False + supports_check_mode=False, ) def build_client(module): """Build a LinodeClient.""" - return LinodeClient( - module.params['access_token'], - user_agent=get_user_agent('linode_v4_module') - ) + return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module")) def main(): @@ -79,30 +70,31 @@ def main(): module = initialise_module() if not HAS_LINODE_DEPENDENCY: - module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR) + module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR) client = build_client(module) stackscript = stackscript_available(module, client) - if module.params['state'] == 'present' and stackscript is not None: + if module.params["state"] == "present" and stackscript is not None: module.exit_json(changed=False, stackscript=stackscript._raw_json) - elif module.params['state'] == 'present' and stackscript is None: + elif module.params["state"] == "present" and stackscript is None: stackscript_json = create_stackscript( - module, client, - label=module.params['label'], - script=module.params['script'], - images=module.params['images'], - desc=module.params['description'], - public=module.params['public'], + module, + client, + label=module.params["label"], + script=module.params["script"], + images=module.params["images"], + desc=module.params["description"], + public=module.params["public"], ) module.exit_json(changed=True, stackscript=stackscript_json) - elif module.params['state'] == 'absent' and stackscript is not None: + elif module.params["state"] == "absent" and stackscript is not None: stackscript.delete() module.exit_json(changed=True, stackscript=stackscript._raw_json) - elif module.params['state'] == 'absent' and stackscript is None: + elif module.params["state"] == "absent" and stackscript is None: module.exit_json(changed=False, stackscript={}) diff --git a/library/linode_v4.py b/library/linode_v4.py index b097ff84..cf93602a 100644 --- a/library/linode_v4.py +++ b/library/linode_v4.py @@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent LINODE_IMP_ERR = None try: from linode_api4 import Instance, LinodeClient + HAS_LINODE_DEPENDENCY = True except ImportError: LINODE_IMP_ERR = traceback.format_exc() @@ -21,82 +22,72 @@ except ImportError: def create_linode(module, client, **kwargs): """Creates a Linode instance and handles return format.""" - if kwargs['root_pass'] is None: - kwargs.pop('root_pass') + if kwargs["root_pass"] is None: + kwargs.pop("root_pass") try: response = client.linode.instance_create(**kwargs) except Exception as exception: - module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) + module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception) try: if isinstance(response, tuple): instance, root_pass = response instance_json = instance._raw_json - instance_json.update({'root_pass': root_pass}) + instance_json.update({"root_pass": root_pass}) return instance_json else: return response._raw_json except TypeError: - module.fail_json(msg='Unable to parse Linode instance creation' - ' response. Please raise a bug against this' - ' module on https://github.com/ansible/ansible/issues' - ) + module.fail_json( + msg="Unable to parse Linode instance creation" + " response. Please raise a bug against this" + " module on https://github.com/ansible/ansible/issues" + ) def maybe_instance_from_label(module, client): """Try to retrieve an instance based on a label.""" try: - label = module.params['label'] + label = module.params["label"] result = client.linode.instances(Instance.label == label) return result[0] except IndexError: return None except Exception as exception: - module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception) + module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception) def initialise_module(): """Initialise the module parameter specification.""" return AnsibleModule( argument_spec=dict( - label=dict(type='str', required=True), - state=dict( - type='str', - required=True, - choices=['present', 'absent'] - ), + label=dict(type="str", required=True), + state=dict(type="str", required=True, choices=["present", "absent"]), access_token=dict( - type='str', + type="str", required=True, no_log=True, - fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']), + fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]), ), - authorized_keys=dict(type='list', required=False), - group=dict(type='str', required=False), - image=dict(type='str', required=False), - region=dict(type='str', required=False), - root_pass=dict(type='str', required=False, no_log=True), - tags=dict(type='list', required=False), - type=dict(type='str', required=False), - stackscript_id=dict(type='int', required=False), + authorized_keys=dict(type="list", required=False), + group=dict(type="str", required=False), + image=dict(type="str", required=False), + region=dict(type="str", required=False), + root_pass=dict(type="str", required=False, no_log=True), + tags=dict(type="list", required=False), + type=dict(type="str", required=False), + stackscript_id=dict(type="int", required=False), ), supports_check_mode=False, - required_one_of=( - ['state', 'label'], - ), - required_together=( - ['region', 'image', 'type'], - ) + required_one_of=(["state", "label"],), + required_together=(["region", "image", "type"],), ) def build_client(module): """Build a LinodeClient.""" - return LinodeClient( - module.params['access_token'], - user_agent=get_user_agent('linode_v4_module') - ) + return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module")) def main(): @@ -104,34 +95,35 @@ def main(): module = initialise_module() if not HAS_LINODE_DEPENDENCY: - module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR) + module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR) client = build_client(module) instance = maybe_instance_from_label(module, client) - if module.params['state'] == 'present' and instance is not None: + if module.params["state"] == "present" and instance is not None: module.exit_json(changed=False, instance=instance._raw_json) - elif module.params['state'] == 'present' and instance is None: + elif module.params["state"] == "present" and instance is None: instance_json = create_linode( - module, client, - authorized_keys=module.params['authorized_keys'], - group=module.params['group'], - image=module.params['image'], - label=module.params['label'], - region=module.params['region'], - root_pass=module.params['root_pass'], - tags=module.params['tags'], - ltype=module.params['type'], - stackscript_id=module.params['stackscript_id'], + module, + client, + authorized_keys=module.params["authorized_keys"], + group=module.params["group"], + image=module.params["image"], + label=module.params["label"], + region=module.params["region"], + root_pass=module.params["root_pass"], + tags=module.params["tags"], + ltype=module.params["type"], + stackscript_id=module.params["stackscript_id"], ) module.exit_json(changed=True, instance=instance_json) - elif module.params['state'] == 'absent' and instance is not None: + elif module.params["state"] == "absent" and instance is not None: instance.delete() module.exit_json(changed=True, instance=instance._raw_json) - elif module.params['state'] == 'absent' and instance is None: + elif module.params["state"] == "absent" and instance is None: module.exit_json(changed=False, instance={}) diff --git a/library/scaleway_compute.py b/library/scaleway_compute.py index 793a6cef..623a62c2 100644 --- a/library/scaleway_compute.py +++ b/library/scaleway_compute.py @@ -8,14 +8,9 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"} -ANSIBLE_METADATA = { - 'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community' -} - -DOCUMENTATION = ''' +DOCUMENTATION = """ --- module: scaleway_compute short_description: Scaleway compute management module @@ -120,9 +115,9 @@ options: - If no value provided, the default security group or current security group will be used required: false version_added: "2.8" -''' +""" -EXAMPLES = ''' +EXAMPLES = """ - name: Create a server scaleway_compute: name: foobar @@ -156,10 +151,10 @@ EXAMPLES = ''' organization: 951df375-e094-4d26-97c1-ba548eeb9c42 region: ams1 commercial_type: VC1S -''' +""" -RETURN = ''' -''' +RETURN = """ +""" import datetime import time @@ -167,19 +162,9 @@ import time from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec -SCALEWAY_SERVER_STATES = ( - 'stopped', - 'stopping', - 'starting', - 'running', - 'locked' -) +SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked") -SCALEWAY_TRANSITIONS_STATES = ( - "stopping", - "starting", - "pending" -) +SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending") def check_image_id(compute_api, image_id): @@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id): if response.ok and response.json: image_ids = [image["id"] for image in response.json["images"]] if image_id not in image_ids: - compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url'))) + compute_api.module.fail_json( + msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url")) + ) else: - compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url')) + compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url")) def fetch_state(compute_api, server): @@ -201,7 +188,7 @@ def fetch_state(compute_api, server): return "absent" if not response.ok: - msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json) + msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) try: @@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip): # We check that the IP we want to attach exists, if so its ID is returned response = compute_api.get("ips") if not response.ok: - msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json) + msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) ip_list = [] @@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip): def create_server(compute_api, server): compute_api.module.debug("Starting a create_server") target_server = None - data = {"enable_ipv6": server["enable_ipv6"], - "tags": server["tags"], - "commercial_type": server["commercial_type"], - "image": server["image"], - "dynamic_ip_required": server["dynamic_ip_required"], - "name": server["name"], - "organization": server["organization"] - } + data = { + "enable_ipv6": server["enable_ipv6"], + "tags": server["tags"], + "commercial_type": server["commercial_type"], + "image": server["image"], + "dynamic_ip_required": server["dynamic_ip_required"], + "name": server["name"], + "organization": server["organization"], + } if server["boot_type"]: data["boot_type"] = server["boot_type"] @@ -278,7 +266,7 @@ def create_server(compute_api, server): response = compute_api.post(path="servers", data=data) if not response.ok: - msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json) + msg = "Error during server creation: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) try: @@ -304,10 +292,9 @@ def start_server(compute_api, server): def perform_action(compute_api, server, action): - response = compute_api.post(path="servers/%s/action" % server["id"], - data={"action": action}) + response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action}) if not response.ok: - msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json) + msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json) compute_api.module.fail_json(msg=msg) wait_to_complete_state_transition(compute_api=compute_api, server=server) @@ -319,7 +306,7 @@ def remove_server(compute_api, server): compute_api.module.debug("Starting remove server strategy") response = compute_api.delete(path="servers/%s" % server["id"]) if not response.ok: - msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json) + msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) wait_to_complete_state_transition(compute_api=compute_api, server=server) @@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server): else: target_server = query_results[0] - if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, - wished_server=wished_server): + if server_attributes_should_be_changed( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ): changed = True if compute_api.module.check_mode: return changed, {"status": "Server %s attributes would be changed." % target_server["id"]} - target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) + target_server = server_change_attributes( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ) return changed, target_server @@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server): response = stop_server(compute_api=compute_api, server=target_server) if not response.ok: - err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]' + err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]" compute_api.module.fail_json(msg=err_msg) wait_to_complete_state_transition(compute_api=compute_api, server=target_server) @@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server): response = remove_server(compute_api=compute_api, server=target_server) if not response.ok: - err_msg = f'Error while removing server [{response.status_code}: {response.json}]' + err_msg = f"Error while removing server [{response.status_code}: {response.json}]" compute_api.module.fail_json(msg=err_msg) return changed, {"status": "Server %s deleted" % target_server["id"]} @@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server): else: target_server = query_results[0] - if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, - wished_server=wished_server): + if server_attributes_should_be_changed( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ): changed = True if compute_api.module.check_mode: return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]} - target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) + target_server = server_change_attributes( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ) current_state = fetch_state(compute_api=compute_api, server=target_server) if current_state not in ("running", "starting"): @@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server): response = start_server(compute_api=compute_api, server=target_server) if not response.ok: - msg = f'Error while running server [{response.status_code}: {response.json}]' + msg = f"Error while running server [{response.status_code}: {response.json}]" compute_api.module.fail_json(msg=msg) return changed, target_server @@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server): changed = False if not query_results: - if compute_api.module.check_mode: return changed, {"status": "A server would be created before being stopped."} @@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server): compute_api.module.debug("stop_strategy: Servers are found.") - if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server, - wished_server=wished_server): + if server_attributes_should_be_changed( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ): changed = True if compute_api.module.check_mode: return changed, { - "status": "Server %s attributes would be changed before stopping it." % target_server["id"]} + "status": "Server %s attributes would be changed before stopping it." % target_server["id"] + } - target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) + target_server = server_change_attributes( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ) wait_to_complete_state_transition(compute_api=compute_api, server=target_server) @@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server): compute_api.module.debug(response.ok) if not response.ok: - msg = f'Error while stopping server [{response.status_code}: {response.json}]' + msg = f"Error while stopping server [{response.status_code}: {response.json}]" compute_api.module.fail_json(msg=msg) return changed, target_server @@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server): else: target_server = query_results[0] - if server_attributes_should_be_changed(compute_api=compute_api, - target_server=target_server, - wished_server=wished_server): + if server_attributes_should_be_changed( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ): changed = True if compute_api.module.check_mode: return changed, { - "status": "Server %s attributes would be changed before rebooting it." % target_server["id"]} + "status": "Server %s attributes would be changed before rebooting it." % target_server["id"] + } - target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server) + target_server = server_change_attributes( + compute_api=compute_api, target_server=target_server, wished_server=wished_server + ) changed = True if compute_api.module.check_mode: @@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server): response = restart_server(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server) if not response.ok: - msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].' + msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]." compute_api.module.fail_json(msg=msg) if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",): response = restart_server(compute_api=compute_api, server=target_server) wait_to_complete_state_transition(compute_api=compute_api, server=target_server) if not response.ok: - msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].' + msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]." compute_api.module.fail_json(msg=msg) return changed, target_server @@ -531,18 +530,17 @@ state_strategy = { "restarted": restart_strategy, "stopped": stop_strategy, "running": running_strategy, - "absent": absent_strategy + "absent": absent_strategy, } def find(compute_api, wished_server, per_page=1): compute_api.module.debug("Getting inside find") # Only the name attribute is accepted in the Compute query API - response = compute_api.get("servers", params={"name": wished_server["name"], - "per_page": per_page}) + response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page}) if not response.ok: - msg = 'Error during server search: (%s) %s' % (response.status_code, response.json) + msg = "Error during server search: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) search_results = response.json["servers"] @@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve compute_api.module.debug("Checking if server attributes should be changed") compute_api.module.debug("Current Server: %s" % target_server) compute_api.module.debug("Wished Server: %s" % wished_server) - debug_dict = dict((x, (target_server[x], wished_server[x])) - for x in PATCH_MUTABLE_SERVER_ATTRIBUTES - if x in target_server and x in wished_server) + debug_dict = dict( + (x, (target_server[x], wished_server[x])) + for x in PATCH_MUTABLE_SERVER_ATTRIBUTES + if x in target_server and x in wished_server + ) compute_api.module.debug("Debug dict %s" % debug_dict) try: for key in PATCH_MUTABLE_SERVER_ATTRIBUTES: if key in target_server and key in wished_server: # When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook - if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys( - ) and target_server[key]["id"] != wished_server[key]: + if ( + isinstance(target_server[key], dict) + and wished_server[key] + and "id" in target_server[key].keys() + and target_server[key]["id"] != wished_server[key] + ): return True # Handling other structure compare simply the two objects content elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]: @@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server): elif not isinstance(target_server[key], dict): patch_payload[key] = wished_server[key] - response = compute_api.patch(path="servers/%s" % target_server["id"], - data=patch_payload) + response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload) if not response.ok: - msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json) + msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json) compute_api.module.fail_json(msg=msg) try: @@ -625,9 +628,9 @@ def core(module): "boot_type": module.params["boot_type"], "tags": module.params["tags"], "organization": module.params["organization"], - "security_group": module.params["security_group"] + "security_group": module.params["security_group"], } - module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"] + module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"] compute_api = Scaleway(module=module) @@ -643,22 +646,24 @@ def core(module): def main(): argument_spec = scaleway_argument_spec() - argument_spec.update(dict( - image=dict(required=True), - name=dict(), - region=dict(required=True, choices=SCALEWAY_LOCATION.keys()), - commercial_type=dict(required=True), - enable_ipv6=dict(default=False, type="bool"), - boot_type=dict(choices=['bootscript', 'local']), - public_ip=dict(default="absent"), - state=dict(choices=state_strategy.keys(), default='present'), - tags=dict(type="list", default=[]), - organization=dict(required=True), - wait=dict(type="bool", default=False), - wait_timeout=dict(type="int", default=300), - wait_sleep_time=dict(type="int", default=3), - security_group=dict(), - )) + argument_spec.update( + dict( + image=dict(required=True), + name=dict(), + region=dict(required=True, choices=SCALEWAY_LOCATION.keys()), + commercial_type=dict(required=True), + enable_ipv6=dict(default=False, type="bool"), + boot_type=dict(choices=["bootscript", "local"]), + public_ip=dict(default="absent"), + state=dict(choices=state_strategy.keys(), default="present"), + tags=dict(type="list", default=[]), + organization=dict(required=True), + wait=dict(type="bool", default=False), + wait_timeout=dict(type="int", default=300), + wait_sleep_time=dict(type="int", default=3), + security_group=dict(), + ) + ) module = AnsibleModule( argument_spec=argument_spec, supports_check_mode=True, @@ -667,5 +672,5 @@ def main(): core(module) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/library/x25519_pubkey.py b/library/x25519_pubkey.py index d56c2fa6..9d7e9c0a 100755 --- a/library/x25519_pubkey.py +++ b/library/x25519_pubkey.py @@ -32,32 +32,30 @@ Returns: def run_module(): """ Main execution function for the x25519_pubkey Ansible module. - + Handles parameter validation, private key processing, public key derivation, and optional file output with idempotent behavior. """ module_args = { - 'private_key_b64': {'type': 'str', 'required': False}, - 'private_key_path': {'type': 'path', 'required': False}, - 'public_key_path': {'type': 'path', 'required': False}, + "private_key_b64": {"type": "str", "required": False}, + "private_key_path": {"type": "path", "required": False}, + "public_key_path": {"type": "path", "required": False}, } result = { - 'changed': False, - 'public_key': '', + "changed": False, + "public_key": "", } module = AnsibleModule( - argument_spec=module_args, - required_one_of=[['private_key_b64', 'private_key_path']], - supports_check_mode=True + argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True ) priv_b64 = None - if module.params['private_key_path']: + if module.params["private_key_path"]: try: - with open(module.params['private_key_path'], 'rb') as f: + with open(module.params["private_key_path"], "rb") as f: data = f.read() try: # First attempt: assume file contains base64 text data @@ -71,12 +69,14 @@ def run_module(): # whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved # Stripping would corrupt the key and cause "got 31 bytes" errors if len(data) != 32: - module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes") + module.fail_json( + msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes" + ) priv_b64 = base64.b64encode(data).decode() except OSError as e: module.fail_json(msg=f"Failed to read private key file: {e}") else: - priv_b64 = module.params['private_key_b64'] + priv_b64 = module.params["private_key_b64"] # Validate input parameters if not priv_b64: @@ -93,15 +93,12 @@ def run_module(): try: priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw) pub_key = priv_key.public_key() - pub_raw = pub_key.public_bytes( - encoding=serialization.Encoding.Raw, - format=serialization.PublicFormat.Raw - ) + pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw) pub_b64 = base64.b64encode(pub_raw).decode() - result['public_key'] = pub_b64 + result["public_key"] = pub_b64 - if module.params['public_key_path']: - pub_path = module.params['public_key_path'] + if module.params["public_key_path"]: + pub_path = module.params["public_key_path"] existing = None try: @@ -112,13 +109,13 @@ def run_module(): if existing != pub_b64: try: - with open(pub_path, 'w') as f: + with open(pub_path, "w") as f: f.write(pub_b64) - result['changed'] = True + result["changed"] = True except OSError as e: module.fail_json(msg=f"Failed to write public key file: {e}") - result['public_key_path'] = pub_path + result["public_key_path"] = pub_path except Exception as e: module.fail_json(msg=f"Failed to derive public key: {e}") @@ -131,5 +128,5 @@ def main(): run_module() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/track-test-effectiveness.py b/scripts/track-test-effectiveness.py index f9cab5b8..a58afeca 100755 --- a/scripts/track-test-effectiveness.py +++ b/scripts/track-test-effectiveness.py @@ -3,6 +3,7 @@ Track test effectiveness by analyzing CI failures and correlating with issues/PRs This helps identify which tests actually catch bugs vs just failing randomly """ + import json import subprocess from collections import defaultdict @@ -12,7 +13,7 @@ from pathlib import Path def get_github_api_data(endpoint): """Fetch data from GitHub API""" - cmd = ['gh', 'api', endpoint] + cmd = ["gh", "api", endpoint] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: print(f"Error fetching {endpoint}: {result.stderr}") @@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30): since = (datetime.now() - timedelta(days=days_back)).isoformat() # Get workflow runs - runs = get_github_api_data( - f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure' - ) + runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure") if not runs: return {} test_failures = defaultdict(list) - for run in runs.get('workflow_runs', []): + for run in runs.get("workflow_runs", []): # Get jobs for this run - jobs = get_github_api_data( - f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs' - ) + jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs") if not jobs: continue - for job in jobs.get('jobs', []): - if job['conclusion'] == 'failure': + for job in jobs.get("jobs", []): + if job["conclusion"] == "failure": # Try to extract which test failed from logs - logs_url = job.get('logs_url') + logs_url = job.get("logs_url") if logs_url: # Parse logs to find test failures - test_name = extract_failed_test(job['name'], run['id']) + test_name = extract_failed_test(job["name"], run["id"]) if test_name: - test_failures[test_name].append({ - 'run_id': run['id'], - 'run_number': run['run_number'], - 'date': run['created_at'], - 'branch': run['head_branch'], - 'commit': run['head_sha'][:7], - 'pr': extract_pr_number(run) - }) + test_failures[test_name].append( + { + "run_id": run["id"], + "run_number": run["run_number"], + "date": run["created_at"], + "branch": run["head_branch"], + "commit": run["head_sha"][:7], + "pr": extract_pr_number(run), + } + ) return test_failures @@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id): """Extract test name from job - this is simplified""" # Map job names to test categories job_to_tests = { - 'Basic sanity tests': 'test_basic_sanity', - 'Ansible syntax check': 'ansible_syntax', - 'Docker build test': 'docker_tests', - 'Configuration generation test': 'config_generation', - 'Ansible dry-run validation': 'ansible_dry_run' + "Basic sanity tests": "test_basic_sanity", + "Ansible syntax check": "ansible_syntax", + "Docker build test": "docker_tests", + "Configuration generation test": "config_generation", + "Ansible dry-run validation": "ansible_dry_run", } return job_to_tests.get(job_name, job_name) def extract_pr_number(run): """Extract PR number from workflow run""" - for pr in run.get('pull_requests', []): - return pr['number'] + for pr in run.get("pull_requests", []): + return pr["number"] return None def correlate_with_issues(repo_owner, repo_name, test_failures): """Correlate test failures with issues/PRs that fixed them""" - correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0}) + correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0}) for test_name, failures in test_failures.items(): for failure in failures: - if failure['pr']: + if failure["pr"]: # Check if PR was merged (indicating it fixed a real issue) - pr = get_github_api_data( - f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}' - ) + pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}") - if pr and pr.get('merged'): + if pr and pr.get("merged"): # Check PR title/body for bug indicators - title = pr.get('title', '').lower() - body = pr.get('body', '').lower() + title = pr.get("title", "").lower() + body = pr.get("body", "").lower() - bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail'] - is_bug_fix = any(keyword in title or keyword in body - for keyword in bug_keywords) + bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"] + is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords) if is_bug_fix: - correlations[test_name]['caught_bugs'] += 1 + correlations[test_name]["caught_bugs"] += 1 else: - correlations[test_name]['false_positives'] += 1 + correlations[test_name]["false_positives"] += 1 return correlations @@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations): scores = [] for test_name, failures in test_failures.items(): failure_count = len(failures) - caught = correlations[test_name]['caught_bugs'] - false_pos = correlations[test_name]['false_positives'] + caught = correlations[test_name]["caught_bugs"] + false_pos = correlations[test_name]["false_positives"] # Calculate effectiveness (bugs caught / total failures) if failure_count > 0: @@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations): elif effectiveness > 0.8: report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})") - return '\n'.join(report) + return "\n".join(report) def save_metrics(test_failures, correlations): """Save metrics to JSON for historical tracking""" - metrics_file = Path('.metrics/test-effectiveness.json') + metrics_file = Path(".metrics/test-effectiveness.json") metrics_file.parent.mkdir(exist_ok=True) # Load existing metrics @@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations): # Add current metrics current = { - 'date': datetime.now().isoformat(), - 'test_failures': { - test: len(failures) for test, failures in test_failures.items() - }, - 'effectiveness': { + "date": datetime.now().isoformat(), + "test_failures": {test: len(failures) for test, failures in test_failures.items()}, + "effectiveness": { test: { - 'caught_bugs': data['caught_bugs'], - 'false_positives': data['false_positives'], - 'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives']) - if (data['caught_bugs'] + data['false_positives']) > 0 else 0 + "caught_bugs": data["caught_bugs"], + "false_positives": data["false_positives"], + "score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"]) + if (data["caught_bugs"] + data["false_positives"]) > 0 + else 0, } for test, data in correlations.items() - } + }, } historical.append(current) # Keep last 12 months of data cutoff = datetime.now() - timedelta(days=365) - historical = [ - h for h in historical - if datetime.fromisoformat(h['date']) > cutoff - ] + historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff] - with open(metrics_file, 'w') as f: + with open(metrics_file, "w") as f: json.dump(historical, f, indent=2) -if __name__ == '__main__': +if __name__ == "__main__": # Configure these for your repo - REPO_OWNER = 'trailofbits' - REPO_NAME = 'algo' + REPO_OWNER = "trailofbits" + REPO_NAME = "algo" print("Analyzing test effectiveness...") @@ -223,9 +215,9 @@ if __name__ == '__main__': print("\n" + report) # Save report - report_file = Path('.metrics/test-effectiveness-report.md') + report_file = Path(".metrics/test-effectiveness-report.md") report_file.parent.mkdir(exist_ok=True) - with open(report_file, 'w') as f: + with open(report_file, "w") as f: f.write(report) print(f"\nReport saved to: {report_file}") diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index a6cb5a0e..9a543f24 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,4 +1,5 @@ """Test fixtures for Algo unit tests""" + from pathlib import Path import yaml @@ -6,7 +7,7 @@ import yaml def load_test_variables(): """Load test variables from YAML fixture""" - fixture_path = Path(__file__).parent / 'test_variables.yml' + fixture_path = Path(__file__).parent / "test_variables.yml" with open(fixture_path) as f: return yaml.safe_load(f) diff --git a/tests/integration/ansible-service-wrapper.py b/tests/integration/ansible-service-wrapper.py index e7400946..ed58405c 100644 --- a/tests/integration/ansible-service-wrapper.py +++ b/tests/integration/ansible-service-wrapper.py @@ -2,49 +2,56 @@ """ Wrapper for Ansible's service module that always succeeds for known services """ + import json import sys # Parse module arguments args = json.loads(sys.stdin.read()) -module_args = args.get('ANSIBLE_MODULE_ARGS', {}) +module_args = args.get("ANSIBLE_MODULE_ARGS", {}) -service_name = module_args.get('name', '') -state = module_args.get('state', 'started') +service_name = module_args.get("name", "") +state = module_args.get("state", "started") # Known services that should always succeed known_services = [ - 'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter', - 'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd', - 'systemd-resolved', 'rsyslog', 'ipfw', 'cron' + "netfilter-persistent", + "iptables", + "wg-quick@wg0", + "strongswan-starter", + "ipsec", + "apparmor", + "unattended-upgrades", + "systemd-networkd", + "systemd-resolved", + "rsyslog", + "ipfw", + "cron", ] # Check if it's a known service service_found = False for svc in known_services: - if service_name == svc or service_name.startswith(svc + '.'): + if service_name == svc or service_name.startswith(svc + "."): service_found = True break if service_found: # Return success result = { - 'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False, - 'name': service_name, - 'state': state, - 'status': { - 'LoadState': 'loaded', - 'ActiveState': 'active' if state != 'stopped' else 'inactive', - 'SubState': 'running' if state != 'stopped' else 'dead' - } + "changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False, + "name": service_name, + "state": state, + "status": { + "LoadState": "loaded", + "ActiveState": "active" if state != "stopped" else "inactive", + "SubState": "running" if state != "stopped" else "dead", + }, } print(json.dumps(result)) sys.exit(0) else: # Service not found - error = { - 'failed': True, - 'msg': f'Could not find the requested service {service_name}: ' - } + error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "} print(json.dumps(error)) sys.exit(1) diff --git a/tests/integration/mock_modules/apt.py b/tests/integration/mock_modules/apt.py index 5bcc8e23..97503633 100644 --- a/tests/integration/mock_modules/apt.py +++ b/tests/integration/mock_modules/apt.py @@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule def main(): module = AnsibleModule( argument_spec={ - 'name': {'type': 'list', 'aliases': ['pkg', 'package']}, - 'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']}, - 'update_cache': {'type': 'bool', 'default': False}, - 'cache_valid_time': {'type': 'int', 'default': 0}, - 'install_recommends': {'type': 'bool'}, - 'force': {'type': 'bool', 'default': False}, - 'allow_unauthenticated': {'type': 'bool', 'default': False}, - 'allow_downgrade': {'type': 'bool', 'default': False}, - 'allow_change_held_packages': {'type': 'bool', 'default': False}, - 'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'}, - 'autoremove': {'type': 'bool', 'default': False}, - 'purge': {'type': 'bool', 'default': False}, - 'force_apt_get': {'type': 'bool', 'default': False}, + "name": {"type": "list", "aliases": ["pkg", "package"]}, + "state": { + "type": "str", + "default": "present", + "choices": ["present", "absent", "latest", "build-dep", "fixed"], + }, + "update_cache": {"type": "bool", "default": False}, + "cache_valid_time": {"type": "int", "default": 0}, + "install_recommends": {"type": "bool"}, + "force": {"type": "bool", "default": False}, + "allow_unauthenticated": {"type": "bool", "default": False}, + "allow_downgrade": {"type": "bool", "default": False}, + "allow_change_held_packages": {"type": "bool", "default": False}, + "dpkg_options": {"type": "str", "default": "force-confdef,force-confold"}, + "autoremove": {"type": "bool", "default": False}, + "purge": {"type": "bool", "default": False}, + "force_apt_get": {"type": "bool", "default": False}, }, - supports_check_mode=True + supports_check_mode=True, ) - name = module.params['name'] - state = module.params['state'] - update_cache = module.params['update_cache'] + name = module.params["name"] + state = module.params["state"] + update_cache = module.params["update_cache"] - result = { - 'changed': False, - 'cache_updated': False, - 'cache_update_time': 0 - } + result = {"changed": False, "cache_updated": False, "cache_update_time": 0} # Log the operation - with open('/var/log/mock-apt-module.log', 'a') as f: + with open("/var/log/mock-apt-module.log", "a") as f: f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n") # Handle cache update if update_cache: # In Docker, apt-get update was already run in entrypoint # Just pretend it succeeded - result['cache_updated'] = True - result['cache_update_time'] = 1754231778 # Fixed timestamp - result['changed'] = True + result["cache_updated"] = True + result["cache_update_time"] = 1754231778 # Fixed timestamp + result["changed"] = True # Handle package installation/removal if name: @@ -56,40 +56,41 @@ def main(): installed_packages = [] for pkg in packages: # Use dpkg to check if package is installed - check_cmd = ['dpkg', '-s', pkg] + check_cmd = ["dpkg", "-s", pkg] rc = subprocess.run(check_cmd, capture_output=True) if rc.returncode == 0: installed_packages.append(pkg) - if state in ['present', 'latest']: + if state in ["present", "latest"]: # Check if we need to install anything missing_packages = [p for p in packages if p not in installed_packages] if missing_packages: # Log what we would install - with open('/var/log/mock-apt-module.log', 'a') as f: + with open("/var/log/mock-apt-module.log", "a") as f: f.write(f"Would install packages: {missing_packages}\n") # For our test purposes, these packages are pre-installed in Docker # Just report success - result['changed'] = True - result['stdout'] = f"Mock: Packages {missing_packages} are already available" - result['stderr'] = "" + result["changed"] = True + result["stdout"] = f"Mock: Packages {missing_packages} are already available" + result["stderr"] = "" else: - result['stdout'] = "All packages are already installed" + result["stdout"] = "All packages are already installed" - elif state == 'absent': + elif state == "absent": # Check if we need to remove anything present_packages = [p for p in packages if p in installed_packages] if present_packages: - result['changed'] = True - result['stdout'] = f"Mock: Would remove packages {present_packages}" + result["changed"] = True + result["stdout"] = f"Mock: Would remove packages {present_packages}" else: - result['stdout'] = "No packages to remove" + result["stdout"] = "No packages to remove" # Always report success for our testing module.exit_json(**result) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tests/integration/mock_modules/command.py b/tests/integration/mock_modules/command.py index 8928d3b5..5505acd9 100644 --- a/tests/integration/mock_modules/command.py +++ b/tests/integration/mock_modules/command.py @@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule def main(): module = AnsibleModule( argument_spec={ - '_raw_params': {'type': 'str'}, - 'cmd': {'type': 'str'}, - 'creates': {'type': 'path'}, - 'removes': {'type': 'path'}, - 'chdir': {'type': 'path'}, - 'executable': {'type': 'path'}, - 'warn': {'type': 'bool', 'default': False}, - 'stdin': {'type': 'str'}, - 'stdin_add_newline': {'type': 'bool', 'default': True}, - 'strip_empty_ends': {'type': 'bool', 'default': True}, - '_uses_shell': {'type': 'bool', 'default': False}, + "_raw_params": {"type": "str"}, + "cmd": {"type": "str"}, + "creates": {"type": "path"}, + "removes": {"type": "path"}, + "chdir": {"type": "path"}, + "executable": {"type": "path"}, + "warn": {"type": "bool", "default": False}, + "stdin": {"type": "str"}, + "stdin_add_newline": {"type": "bool", "default": True}, + "strip_empty_ends": {"type": "bool", "default": True}, + "_uses_shell": {"type": "bool", "default": False}, }, - supports_check_mode=True + supports_check_mode=True, ) # Get the command - raw_params = module.params.get('_raw_params') - cmd = module.params.get('cmd') or raw_params + raw_params = module.params.get("_raw_params") + cmd = module.params.get("cmd") or raw_params if not cmd: module.fail_json(msg="no command given") - result = { - 'changed': False, - 'cmd': cmd, - 'rc': 0, - 'stdout': '', - 'stderr': '', - 'stdout_lines': [], - 'stderr_lines': [] - } + result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []} # Log the operation - with open('/var/log/mock-command-module.log', 'a') as f: + with open("/var/log/mock-command-module.log", "a") as f: f.write(f"command module called: cmd={cmd}\n") # Handle specific commands - if 'apparmor_status' in cmd: + if "apparmor_status" in cmd: # Pretend apparmor is not installed/active - result['rc'] = 127 - result['stderr'] = "apparmor_status: command not found" - result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'" - module.fail_json(msg=result['msg'], **result) - elif 'netplan apply' in cmd: + result["rc"] = 127 + result["stderr"] = "apparmor_status: command not found" + result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'" + module.fail_json(msg=result["msg"], **result) + elif "netplan apply" in cmd: # Pretend netplan succeeded - result['stdout'] = "Mock: netplan configuration applied" - result['changed'] = True - elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd: + result["stdout"] = "Mock: netplan configuration applied" + result["changed"] = True + elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd: # Routing cache flush - result['stdout'] = "1" - result['changed'] = True + result["stdout"] = "1" + result["changed"] = True else: # For other commands, try to run them try: - proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir')) - result['rc'] = proc.returncode - result['stdout'] = proc.stdout - result['stderr'] = proc.stderr - result['stdout_lines'] = proc.stdout.splitlines() - result['stderr_lines'] = proc.stderr.splitlines() - result['changed'] = True + proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir")) + result["rc"] = proc.returncode + result["stdout"] = proc.stdout + result["stderr"] = proc.stderr + result["stdout_lines"] = proc.stdout.splitlines() + result["stderr_lines"] = proc.stderr.splitlines() + result["changed"] = True except Exception as e: - result['rc'] = 1 - result['stderr'] = str(e) - result['msg'] = str(e) - module.fail_json(msg=result['msg'], **result) + result["rc"] = 1 + result["stderr"] = str(e) + result["msg"] = str(e) + module.fail_json(msg=result["msg"], **result) - if result['rc'] == 0: + if result["rc"] == 0: module.exit_json(**result) else: - if 'msg' not in result: - result['msg'] = f"Command failed with return code {result['rc']}" - module.fail_json(msg=result['msg'], **result) + if "msg" not in result: + result["msg"] = f"Command failed with return code {result['rc']}" + module.fail_json(msg=result["msg"], **result) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tests/integration/mock_modules/shell.py b/tests/integration/mock_modules/shell.py index 71ea5b14..34949c21 100644 --- a/tests/integration/mock_modules/shell.py +++ b/tests/integration/mock_modules/shell.py @@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule def main(): module = AnsibleModule( argument_spec={ - '_raw_params': {'type': 'str'}, - 'cmd': {'type': 'str'}, - 'creates': {'type': 'path'}, - 'removes': {'type': 'path'}, - 'chdir': {'type': 'path'}, - 'executable': {'type': 'path', 'default': '/bin/sh'}, - 'warn': {'type': 'bool', 'default': False}, - 'stdin': {'type': 'str'}, - 'stdin_add_newline': {'type': 'bool', 'default': True}, + "_raw_params": {"type": "str"}, + "cmd": {"type": "str"}, + "creates": {"type": "path"}, + "removes": {"type": "path"}, + "chdir": {"type": "path"}, + "executable": {"type": "path", "default": "/bin/sh"}, + "warn": {"type": "bool", "default": False}, + "stdin": {"type": "str"}, + "stdin_add_newline": {"type": "bool", "default": True}, }, - supports_check_mode=True + supports_check_mode=True, ) # Get the command - raw_params = module.params.get('_raw_params') - cmd = module.params.get('cmd') or raw_params + raw_params = module.params.get("_raw_params") + cmd = module.params.get("cmd") or raw_params if not cmd: module.fail_json(msg="no command given") - result = { - 'changed': False, - 'cmd': cmd, - 'rc': 0, - 'stdout': '', - 'stderr': '', - 'stdout_lines': [], - 'stderr_lines': [] - } + result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []} # Log the operation - with open('/var/log/mock-shell-module.log', 'a') as f: + with open("/var/log/mock-shell-module.log", "a") as f: f.write(f"shell module called: cmd={cmd}\n") # Handle specific commands - if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd: + if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd: # Routing cache flush - just pretend it worked - result['stdout'] = "" - result['changed'] = True - elif 'ifconfig lo100' in cmd: + result["stdout"] = "" + result["changed"] = True + elif "ifconfig lo100" in cmd: # BSD loopback commands - simulate success - result['stdout'] = "0" - result['changed'] = True + result["stdout"] = "0" + result["changed"] = True else: # For other commands, try to run them try: - proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, - executable=module.params.get('executable'), - cwd=module.params.get('chdir')) - result['rc'] = proc.returncode - result['stdout'] = proc.stdout - result['stderr'] = proc.stderr - result['stdout_lines'] = proc.stdout.splitlines() - result['stderr_lines'] = proc.stderr.splitlines() - result['changed'] = True + proc = subprocess.run( + cmd, + shell=True, + capture_output=True, + text=True, + executable=module.params.get("executable"), + cwd=module.params.get("chdir"), + ) + result["rc"] = proc.returncode + result["stdout"] = proc.stdout + result["stderr"] = proc.stderr + result["stdout_lines"] = proc.stdout.splitlines() + result["stderr_lines"] = proc.stderr.splitlines() + result["changed"] = True except Exception as e: - result['rc'] = 1 - result['stderr'] = str(e) - result['msg'] = str(e) - module.fail_json(msg=result['msg'], **result) + result["rc"] = 1 + result["stderr"] = str(e) + result["msg"] = str(e) + module.fail_json(msg=result["msg"], **result) - if result['rc'] == 0: + if result["rc"] == 0: module.exit_json(**result) else: - if 'msg' not in result: - result['msg'] = f"Command failed with return code {result['rc']}" - module.fail_json(msg=result['msg'], **result) + if "msg" not in result: + result["msg"] = f"Command failed with return code {result['rc']}" + module.fail_json(msg=result["msg"], **result) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tests/test_cloud_init_template.py b/tests/test_cloud_init_template.py index 047f88ef..806fece6 100644 --- a/tests/test_cloud_init_template.py +++ b/tests/test_cloud_init_template.py @@ -24,6 +24,7 @@ import yaml PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) + def create_expected_cloud_init(): """ Create the expected cloud-init content that should be generated @@ -74,6 +75,7 @@ runcmd: - systemctl restart sshd.service """ + class TestCloudInitTemplate: """Test class for cloud-init template validation.""" @@ -98,10 +100,7 @@ class TestCloudInitTemplate: parsed = self.test_yaml_validity() - required_sections = [ - 'package_update', 'package_upgrade', 'packages', - 'users', 'write_files', 'runcmd' - ] + required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"] missing = [section for section in required_sections if section not in parsed] assert not missing, f"Missing required sections: {missing}" @@ -114,35 +113,30 @@ class TestCloudInitTemplate: parsed = self.test_yaml_validity() - write_files = parsed.get('write_files', []) + write_files = parsed.get("write_files", []) assert write_files, "write_files section should be present" # Find sshd_config file sshd_config = None for file_entry in write_files: - if file_entry.get('path') == '/etc/ssh/sshd_config': + if file_entry.get("path") == "/etc/ssh/sshd_config": sshd_config = file_entry break assert sshd_config, "sshd_config file should be in write_files" - content = sshd_config.get('content', '') + content = sshd_config.get("content", "") assert content, "sshd_config should have content" # Check required SSH configurations - required_configs = [ - 'Port 4160', - 'AllowGroups algo', - 'PermitRootLogin no', - 'PasswordAuthentication no' - ] + required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"] missing = [config for config in required_configs if config not in content] assert not missing, f"Missing SSH configurations: {missing}" # Verify proper formatting - first line should be Port directive - lines = content.strip().split('\n') - assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}" + lines = content.strip().split("\n") + assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}" print("✅ SSH configuration correct") @@ -152,26 +146,26 @@ class TestCloudInitTemplate: parsed = self.test_yaml_validity() - users = parsed.get('users', []) + users = parsed.get("users", []) assert users, "users section should be present" # Find algo user algo_user = None for user in users: - if isinstance(user, dict) and user.get('name') == 'algo': + if isinstance(user, dict) and user.get("name") == "algo": algo_user = user break assert algo_user, "algo user should be defined" # Check required user properties - required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys'] + required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"] missing = [prop for prop in required_props if prop not in algo_user] assert not missing, f"algo user missing properties: {missing}" # Verify sudo configuration - sudo_config = algo_user.get('sudo', '') - assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}" + sudo_config = algo_user.get("sudo", "") + assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}" print("✅ User creation correct") @@ -181,13 +175,13 @@ class TestCloudInitTemplate: parsed = self.test_yaml_validity() - runcmd = parsed.get('runcmd', []) + runcmd = parsed.get("runcmd", []) assert runcmd, "runcmd section should be present" # Check for SSH restart command ssh_restart_found = False for cmd in runcmd: - if 'systemctl restart sshd' in str(cmd): + if "systemctl restart sshd" in str(cmd): ssh_restart_found = True break @@ -202,18 +196,18 @@ class TestCloudInitTemplate: cloud_init_content = create_expected_cloud_init() # Extract the sshd_config content lines - lines = cloud_init_content.split('\n') + lines = cloud_init_content.split("\n") in_sshd_content = False sshd_lines = [] for line in lines: - if 'content: |' in line: + if "content: |" in line: in_sshd_content = True continue elif in_sshd_content: - if line.strip() == '' and len(sshd_lines) > 0: + if line.strip() == "" and len(sshd_lines) > 0: break - if line.startswith('runcmd:'): + if line.startswith("runcmd:"): break sshd_lines.append(line) @@ -225,11 +219,13 @@ class TestCloudInitTemplate: for line in non_empty_lines: # Each line should start with exactly 6 spaces - assert line.startswith(' ') and not line.startswith(' '), \ + assert line.startswith(" ") and not line.startswith(" "), ( f"Line should have exactly 6 spaces indentation: {repr(line)}" + ) print("✅ Indentation is consistent") + def run_tests(): """Run all tests manually (for non-pytest usage).""" print("🚀 Cloud-init template validation tests") @@ -258,6 +254,7 @@ def run_tests(): print(f"❌ Unexpected error: {e}") return False + if __name__ == "__main__": success = run_tests() sys.exit(0 if success else 1) diff --git a/tests/test_package_preinstall.py b/tests/test_package_preinstall.py index 8474f3b2..e6912079 100644 --- a/tests/test_package_preinstall.py +++ b/tests/test_package_preinstall.py @@ -30,49 +30,49 @@ packages: rendered = self.packages_template.render({}) # Should only have sudo package - self.assertIn('- sudo', rendered) - self.assertNotIn('- git', rendered) - self.assertNotIn('- screen', rendered) - self.assertNotIn('- apparmor-utils', rendered) + self.assertIn("- sudo", rendered) + self.assertNotIn("- git", rendered) + self.assertNotIn("- screen", rendered) + self.assertNotIn("- apparmor-utils", rendered) def test_preinstall_enabled(self): """Test that package pre-installation works when enabled.""" # Test with pre-installation enabled - rendered = self.packages_template.render({'performance_preinstall_packages': True}) + rendered = self.packages_template.render({"performance_preinstall_packages": True}) # Should have sudo and all universal packages - self.assertIn('- sudo', rendered) - self.assertIn('- git', rendered) - self.assertIn('- screen', rendered) - self.assertIn('- apparmor-utils', rendered) - self.assertIn('- uuid-runtime', rendered) - self.assertIn('- coreutils', rendered) - self.assertIn('- iptables-persistent', rendered) - self.assertIn('- cgroup-tools', rendered) + self.assertIn("- sudo", rendered) + self.assertIn("- git", rendered) + self.assertIn("- screen", rendered) + self.assertIn("- apparmor-utils", rendered) + self.assertIn("- uuid-runtime", rendered) + self.assertIn("- coreutils", rendered) + self.assertIn("- iptables-persistent", rendered) + self.assertIn("- cgroup-tools", rendered) def test_preinstall_disabled_explicitly(self): """Test that package pre-installation is disabled when set to false.""" # Test with pre-installation explicitly disabled - rendered = self.packages_template.render({'performance_preinstall_packages': False}) + rendered = self.packages_template.render({"performance_preinstall_packages": False}) # Should only have sudo package - self.assertIn('- sudo', rendered) - self.assertNotIn('- git', rendered) - self.assertNotIn('- screen', rendered) - self.assertNotIn('- apparmor-utils', rendered) + self.assertIn("- sudo", rendered) + self.assertNotIn("- git", rendered) + self.assertNotIn("- screen", rendered) + self.assertNotIn("- apparmor-utils", rendered) def test_package_count(self): """Test that the correct number of packages are included.""" # Default: should only have sudo (1 package) rendered_default = self.packages_template.render({}) - lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')] + lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")] self.assertEqual(len(lines_default), 1) # Enabled: should have sudo + 7 universal packages (8 total) - rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True}) - lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')] + rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True}) + lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")] self.assertEqual(len(lines_enabled), 8) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_basic_sanity.py b/tests/unit/test_basic_sanity.py index d3675eb6..e2a7b916 100644 --- a/tests/unit/test_basic_sanity.py +++ b/tests/unit/test_basic_sanity.py @@ -2,6 +2,7 @@ """ Basic sanity tests for Algo VPN that don't require deployment """ + import os import subprocess import sys @@ -44,11 +45,7 @@ def test_config_file_valid(): def test_ansible_syntax(): """Check that main playbook has valid syntax""" - result = subprocess.run( - ["ansible-playbook", "main.yml", "--syntax-check"], - capture_output=True, - text=True - ) + result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True) assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}" print("✓ Ansible playbook syntax is valid") @@ -60,11 +57,7 @@ def test_shellcheck(): for script in shell_scripts: if os.path.exists(script): - result = subprocess.run( - ["shellcheck", script], - capture_output=True, - text=True - ) + result = subprocess.run(["shellcheck", script], capture_output=True, text=True) assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}" print(f"✓ {script} passed shellcheck") @@ -87,7 +80,7 @@ def test_cloud_init_header_format(): assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found" with open(cloud_init_file) as f: - first_line = f.readline().rstrip('\n\r') + first_line = f.readline().rstrip("\n\r") # The first line MUST be exactly "#cloud-config" (no space after #) # This regression was introduced in PR #14775 and broke DigitalOcean deployments diff --git a/tests/unit/test_cloud_provider_configs.py b/tests/unit/test_cloud_provider_configs.py index b48ee115..68738c1f 100644 --- a/tests/unit/test_cloud_provider_configs.py +++ b/tests/unit/test_cloud_provider_configs.py @@ -4,31 +4,30 @@ Test cloud provider instance type configurations Focused on validating that configured instance types are current/valid Based on issues #14730 - Hetzner changed from cx11 to cx22 """ + import sys def test_hetzner_server_types(): """Test Hetzner server type configurations (issue #14730)""" # Hetzner deprecated cx11 and cpx11 - smallest is now cx22 - deprecated_types = ['cx11', 'cpx11'] - current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42'] + deprecated_types = ["cx11", "cpx11"] + current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"] # Test that we're not using deprecated types in any configs test_config = { - 'cloud_providers': { - 'hetzner': { - 'size': 'cx22', # Should be cx22, not cx11 - 'image': 'ubuntu-22.04', - 'location': 'hel1' + "cloud_providers": { + "hetzner": { + "size": "cx22", # Should be cx22, not cx11 + "image": "ubuntu-22.04", + "location": "hel1", } } } - hetzner = test_config['cloud_providers']['hetzner'] - assert hetzner['size'] not in deprecated_types, \ - f"Using deprecated Hetzner type: {hetzner['size']}" - assert hetzner['size'] in current_types, \ - f"Unknown Hetzner type: {hetzner['size']}" + hetzner = test_config["cloud_providers"]["hetzner"] + assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}" + assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}" print("✓ Hetzner server types test passed") @@ -36,10 +35,10 @@ def test_hetzner_server_types(): def test_digitalocean_instance_types(): """Test DigitalOcean droplet size naming""" # DigitalOcean uses format like s-1vcpu-1gb - valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb'] - deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme + valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"] + deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme - test_size = 's-2vcpu-2gb' + test_size = "s-2vcpu-2gb" assert test_size in valid_sizes, f"Invalid DO size: {test_size}" assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}" @@ -49,10 +48,10 @@ def test_digitalocean_instance_types(): def test_aws_instance_types(): """Test AWS EC2 instance type naming""" # Common valid instance types - valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large'] - deprecated_types = ['t1.micro', 'm1.small'] # Very old types + valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"] + deprecated_types = ["t1.micro", "m1.small"] # Very old types - test_type = 't3.micro' + test_type = "t3.micro" assert test_type in valid_types, f"Unknown EC2 type: {test_type}" assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}" @@ -62,9 +61,10 @@ def test_aws_instance_types(): def test_vultr_instance_types(): """Test Vultr instance type naming""" # Vultr uses format like vc2-1c-1gb - test_type = 'vc2-1c-1gb' - assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \ + test_type = "vc2-1c-1gb" + assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), ( f"Invalid Vultr type format: {test_type}" + ) print("✓ Vultr instance types test passed") diff --git a/tests/unit/test_config_validation.py b/tests/unit/test_config_validation.py index 39d4fd7f..65eddd20 100644 --- a/tests/unit/test_config_validation.py +++ b/tests/unit/test_config_validation.py @@ -2,6 +2,7 @@ """ Test configuration file validation without deployment """ + import configparser import os import re @@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820 config = configparser.ConfigParser() config.read_string(sample_config) - assert 'Interface' in config, "Missing [Interface] section" - assert 'Peer' in config, "Missing [Peer] section" + assert "Interface" in config, "Missing [Interface] section" + assert "Peer" in config, "Missing [Peer] section" # Validate required fields - assert config['Interface'].get('PrivateKey'), "Missing PrivateKey" - assert config['Interface'].get('Address'), "Missing Address" - assert config['Peer'].get('PublicKey'), "Missing PublicKey" - assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs" + assert config["Interface"].get("PrivateKey"), "Missing PrivateKey" + assert config["Interface"].get("Address"), "Missing Address" + assert config["Peer"].get("PublicKey"), "Missing PublicKey" + assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs" print("✓ WireGuard config format validation passed") @@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820 def test_base64_key_format(): """Test that keys are in valid base64 format""" # Base64 keys can have variable length, just check format - key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$') + key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$") test_keys = [ "aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=", @@ -58,8 +59,8 @@ def test_base64_key_format(): def test_ip_address_format(): """Test IP address and CIDR notation validation""" - ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$') - endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$') + ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$") + endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$") # Test CIDR notation assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation" @@ -74,11 +75,7 @@ def test_ip_address_format(): def test_mobile_config_xml(): """Test that mobile config files would be valid XML""" # First check if xmllint is available - xmllint_check = subprocess.run( - ['which', 'xmllint'], - capture_output=True, - text=True - ) + xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True) if xmllint_check.returncode != 0: print("⚠ Skipping XML validation test (xmllint not installed)") @@ -99,17 +96,13 @@ def test_mobile_config_xml(): """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f: f.write(sample_mobileconfig) temp_file = f.name try: # Use xmllint to validate - result = subprocess.run( - ['xmllint', '--noout', temp_file], - capture_output=True, - text=True - ) + result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True) assert result.returncode == 0, f"XML validation failed: {result.stderr}" print("✓ Mobile config XML validation passed") diff --git a/tests/unit/test_docker_localhost_deployment.py b/tests/unit/test_docker_localhost_deployment.py index 8cf67d73..8e3aa9bc 100755 --- a/tests/unit/test_docker_localhost_deployment.py +++ b/tests/unit/test_docker_localhost_deployment.py @@ -3,6 +3,7 @@ Simplified Docker-based localhost deployment tests Verifies services can start and config files exist in expected locations """ + import os import subprocess import sys @@ -11,7 +12,7 @@ import sys def check_docker_available(): """Check if Docker is available""" try: - result = subprocess.run(['docker', '--version'], capture_output=True, text=True) + result = subprocess.run(["docker", "--version"], capture_output=True, text=True) return result.returncode == 0 except FileNotFoundError: return False @@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128 """ # Just validate the format - required_sections = ['[Interface]', '[Peer]'] - required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs'] + required_sections = ["[Interface]", "[Peer]"] + required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"] for section in required_sections: if section not in config: @@ -68,15 +69,15 @@ conn ikev2-pubkey """ # Validate format - if 'config setup' not in config: + if "config setup" not in config: print("✗ Missing 'config setup' section") return False - if 'conn %default' not in config: + if "conn %default" not in config: print("✗ Missing 'conn %default' section") return False - if 'keyexchange=ikev2' not in config: + if "keyexchange=ikev2" not in config: print("✗ Missing IKEv2 configuration") return False @@ -87,19 +88,19 @@ conn ikev2-pubkey def test_docker_algo_image(): """Test that the Algo Docker image can be built""" # Check if Dockerfile exists - if not os.path.exists('Dockerfile'): + if not os.path.exists("Dockerfile"): print("✗ Dockerfile not found") return False # Read Dockerfile and validate basic structure - with open('Dockerfile') as f: + with open("Dockerfile") as f: dockerfile_content = f.read() required_elements = [ - 'FROM', # Base image - 'RUN', # Build commands - 'COPY', # Copy Algo files - 'python' # Python dependency + "FROM", # Base image + "RUN", # Build commands + "COPY", # Copy Algo files + "python", # Python dependency ] missing = [] @@ -115,16 +116,14 @@ def test_docker_algo_image(): return True - - def test_localhost_deployment_requirements(): """Test that localhost deployment requirements are met""" requirements = { - 'Python 3.8+': sys.version_info >= (3, 8), - 'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0, - 'Main playbook exists': os.path.exists('main.yml'), - 'Project config exists': os.path.exists('pyproject.toml'), - 'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'), + "Python 3.8+": sys.version_info >= (3, 8), + "Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0, + "Main playbook exists": os.path.exists("main.yml"), + "Project config exists": os.path.exists("pyproject.toml"), + "Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"), } all_met = True @@ -138,10 +137,6 @@ def test_localhost_deployment_requirements(): return all_met - - - - if __name__ == "__main__": print("Running Docker localhost deployment tests...") print("=" * 50) diff --git a/tests/unit/test_generated_configs.py b/tests/unit/test_generated_configs.py index 829a6255..a3345e20 100644 --- a/tests/unit/test_generated_configs.py +++ b/tests/unit/test_generated_configs.py @@ -3,6 +3,7 @@ Test that generated configuration files have valid syntax This validates WireGuard, StrongSwan, SSH, and other configs """ + import re import subprocess import sys @@ -11,7 +12,7 @@ import sys def check_command_available(cmd): """Check if a command is available on the system""" try: - subprocess.run([cmd, '--version'], capture_output=True, check=False) + subprocess.run([cmd, "--version"], capture_output=True, check=False) return True except FileNotFoundError: return False @@ -37,51 +38,50 @@ PersistentKeepalive = 25 errors = [] # Check for required sections - if '[Interface]' not in sample_config: + if "[Interface]" not in sample_config: errors.append("Missing [Interface] section") - if '[Peer]' not in sample_config: + if "[Peer]" not in sample_config: errors.append("Missing [Peer] section") # Validate Interface section - interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL) + interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL) if interface_match: interface_section = interface_match.group(1) # Check required fields - if not re.search(r'Address\s*=', interface_section): + if not re.search(r"Address\s*=", interface_section): errors.append("Missing Address in Interface section") - if not re.search(r'PrivateKey\s*=', interface_section): + if not re.search(r"PrivateKey\s*=", interface_section): errors.append("Missing PrivateKey in Interface section") # Validate IP addresses - address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section) + address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section) if address_match: - addresses = address_match.group(1).split(',') + addresses = address_match.group(1).split(",") for addr in addresses: addr = addr.strip() # Basic IP validation - if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \ - not re.match(r'^[0-9a-fA-F:]+/\d+$', addr): + if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr): errors.append(f"Invalid IP address format: {addr}") # Validate Peer section - peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL) + peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL) if peer_match: peer_section = peer_match.group(1) # Check required fields - if not re.search(r'PublicKey\s*=', peer_section): + if not re.search(r"PublicKey\s*=", peer_section): errors.append("Missing PublicKey in Peer section") - if not re.search(r'AllowedIPs\s*=', peer_section): + if not re.search(r"AllowedIPs\s*=", peer_section): errors.append("Missing AllowedIPs in Peer section") - if not re.search(r'Endpoint\s*=', peer_section): + if not re.search(r"Endpoint\s*=", peer_section): errors.append("Missing Endpoint in Peer section") # Validate endpoint format - endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section) + endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section) if endpoint_match: endpoint = endpoint_match.group(1).strip() - if not re.match(r'^[\d\.\:]+:\d+$', endpoint): + if not re.match(r"^[\d\.\:]+:\d+$", endpoint): errors.append(f"Invalid Endpoint format: {endpoint}") if errors: @@ -132,33 +132,32 @@ conn ikev2-pubkey errors = [] # Check for required sections - if 'config setup' not in sample_config: + if "config setup" not in sample_config: errors.append("Missing 'config setup' section") - if 'conn %default' not in sample_config: + if "conn %default" not in sample_config: errors.append("Missing 'conn %default' section") # Validate connection settings - conn_pattern = re.compile(r'conn\s+(\S+)') + conn_pattern = re.compile(r"conn\s+(\S+)") connections = conn_pattern.findall(sample_config) if len(connections) < 2: # Should have at least %default and one other errors.append("Not enough connection definitions") # Check for required parameters in connections - required_params = ['keyexchange', 'left', 'right'] + required_params = ["keyexchange", "left", "right"] for param in required_params: - if f'{param}=' not in sample_config: + if f"{param}=" not in sample_config: errors.append(f"Missing required parameter: {param}") # Validate IP subnet formats - subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)') + subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)") for match in subnet_pattern.finditer(sample_config): - subnets = match.group(2).split(',') + subnets = match.group(2).split(",") for subnet in subnets: subnet = subnet.strip() - if subnet != '0.0.0.0/0' and subnet != '::/0': - if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \ - not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet): + if subnet != "0.0.0.0/0" and subnet != "::/0": + if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet): errors.append(f"Invalid subnet format: {subnet}") if errors: @@ -188,21 +187,21 @@ def test_ssh_config_syntax(): errors = [] # Parse SSH config format - lines = sample_config.strip().split('\n') + lines = sample_config.strip().split("\n") current_host = None for line in lines: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue - if line.startswith('Host '): + if line.startswith("Host "): current_host = line.split()[1] - elif current_host and ' ' in line: + elif current_host and " " in line: key, value = line.split(None, 1) # Validate common SSH options - if key == 'Port': + if key == "Port": try: port = int(value) if not 1 <= port <= 65535: @@ -210,7 +209,7 @@ def test_ssh_config_syntax(): except ValueError: errors.append(f"Port must be a number: {value}") - elif key == 'LocalForward': + elif key == "LocalForward": # Format: LocalForward [bind_address:]port host:hostport parts = value.split() if len(parts) != 2: @@ -256,35 +255,35 @@ COMMIT errors = [] # Check table definitions - tables = re.findall(r'\*(\w+)', sample_rules) - if 'filter' not in tables: + tables = re.findall(r"\*(\w+)", sample_rules) + if "filter" not in tables: errors.append("Missing *filter table") - if 'nat' not in tables: + if "nat" not in tables: errors.append("Missing *nat table") # Check for COMMIT statements - commit_count = sample_rules.count('COMMIT') + commit_count = sample_rules.count("COMMIT") if commit_count != len(tables): errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})") # Validate chain policies - chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE) + chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE) chains = chain_pattern.findall(sample_rules) - required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')] + required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")] for chain, _policy in required_chains: if not any(c[0] == chain for c in chains): errors.append(f"Missing required chain: {chain}") # Validate rule syntax - rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE) + rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE) rules = rule_pattern.findall(sample_rules) if len(rules) < 5: errors.append("Insufficient firewall rules") # Check for essential security rules - if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules: + if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules: errors.append("Missing stateful connection tracking rule") if errors: @@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts errors = [] # Parse config - for line in sample_config.strip().split('\n'): + for line in sample_config.strip().split("\n"): line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Most dnsmasq options are key=value or just key - if '=' in line: - key, value = line.split('=', 1) + if "=" in line: + key, value = line.split("=", 1) # Validate specific options - if key == 'interface': - if not re.match(r'^[a-zA-Z0-9\-_]+$', value): + if key == "interface": + if not re.match(r"^[a-zA-Z0-9\-_]+$", value): errors.append(f"Invalid interface name: {value}") - elif key == 'server': + elif key == "server": # Basic IP validation - if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \ - not re.match(r'^[0-9a-fA-F:]+$', value): + if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value): errors.append(f"Invalid DNS server IP: {value}") - elif key == 'cache-size': + elif key == "cache-size": try: size = int(value) if size < 0: @@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts errors.append(f"Cache size must be a number: {value}") # Check for required options - required = ['interface', 'server'] + required = ["interface", "server"] for req in required: - if f'{req}=' not in sample_config: + if f"{req}=" not in sample_config: errors.append(f"Missing required option: {req}") if errors: diff --git a/tests/unit/test_iptables_rules.py b/tests/unit/test_iptables_rules.py index af0063f8..ec561b5a 100644 --- a/tests/unit/test_iptables_rules.py +++ b/tests/unit/test_iptables_rules.py @@ -14,203 +14,203 @@ from jinja2 import Environment, FileSystemLoader def load_template(template_name): """Load a Jinja2 template from the roles/common/templates directory.""" - template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates' + template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates" env = Environment(loader=FileSystemLoader(str(template_dir))) return env.get_template(template_name) def test_wireguard_nat_rules_ipv4(): """Test that WireGuard traffic gets proper NAT rules without policy matching.""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") # Test with WireGuard enabled result = template.render( ipsec_enabled=False, wireguard_enabled=True, - wireguard_network_ipv4='10.49.0.0/16', + wireguard_network_ipv4="10.49.0.0/16", wireguard_port=51820, wireguard_port_avoid=53, wireguard_port_actual=51820, - ansible_default_ipv4={'interface': 'eth0'}, + ansible_default_ipv4={"interface": "eth0"}, snat_aipv4=None, BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.49.0.1', + local_service_ip="10.49.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # Verify NAT rule exists with output interface and without policy matching - assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result + assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result # Verify no policy matching in WireGuard NAT rules - assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result + assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result def test_ipsec_nat_rules_ipv4(): """Test that IPsec traffic gets proper NAT rules without policy matching.""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") # Test with IPsec enabled result = template.render( ipsec_enabled=True, wireguard_enabled=False, - strongswan_network='10.48.0.0/16', - strongswan_network_ipv6='2001:db8::/48', - ansible_default_ipv4={'interface': 'eth0'}, + strongswan_network="10.48.0.0/16", + strongswan_network_ipv6="2001:db8::/48", + ansible_default_ipv4={"interface": "eth0"}, snat_aipv4=None, BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.48.0.1', + local_service_ip="10.48.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # Verify NAT rule exists with output interface and without policy matching - assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result + assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result # Verify no policy matching in IPsec NAT rules (this was the bug) - assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result + assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result def test_both_vpns_nat_rules_ipv4(): """Test NAT rules when both VPN types are enabled.""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") result = template.render( ipsec_enabled=True, wireguard_enabled=True, - strongswan_network='10.48.0.0/16', - wireguard_network_ipv4='10.49.0.0/16', - strongswan_network_ipv6='2001:db8::/48', - wireguard_network_ipv6='2001:db8:a160::/48', + strongswan_network="10.48.0.0/16", + wireguard_network_ipv4="10.49.0.0/16", + strongswan_network_ipv6="2001:db8::/48", + wireguard_network_ipv6="2001:db8:a160::/48", wireguard_port=51820, wireguard_port_avoid=53, wireguard_port_actual=51820, - ansible_default_ipv4={'interface': 'eth0'}, + ansible_default_ipv4={"interface": "eth0"}, snat_aipv4=None, BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.49.0.1', + local_service_ip="10.49.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # Both should have NAT rules with output interface - assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result - assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result + assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result + assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result # Neither should have policy matching - assert '-m policy --pol none' not in result + assert "-m policy --pol none" not in result def test_alternative_ingress_snat(): """Test that alternative ingress IP uses SNAT instead of MASQUERADE.""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") result = template.render( ipsec_enabled=True, wireguard_enabled=True, - strongswan_network='10.48.0.0/16', - wireguard_network_ipv4='10.49.0.0/16', - strongswan_network_ipv6='2001:db8::/48', - wireguard_network_ipv6='2001:db8:a160::/48', + strongswan_network="10.48.0.0/16", + wireguard_network_ipv4="10.49.0.0/16", + strongswan_network_ipv6="2001:db8::/48", + wireguard_network_ipv6="2001:db8:a160::/48", wireguard_port=51820, wireguard_port_avoid=53, wireguard_port_actual=51820, - ansible_default_ipv4={'interface': 'eth0'}, - snat_aipv4='192.168.1.100', # Alternative ingress IP + ansible_default_ipv4={"interface": "eth0"}, + snat_aipv4="192.168.1.100", # Alternative ingress IP BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.49.0.1', + local_service_ip="10.49.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # Should use SNAT with specific IP and output interface instead of MASQUERADE - assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result - assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result - assert 'MASQUERADE' not in result + assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result + assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result + assert "MASQUERADE" not in result def test_ipsec_forward_rule_has_policy_match(): """Test that IPsec FORWARD rules still use policy matching (this is correct).""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") result = template.render( ipsec_enabled=True, wireguard_enabled=False, - strongswan_network='10.48.0.0/16', - strongswan_network_ipv6='2001:db8::/48', - ansible_default_ipv4={'interface': 'eth0'}, + strongswan_network="10.48.0.0/16", + strongswan_network_ipv6="2001:db8::/48", + ansible_default_ipv4={"interface": "eth0"}, snat_aipv4=None, BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.48.0.1', + local_service_ip="10.48.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # FORWARD rule should have policy match (this is correct and should stay) - assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result + assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result def test_wireguard_forward_rule_no_policy_match(): """Test that WireGuard FORWARD rules don't use policy matching.""" - template = load_template('rules.v4.j2') + template = load_template("rules.v4.j2") result = template.render( ipsec_enabled=False, wireguard_enabled=True, - wireguard_network_ipv4='10.49.0.0/16', + wireguard_network_ipv4="10.49.0.0/16", wireguard_port=51820, wireguard_port_avoid=53, wireguard_port_actual=51820, - ansible_default_ipv4={'interface': 'eth0'}, + ansible_default_ipv4={"interface": "eth0"}, snat_aipv4=None, BetweenClients_DROP=True, block_smb=True, block_netbios=True, - local_service_ip='10.49.0.1', + local_service_ip="10.49.0.1", ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) # WireGuard FORWARD rule should NOT have any policy match - assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result - assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result + assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result + assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result def test_output_interface_in_nat_rules(): """Test that output interface is specified in NAT rules.""" - template = load_template('rules.v4.j2') - + template = load_template("rules.v4.j2") + result = template.render( snat_aipv4=False, wireguard_enabled=True, ipsec_enabled=True, - wireguard_network_ipv4='10.49.0.0/16', - strongswan_network='10.48.0.0/16', - ansible_default_ipv4={'interface': 'eth0', 'address': '10.0.0.1'}, - ansible_default_ipv6={'interface': 'eth0', 'address': 'fd9d:bc11:4020::1'}, + wireguard_network_ipv4="10.49.0.0/16", + strongswan_network="10.48.0.0/16", + ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"}, + ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"}, wireguard_port_actual=51820, wireguard_port_avoid=53, wireguard_port=51820, ansible_ssh_port=22, - reduce_mtu=0 + reduce_mtu=0, ) - + # Check that output interface is specified for both VPNs - assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result - assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result - + assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result + assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result + # Ensure we don't have rules without output interface - assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' not in result - assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' not in result + assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result + assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result -if __name__ == '__main__': - pytest.main([__file__, '-v']) +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_lightsail_boto3_fix.py b/tests/unit/test_lightsail_boto3_fix.py index d8dbbd4d..2f4216e6 100644 --- a/tests/unit/test_lightsail_boto3_fix.py +++ b/tests/unit/test_lightsail_boto3_fix.py @@ -12,7 +12,7 @@ import unittest from unittest.mock import MagicMock, patch # Add the library directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library")) class TestLightsailBoto3Fix(unittest.TestCase): @@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase): """Set up test fixtures.""" # Mock the ansible module_utils since we're testing outside of Ansible self.mock_modules = { - 'ansible.module_utils.basic': MagicMock(), - 'ansible.module_utils.ec2': MagicMock(), - 'ansible.module_utils.aws.core': MagicMock(), + "ansible.module_utils.basic": MagicMock(), + "ansible.module_utils.ec2": MagicMock(), + "ansible.module_utils.aws.core": MagicMock(), } # Apply mocks self.patches = [] for module_name, mock_module in self.mock_modules.items(): - patcher = patch.dict('sys.modules', {module_name: mock_module}) + patcher = patch.dict("sys.modules", {module_name: mock_module}) patcher.start() self.patches.append(patcher) @@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase): # Import the module spec = importlib.util.spec_from_file_location( "lightsail_region_facts", - os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') + os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"), ) module = importlib.util.module_from_spec(spec) @@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase): # Verify the module loaded self.assertIsNotNone(module) - self.assertTrue(hasattr(module, 'main')) + self.assertTrue(hasattr(module, "main")) except Exception as e: self.fail(f"Failed to import lightsail_region_facts: {e}") @@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase): def test_get_aws_connection_info_called_without_boto3(self): """Test that get_aws_connection_info is called without boto3 parameter.""" # Mock get_aws_connection_info to track calls - mock_get_aws_connection_info = MagicMock( - return_value=('us-west-2', None, {}) - ) + mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {})) - with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info): + with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info): # Import the module spec = importlib.util.spec_from_file_location( "lightsail_region_facts", - os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') + os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"), ) module = importlib.util.module_from_spec(spec) @@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase): mock_ansible_module.params = {} mock_ansible_module.check_mode = False - with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module): + with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module): # Execute the module try: spec.loader.exec_module(module) @@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase): if call_args: # Check positional arguments if call_args[0]: # args - self.assertTrue(len(call_args[0]) <= 1, - "get_aws_connection_info should be called with at most 1 positional arg (module)") + self.assertTrue( + len(call_args[0]) <= 1, + "get_aws_connection_info should be called with at most 1 positional arg (module)", + ) # Check keyword arguments if call_args[1]: # kwargs - self.assertNotIn('boto3', call_args[1], - "get_aws_connection_info should not be called with boto3 parameter") + self.assertNotIn( + "boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter" + ) def test_no_boto3_parameter_in_source(self): """Verify that boto3 parameter is not present in the source code.""" - lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') + lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py") with open(lightsail_path) as f: content = f.read() # Check that boto3=True is not in the file - self.assertNotIn('boto3=True', content, - "boto3=True parameter should not be present in lightsail_region_facts.py") + self.assertNotIn( + "boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py" + ) # Check that boto3 parameter is not used with get_aws_connection_info - self.assertNotIn('get_aws_connection_info(module, boto3', content, - "get_aws_connection_info should not be called with boto3 parameter") + self.assertNotIn( + "get_aws_connection_info(module, boto3", + content, + "get_aws_connection_info should not be called with boto3 parameter", + ) def test_regression_issue_14822(self): """ @@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase): # The boto3 parameter was deprecated and removed in amazon.aws collection # that comes with Ansible 11.x - lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py') + lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py") with open(lightsail_path) as f: lines = f.readlines() # Find the line that calls get_aws_connection_info for line_num, line in enumerate(lines, 1): - if 'get_aws_connection_info' in line and 'region' in line: + if "get_aws_connection_info" in line and "region" in line: # This should be around line 85 # Verify it doesn't have boto3=True - self.assertNotIn('boto3', line, - f"Line {line_num} should not contain boto3 parameter") + self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter") # Verify the correct format - self.assertIn('get_aws_connection_info(module)', line, - f"Line {line_num} should call get_aws_connection_info(module) without boto3") + self.assertIn( + "get_aws_connection_info(module)", + line, + f"Line {line_num} should call get_aws_connection_info(module) without boto3", + ) break else: self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_openssl_compatibility.py b/tests/unit/test_openssl_compatibility.py index 4f70ccf0..62e937cc 100644 --- a/tests/unit/test_openssl_compatibility.py +++ b/tests/unit/test_openssl_compatibility.py @@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa Based on issues #14755, #14718 - Apple device compatibility Issues #75, #153 - Security enhancements (name constraints, EKU restrictions) """ + import glob import os import re @@ -22,7 +23,7 @@ def find_generated_certificates(): config_patterns = [ "configs/*/ipsec/.pki/cacert.pem", "../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory - "../../configs/*/ipsec/.pki/cacert.pem" # Alternative path + "../../configs/*/ipsec/.pki/cacert.pem", # Alternative path ] for pattern in config_patterns: @@ -30,26 +31,23 @@ def find_generated_certificates(): if ca_certs: base_path = os.path.dirname(ca_certs[0]) return { - 'ca_cert': ca_certs[0], - 'base_path': base_path, - 'server_certs': glob.glob(f"{base_path}/certs/*.crt"), - 'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12") + "ca_cert": ca_certs[0], + "base_path": base_path, + "server_certs": glob.glob(f"{base_path}/certs/*.crt"), + "p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"), } return None + def test_openssl_version_detection(): """Test that we can detect OpenSSL version for compatibility checks""" - result = subprocess.run( - ['openssl', 'version'], - capture_output=True, - text=True - ) + result = subprocess.run(["openssl", "version"], capture_output=True, text=True) assert result.returncode == 0, "Failed to get OpenSSL version" # Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022" - version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout) + version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout) assert version_match, f"Can't parse OpenSSL version: {result.stdout}" major = int(version_match.group(1)) @@ -62,7 +60,7 @@ def test_openssl_version_detection(): def validate_ca_certificate_real(cert_files): """Validate actual Ansible-generated CA certificate""" # Read the actual CA certificate generated by Ansible - with open(cert_files['ca_cert'], 'rb') as f: + with open(cert_files["ca_cert"], "rb") as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) @@ -89,30 +87,34 @@ def validate_ca_certificate_real(cert_files): assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints" # Verify public domains are excluded - excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees - if isinstance(constraint, x509.DNSName)] + excluded_dns = [ + constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName) + ] public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in public_domains: assert domain in excluded_dns, f"CA should exclude public domain {domain}" # Verify private IP ranges are excluded (Issue #75) - excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees - if isinstance(constraint, x509.IPAddress)] + excluded_ips = [ + constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress) + ] assert len(excluded_ips) > 0, "CA should exclude private IP ranges" # Verify email domains are also excluded (Issue #153) - excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees - if isinstance(constraint, x509.RFC822Name)] + excluded_emails = [ + constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name) + ] email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in email_domains: assert domain in excluded_emails, f"CA should exclude email domain {domain}" print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}") + def validate_ca_certificate_config(): """Validate CA certificate configuration in Ansible files (CI mode)""" # Check that the Ansible task file has proper CA certificate configuration - openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') + openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml") if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return @@ -122,15 +124,15 @@ def validate_ca_certificate_config(): # Verify key security configurations are present security_checks = [ - ('name_constraints_permitted', 'Name constraints should be configured'), - ('name_constraints_excluded', 'Excluded name constraints should be configured'), - ('extended_key_usage', 'Extended Key Usage should be configured'), - ('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'), - ('serverAuth', 'Server authentication EKU should be present'), - ('clientAuth', 'Client authentication EKU should be present'), - ('basic_constraints', 'Basic constraints should be configured'), - ('CA:TRUE', 'CA certificate should be marked as CA'), - ('pathlen:0', 'Path length constraint should be set') + ("name_constraints_permitted", "Name constraints should be configured"), + ("name_constraints_excluded", "Excluded name constraints should be configured"), + ("extended_key_usage", "Extended Key Usage should be configured"), + ("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"), + ("serverAuth", "Server authentication EKU should be present"), + ("clientAuth", "Client authentication EKU should be present"), + ("basic_constraints", "Basic constraints should be configured"), + ("CA:TRUE", "CA certificate should be marked as CA"), + ("pathlen:0", "Path length constraint should be set"), ] for check, message in security_checks: @@ -140,7 +142,9 @@ def validate_ca_certificate_config(): public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in public_domains: # Handle both double quotes and single quotes in YAML - assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded" + assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, ( + f"Public domain {domain} should be excluded" + ) # Verify private IP ranges are excluded private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"] @@ -151,13 +155,16 @@ def validate_ca_certificate_config(): email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"] for domain in email_domains: # Handle both double quotes and single quotes in YAML - assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded" + assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, ( + f"Email domain {domain} should be excluded" + ) # Verify IPv6 constraints are present (Issue #153) assert "IP:::/0" in content, "IPv6 all addresses should be excluded" print("✓ CA certificate configuration has proper security constraints") + def test_ca_certificate(): """Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)""" cert_files = find_generated_certificates() @@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files): # Filter to only actual server certificates (not client certs) # Server certificates contain IP addresses in the filename import re - server_certs = [f for f in cert_files['server_certs'] - if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)] + + server_certs = [ + f + for f in cert_files["server_certs"] + if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f) + ] if not server_certs: print("⚠ No server certificates found") return for server_cert_path in server_certs: - with open(server_cert_path, 'rb') as f: + with open(server_cert_path, "rb") as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) @@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files): assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU" assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU" # Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153) - assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation" + assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, ( + "Server cert should NOT have clientAuth EKU for role separation" + ) # Check SAN extension exists (required for Apple devices) try: @@ -204,9 +217,10 @@ def validate_server_certificates_real(cert_files): print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}") + def validate_server_certificates_config(): """Validate server certificate configuration in Ansible files (CI mode)""" - openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') + openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml") if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return @@ -215,7 +229,7 @@ def validate_server_certificates_config(): content = f.read() # Look for server certificate CSR section - server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL) + server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL) if not server_csr_section: print("⚠ Could not find server certificate CSR section") return @@ -224,11 +238,11 @@ def validate_server_certificates_config(): # Check server certificate CSR configuration server_checks = [ - ('subject_alt_name', 'Server certificates should have SAN extension'), - ('serverAuth', 'Server certificates should have serverAuth EKU'), - ('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'), - ('digitalSignature', 'Server certificates should have digital signature usage'), - ('keyEncipherment', 'Server certificates should have key encipherment usage') + ("subject_alt_name", "Server certificates should have SAN extension"), + ("serverAuth", "Server certificates should have serverAuth EKU"), + ("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"), + ("digitalSignature", "Server certificates should have digital signature usage"), + ("keyEncipherment", "Server certificates should have key encipherment usage"), ] for check, message in server_checks: @@ -236,15 +250,20 @@ def validate_server_certificates_config(): # Security check: Server certificates should NOT have clientAuth (Issue #153) # Look for clientAuth in extended_key_usage section, not in comments - eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)] - has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- ')) + eku_lines = [ + line + for line in server_section.split("\n") + if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line) + ] + has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- ")) assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation" # Verify SAN extension is configured for Apple compatibility - assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility" + assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility" print("✓ Server certificate configuration has proper EKU and SAN settings") + def test_server_certificates(): """Test server certificates - uses real certs if available, else validates config""" cert_files = find_generated_certificates() @@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files): """Validate actual Ansible-generated client certificates""" # Find client certificates (not CA cert, not server cert with IP/DNS name) client_certs = [] - for cert_path in cert_files['server_certs']: - if 'cacert.pem' in cert_path: + for cert_path in cert_files["server_certs"]: + if "cacert.pem" in cert_path: continue - with open(cert_path, 'rb') as f: + with open(cert_path, "rb") as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) # Check if this looks like a client cert vs server cert cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # Server certs typically have IP addresses or domain names as CN - if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4): + if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4): client_certs.append((cert_path, certificate)) if not client_certs: @@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files): assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU" # Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153) - assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation" + assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, ( + "Client cert must NOT have serverAuth EKU to prevent server impersonation" + ) # Check SAN extension for email try: @@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files): print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}") + def validate_client_certificates_config(): """Validate client certificate configuration in Ansible files (CI mode)""" - openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') + openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml") if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return @@ -310,7 +332,9 @@ def validate_client_certificates_config(): content = f.read() # Look for client certificate CSR section - client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL) + client_csr_section = re.search( + r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL + ) if not client_csr_section: print("⚠ Could not find client certificate CSR section") return @@ -319,11 +343,11 @@ def validate_client_certificates_config(): # Check client certificate configuration client_checks = [ - ('clientAuth', 'Client certificates should have clientAuth EKU'), - ('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'), - ('digitalSignature', 'Client certificates should have digital signature usage'), - ('keyEncipherment', 'Client certificates should have key encipherment usage'), - ('email:', 'Client certificates should have email SAN') + ("clientAuth", "Client certificates should have clientAuth EKU"), + ("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"), + ("digitalSignature", "Client certificates should have digital signature usage"), + ("keyEncipherment", "Client certificates should have key encipherment usage"), + ("email:", "Client certificates should have email SAN"), ] for check, message in client_checks: @@ -331,15 +355,22 @@ def validate_client_certificates_config(): # Security check: Client certificates should NOT have serverAuth (Issue #153) # Look for serverAuth in extended_key_usage section, not in comments - eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)] - has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- ')) + eku_lines = [ + line + for line in client_section.split("\n") + if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line) + ] + has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- ")) assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation" # Verify client certificates use unique email domains (Issue #153) - assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment" + assert "openssl_constraint_random_id" in client_section, ( + "Client certificates should use unique email domain per deployment" + ) print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)") + def test_client_certificates(): """Test client certificates - uses real certs if available, else validates config (Issue #75, #153)""" cert_files = find_generated_certificates() @@ -351,24 +382,33 @@ def test_client_certificates(): def validate_pkcs12_files_real(cert_files): """Validate actual Ansible-generated PKCS#12 files""" - if not cert_files.get('p12_files'): + if not cert_files.get("p12_files"): print("⚠ No PKCS#12 files found") return major, minor = test_openssl_version_detection() - for p12_file in cert_files['p12_files']: + for p12_file in cert_files["p12_files"]: assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}" # Test that PKCS#12 file can be read (validates format) - legacy_flag = ['-legacy'] if major >= 3 else [] + legacy_flag = ["-legacy"] if major >= 3 else [] - result = subprocess.run([ - 'openssl', 'pkcs12', '-info', - '-in', p12_file, - '-passin', 'pass:', # Try empty password first - '-noout' - ] + legacy_flag, capture_output=True, text=True) + result = subprocess.run( + [ + "openssl", + "pkcs12", + "-info", + "-in", + p12_file, + "-passin", + "pass:", # Try empty password first + "-noout", + ] + + legacy_flag, + capture_output=True, + text=True, + ) # PKCS#12 files should be readable (even if password-protected) # We're just testing format validity, not trying to extract contents @@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files): print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}") + def validate_pkcs12_files_config(): """Validate PKCS#12 file configuration in Ansible files (CI mode)""" - openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') + openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml") if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return @@ -390,13 +431,13 @@ def validate_pkcs12_files_config(): # Check PKCS#12 generation configuration p12_checks = [ - ('openssl_pkcs12', 'PKCS#12 generation should be configured'), - ('encryption_level', 'PKCS#12 encryption level should be configured'), - ('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'), - ('friendly_name', 'PKCS#12 should have friendly names'), - ('other_certificates', 'PKCS#12 should include CA certificate for full chain'), - ('passphrase', 'PKCS#12 files should be password protected'), - ('mode: "0600"', 'PKCS#12 files should have secure permissions') + ("openssl_pkcs12", "PKCS#12 generation should be configured"), + ("encryption_level", "PKCS#12 encryption level should be configured"), + ("compatibility2022", "PKCS#12 should use Apple-compatible encryption"), + ("friendly_name", "PKCS#12 should have friendly names"), + ("other_certificates", "PKCS#12 should include CA certificate for full chain"), + ("passphrase", "PKCS#12 files should be password protected"), + ('mode: "0600"', "PKCS#12 files should have secure permissions"), ] for check, message in p12_checks: @@ -404,6 +445,7 @@ def validate_pkcs12_files_config(): print("✓ PKCS#12 configuration has proper Apple device compatibility settings") + def test_pkcs12_files(): """Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)""" cert_files = find_generated_certificates() @@ -416,19 +458,19 @@ def test_pkcs12_files(): def validate_certificate_chain_real(cert_files): """Validate actual Ansible-generated certificate chain""" # Load CA certificate - with open(cert_files['ca_cert'], 'rb') as f: + with open(cert_files["ca_cert"], "rb") as f: ca_cert_data = f.read() ca_certificate = x509.load_pem_x509_certificate(ca_cert_data) # Test that all other certificates are signed by the CA - other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']] + other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]] if not other_certs: print("⚠ No client/server certificates found to validate") return for cert_path in other_certs: - with open(cert_path, 'rb') as f: + with open(cert_path, "rb") as f: cert_data = f.read() certificate = x509.load_pem_x509_certificate(cert_data) @@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files): # Verify certificate is currently valid (not expired) from datetime import datetime + now = datetime.now(UTC) assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid" assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired" @@ -445,9 +488,10 @@ def validate_certificate_chain_real(cert_files): print("✓ All real certificates properly signed by CA") + def validate_certificate_chain_config(): """Validate certificate chain configuration in Ansible files (CI mode)""" - openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml') + openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml") if not openssl_task_file: print("⚠ Could not find openssl.yml task file") return @@ -457,15 +501,18 @@ def validate_certificate_chain_config(): # Check certificate signing configuration chain_checks = [ - ('provider: ownca', 'Certificates should be signed by own CA'), - ('ownca_path', 'CA certificate path should be specified'), - ('ownca_privatekey_path', 'CA private key path should be specified'), - ('ownca_privatekey_passphrase', 'CA private key should be password protected'), - ('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'), - ('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'), - ('ownca_not_before: "-1d"', 'Certificates should have backdated start time'), - ('curve: secp384r1', 'Should use strong elliptic curve cryptography'), - ('type: ECC', 'Should use elliptic curve keys for better security') + ("provider: ownca", "Certificates should be signed by own CA"), + ("ownca_path", "CA certificate path should be specified"), + ("ownca_privatekey_path", "CA private key path should be specified"), + ("ownca_privatekey_passphrase", "CA private key should be password protected"), + ("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"), + ( + 'ownca_not_after: "+{{ certificate_validity_days }}d"', + "Certificates should use configurable validity period", + ), + ('ownca_not_before: "-1d"', "Certificates should have backdated start time"), + ("curve: secp384r1", "Should use strong elliptic curve cryptography"), + ("type: ECC", "Should use elliptic curve keys for better security"), ] for check, message in chain_checks: @@ -473,6 +520,7 @@ def validate_certificate_chain_config(): print("✓ Certificate chain configuration properly set up for CA signing") + def test_certificate_chain(): """Test certificate chain - uses real certs if available, else validates config""" cert_files = find_generated_certificates() @@ -499,6 +547,7 @@ def find_ansible_file(relative_path): return None + if __name__ == "__main__": tests = [ test_openssl_version_detection, diff --git a/tests/unit/test_strongswan_templates.py b/tests/unit/test_strongswan_templates.py index 861555fc..943e7592 100644 --- a/tests/unit/test_strongswan_templates.py +++ b/tests/unit/test_strongswan_templates.py @@ -3,6 +3,7 @@ Enhanced tests for StrongSwan templates. Tests all strongswan role templates with various configurations. """ + import os import sys import uuid @@ -21,7 +22,7 @@ def mock_to_uuid(value): def mock_bool(value): """Mock the bool filter""" - return str(value).lower() in ('true', '1', 'yes', 'on') + return str(value).lower() in ("true", "1", "yes", "on") def mock_version(version_string, comparison): @@ -33,67 +34,67 @@ def mock_version(version_string, comparison): def mock_b64encode(value): """Mock base64 encoding""" import base64 + if isinstance(value, str): - value = value.encode('utf-8') - return base64.b64encode(value).decode('ascii') + value = value.encode("utf-8") + return base64.b64encode(value).decode("ascii") def mock_b64decode(value): """Mock base64 decoding""" import base64 - return base64.b64decode(value).decode('utf-8') + + return base64.b64decode(value).decode("utf-8") -def get_strongswan_test_variables(scenario='default'): +def get_strongswan_test_variables(scenario="default"): """Get test variables for StrongSwan templates with different scenarios.""" base_vars = load_test_variables() # Add StrongSwan specific variables strongswan_vars = { - 'ipsec_config_path': '/etc/ipsec.d', - 'ipsec_pki_path': '/etc/ipsec.d', - 'strongswan_enabled': True, - 'strongswan_network': '10.19.48.0/24', - 'strongswan_network_ipv6': 'fd9d:bc11:4021::/64', - 'strongswan_log_level': '2', - 'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()), - 'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a', - 'subjectAltName_type': 'IP', - 'subjectAltName_client': 'IP:10.0.0.1', - 'ansible_default_ipv6': { - 'address': '2600:3c01::f03c:91ff:fedf:3b2a' - }, - 'openssl_version': '3.0.0', - 'p12_export_password': 'test-password', - 'ike_lifetime': '24h', - 'ipsec_lifetime': '8h', - 'ike_dpd': '30s', - 'ipsec_dead_peer_detection': True, - 'rekey_margin': '3m', - 'rekeymargin': '3m', - 'dpddelay': '35s', - 'keyexchange': 'ikev2', - 'ike_cipher': 'aes128gcm16-prfsha512-ecp256', - 'esp_cipher': 'aes128gcm16-ecp256', - 'leftsourceip': '10.19.48.1', - 'leftsubnet': '0.0.0.0/0,::/0', - 'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64', + "ipsec_config_path": "/etc/ipsec.d", + "ipsec_pki_path": "/etc/ipsec.d", + "strongswan_enabled": True, + "strongswan_network": "10.19.48.0/24", + "strongswan_network_ipv6": "fd9d:bc11:4021::/64", + "strongswan_log_level": "2", + "openssl_constraint_random_id": "test-" + str(uuid.uuid4()), + "subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a", + "subjectAltName_type": "IP", + "subjectAltName_client": "IP:10.0.0.1", + "ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"}, + "openssl_version": "3.0.0", + "p12_export_password": "test-password", + "ike_lifetime": "24h", + "ipsec_lifetime": "8h", + "ike_dpd": "30s", + "ipsec_dead_peer_detection": True, + "rekey_margin": "3m", + "rekeymargin": "3m", + "dpddelay": "35s", + "keyexchange": "ikev2", + "ike_cipher": "aes128gcm16-prfsha512-ecp256", + "esp_cipher": "aes128gcm16-ecp256", + "leftsourceip": "10.19.48.1", + "leftsubnet": "0.0.0.0/0,::/0", + "rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64", } # Merge with base variables test_vars = {**base_vars, **strongswan_vars} # Apply scenario-specific overrides - if scenario == 'ipv4_only': - test_vars['ipv6_support'] = False - test_vars['subjectAltName'] = 'IP:10.0.0.1' - test_vars['ansible_default_ipv6'] = None - elif scenario == 'dns_hostname': - test_vars['IP_subject_alt_name'] = 'vpn.example.com' - test_vars['subjectAltName'] = 'DNS:vpn.example.com' - test_vars['subjectAltName_type'] = 'DNS' - elif scenario == 'openssl_legacy': - test_vars['openssl_version'] = '1.1.1' + if scenario == "ipv4_only": + test_vars["ipv6_support"] = False + test_vars["subjectAltName"] = "IP:10.0.0.1" + test_vars["ansible_default_ipv6"] = None + elif scenario == "dns_hostname": + test_vars["IP_subject_alt_name"] = "vpn.example.com" + test_vars["subjectAltName"] = "DNS:vpn.example.com" + test_vars["subjectAltName_type"] = "DNS" + elif scenario == "openssl_legacy": + test_vars["openssl_version"] = "1.1.1" return test_vars @@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'): def test_strongswan_templates(): """Test all StrongSwan templates with various configurations.""" templates = [ - 'roles/strongswan/templates/ipsec.conf.j2', - 'roles/strongswan/templates/ipsec.secrets.j2', - 'roles/strongswan/templates/strongswan.conf.j2', - 'roles/strongswan/templates/charon.conf.j2', - 'roles/strongswan/templates/client_ipsec.conf.j2', - 'roles/strongswan/templates/client_ipsec.secrets.j2', - 'roles/strongswan/templates/100-CustomLimitations.conf.j2', + "roles/strongswan/templates/ipsec.conf.j2", + "roles/strongswan/templates/ipsec.secrets.j2", + "roles/strongswan/templates/strongswan.conf.j2", + "roles/strongswan/templates/charon.conf.j2", + "roles/strongswan/templates/client_ipsec.conf.j2", + "roles/strongswan/templates/client_ipsec.secrets.j2", + "roles/strongswan/templates/100-CustomLimitations.conf.j2", ] - scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy'] + scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"] errors = [] tested = 0 @@ -127,21 +128,18 @@ def test_strongswan_templates(): test_vars = get_strongswan_test_variables(scenario) try: - env = Environment( - loader=FileSystemLoader(template_dir), - undefined=StrictUndefined - ) + env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined) # Add mock filters - env.filters['to_uuid'] = mock_to_uuid - env.filters['bool'] = mock_bool - env.filters['b64encode'] = mock_b64encode - env.filters['b64decode'] = mock_b64decode - env.tests['version'] = mock_version + env.filters["to_uuid"] = mock_to_uuid + env.filters["bool"] = mock_bool + env.filters["b64encode"] = mock_b64encode + env.filters["b64decode"] = mock_b64decode + env.tests["version"] = mock_version # For client templates, add item context - if 'client' in template_name: - test_vars['item'] = 'testuser' + if "client" in template_name: + test_vars["item"] = "testuser" template = env.get_template(template_name) output = template.render(**test_vars) @@ -150,16 +148,16 @@ def test_strongswan_templates(): assert len(output) > 0, f"Empty output from {template_path} ({scenario})" # Specific validations based on template - if 'ipsec.conf' in template_name and 'client' not in template_name: - assert 'conn' in output, "Missing connection definition" - if scenario != 'ipv4_only' and test_vars.get('ipv6_support'): - assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration" + if "ipsec.conf" in template_name and "client" not in template_name: + assert "conn" in output, "Missing connection definition" + if scenario != "ipv4_only" and test_vars.get("ipv6_support"): + assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration" - if 'ipsec.secrets' in template_name: - assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method" + if "ipsec.secrets" in template_name: + assert "PSK" in output or "ECDSA" in output, "Missing authentication method" - if 'strongswan.conf' in template_name: - assert 'charon' in output, "Missing charon configuration" + if "strongswan.conf" in template_name: + assert "charon" in output, "Missing charon configuration" print(f" ✅ {template_name} ({scenario})") @@ -182,7 +180,7 @@ def test_openssl_template_constraints(): # This tests the actual openssl.yml task file to ensure our fix works import yaml - openssl_path = 'roles/strongswan/tasks/openssl.yml' + openssl_path = "roles/strongswan/tasks/openssl.yml" if not os.path.exists(openssl_path): print("⚠️ OpenSSL tasks file not found") return True @@ -194,22 +192,23 @@ def test_openssl_template_constraints(): # Find the CA CSR task ca_csr_task = None for task in content: - if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'): + if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"): ca_csr_task = task break if ca_csr_task: # Check that name_constraints_permitted is properly formatted - csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {}) - constraints = csr_module.get('name_constraints_permitted', '') + csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {}) + constraints = csr_module.get("name_constraints_permitted", "") # The constraints should be a Jinja2 template without inline comments - if '#' in str(constraints): + if "#" in str(constraints): # Check if the # is within {{ }} import re - jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL) + + jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL) for block in jinja_blocks: - if '#' in block: + if "#" in block: print("❌ Found inline comment in Jinja2 expression") return False @@ -223,7 +222,7 @@ def test_openssl_template_constraints(): def test_mobileconfig_template(): """Test the mobileconfig template with various scenarios.""" - template_path = 'roles/strongswan/templates/mobileconfig.j2' + template_path = "roles/strongswan/templates/mobileconfig.j2" if not os.path.exists(template_path): print("⚠️ Mobileconfig template not found") @@ -237,20 +236,20 @@ def test_mobileconfig_template(): test_cases = [ { - 'name': 'iPhone with cellular on-demand', - 'algo_ondemand_cellular': 'true', - 'algo_ondemand_wifi': 'false', + "name": "iPhone with cellular on-demand", + "algo_ondemand_cellular": "true", + "algo_ondemand_wifi": "false", }, { - 'name': 'iPad with WiFi on-demand', - 'algo_ondemand_cellular': 'false', - 'algo_ondemand_wifi': 'true', - 'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi', + "name": "iPad with WiFi on-demand", + "algo_ondemand_cellular": "false", + "algo_ondemand_wifi": "true", + "algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi", }, { - 'name': 'Mac without on-demand', - 'algo_ondemand_cellular': 'false', - 'algo_ondemand_wifi': 'false', + "name": "Mac without on-demand", + "algo_ondemand_cellular": "false", + "algo_ondemand_wifi": "false", }, ] @@ -258,43 +257,41 @@ def test_mobileconfig_template(): for test_case in test_cases: test_vars = get_strongswan_test_variables() test_vars.update(test_case) + # Mock Ansible task result format for item class MockTaskResult: def __init__(self, content): self.stdout = content - test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result - test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64 - test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64 - test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4()) - test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12 - test_vars['algo_server_name'] = 'test-algo-vpn' - test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4()) - test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4()) - test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64 + test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result + test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64 + test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64 + test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4()) + test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12 + test_vars["algo_server_name"] = "test-algo-vpn" + test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4()) + test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4()) + test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64 try: - env = Environment( - loader=FileSystemLoader('roles/strongswan/templates'), - undefined=StrictUndefined - ) + env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined) # Add mock filters - env.filters['to_uuid'] = mock_to_uuid - env.filters['b64encode'] = mock_b64encode - env.filters['b64decode'] = mock_b64decode + env.filters["to_uuid"] = mock_to_uuid + env.filters["b64encode"] = mock_b64encode + env.filters["b64decode"] = mock_b64decode - template = env.get_template('mobileconfig.j2') + template = env.get_template("mobileconfig.j2") output = template.render(**test_vars) # Validate output - assert '= 12, f"Password too short: {pwd}" - assert ' ' not in pwd, f"Password contains spaces: {pwd}" + assert " " not in pwd, f"Password contains spaces: {pwd}" for pwd in invalid_passwords: issues = [] if len(pwd) < 12: issues.append("too short") - if ' ' in pwd: + if " " in pwd: issues.append("contains spaces") if not pwd: issues.append("empty") @@ -137,8 +126,8 @@ def test_ca_password_handling(): def test_user_config_generation(): """Test that user configs would be generated correctly""" - users = ['alice', 'bob', 'charlie'] - server_name = 'test-server' + users = ["alice", "bob", "charlie"] + server_name = "test-server" # Simulate config file structure for user in users: @@ -168,7 +157,7 @@ users: """ config = yaml.safe_load(test_config) - users = config.get('users', []) + users = config.get("users", []) # Check for duplicates unique_users = list(set(users)) @@ -182,7 +171,7 @@ users: duplicates.append(user) seen.add(user) - assert 'alice' in duplicates, "Duplicate 'alice' not detected" + assert "alice" in duplicates, "Duplicate 'alice' not detected" print("✓ Duplicate user handling test passed") diff --git a/tests/unit/test_wireguard_key_generation.py b/tests/unit/test_wireguard_key_generation.py index f75e3b10..1b7019cb 100644 --- a/tests/unit/test_wireguard_key_generation.py +++ b/tests/unit/test_wireguard_key_generation.py @@ -3,6 +3,7 @@ Test WireGuard key generation - focused on x25519_pubkey module integration Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation """ + import base64 import os import subprocess @@ -10,13 +11,13 @@ import sys import tempfile # Add library directory to path to import our custom module -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library")) def test_wireguard_tools_available(): """Test that WireGuard tools are available for validation""" try: - result = subprocess.run(['wg', '--version'], capture_output=True, text=True) + result = subprocess.run(["wg", "--version"], capture_output=True, text=True) assert result.returncode == 0, "WireGuard tools not available" print(f"✓ WireGuard tools available: {result.stdout.strip()}") return True @@ -29,6 +30,7 @@ def test_x25519_module_import(): """Test that our custom x25519_pubkey module can be imported and used""" try: import x25519_pubkey # noqa: F401 + print("✓ x25519_pubkey module imports successfully") return True except ImportError as e: @@ -37,16 +39,17 @@ def test_x25519_module_import(): def generate_test_private_key(): """Generate a test private key using the same method as Algo""" - with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file: raw_key_path = temp_file.name try: # Generate 32 random bytes for X25519 private key (same as community.crypto does) import secrets + raw_data = secrets.token_bytes(32) # Write raw key to file (like community.crypto openssl_privatekey with format: raw) - with open(raw_key_path, 'wb') as f: + with open(raw_key_path, "wb") as f: f.write(raw_data) assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}" @@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file(): def exit_json(self, **kwargs): self.result = kwargs - with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub: + with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub: public_key_path = temp_pub.name try: @@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file(): try: # Mock the module call - mock_module = MockModule({ - 'private_key_path': raw_key_path, - 'public_key_path': public_key_path, - 'private_key_b64': None - }) + mock_module = MockModule( + {"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None} + ) x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module @@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file(): run_module() # Check the result - assert 'public_key' in mock_module.result - assert mock_module.result['changed'] + assert "public_key" in mock_module.result + assert mock_module.result["changed"] assert os.path.exists(public_key_path) with open(public_key_path) as f: @@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string(): original_AnsibleModule = x25519_pubkey.AnsibleModule try: - mock_module = MockModule({ - 'private_key_b64': b64_key, - 'private_key_path': None, - 'public_key_path': None - }) + mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None}) x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module @@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string(): run_module() # Check the result - assert 'public_key' in mock_module.result - derived_pubkey = mock_module.result['public_key'] + assert "public_key" in mock_module.result + derived_pubkey = mock_module.result["public_key"] # Validate base64 format try: @@ -222,21 +219,17 @@ def test_wireguard_validation(): original_AnsibleModule = x25519_pubkey.AnsibleModule try: - mock_module = MockModule({ - 'private_key_b64': b64_key, - 'private_key_path': None, - 'public_key_path': None - }) + mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None}) x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module run_module() - derived_pubkey = mock_module.result['public_key'] + derived_pubkey = mock_module.result["public_key"] finally: x25519_pubkey.AnsibleModule = original_AnsibleModule - with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config: + with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config: # Create a WireGuard config using our keys wg_config = f"""[Interface] PrivateKey = {b64_key} @@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32 try: # Test that WireGuard can parse our config - result = subprocess.run([ - 'wg-quick', 'strip', config_path - ], capture_output=True, text=True) + result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True) assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}" # Test key derivation with wg pubkey command - wg_result = subprocess.run([ - 'wg', 'pubkey' - ], input=b64_key, capture_output=True, text=True) + wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True) if wg_result.returncode == 0: wg_derived = wg_result.stdout.strip() @@ -286,8 +275,8 @@ def test_key_consistency(): raw_key_path, b64_key = generate_test_private_key() try: - def derive_pubkey_from_same_key(): + def derive_pubkey_from_same_key(): class MockModule: def __init__(self, params): self.params = params @@ -305,16 +294,18 @@ def test_key_consistency(): original_AnsibleModule = x25519_pubkey.AnsibleModule try: - mock_module = MockModule({ - 'private_key_b64': b64_key, # SAME key each time - 'private_key_path': None, - 'public_key_path': None - }) + mock_module = MockModule( + { + "private_key_b64": b64_key, # SAME key each time + "private_key_path": None, + "public_key_path": None, + } + ) x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module run_module() - return mock_module.result['public_key'] + return mock_module.result["public_key"] finally: x25519_pubkey.AnsibleModule = original_AnsibleModule diff --git a/tests/validate_jinja2_templates.py b/tests/validate_jinja2_templates.py index b5a578a1..1adfebdd 100755 --- a/tests/validate_jinja2_templates.py +++ b/tests/validate_jinja2_templates.py @@ -14,13 +14,13 @@ from pathlib import Path from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta -def find_jinja2_templates(root_dir: str = '.') -> list[Path]: +def find_jinja2_templates(root_dir: str = ".") -> list[Path]: """Find all Jinja2 template files in the project.""" templates = [] - patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2'] + patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"] # Skip these directories - skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'} + skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"} for pattern in patterns: for path in Path(root_dir).glob(pattern): @@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P errors = [] # Pattern to find Jinja2 expressions - jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL) + jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL) for match in jinja_pattern.finditer(template_content): expression = match.group() - lines = expression.split('\n') + lines = expression.split("\n") for i, line in enumerate(lines): # Check for # that's not in a string # Simple heuristic: if # appears after non-whitespace and not in quotes - if '#' in line: + if "#" in line: # Remove quoted strings to avoid false positives - cleaned = re.sub(r'"[^"]*"', '', line) - cleaned = re.sub(r"'[^']*'", '', cleaned) + cleaned = re.sub(r'"[^"]*"', "", line) + cleaned = re.sub(r"'[^']*'", "", cleaned) - if '#' in cleaned: + if "#" in cleaned: # Check if it's likely a comment (has text after it) - hash_pos = cleaned.index('#') - if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\': - line_num = template_content[:match.start()].count('\n') + i + 1 + hash_pos = cleaned.index("#") + if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\": + line_num = template_content[: match.start()].count("\n") + i + 1 errors.append( f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. " f"Move comments outside the expression." @@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]: # Common Ansible variables that are always available ansible_builtins = { - 'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname', - 'ansible_distribution', 'ansible_distribution_version', 'ansible_facts', - 'inventory_hostname', 'hostvars', 'groups', 'group_names', - 'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host', - 'item', 'ansible_loop', 'ansible_index', 'lookup' + "ansible_default_ipv4", + "ansible_default_ipv6", + "ansible_hostname", + "ansible_distribution", + "ansible_distribution_version", + "ansible_facts", + "inventory_hostname", + "hostvars", + "groups", + "group_names", + "play_hosts", + "ansible_version", + "ansible_user", + "ansible_host", + "item", + "ansible_loop", + "ansible_index", + "lookup", } # Filter out known Ansible variables @@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]: # Only report if there are truly unknown variables if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars - errors.append( - f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}" - ) + errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}") except Exception: # Don't report parse errors here, they're handled elsewhere @@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]: # Skip full parsing for templates that use Ansible-specific features heavily # We still check for inline comments but skip full template parsing ansible_specific_templates = { - 'dnscrypt-proxy.toml.j2', # Uses |bool filter - 'mobileconfig.j2', # Uses |to_uuid filter and complex item structures - 'vpn-dict.j2', # Uses |to_uuid filter + "dnscrypt-proxy.toml.j2", # Uses |bool filter + "mobileconfig.j2", # Uses |to_uuid filter and complex item structures + "vpn-dict.j2", # Uses |to_uuid filter } if template_path.name in ansible_specific_templates: @@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]: errors.extend(check_inline_comments_in_expressions(template_content, template_path)) # Try to parse the template - env = Environment( - loader=FileSystemLoader(template_path.parent), - undefined=StrictUndefined - ) + env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined) # Add mock Ansible filters to avoid syntax errors - env.filters['bool'] = lambda x: x - env.filters['to_uuid'] = lambda x: x - env.filters['b64encode'] = lambda x: x - env.filters['b64decode'] = lambda x: x - env.filters['regex_replace'] = lambda x, y, z: x - env.filters['default'] = lambda x, d: x if x else d + env.filters["bool"] = lambda x: x + env.filters["to_uuid"] = lambda x: x + env.filters["b64encode"] = lambda x: x + env.filters["b64decode"] = lambda x: x + env.filters["regex_replace"] = lambda x, y, z: x + env.filters["default"] = lambda x, d: x if x else d # This will raise TemplateSyntaxError if there's a syntax problem env.get_template(template_path.name) @@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]: content = f.read() # Check for missing spaces around filters - if re.search(r'\{\{[^}]+\|[^ ]', content): + if re.search(r"\{\{[^}]+\|[^ ]", content): warnings.append(f"{template_path}: Missing space after filter pipe (|)") # Check for deprecated 'when' in Jinja2 (should use if) - if re.search(r'\{%\s*when\s+', content): + if re.search(r"\{%\s*when\s+", content): warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates") # Check for extremely long expressions (harder to debug) - for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL): + for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL): if len(match.group(1)) > 200: - line_num = content[:match.start()].count('\n') + 1 - warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up") + line_num = content[: match.start()].count("\n") + 1 + warnings.append( + f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up" + ) except Exception: pass # Ignore errors in anti-pattern checking