mirror of
https://github.com/trailofbits/algo.git
synced 2025-09-03 10:33:13 +02:00
Apply Python linting and formatting
- Run ruff check --fix to fix linting issues - Run ruff format to ensure consistent formatting - All tests still pass after formatting changes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
51847f3fbf
commit
15be88d28b
28 changed files with 1178 additions and 1199 deletions
|
@ -9,11 +9,9 @@ import time
|
|||
from ansible.module_utils.basic import AnsibleModule, env_fallback
|
||||
from ansible.module_utils.digital_ocean import DigitalOceanHelper
|
||||
|
||||
ANSIBLE_METADATA = {'metadata_version': '1.1',
|
||||
'status': ['preview'],
|
||||
'supported_by': 'community'}
|
||||
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||
|
||||
DOCUMENTATION = '''
|
||||
DOCUMENTATION = """
|
||||
---
|
||||
module: digital_ocean_floating_ip
|
||||
short_description: Manage DigitalOcean Floating IPs
|
||||
|
@ -44,10 +42,10 @@ notes:
|
|||
- Version 2 of DigitalOcean API is used.
|
||||
requirements:
|
||||
- "python >= 2.6"
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
EXAMPLES = '''
|
||||
EXAMPLES = """
|
||||
- name: "Create a Floating IP in region lon1"
|
||||
digital_ocean_floating_ip:
|
||||
state: present
|
||||
|
@ -63,10 +61,10 @@ EXAMPLES = '''
|
|||
state: absent
|
||||
ip: "1.2.3.4"
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
RETURN = '''
|
||||
RETURN = """
|
||||
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
|
||||
data:
|
||||
description: a DigitalOcean Floating IP resource
|
||||
|
@ -106,11 +104,10 @@ data:
|
|||
"region_slug": "nyc3"
|
||||
}
|
||||
}
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class Response:
|
||||
|
||||
def __init__(self, resp, info):
|
||||
self.body = None
|
||||
if resp:
|
||||
|
@ -132,36 +129,37 @@ class Response:
|
|||
def status_code(self):
|
||||
return self.info["status"]
|
||||
|
||||
|
||||
def wait_action(module, rest, ip, action_id, timeout=10):
|
||||
end_time = time.time() + 10
|
||||
while time.time() < end_time:
|
||||
response = rest.get(f'floating_ips/{ip}/actions/{action_id}')
|
||||
response = rest.get(f"floating_ips/{ip}/actions/{action_id}")
|
||||
# status_code = response.status_code # TODO: check status_code == 200?
|
||||
status = response.json['action']['status']
|
||||
if status == 'completed':
|
||||
status = response.json["action"]["status"]
|
||||
if status == "completed":
|
||||
return True
|
||||
elif status == 'errored':
|
||||
module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json)
|
||||
elif status == "errored":
|
||||
module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json)
|
||||
|
||||
module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json)
|
||||
module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json)
|
||||
|
||||
|
||||
def core(module):
|
||||
# api_token = module.params['oauth_token'] # unused for now
|
||||
state = module.params['state']
|
||||
ip = module.params['ip']
|
||||
droplet_id = module.params['droplet_id']
|
||||
state = module.params["state"]
|
||||
ip = module.params["ip"]
|
||||
droplet_id = module.params["droplet_id"]
|
||||
|
||||
rest = DigitalOceanHelper(module)
|
||||
|
||||
if state in ('present'):
|
||||
if droplet_id is not None and module.params['ip'] is not None:
|
||||
if state in ("present"):
|
||||
if droplet_id is not None and module.params["ip"] is not None:
|
||||
# Lets try to associate the ip to the specified droplet
|
||||
associate_floating_ips(module, rest)
|
||||
else:
|
||||
create_floating_ips(module, rest)
|
||||
|
||||
elif state in ('absent'):
|
||||
elif state in ("absent"):
|
||||
response = rest.delete(f"floating_ips/{ip}")
|
||||
status_code = response.status_code
|
||||
json_data = response.json
|
||||
|
@ -174,65 +172,68 @@ def core(module):
|
|||
|
||||
|
||||
def get_floating_ip_details(module, rest):
|
||||
ip = module.params['ip']
|
||||
ip = module.params["ip"]
|
||||
|
||||
response = rest.get(f"floating_ips/{ip}")
|
||||
status_code = response.status_code
|
||||
json_data = response.json
|
||||
if status_code == 200:
|
||||
return json_data['floating_ip']
|
||||
return json_data["floating_ip"]
|
||||
else:
|
||||
module.fail_json(msg="Error assigning floating ip [{}: {}]".format(
|
||||
status_code, json_data["message"]), region=module.params['region'])
|
||||
module.fail_json(
|
||||
msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||
region=module.params["region"],
|
||||
)
|
||||
|
||||
|
||||
def assign_floating_id_to_droplet(module, rest):
|
||||
ip = module.params['ip']
|
||||
ip = module.params["ip"]
|
||||
|
||||
payload = {
|
||||
"type": "assign",
|
||||
"droplet_id": module.params['droplet_id'],
|
||||
"droplet_id": module.params["droplet_id"],
|
||||
}
|
||||
|
||||
response = rest.post(f"floating_ips/{ip}/actions", data=payload)
|
||||
status_code = response.status_code
|
||||
json_data = response.json
|
||||
if status_code == 201:
|
||||
wait_action(module, rest, ip, json_data['action']['id'])
|
||||
wait_action(module, rest, ip, json_data["action"]["id"])
|
||||
|
||||
module.exit_json(changed=True, data=json_data)
|
||||
else:
|
||||
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
|
||||
status_code, json_data["message"]), region=module.params['region'])
|
||||
module.fail_json(
|
||||
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||
region=module.params["region"],
|
||||
)
|
||||
|
||||
|
||||
def associate_floating_ips(module, rest):
|
||||
floating_ip = get_floating_ip_details(module, rest)
|
||||
droplet = floating_ip['droplet']
|
||||
droplet = floating_ip["droplet"]
|
||||
|
||||
# TODO: If already assigned to a droplet verify if is one of the specified as valid
|
||||
if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]:
|
||||
if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]:
|
||||
module.exit_json(changed=False)
|
||||
else:
|
||||
assign_floating_id_to_droplet(module, rest)
|
||||
|
||||
|
||||
def create_floating_ips(module, rest):
|
||||
payload = {
|
||||
}
|
||||
payload = {}
|
||||
floating_ip_data = None
|
||||
|
||||
if module.params['region'] is not None:
|
||||
payload["region"] = module.params['region']
|
||||
if module.params["region"] is not None:
|
||||
payload["region"] = module.params["region"]
|
||||
|
||||
if module.params['droplet_id'] is not None:
|
||||
payload["droplet_id"] = module.params['droplet_id']
|
||||
if module.params["droplet_id"] is not None:
|
||||
payload["droplet_id"] = module.params["droplet_id"]
|
||||
|
||||
floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips')
|
||||
floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips")
|
||||
|
||||
for floating_ip in floating_ips:
|
||||
if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']:
|
||||
floating_ip_data = {'floating_ip': floating_ip}
|
||||
if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]:
|
||||
floating_ip_data = {"floating_ip": floating_ip}
|
||||
|
||||
if floating_ip_data:
|
||||
module.exit_json(changed=False, data=floating_ip_data)
|
||||
|
@ -244,36 +245,34 @@ def create_floating_ips(module, rest):
|
|||
if status_code == 202:
|
||||
module.exit_json(changed=True, data=json_data)
|
||||
else:
|
||||
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
|
||||
status_code, json_data["message"]), region=module.params['region'])
|
||||
module.fail_json(
|
||||
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||
region=module.params["region"],
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
module = AnsibleModule(
|
||||
argument_spec={
|
||||
'state': {'choices': ['present', 'absent'], 'default': 'present'},
|
||||
'ip': {'aliases': ['id'], 'required': False},
|
||||
'region': {'required': False},
|
||||
'droplet_id': {'required': False, 'type': 'int'},
|
||||
'oauth_token': {
|
||||
'no_log': True,
|
||||
"state": {"choices": ["present", "absent"], "default": "present"},
|
||||
"ip": {"aliases": ["id"], "required": False},
|
||||
"region": {"required": False},
|
||||
"droplet_id": {"required": False, "type": "int"},
|
||||
"oauth_token": {
|
||||
"no_log": True,
|
||||
# Support environment variable for DigitalOcean OAuth Token
|
||||
'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']),
|
||||
'required': True,
|
||||
"fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]),
|
||||
"required": True,
|
||||
},
|
||||
'validate_certs': {'type': 'bool', 'default': True},
|
||||
'timeout': {'type': 'int', 'default': 30},
|
||||
"validate_certs": {"type": "bool", "default": True},
|
||||
"timeout": {"type": "int", "default": 30},
|
||||
},
|
||||
required_if=[
|
||||
('state', 'delete', ['ip'])
|
||||
],
|
||||
mutually_exclusive=[
|
||||
['region', 'droplet_id']
|
||||
],
|
||||
required_if=[("state", "delete", ["ip"])],
|
||||
mutually_exclusive=[["region", "droplet_id"]],
|
||||
)
|
||||
|
||||
core(module)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
|
||||
|
||||
import json
|
||||
|
||||
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
|
||||
|
@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
|
|||
# Documentation
|
||||
################################################################################
|
||||
|
||||
ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'}
|
||||
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||
|
||||
################################################################################
|
||||
# Main
|
||||
|
@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported
|
|||
|
||||
|
||||
def main():
|
||||
module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}})
|
||||
module = GcpModule(
|
||||
argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}}
|
||||
)
|
||||
|
||||
if module._name == 'gcp_compute_image_facts':
|
||||
module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13')
|
||||
if module._name == "gcp_compute_image_facts":
|
||||
module.deprecate(
|
||||
"The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13"
|
||||
)
|
||||
|
||||
if not module.params['scopes']:
|
||||
module.params['scopes'] = ['https://www.googleapis.com/auth/compute']
|
||||
if not module.params["scopes"]:
|
||||
module.params["scopes"] = ["https://www.googleapis.com/auth/compute"]
|
||||
|
||||
items = fetch_list(module, collection(module), query_options(module.params['filters']))
|
||||
if items.get('items'):
|
||||
items = items.get('items')
|
||||
items = fetch_list(module, collection(module), query_options(module.params["filters"]))
|
||||
if items.get("items"):
|
||||
items = items.get("items")
|
||||
else:
|
||||
items = []
|
||||
return_value = {'resources': items}
|
||||
return_value = {"resources": items}
|
||||
module.exit_json(**return_value)
|
||||
|
||||
|
||||
|
@ -40,14 +43,14 @@ def collection(module):
|
|||
|
||||
|
||||
def fetch_list(module, link, query):
|
||||
auth = GcpSession(module, 'compute')
|
||||
response = auth.get(link, params={'filter': query})
|
||||
auth = GcpSession(module, "compute")
|
||||
response = auth.get(link, params={"filter": query})
|
||||
return return_if_object(module, response)
|
||||
|
||||
|
||||
def query_options(filters):
|
||||
if not filters:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
if len(filters) == 1:
|
||||
return filters[0]
|
||||
|
@ -55,12 +58,12 @@ def query_options(filters):
|
|||
queries = []
|
||||
for f in filters:
|
||||
# For multiple queries, all queries should have ()
|
||||
if f[0] != '(' and f[-1] != ')':
|
||||
queries.append("({})".format(''.join(f)))
|
||||
if f[0] != "(" and f[-1] != ")":
|
||||
queries.append("({})".format("".join(f)))
|
||||
else:
|
||||
queries.append(f)
|
||||
|
||||
return ' '.join(queries)
|
||||
return " ".join(queries)
|
||||
|
||||
|
||||
def return_if_object(module, response):
|
||||
|
@ -75,11 +78,11 @@ def return_if_object(module, response):
|
|||
try:
|
||||
module.raise_for_status(response)
|
||||
result = response.json()
|
||||
except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst:
|
||||
except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
|
||||
module.fail_json(msg=f"Invalid JSON response with error: {inst}")
|
||||
|
||||
if navigate_hash(result, ['error', 'errors']):
|
||||
module.fail_json(msg=navigate_hash(result, ['error', 'errors']))
|
||||
if navigate_hash(result, ["error", "errors"]):
|
||||
module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -3,12 +3,9 @@
|
|||
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
||||
|
||||
|
||||
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||
|
||||
ANSIBLE_METADATA = {'metadata_version': '1.1',
|
||||
'status': ['preview'],
|
||||
'supported_by': 'community'}
|
||||
|
||||
DOCUMENTATION = '''
|
||||
DOCUMENTATION = """
|
||||
---
|
||||
module: lightsail_region_facts
|
||||
short_description: Gather facts about AWS Lightsail regions.
|
||||
|
@ -24,15 +21,15 @@ requirements:
|
|||
extends_documentation_fragment:
|
||||
- aws
|
||||
- ec2
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
EXAMPLES = '''
|
||||
EXAMPLES = """
|
||||
# Gather facts about all regions
|
||||
- lightsail_region_facts:
|
||||
'''
|
||||
"""
|
||||
|
||||
RETURN = '''
|
||||
RETURN = """
|
||||
regions:
|
||||
returned: on success
|
||||
description: >
|
||||
|
@ -46,12 +43,13 @@ regions:
|
|||
"displayName": "Virginia",
|
||||
"name": "us-east-1"
|
||||
}]"
|
||||
'''
|
||||
"""
|
||||
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import botocore
|
||||
|
||||
HAS_BOTOCORE = True
|
||||
except ImportError:
|
||||
HAS_BOTOCORE = False
|
||||
|
@ -86,18 +84,19 @@ def main():
|
|||
|
||||
client = None
|
||||
try:
|
||||
client = boto3_conn(module, conn_type='client', resource='lightsail',
|
||||
region=region, endpoint=ec2_url, **aws_connect_kwargs)
|
||||
client = boto3_conn(
|
||||
module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs
|
||||
)
|
||||
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
|
||||
module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc())
|
||||
module.fail_json(
|
||||
msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc()
|
||||
)
|
||||
|
||||
response = client.get_regions(
|
||||
includeAvailabilityZones=False
|
||||
)
|
||||
response = client.get_regions(includeAvailabilityZones=False)
|
||||
module.exit_json(changed=False, data=response)
|
||||
except (botocore.exceptions.ClientError, Exception) as e:
|
||||
module.fail_json(msg=str(e), exception=traceback.format_exc())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent
|
|||
LINODE_IMP_ERR = None
|
||||
try:
|
||||
from linode_api4 import LinodeClient, StackScript
|
||||
|
||||
HAS_LINODE_DEPENDENCY = True
|
||||
except ImportError:
|
||||
LINODE_IMP_ERR = traceback.format_exc()
|
||||
|
@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs):
|
|||
response = client.linode.stackscript_create(**kwargs)
|
||||
return response._raw_json
|
||||
except Exception as exception:
|
||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
||||
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||
|
||||
|
||||
def stackscript_available(module, client):
|
||||
"""Try to retrieve a stackscript."""
|
||||
try:
|
||||
label = module.params['label']
|
||||
desc = module.params['description']
|
||||
label = module.params["label"]
|
||||
desc = module.params["description"]
|
||||
|
||||
result = client.linode.stackscripts(StackScript.label == label,
|
||||
StackScript.description == desc,
|
||||
mine_only=True
|
||||
)
|
||||
result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True)
|
||||
return result[0]
|
||||
except IndexError:
|
||||
return None
|
||||
except Exception as exception:
|
||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
||||
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||
|
||||
|
||||
def initialise_module():
|
||||
"""Initialise the module parameter specification."""
|
||||
return AnsibleModule(
|
||||
argument_spec=dict(
|
||||
label=dict(type='str', required=True),
|
||||
state=dict(
|
||||
type='str',
|
||||
required=True,
|
||||
choices=['present', 'absent']
|
||||
),
|
||||
label=dict(type="str", required=True),
|
||||
state=dict(type="str", required=True, choices=["present", "absent"]),
|
||||
access_token=dict(
|
||||
type='str',
|
||||
type="str",
|
||||
required=True,
|
||||
no_log=True,
|
||||
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
|
||||
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
|
||||
),
|
||||
script=dict(type='str', required=True),
|
||||
images=dict(type='list', required=True),
|
||||
description=dict(type='str', required=False),
|
||||
public=dict(type='bool', required=False, default=False),
|
||||
script=dict(type="str", required=True),
|
||||
images=dict(type="list", required=True),
|
||||
description=dict(type="str", required=False),
|
||||
public=dict(type="bool", required=False, default=False),
|
||||
),
|
||||
supports_check_mode=False
|
||||
supports_check_mode=False,
|
||||
)
|
||||
|
||||
|
||||
def build_client(module):
|
||||
"""Build a LinodeClient."""
|
||||
return LinodeClient(
|
||||
module.params['access_token'],
|
||||
user_agent=get_user_agent('linode_v4_module')
|
||||
)
|
||||
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -79,30 +70,31 @@ def main():
|
|||
module = initialise_module()
|
||||
|
||||
if not HAS_LINODE_DEPENDENCY:
|
||||
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
|
||||
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
|
||||
|
||||
client = build_client(module)
|
||||
stackscript = stackscript_available(module, client)
|
||||
|
||||
if module.params['state'] == 'present' and stackscript is not None:
|
||||
if module.params["state"] == "present" and stackscript is not None:
|
||||
module.exit_json(changed=False, stackscript=stackscript._raw_json)
|
||||
|
||||
elif module.params['state'] == 'present' and stackscript is None:
|
||||
elif module.params["state"] == "present" and stackscript is None:
|
||||
stackscript_json = create_stackscript(
|
||||
module, client,
|
||||
label=module.params['label'],
|
||||
script=module.params['script'],
|
||||
images=module.params['images'],
|
||||
desc=module.params['description'],
|
||||
public=module.params['public'],
|
||||
module,
|
||||
client,
|
||||
label=module.params["label"],
|
||||
script=module.params["script"],
|
||||
images=module.params["images"],
|
||||
desc=module.params["description"],
|
||||
public=module.params["public"],
|
||||
)
|
||||
module.exit_json(changed=True, stackscript=stackscript_json)
|
||||
|
||||
elif module.params['state'] == 'absent' and stackscript is not None:
|
||||
elif module.params["state"] == "absent" and stackscript is not None:
|
||||
stackscript.delete()
|
||||
module.exit_json(changed=True, stackscript=stackscript._raw_json)
|
||||
|
||||
elif module.params['state'] == 'absent' and stackscript is None:
|
||||
elif module.params["state"] == "absent" and stackscript is None:
|
||||
module.exit_json(changed=False, stackscript={})
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent
|
|||
LINODE_IMP_ERR = None
|
||||
try:
|
||||
from linode_api4 import Instance, LinodeClient
|
||||
|
||||
HAS_LINODE_DEPENDENCY = True
|
||||
except ImportError:
|
||||
LINODE_IMP_ERR = traceback.format_exc()
|
||||
|
@ -21,82 +22,72 @@ except ImportError:
|
|||
|
||||
def create_linode(module, client, **kwargs):
|
||||
"""Creates a Linode instance and handles return format."""
|
||||
if kwargs['root_pass'] is None:
|
||||
kwargs.pop('root_pass')
|
||||
if kwargs["root_pass"] is None:
|
||||
kwargs.pop("root_pass")
|
||||
|
||||
try:
|
||||
response = client.linode.instance_create(**kwargs)
|
||||
except Exception as exception:
|
||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
||||
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||
|
||||
try:
|
||||
if isinstance(response, tuple):
|
||||
instance, root_pass = response
|
||||
instance_json = instance._raw_json
|
||||
instance_json.update({'root_pass': root_pass})
|
||||
instance_json.update({"root_pass": root_pass})
|
||||
return instance_json
|
||||
else:
|
||||
return response._raw_json
|
||||
except TypeError:
|
||||
module.fail_json(msg='Unable to parse Linode instance creation'
|
||||
' response. Please raise a bug against this'
|
||||
' module on https://github.com/ansible/ansible/issues'
|
||||
)
|
||||
module.fail_json(
|
||||
msg="Unable to parse Linode instance creation"
|
||||
" response. Please raise a bug against this"
|
||||
" module on https://github.com/ansible/ansible/issues"
|
||||
)
|
||||
|
||||
|
||||
def maybe_instance_from_label(module, client):
|
||||
"""Try to retrieve an instance based on a label."""
|
||||
try:
|
||||
label = module.params['label']
|
||||
label = module.params["label"]
|
||||
result = client.linode.instances(Instance.label == label)
|
||||
return result[0]
|
||||
except IndexError:
|
||||
return None
|
||||
except Exception as exception:
|
||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
||||
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||
|
||||
|
||||
def initialise_module():
|
||||
"""Initialise the module parameter specification."""
|
||||
return AnsibleModule(
|
||||
argument_spec=dict(
|
||||
label=dict(type='str', required=True),
|
||||
state=dict(
|
||||
type='str',
|
||||
required=True,
|
||||
choices=['present', 'absent']
|
||||
),
|
||||
label=dict(type="str", required=True),
|
||||
state=dict(type="str", required=True, choices=["present", "absent"]),
|
||||
access_token=dict(
|
||||
type='str',
|
||||
type="str",
|
||||
required=True,
|
||||
no_log=True,
|
||||
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
|
||||
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
|
||||
),
|
||||
authorized_keys=dict(type='list', required=False),
|
||||
group=dict(type='str', required=False),
|
||||
image=dict(type='str', required=False),
|
||||
region=dict(type='str', required=False),
|
||||
root_pass=dict(type='str', required=False, no_log=True),
|
||||
tags=dict(type='list', required=False),
|
||||
type=dict(type='str', required=False),
|
||||
stackscript_id=dict(type='int', required=False),
|
||||
authorized_keys=dict(type="list", required=False),
|
||||
group=dict(type="str", required=False),
|
||||
image=dict(type="str", required=False),
|
||||
region=dict(type="str", required=False),
|
||||
root_pass=dict(type="str", required=False, no_log=True),
|
||||
tags=dict(type="list", required=False),
|
||||
type=dict(type="str", required=False),
|
||||
stackscript_id=dict(type="int", required=False),
|
||||
),
|
||||
supports_check_mode=False,
|
||||
required_one_of=(
|
||||
['state', 'label'],
|
||||
),
|
||||
required_together=(
|
||||
['region', 'image', 'type'],
|
||||
)
|
||||
required_one_of=(["state", "label"],),
|
||||
required_together=(["region", "image", "type"],),
|
||||
)
|
||||
|
||||
|
||||
def build_client(module):
|
||||
"""Build a LinodeClient."""
|
||||
return LinodeClient(
|
||||
module.params['access_token'],
|
||||
user_agent=get_user_agent('linode_v4_module')
|
||||
)
|
||||
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -104,34 +95,35 @@ def main():
|
|||
module = initialise_module()
|
||||
|
||||
if not HAS_LINODE_DEPENDENCY:
|
||||
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
|
||||
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
|
||||
|
||||
client = build_client(module)
|
||||
instance = maybe_instance_from_label(module, client)
|
||||
|
||||
if module.params['state'] == 'present' and instance is not None:
|
||||
if module.params["state"] == "present" and instance is not None:
|
||||
module.exit_json(changed=False, instance=instance._raw_json)
|
||||
|
||||
elif module.params['state'] == 'present' and instance is None:
|
||||
elif module.params["state"] == "present" and instance is None:
|
||||
instance_json = create_linode(
|
||||
module, client,
|
||||
authorized_keys=module.params['authorized_keys'],
|
||||
group=module.params['group'],
|
||||
image=module.params['image'],
|
||||
label=module.params['label'],
|
||||
region=module.params['region'],
|
||||
root_pass=module.params['root_pass'],
|
||||
tags=module.params['tags'],
|
||||
ltype=module.params['type'],
|
||||
stackscript_id=module.params['stackscript_id'],
|
||||
module,
|
||||
client,
|
||||
authorized_keys=module.params["authorized_keys"],
|
||||
group=module.params["group"],
|
||||
image=module.params["image"],
|
||||
label=module.params["label"],
|
||||
region=module.params["region"],
|
||||
root_pass=module.params["root_pass"],
|
||||
tags=module.params["tags"],
|
||||
ltype=module.params["type"],
|
||||
stackscript_id=module.params["stackscript_id"],
|
||||
)
|
||||
module.exit_json(changed=True, instance=instance_json)
|
||||
|
||||
elif module.params['state'] == 'absent' and instance is not None:
|
||||
elif module.params["state"] == "absent" and instance is not None:
|
||||
instance.delete()
|
||||
module.exit_json(changed=True, instance=instance._raw_json)
|
||||
|
||||
elif module.params['state'] == 'absent' and instance is None:
|
||||
elif module.params["state"] == "absent" and instance is None:
|
||||
module.exit_json(changed=False, instance={})
|
||||
|
||||
|
||||
|
|
|
@ -8,14 +8,9 @@
|
|||
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
||||
|
||||
|
||||
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||
|
||||
ANSIBLE_METADATA = {
|
||||
'metadata_version': '1.1',
|
||||
'status': ['preview'],
|
||||
'supported_by': 'community'
|
||||
}
|
||||
|
||||
DOCUMENTATION = '''
|
||||
DOCUMENTATION = """
|
||||
---
|
||||
module: scaleway_compute
|
||||
short_description: Scaleway compute management module
|
||||
|
@ -120,9 +115,9 @@ options:
|
|||
- If no value provided, the default security group or current security group will be used
|
||||
required: false
|
||||
version_added: "2.8"
|
||||
'''
|
||||
"""
|
||||
|
||||
EXAMPLES = '''
|
||||
EXAMPLES = """
|
||||
- name: Create a server
|
||||
scaleway_compute:
|
||||
name: foobar
|
||||
|
@ -156,10 +151,10 @@ EXAMPLES = '''
|
|||
organization: 951df375-e094-4d26-97c1-ba548eeb9c42
|
||||
region: ams1
|
||||
commercial_type: VC1S
|
||||
'''
|
||||
"""
|
||||
|
||||
RETURN = '''
|
||||
'''
|
||||
RETURN = """
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import time
|
||||
|
@ -167,19 +162,9 @@ import time
|
|||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
|
||||
|
||||
SCALEWAY_SERVER_STATES = (
|
||||
'stopped',
|
||||
'stopping',
|
||||
'starting',
|
||||
'running',
|
||||
'locked'
|
||||
)
|
||||
SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked")
|
||||
|
||||
SCALEWAY_TRANSITIONS_STATES = (
|
||||
"stopping",
|
||||
"starting",
|
||||
"pending"
|
||||
)
|
||||
SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending")
|
||||
|
||||
|
||||
def check_image_id(compute_api, image_id):
|
||||
|
@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id):
|
|||
if response.ok and response.json:
|
||||
image_ids = [image["id"] for image in response.json["images"]]
|
||||
if image_id not in image_ids:
|
||||
compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url')))
|
||||
compute_api.module.fail_json(
|
||||
msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url"))
|
||||
)
|
||||
else:
|
||||
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url'))
|
||||
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url"))
|
||||
|
||||
|
||||
def fetch_state(compute_api, server):
|
||||
|
@ -201,7 +188,7 @@ def fetch_state(compute_api, server):
|
|||
return "absent"
|
||||
|
||||
if not response.ok:
|
||||
msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
try:
|
||||
|
@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip):
|
|||
# We check that the IP we want to attach exists, if so its ID is returned
|
||||
response = compute_api.get("ips")
|
||||
if not response.ok:
|
||||
msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
ip_list = []
|
||||
|
@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip):
|
|||
def create_server(compute_api, server):
|
||||
compute_api.module.debug("Starting a create_server")
|
||||
target_server = None
|
||||
data = {"enable_ipv6": server["enable_ipv6"],
|
||||
"tags": server["tags"],
|
||||
"commercial_type": server["commercial_type"],
|
||||
"image": server["image"],
|
||||
"dynamic_ip_required": server["dynamic_ip_required"],
|
||||
"name": server["name"],
|
||||
"organization": server["organization"]
|
||||
}
|
||||
data = {
|
||||
"enable_ipv6": server["enable_ipv6"],
|
||||
"tags": server["tags"],
|
||||
"commercial_type": server["commercial_type"],
|
||||
"image": server["image"],
|
||||
"dynamic_ip_required": server["dynamic_ip_required"],
|
||||
"name": server["name"],
|
||||
"organization": server["organization"],
|
||||
}
|
||||
|
||||
if server["boot_type"]:
|
||||
data["boot_type"] = server["boot_type"]
|
||||
|
@ -278,7 +266,7 @@ def create_server(compute_api, server):
|
|||
response = compute_api.post(path="servers", data=data)
|
||||
|
||||
if not response.ok:
|
||||
msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during server creation: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
try:
|
||||
|
@ -304,10 +292,9 @@ def start_server(compute_api, server):
|
|||
|
||||
|
||||
def perform_action(compute_api, server, action):
|
||||
response = compute_api.post(path="servers/%s/action" % server["id"],
|
||||
data={"action": action})
|
||||
response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action})
|
||||
if not response.ok:
|
||||
msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json)
|
||||
msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
||||
|
@ -319,7 +306,7 @@ def remove_server(compute_api, server):
|
|||
compute_api.module.debug("Starting remove server strategy")
|
||||
response = compute_api.delete(path="servers/%s" % server["id"])
|
||||
if not response.ok:
|
||||
msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
||||
|
@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server):
|
|||
else:
|
||||
target_server = query_results[0]
|
||||
|
||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
||||
wished_server=wished_server):
|
||||
if server_attributes_should_be_changed(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
):
|
||||
changed = True
|
||||
|
||||
if compute_api.module.check_mode:
|
||||
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
|
||||
|
||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
||||
target_server = server_change_attributes(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
)
|
||||
|
||||
return changed, target_server
|
||||
|
||||
|
@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server):
|
|||
response = stop_server(compute_api=compute_api, server=target_server)
|
||||
|
||||
if not response.ok:
|
||||
err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]'
|
||||
err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]"
|
||||
compute_api.module.fail_json(msg=err_msg)
|
||||
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||
|
@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server):
|
|||
response = remove_server(compute_api=compute_api, server=target_server)
|
||||
|
||||
if not response.ok:
|
||||
err_msg = f'Error while removing server [{response.status_code}: {response.json}]'
|
||||
err_msg = f"Error while removing server [{response.status_code}: {response.json}]"
|
||||
compute_api.module.fail_json(msg=err_msg)
|
||||
|
||||
return changed, {"status": "Server %s deleted" % target_server["id"]}
|
||||
|
@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server):
|
|||
else:
|
||||
target_server = query_results[0]
|
||||
|
||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
||||
wished_server=wished_server):
|
||||
if server_attributes_should_be_changed(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
):
|
||||
changed = True
|
||||
|
||||
if compute_api.module.check_mode:
|
||||
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
|
||||
|
||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
||||
target_server = server_change_attributes(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
)
|
||||
|
||||
current_state = fetch_state(compute_api=compute_api, server=target_server)
|
||||
if current_state not in ("running", "starting"):
|
||||
|
@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server):
|
|||
|
||||
response = start_server(compute_api=compute_api, server=target_server)
|
||||
if not response.ok:
|
||||
msg = f'Error while running server [{response.status_code}: {response.json}]'
|
||||
msg = f"Error while running server [{response.status_code}: {response.json}]"
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
return changed, target_server
|
||||
|
@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server):
|
|||
changed = False
|
||||
|
||||
if not query_results:
|
||||
|
||||
if compute_api.module.check_mode:
|
||||
return changed, {"status": "A server would be created before being stopped."}
|
||||
|
||||
|
@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server):
|
|||
|
||||
compute_api.module.debug("stop_strategy: Servers are found.")
|
||||
|
||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
||||
wished_server=wished_server):
|
||||
if server_attributes_should_be_changed(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
):
|
||||
changed = True
|
||||
|
||||
if compute_api.module.check_mode:
|
||||
return changed, {
|
||||
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]}
|
||||
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]
|
||||
}
|
||||
|
||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
||||
target_server = server_change_attributes(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
)
|
||||
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||
|
||||
|
@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server):
|
|||
compute_api.module.debug(response.ok)
|
||||
|
||||
if not response.ok:
|
||||
msg = f'Error while stopping server [{response.status_code}: {response.json}]'
|
||||
msg = f"Error while stopping server [{response.status_code}: {response.json}]"
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
return changed, target_server
|
||||
|
@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server):
|
|||
else:
|
||||
target_server = query_results[0]
|
||||
|
||||
if server_attributes_should_be_changed(compute_api=compute_api,
|
||||
target_server=target_server,
|
||||
wished_server=wished_server):
|
||||
if server_attributes_should_be_changed(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
):
|
||||
changed = True
|
||||
|
||||
if compute_api.module.check_mode:
|
||||
return changed, {
|
||||
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]}
|
||||
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]
|
||||
}
|
||||
|
||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
||||
target_server = server_change_attributes(
|
||||
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||
)
|
||||
|
||||
changed = True
|
||||
if compute_api.module.check_mode:
|
||||
|
@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server):
|
|||
response = restart_server(compute_api=compute_api, server=target_server)
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||
if not response.ok:
|
||||
msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].'
|
||||
msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]."
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
|
||||
response = restart_server(compute_api=compute_api, server=target_server)
|
||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||
if not response.ok:
|
||||
msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].'
|
||||
msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]."
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
return changed, target_server
|
||||
|
@ -531,18 +530,17 @@ state_strategy = {
|
|||
"restarted": restart_strategy,
|
||||
"stopped": stop_strategy,
|
||||
"running": running_strategy,
|
||||
"absent": absent_strategy
|
||||
"absent": absent_strategy,
|
||||
}
|
||||
|
||||
|
||||
def find(compute_api, wished_server, per_page=1):
|
||||
compute_api.module.debug("Getting inside find")
|
||||
# Only the name attribute is accepted in the Compute query API
|
||||
response = compute_api.get("servers", params={"name": wished_server["name"],
|
||||
"per_page": per_page})
|
||||
response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page})
|
||||
|
||||
if not response.ok:
|
||||
msg = 'Error during server search: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during server search: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
search_results = response.json["servers"]
|
||||
|
@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve
|
|||
compute_api.module.debug("Checking if server attributes should be changed")
|
||||
compute_api.module.debug("Current Server: %s" % target_server)
|
||||
compute_api.module.debug("Wished Server: %s" % wished_server)
|
||||
debug_dict = dict((x, (target_server[x], wished_server[x]))
|
||||
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
|
||||
if x in target_server and x in wished_server)
|
||||
debug_dict = dict(
|
||||
(x, (target_server[x], wished_server[x]))
|
||||
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
|
||||
if x in target_server and x in wished_server
|
||||
)
|
||||
compute_api.module.debug("Debug dict %s" % debug_dict)
|
||||
try:
|
||||
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
|
||||
if key in target_server and key in wished_server:
|
||||
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
|
||||
if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys(
|
||||
) and target_server[key]["id"] != wished_server[key]:
|
||||
if (
|
||||
isinstance(target_server[key], dict)
|
||||
and wished_server[key]
|
||||
and "id" in target_server[key].keys()
|
||||
and target_server[key]["id"] != wished_server[key]
|
||||
):
|
||||
return True
|
||||
# Handling other structure compare simply the two objects content
|
||||
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
|
||||
|
@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server):
|
|||
elif not isinstance(target_server[key], dict):
|
||||
patch_payload[key] = wished_server[key]
|
||||
|
||||
response = compute_api.patch(path="servers/%s" % target_server["id"],
|
||||
data=patch_payload)
|
||||
response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload)
|
||||
if not response.ok:
|
||||
msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json)
|
||||
msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json)
|
||||
compute_api.module.fail_json(msg=msg)
|
||||
|
||||
try:
|
||||
|
@ -625,9 +628,9 @@ def core(module):
|
|||
"boot_type": module.params["boot_type"],
|
||||
"tags": module.params["tags"],
|
||||
"organization": module.params["organization"],
|
||||
"security_group": module.params["security_group"]
|
||||
"security_group": module.params["security_group"],
|
||||
}
|
||||
module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"]
|
||||
module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"]
|
||||
|
||||
compute_api = Scaleway(module=module)
|
||||
|
||||
|
@ -643,22 +646,24 @@ def core(module):
|
|||
|
||||
def main():
|
||||
argument_spec = scaleway_argument_spec()
|
||||
argument_spec.update(dict(
|
||||
image=dict(required=True),
|
||||
name=dict(),
|
||||
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
|
||||
commercial_type=dict(required=True),
|
||||
enable_ipv6=dict(default=False, type="bool"),
|
||||
boot_type=dict(choices=['bootscript', 'local']),
|
||||
public_ip=dict(default="absent"),
|
||||
state=dict(choices=state_strategy.keys(), default='present'),
|
||||
tags=dict(type="list", default=[]),
|
||||
organization=dict(required=True),
|
||||
wait=dict(type="bool", default=False),
|
||||
wait_timeout=dict(type="int", default=300),
|
||||
wait_sleep_time=dict(type="int", default=3),
|
||||
security_group=dict(),
|
||||
))
|
||||
argument_spec.update(
|
||||
dict(
|
||||
image=dict(required=True),
|
||||
name=dict(),
|
||||
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
|
||||
commercial_type=dict(required=True),
|
||||
enable_ipv6=dict(default=False, type="bool"),
|
||||
boot_type=dict(choices=["bootscript", "local"]),
|
||||
public_ip=dict(default="absent"),
|
||||
state=dict(choices=state_strategy.keys(), default="present"),
|
||||
tags=dict(type="list", default=[]),
|
||||
organization=dict(required=True),
|
||||
wait=dict(type="bool", default=False),
|
||||
wait_timeout=dict(type="int", default=300),
|
||||
wait_sleep_time=dict(type="int", default=3),
|
||||
security_group=dict(),
|
||||
)
|
||||
)
|
||||
module = AnsibleModule(
|
||||
argument_spec=argument_spec,
|
||||
supports_check_mode=True,
|
||||
|
@ -667,5 +672,5 @@ def main():
|
|||
core(module)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -32,32 +32,30 @@ Returns:
|
|||
def run_module():
|
||||
"""
|
||||
Main execution function for the x25519_pubkey Ansible module.
|
||||
|
||||
|
||||
Handles parameter validation, private key processing, public key derivation,
|
||||
and optional file output with idempotent behavior.
|
||||
"""
|
||||
module_args = {
|
||||
'private_key_b64': {'type': 'str', 'required': False},
|
||||
'private_key_path': {'type': 'path', 'required': False},
|
||||
'public_key_path': {'type': 'path', 'required': False},
|
||||
"private_key_b64": {"type": "str", "required": False},
|
||||
"private_key_path": {"type": "path", "required": False},
|
||||
"public_key_path": {"type": "path", "required": False},
|
||||
}
|
||||
|
||||
result = {
|
||||
'changed': False,
|
||||
'public_key': '',
|
||||
"changed": False,
|
||||
"public_key": "",
|
||||
}
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=module_args,
|
||||
required_one_of=[['private_key_b64', 'private_key_path']],
|
||||
supports_check_mode=True
|
||||
argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True
|
||||
)
|
||||
|
||||
priv_b64 = None
|
||||
|
||||
if module.params['private_key_path']:
|
||||
if module.params["private_key_path"]:
|
||||
try:
|
||||
with open(module.params['private_key_path'], 'rb') as f:
|
||||
with open(module.params["private_key_path"], "rb") as f:
|
||||
data = f.read()
|
||||
try:
|
||||
# First attempt: assume file contains base64 text data
|
||||
|
@ -71,12 +69,14 @@ def run_module():
|
|||
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
|
||||
# Stripping would corrupt the key and cause "got 31 bytes" errors
|
||||
if len(data) != 32:
|
||||
module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes")
|
||||
module.fail_json(
|
||||
msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes"
|
||||
)
|
||||
priv_b64 = base64.b64encode(data).decode()
|
||||
except OSError as e:
|
||||
module.fail_json(msg=f"Failed to read private key file: {e}")
|
||||
else:
|
||||
priv_b64 = module.params['private_key_b64']
|
||||
priv_b64 = module.params["private_key_b64"]
|
||||
|
||||
# Validate input parameters
|
||||
if not priv_b64:
|
||||
|
@ -93,15 +93,12 @@ def run_module():
|
|||
try:
|
||||
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
|
||||
pub_key = priv_key.public_key()
|
||||
pub_raw = pub_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw
|
||||
)
|
||||
pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
|
||||
pub_b64 = base64.b64encode(pub_raw).decode()
|
||||
result['public_key'] = pub_b64
|
||||
result["public_key"] = pub_b64
|
||||
|
||||
if module.params['public_key_path']:
|
||||
pub_path = module.params['public_key_path']
|
||||
if module.params["public_key_path"]:
|
||||
pub_path = module.params["public_key_path"]
|
||||
existing = None
|
||||
|
||||
try:
|
||||
|
@ -112,13 +109,13 @@ def run_module():
|
|||
|
||||
if existing != pub_b64:
|
||||
try:
|
||||
with open(pub_path, 'w') as f:
|
||||
with open(pub_path, "w") as f:
|
||||
f.write(pub_b64)
|
||||
result['changed'] = True
|
||||
result["changed"] = True
|
||||
except OSError as e:
|
||||
module.fail_json(msg=f"Failed to write public key file: {e}")
|
||||
|
||||
result['public_key_path'] = pub_path
|
||||
result["public_key_path"] = pub_path
|
||||
|
||||
except Exception as e:
|
||||
module.fail_json(msg=f"Failed to derive public key: {e}")
|
||||
|
@ -131,5 +128,5 @@ def main():
|
|||
run_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Track test effectiveness by analyzing CI failures and correlating with issues/PRs
|
||||
This helps identify which tests actually catch bugs vs just failing randomly
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
|
@ -12,7 +13,7 @@ from pathlib import Path
|
|||
|
||||
def get_github_api_data(endpoint):
|
||||
"""Fetch data from GitHub API"""
|
||||
cmd = ['gh', 'api', endpoint]
|
||||
cmd = ["gh", "api", endpoint]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"Error fetching {endpoint}: {result.stderr}")
|
||||
|
@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30):
|
|||
since = (datetime.now() - timedelta(days=days_back)).isoformat()
|
||||
|
||||
# Get workflow runs
|
||||
runs = get_github_api_data(
|
||||
f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure'
|
||||
)
|
||||
runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure")
|
||||
|
||||
if not runs:
|
||||
return {}
|
||||
|
||||
test_failures = defaultdict(list)
|
||||
|
||||
for run in runs.get('workflow_runs', []):
|
||||
for run in runs.get("workflow_runs", []):
|
||||
# Get jobs for this run
|
||||
jobs = get_github_api_data(
|
||||
f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs'
|
||||
)
|
||||
jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs")
|
||||
|
||||
if not jobs:
|
||||
continue
|
||||
|
||||
for job in jobs.get('jobs', []):
|
||||
if job['conclusion'] == 'failure':
|
||||
for job in jobs.get("jobs", []):
|
||||
if job["conclusion"] == "failure":
|
||||
# Try to extract which test failed from logs
|
||||
logs_url = job.get('logs_url')
|
||||
logs_url = job.get("logs_url")
|
||||
if logs_url:
|
||||
# Parse logs to find test failures
|
||||
test_name = extract_failed_test(job['name'], run['id'])
|
||||
test_name = extract_failed_test(job["name"], run["id"])
|
||||
if test_name:
|
||||
test_failures[test_name].append({
|
||||
'run_id': run['id'],
|
||||
'run_number': run['run_number'],
|
||||
'date': run['created_at'],
|
||||
'branch': run['head_branch'],
|
||||
'commit': run['head_sha'][:7],
|
||||
'pr': extract_pr_number(run)
|
||||
})
|
||||
test_failures[test_name].append(
|
||||
{
|
||||
"run_id": run["id"],
|
||||
"run_number": run["run_number"],
|
||||
"date": run["created_at"],
|
||||
"branch": run["head_branch"],
|
||||
"commit": run["head_sha"][:7],
|
||||
"pr": extract_pr_number(run),
|
||||
}
|
||||
)
|
||||
|
||||
return test_failures
|
||||
|
||||
|
@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id):
|
|||
"""Extract test name from job - this is simplified"""
|
||||
# Map job names to test categories
|
||||
job_to_tests = {
|
||||
'Basic sanity tests': 'test_basic_sanity',
|
||||
'Ansible syntax check': 'ansible_syntax',
|
||||
'Docker build test': 'docker_tests',
|
||||
'Configuration generation test': 'config_generation',
|
||||
'Ansible dry-run validation': 'ansible_dry_run'
|
||||
"Basic sanity tests": "test_basic_sanity",
|
||||
"Ansible syntax check": "ansible_syntax",
|
||||
"Docker build test": "docker_tests",
|
||||
"Configuration generation test": "config_generation",
|
||||
"Ansible dry-run validation": "ansible_dry_run",
|
||||
}
|
||||
return job_to_tests.get(job_name, job_name)
|
||||
|
||||
|
||||
def extract_pr_number(run):
|
||||
"""Extract PR number from workflow run"""
|
||||
for pr in run.get('pull_requests', []):
|
||||
return pr['number']
|
||||
for pr in run.get("pull_requests", []):
|
||||
return pr["number"]
|
||||
return None
|
||||
|
||||
|
||||
def correlate_with_issues(repo_owner, repo_name, test_failures):
|
||||
"""Correlate test failures with issues/PRs that fixed them"""
|
||||
correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0})
|
||||
correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0})
|
||||
|
||||
for test_name, failures in test_failures.items():
|
||||
for failure in failures:
|
||||
if failure['pr']:
|
||||
if failure["pr"]:
|
||||
# Check if PR was merged (indicating it fixed a real issue)
|
||||
pr = get_github_api_data(
|
||||
f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}'
|
||||
)
|
||||
pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}")
|
||||
|
||||
if pr and pr.get('merged'):
|
||||
if pr and pr.get("merged"):
|
||||
# Check PR title/body for bug indicators
|
||||
title = pr.get('title', '').lower()
|
||||
body = pr.get('body', '').lower()
|
||||
title = pr.get("title", "").lower()
|
||||
body = pr.get("body", "").lower()
|
||||
|
||||
bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail']
|
||||
is_bug_fix = any(keyword in title or keyword in body
|
||||
for keyword in bug_keywords)
|
||||
bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"]
|
||||
is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords)
|
||||
|
||||
if is_bug_fix:
|
||||
correlations[test_name]['caught_bugs'] += 1
|
||||
correlations[test_name]["caught_bugs"] += 1
|
||||
else:
|
||||
correlations[test_name]['false_positives'] += 1
|
||||
correlations[test_name]["false_positives"] += 1
|
||||
|
||||
return correlations
|
||||
|
||||
|
@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations):
|
|||
scores = []
|
||||
for test_name, failures in test_failures.items():
|
||||
failure_count = len(failures)
|
||||
caught = correlations[test_name]['caught_bugs']
|
||||
false_pos = correlations[test_name]['false_positives']
|
||||
caught = correlations[test_name]["caught_bugs"]
|
||||
false_pos = correlations[test_name]["false_positives"]
|
||||
|
||||
# Calculate effectiveness (bugs caught / total failures)
|
||||
if failure_count > 0:
|
||||
|
@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations):
|
|||
elif effectiveness > 0.8:
|
||||
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
|
||||
|
||||
return '\n'.join(report)
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
def save_metrics(test_failures, correlations):
|
||||
"""Save metrics to JSON for historical tracking"""
|
||||
metrics_file = Path('.metrics/test-effectiveness.json')
|
||||
metrics_file = Path(".metrics/test-effectiveness.json")
|
||||
metrics_file.parent.mkdir(exist_ok=True)
|
||||
|
||||
# Load existing metrics
|
||||
|
@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations):
|
|||
|
||||
# Add current metrics
|
||||
current = {
|
||||
'date': datetime.now().isoformat(),
|
||||
'test_failures': {
|
||||
test: len(failures) for test, failures in test_failures.items()
|
||||
},
|
||||
'effectiveness': {
|
||||
"date": datetime.now().isoformat(),
|
||||
"test_failures": {test: len(failures) for test, failures in test_failures.items()},
|
||||
"effectiveness": {
|
||||
test: {
|
||||
'caught_bugs': data['caught_bugs'],
|
||||
'false_positives': data['false_positives'],
|
||||
'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives'])
|
||||
if (data['caught_bugs'] + data['false_positives']) > 0 else 0
|
||||
"caught_bugs": data["caught_bugs"],
|
||||
"false_positives": data["false_positives"],
|
||||
"score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"])
|
||||
if (data["caught_bugs"] + data["false_positives"]) > 0
|
||||
else 0,
|
||||
}
|
||||
for test, data in correlations.items()
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
historical.append(current)
|
||||
|
||||
# Keep last 12 months of data
|
||||
cutoff = datetime.now() - timedelta(days=365)
|
||||
historical = [
|
||||
h for h in historical
|
||||
if datetime.fromisoformat(h['date']) > cutoff
|
||||
]
|
||||
historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff]
|
||||
|
||||
with open(metrics_file, 'w') as f:
|
||||
with open(metrics_file, "w") as f:
|
||||
json.dump(historical, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# Configure these for your repo
|
||||
REPO_OWNER = 'trailofbits'
|
||||
REPO_NAME = 'algo'
|
||||
REPO_OWNER = "trailofbits"
|
||||
REPO_NAME = "algo"
|
||||
|
||||
print("Analyzing test effectiveness...")
|
||||
|
||||
|
@ -223,9 +215,9 @@ if __name__ == '__main__':
|
|||
print("\n" + report)
|
||||
|
||||
# Save report
|
||||
report_file = Path('.metrics/test-effectiveness-report.md')
|
||||
report_file = Path(".metrics/test-effectiveness-report.md")
|
||||
report_file.parent.mkdir(exist_ok=True)
|
||||
with open(report_file, 'w') as f:
|
||||
with open(report_file, "w") as f:
|
||||
f.write(report)
|
||||
print(f"\nReport saved to: {report_file}")
|
||||
|
||||
|
|
3
tests/fixtures/__init__.py
vendored
3
tests/fixtures/__init__.py
vendored
|
@ -1,4 +1,5 @@
|
|||
"""Test fixtures for Algo unit tests"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
@ -6,7 +7,7 @@ import yaml
|
|||
|
||||
def load_test_variables():
|
||||
"""Load test variables from YAML fixture"""
|
||||
fixture_path = Path(__file__).parent / 'test_variables.yml'
|
||||
fixture_path = Path(__file__).parent / "test_variables.yml"
|
||||
with open(fixture_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
|
|
@ -2,49 +2,56 @@
|
|||
"""
|
||||
Wrapper for Ansible's service module that always succeeds for known services
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Parse module arguments
|
||||
args = json.loads(sys.stdin.read())
|
||||
module_args = args.get('ANSIBLE_MODULE_ARGS', {})
|
||||
module_args = args.get("ANSIBLE_MODULE_ARGS", {})
|
||||
|
||||
service_name = module_args.get('name', '')
|
||||
state = module_args.get('state', 'started')
|
||||
service_name = module_args.get("name", "")
|
||||
state = module_args.get("state", "started")
|
||||
|
||||
# Known services that should always succeed
|
||||
known_services = [
|
||||
'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter',
|
||||
'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd',
|
||||
'systemd-resolved', 'rsyslog', 'ipfw', 'cron'
|
||||
"netfilter-persistent",
|
||||
"iptables",
|
||||
"wg-quick@wg0",
|
||||
"strongswan-starter",
|
||||
"ipsec",
|
||||
"apparmor",
|
||||
"unattended-upgrades",
|
||||
"systemd-networkd",
|
||||
"systemd-resolved",
|
||||
"rsyslog",
|
||||
"ipfw",
|
||||
"cron",
|
||||
]
|
||||
|
||||
# Check if it's a known service
|
||||
service_found = False
|
||||
for svc in known_services:
|
||||
if service_name == svc or service_name.startswith(svc + '.'):
|
||||
if service_name == svc or service_name.startswith(svc + "."):
|
||||
service_found = True
|
||||
break
|
||||
|
||||
if service_found:
|
||||
# Return success
|
||||
result = {
|
||||
'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False,
|
||||
'name': service_name,
|
||||
'state': state,
|
||||
'status': {
|
||||
'LoadState': 'loaded',
|
||||
'ActiveState': 'active' if state != 'stopped' else 'inactive',
|
||||
'SubState': 'running' if state != 'stopped' else 'dead'
|
||||
}
|
||||
"changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False,
|
||||
"name": service_name,
|
||||
"state": state,
|
||||
"status": {
|
||||
"LoadState": "loaded",
|
||||
"ActiveState": "active" if state != "stopped" else "inactive",
|
||||
"SubState": "running" if state != "stopped" else "dead",
|
||||
},
|
||||
}
|
||||
print(json.dumps(result))
|
||||
sys.exit(0)
|
||||
else:
|
||||
# Service not found
|
||||
error = {
|
||||
'failed': True,
|
||||
'msg': f'Could not find the requested service {service_name}: '
|
||||
}
|
||||
error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "}
|
||||
print(json.dumps(error))
|
||||
sys.exit(1)
|
||||
|
|
|
@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
def main():
|
||||
module = AnsibleModule(
|
||||
argument_spec={
|
||||
'name': {'type': 'list', 'aliases': ['pkg', 'package']},
|
||||
'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']},
|
||||
'update_cache': {'type': 'bool', 'default': False},
|
||||
'cache_valid_time': {'type': 'int', 'default': 0},
|
||||
'install_recommends': {'type': 'bool'},
|
||||
'force': {'type': 'bool', 'default': False},
|
||||
'allow_unauthenticated': {'type': 'bool', 'default': False},
|
||||
'allow_downgrade': {'type': 'bool', 'default': False},
|
||||
'allow_change_held_packages': {'type': 'bool', 'default': False},
|
||||
'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'},
|
||||
'autoremove': {'type': 'bool', 'default': False},
|
||||
'purge': {'type': 'bool', 'default': False},
|
||||
'force_apt_get': {'type': 'bool', 'default': False},
|
||||
"name": {"type": "list", "aliases": ["pkg", "package"]},
|
||||
"state": {
|
||||
"type": "str",
|
||||
"default": "present",
|
||||
"choices": ["present", "absent", "latest", "build-dep", "fixed"],
|
||||
},
|
||||
"update_cache": {"type": "bool", "default": False},
|
||||
"cache_valid_time": {"type": "int", "default": 0},
|
||||
"install_recommends": {"type": "bool"},
|
||||
"force": {"type": "bool", "default": False},
|
||||
"allow_unauthenticated": {"type": "bool", "default": False},
|
||||
"allow_downgrade": {"type": "bool", "default": False},
|
||||
"allow_change_held_packages": {"type": "bool", "default": False},
|
||||
"dpkg_options": {"type": "str", "default": "force-confdef,force-confold"},
|
||||
"autoremove": {"type": "bool", "default": False},
|
||||
"purge": {"type": "bool", "default": False},
|
||||
"force_apt_get": {"type": "bool", "default": False},
|
||||
},
|
||||
supports_check_mode=True
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
name = module.params['name']
|
||||
state = module.params['state']
|
||||
update_cache = module.params['update_cache']
|
||||
name = module.params["name"]
|
||||
state = module.params["state"]
|
||||
update_cache = module.params["update_cache"]
|
||||
|
||||
result = {
|
||||
'changed': False,
|
||||
'cache_updated': False,
|
||||
'cache_update_time': 0
|
||||
}
|
||||
result = {"changed": False, "cache_updated": False, "cache_update_time": 0}
|
||||
|
||||
# Log the operation
|
||||
with open('/var/log/mock-apt-module.log', 'a') as f:
|
||||
with open("/var/log/mock-apt-module.log", "a") as f:
|
||||
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
|
||||
|
||||
# Handle cache update
|
||||
if update_cache:
|
||||
# In Docker, apt-get update was already run in entrypoint
|
||||
# Just pretend it succeeded
|
||||
result['cache_updated'] = True
|
||||
result['cache_update_time'] = 1754231778 # Fixed timestamp
|
||||
result['changed'] = True
|
||||
result["cache_updated"] = True
|
||||
result["cache_update_time"] = 1754231778 # Fixed timestamp
|
||||
result["changed"] = True
|
||||
|
||||
# Handle package installation/removal
|
||||
if name:
|
||||
|
@ -56,40 +56,41 @@ def main():
|
|||
installed_packages = []
|
||||
for pkg in packages:
|
||||
# Use dpkg to check if package is installed
|
||||
check_cmd = ['dpkg', '-s', pkg]
|
||||
check_cmd = ["dpkg", "-s", pkg]
|
||||
rc = subprocess.run(check_cmd, capture_output=True)
|
||||
if rc.returncode == 0:
|
||||
installed_packages.append(pkg)
|
||||
|
||||
if state in ['present', 'latest']:
|
||||
if state in ["present", "latest"]:
|
||||
# Check if we need to install anything
|
||||
missing_packages = [p for p in packages if p not in installed_packages]
|
||||
|
||||
if missing_packages:
|
||||
# Log what we would install
|
||||
with open('/var/log/mock-apt-module.log', 'a') as f:
|
||||
with open("/var/log/mock-apt-module.log", "a") as f:
|
||||
f.write(f"Would install packages: {missing_packages}\n")
|
||||
|
||||
# For our test purposes, these packages are pre-installed in Docker
|
||||
# Just report success
|
||||
result['changed'] = True
|
||||
result['stdout'] = f"Mock: Packages {missing_packages} are already available"
|
||||
result['stderr'] = ""
|
||||
result["changed"] = True
|
||||
result["stdout"] = f"Mock: Packages {missing_packages} are already available"
|
||||
result["stderr"] = ""
|
||||
else:
|
||||
result['stdout'] = "All packages are already installed"
|
||||
result["stdout"] = "All packages are already installed"
|
||||
|
||||
elif state == 'absent':
|
||||
elif state == "absent":
|
||||
# Check if we need to remove anything
|
||||
present_packages = [p for p in packages if p in installed_packages]
|
||||
|
||||
if present_packages:
|
||||
result['changed'] = True
|
||||
result['stdout'] = f"Mock: Would remove packages {present_packages}"
|
||||
result["changed"] = True
|
||||
result["stdout"] = f"Mock: Would remove packages {present_packages}"
|
||||
else:
|
||||
result['stdout'] = "No packages to remove"
|
||||
result["stdout"] = "No packages to remove"
|
||||
|
||||
# Always report success for our testing
|
||||
module.exit_json(**result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
def main():
|
||||
module = AnsibleModule(
|
||||
argument_spec={
|
||||
'_raw_params': {'type': 'str'},
|
||||
'cmd': {'type': 'str'},
|
||||
'creates': {'type': 'path'},
|
||||
'removes': {'type': 'path'},
|
||||
'chdir': {'type': 'path'},
|
||||
'executable': {'type': 'path'},
|
||||
'warn': {'type': 'bool', 'default': False},
|
||||
'stdin': {'type': 'str'},
|
||||
'stdin_add_newline': {'type': 'bool', 'default': True},
|
||||
'strip_empty_ends': {'type': 'bool', 'default': True},
|
||||
'_uses_shell': {'type': 'bool', 'default': False},
|
||||
"_raw_params": {"type": "str"},
|
||||
"cmd": {"type": "str"},
|
||||
"creates": {"type": "path"},
|
||||
"removes": {"type": "path"},
|
||||
"chdir": {"type": "path"},
|
||||
"executable": {"type": "path"},
|
||||
"warn": {"type": "bool", "default": False},
|
||||
"stdin": {"type": "str"},
|
||||
"stdin_add_newline": {"type": "bool", "default": True},
|
||||
"strip_empty_ends": {"type": "bool", "default": True},
|
||||
"_uses_shell": {"type": "bool", "default": False},
|
||||
},
|
||||
supports_check_mode=True
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
# Get the command
|
||||
raw_params = module.params.get('_raw_params')
|
||||
cmd = module.params.get('cmd') or raw_params
|
||||
raw_params = module.params.get("_raw_params")
|
||||
cmd = module.params.get("cmd") or raw_params
|
||||
|
||||
if not cmd:
|
||||
module.fail_json(msg="no command given")
|
||||
|
||||
result = {
|
||||
'changed': False,
|
||||
'cmd': cmd,
|
||||
'rc': 0,
|
||||
'stdout': '',
|
||||
'stderr': '',
|
||||
'stdout_lines': [],
|
||||
'stderr_lines': []
|
||||
}
|
||||
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
|
||||
|
||||
# Log the operation
|
||||
with open('/var/log/mock-command-module.log', 'a') as f:
|
||||
with open("/var/log/mock-command-module.log", "a") as f:
|
||||
f.write(f"command module called: cmd={cmd}\n")
|
||||
|
||||
# Handle specific commands
|
||||
if 'apparmor_status' in cmd:
|
||||
if "apparmor_status" in cmd:
|
||||
# Pretend apparmor is not installed/active
|
||||
result['rc'] = 127
|
||||
result['stderr'] = "apparmor_status: command not found"
|
||||
result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'"
|
||||
module.fail_json(msg=result['msg'], **result)
|
||||
elif 'netplan apply' in cmd:
|
||||
result["rc"] = 127
|
||||
result["stderr"] = "apparmor_status: command not found"
|
||||
result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'"
|
||||
module.fail_json(msg=result["msg"], **result)
|
||||
elif "netplan apply" in cmd:
|
||||
# Pretend netplan succeeded
|
||||
result['stdout'] = "Mock: netplan configuration applied"
|
||||
result['changed'] = True
|
||||
elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
|
||||
result["stdout"] = "Mock: netplan configuration applied"
|
||||
result["changed"] = True
|
||||
elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
|
||||
# Routing cache flush
|
||||
result['stdout'] = "1"
|
||||
result['changed'] = True
|
||||
result["stdout"] = "1"
|
||||
result["changed"] = True
|
||||
else:
|
||||
# For other commands, try to run them
|
||||
try:
|
||||
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir'))
|
||||
result['rc'] = proc.returncode
|
||||
result['stdout'] = proc.stdout
|
||||
result['stderr'] = proc.stderr
|
||||
result['stdout_lines'] = proc.stdout.splitlines()
|
||||
result['stderr_lines'] = proc.stderr.splitlines()
|
||||
result['changed'] = True
|
||||
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir"))
|
||||
result["rc"] = proc.returncode
|
||||
result["stdout"] = proc.stdout
|
||||
result["stderr"] = proc.stderr
|
||||
result["stdout_lines"] = proc.stdout.splitlines()
|
||||
result["stderr_lines"] = proc.stderr.splitlines()
|
||||
result["changed"] = True
|
||||
except Exception as e:
|
||||
result['rc'] = 1
|
||||
result['stderr'] = str(e)
|
||||
result['msg'] = str(e)
|
||||
module.fail_json(msg=result['msg'], **result)
|
||||
result["rc"] = 1
|
||||
result["stderr"] = str(e)
|
||||
result["msg"] = str(e)
|
||||
module.fail_json(msg=result["msg"], **result)
|
||||
|
||||
if result['rc'] == 0:
|
||||
if result["rc"] == 0:
|
||||
module.exit_json(**result)
|
||||
else:
|
||||
if 'msg' not in result:
|
||||
result['msg'] = f"Command failed with return code {result['rc']}"
|
||||
module.fail_json(msg=result['msg'], **result)
|
||||
if "msg" not in result:
|
||||
result["msg"] = f"Command failed with return code {result['rc']}"
|
||||
module.fail_json(msg=result["msg"], **result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
def main():
|
||||
module = AnsibleModule(
|
||||
argument_spec={
|
||||
'_raw_params': {'type': 'str'},
|
||||
'cmd': {'type': 'str'},
|
||||
'creates': {'type': 'path'},
|
||||
'removes': {'type': 'path'},
|
||||
'chdir': {'type': 'path'},
|
||||
'executable': {'type': 'path', 'default': '/bin/sh'},
|
||||
'warn': {'type': 'bool', 'default': False},
|
||||
'stdin': {'type': 'str'},
|
||||
'stdin_add_newline': {'type': 'bool', 'default': True},
|
||||
"_raw_params": {"type": "str"},
|
||||
"cmd": {"type": "str"},
|
||||
"creates": {"type": "path"},
|
||||
"removes": {"type": "path"},
|
||||
"chdir": {"type": "path"},
|
||||
"executable": {"type": "path", "default": "/bin/sh"},
|
||||
"warn": {"type": "bool", "default": False},
|
||||
"stdin": {"type": "str"},
|
||||
"stdin_add_newline": {"type": "bool", "default": True},
|
||||
},
|
||||
supports_check_mode=True
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
# Get the command
|
||||
raw_params = module.params.get('_raw_params')
|
||||
cmd = module.params.get('cmd') or raw_params
|
||||
raw_params = module.params.get("_raw_params")
|
||||
cmd = module.params.get("cmd") or raw_params
|
||||
|
||||
if not cmd:
|
||||
module.fail_json(msg="no command given")
|
||||
|
||||
result = {
|
||||
'changed': False,
|
||||
'cmd': cmd,
|
||||
'rc': 0,
|
||||
'stdout': '',
|
||||
'stderr': '',
|
||||
'stdout_lines': [],
|
||||
'stderr_lines': []
|
||||
}
|
||||
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
|
||||
|
||||
# Log the operation
|
||||
with open('/var/log/mock-shell-module.log', 'a') as f:
|
||||
with open("/var/log/mock-shell-module.log", "a") as f:
|
||||
f.write(f"shell module called: cmd={cmd}\n")
|
||||
|
||||
# Handle specific commands
|
||||
if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
|
||||
if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
|
||||
# Routing cache flush - just pretend it worked
|
||||
result['stdout'] = ""
|
||||
result['changed'] = True
|
||||
elif 'ifconfig lo100' in cmd:
|
||||
result["stdout"] = ""
|
||||
result["changed"] = True
|
||||
elif "ifconfig lo100" in cmd:
|
||||
# BSD loopback commands - simulate success
|
||||
result['stdout'] = "0"
|
||||
result['changed'] = True
|
||||
result["stdout"] = "0"
|
||||
result["changed"] = True
|
||||
else:
|
||||
# For other commands, try to run them
|
||||
try:
|
||||
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True,
|
||||
executable=module.params.get('executable'),
|
||||
cwd=module.params.get('chdir'))
|
||||
result['rc'] = proc.returncode
|
||||
result['stdout'] = proc.stdout
|
||||
result['stderr'] = proc.stderr
|
||||
result['stdout_lines'] = proc.stdout.splitlines()
|
||||
result['stderr_lines'] = proc.stderr.splitlines()
|
||||
result['changed'] = True
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
executable=module.params.get("executable"),
|
||||
cwd=module.params.get("chdir"),
|
||||
)
|
||||
result["rc"] = proc.returncode
|
||||
result["stdout"] = proc.stdout
|
||||
result["stderr"] = proc.stderr
|
||||
result["stdout_lines"] = proc.stdout.splitlines()
|
||||
result["stderr_lines"] = proc.stderr.splitlines()
|
||||
result["changed"] = True
|
||||
except Exception as e:
|
||||
result['rc'] = 1
|
||||
result['stderr'] = str(e)
|
||||
result['msg'] = str(e)
|
||||
module.fail_json(msg=result['msg'], **result)
|
||||
result["rc"] = 1
|
||||
result["stderr"] = str(e)
|
||||
result["msg"] = str(e)
|
||||
module.fail_json(msg=result["msg"], **result)
|
||||
|
||||
if result['rc'] == 0:
|
||||
if result["rc"] == 0:
|
||||
module.exit_json(**result)
|
||||
else:
|
||||
if 'msg' not in result:
|
||||
result['msg'] = f"Command failed with return code {result['rc']}"
|
||||
module.fail_json(msg=result['msg'], **result)
|
||||
if "msg" not in result:
|
||||
result["msg"] = f"Command failed with return code {result['rc']}"
|
||||
module.fail_json(msg=result["msg"], **result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -24,6 +24,7 @@ import yaml
|
|||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def create_expected_cloud_init():
|
||||
"""
|
||||
Create the expected cloud-init content that should be generated
|
||||
|
@ -74,6 +75,7 @@ runcmd:
|
|||
- systemctl restart sshd.service
|
||||
"""
|
||||
|
||||
|
||||
class TestCloudInitTemplate:
|
||||
"""Test class for cloud-init template validation."""
|
||||
|
||||
|
@ -98,10 +100,7 @@ class TestCloudInitTemplate:
|
|||
|
||||
parsed = self.test_yaml_validity()
|
||||
|
||||
required_sections = [
|
||||
'package_update', 'package_upgrade', 'packages',
|
||||
'users', 'write_files', 'runcmd'
|
||||
]
|
||||
required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"]
|
||||
|
||||
missing = [section for section in required_sections if section not in parsed]
|
||||
assert not missing, f"Missing required sections: {missing}"
|
||||
|
@ -114,35 +113,30 @@ class TestCloudInitTemplate:
|
|||
|
||||
parsed = self.test_yaml_validity()
|
||||
|
||||
write_files = parsed.get('write_files', [])
|
||||
write_files = parsed.get("write_files", [])
|
||||
assert write_files, "write_files section should be present"
|
||||
|
||||
# Find sshd_config file
|
||||
sshd_config = None
|
||||
for file_entry in write_files:
|
||||
if file_entry.get('path') == '/etc/ssh/sshd_config':
|
||||
if file_entry.get("path") == "/etc/ssh/sshd_config":
|
||||
sshd_config = file_entry
|
||||
break
|
||||
|
||||
assert sshd_config, "sshd_config file should be in write_files"
|
||||
|
||||
content = sshd_config.get('content', '')
|
||||
content = sshd_config.get("content", "")
|
||||
assert content, "sshd_config should have content"
|
||||
|
||||
# Check required SSH configurations
|
||||
required_configs = [
|
||||
'Port 4160',
|
||||
'AllowGroups algo',
|
||||
'PermitRootLogin no',
|
||||
'PasswordAuthentication no'
|
||||
]
|
||||
required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"]
|
||||
|
||||
missing = [config for config in required_configs if config not in content]
|
||||
assert not missing, f"Missing SSH configurations: {missing}"
|
||||
|
||||
# Verify proper formatting - first line should be Port directive
|
||||
lines = content.strip().split('\n')
|
||||
assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}"
|
||||
lines = content.strip().split("\n")
|
||||
assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}"
|
||||
|
||||
print("✅ SSH configuration correct")
|
||||
|
||||
|
@ -152,26 +146,26 @@ class TestCloudInitTemplate:
|
|||
|
||||
parsed = self.test_yaml_validity()
|
||||
|
||||
users = parsed.get('users', [])
|
||||
users = parsed.get("users", [])
|
||||
assert users, "users section should be present"
|
||||
|
||||
# Find algo user
|
||||
algo_user = None
|
||||
for user in users:
|
||||
if isinstance(user, dict) and user.get('name') == 'algo':
|
||||
if isinstance(user, dict) and user.get("name") == "algo":
|
||||
algo_user = user
|
||||
break
|
||||
|
||||
assert algo_user, "algo user should be defined"
|
||||
|
||||
# Check required user properties
|
||||
required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys']
|
||||
required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"]
|
||||
missing = [prop for prop in required_props if prop not in algo_user]
|
||||
assert not missing, f"algo user missing properties: {missing}"
|
||||
|
||||
# Verify sudo configuration
|
||||
sudo_config = algo_user.get('sudo', '')
|
||||
assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
|
||||
sudo_config = algo_user.get("sudo", "")
|
||||
assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
|
||||
|
||||
print("✅ User creation correct")
|
||||
|
||||
|
@ -181,13 +175,13 @@ class TestCloudInitTemplate:
|
|||
|
||||
parsed = self.test_yaml_validity()
|
||||
|
||||
runcmd = parsed.get('runcmd', [])
|
||||
runcmd = parsed.get("runcmd", [])
|
||||
assert runcmd, "runcmd section should be present"
|
||||
|
||||
# Check for SSH restart command
|
||||
ssh_restart_found = False
|
||||
for cmd in runcmd:
|
||||
if 'systemctl restart sshd' in str(cmd):
|
||||
if "systemctl restart sshd" in str(cmd):
|
||||
ssh_restart_found = True
|
||||
break
|
||||
|
||||
|
@ -202,18 +196,18 @@ class TestCloudInitTemplate:
|
|||
cloud_init_content = create_expected_cloud_init()
|
||||
|
||||
# Extract the sshd_config content lines
|
||||
lines = cloud_init_content.split('\n')
|
||||
lines = cloud_init_content.split("\n")
|
||||
in_sshd_content = False
|
||||
sshd_lines = []
|
||||
|
||||
for line in lines:
|
||||
if 'content: |' in line:
|
||||
if "content: |" in line:
|
||||
in_sshd_content = True
|
||||
continue
|
||||
elif in_sshd_content:
|
||||
if line.strip() == '' and len(sshd_lines) > 0:
|
||||
if line.strip() == "" and len(sshd_lines) > 0:
|
||||
break
|
||||
if line.startswith('runcmd:'):
|
||||
if line.startswith("runcmd:"):
|
||||
break
|
||||
sshd_lines.append(line)
|
||||
|
||||
|
@ -225,11 +219,13 @@ class TestCloudInitTemplate:
|
|||
|
||||
for line in non_empty_lines:
|
||||
# Each line should start with exactly 6 spaces
|
||||
assert line.startswith(' ') and not line.startswith(' '), \
|
||||
assert line.startswith(" ") and not line.startswith(" "), (
|
||||
f"Line should have exactly 6 spaces indentation: {repr(line)}"
|
||||
)
|
||||
|
||||
print("✅ Indentation is consistent")
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests manually (for non-pytest usage)."""
|
||||
print("🚀 Cloud-init template validation tests")
|
||||
|
@ -258,6 +254,7 @@ def run_tests():
|
|||
print(f"❌ Unexpected error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
|
@ -30,49 +30,49 @@ packages:
|
|||
rendered = self.packages_template.render({})
|
||||
|
||||
# Should only have sudo package
|
||||
self.assertIn('- sudo', rendered)
|
||||
self.assertNotIn('- git', rendered)
|
||||
self.assertNotIn('- screen', rendered)
|
||||
self.assertNotIn('- apparmor-utils', rendered)
|
||||
self.assertIn("- sudo", rendered)
|
||||
self.assertNotIn("- git", rendered)
|
||||
self.assertNotIn("- screen", rendered)
|
||||
self.assertNotIn("- apparmor-utils", rendered)
|
||||
|
||||
def test_preinstall_enabled(self):
|
||||
"""Test that package pre-installation works when enabled."""
|
||||
# Test with pre-installation enabled
|
||||
rendered = self.packages_template.render({'performance_preinstall_packages': True})
|
||||
rendered = self.packages_template.render({"performance_preinstall_packages": True})
|
||||
|
||||
# Should have sudo and all universal packages
|
||||
self.assertIn('- sudo', rendered)
|
||||
self.assertIn('- git', rendered)
|
||||
self.assertIn('- screen', rendered)
|
||||
self.assertIn('- apparmor-utils', rendered)
|
||||
self.assertIn('- uuid-runtime', rendered)
|
||||
self.assertIn('- coreutils', rendered)
|
||||
self.assertIn('- iptables-persistent', rendered)
|
||||
self.assertIn('- cgroup-tools', rendered)
|
||||
self.assertIn("- sudo", rendered)
|
||||
self.assertIn("- git", rendered)
|
||||
self.assertIn("- screen", rendered)
|
||||
self.assertIn("- apparmor-utils", rendered)
|
||||
self.assertIn("- uuid-runtime", rendered)
|
||||
self.assertIn("- coreutils", rendered)
|
||||
self.assertIn("- iptables-persistent", rendered)
|
||||
self.assertIn("- cgroup-tools", rendered)
|
||||
|
||||
def test_preinstall_disabled_explicitly(self):
|
||||
"""Test that package pre-installation is disabled when set to false."""
|
||||
# Test with pre-installation explicitly disabled
|
||||
rendered = self.packages_template.render({'performance_preinstall_packages': False})
|
||||
rendered = self.packages_template.render({"performance_preinstall_packages": False})
|
||||
|
||||
# Should only have sudo package
|
||||
self.assertIn('- sudo', rendered)
|
||||
self.assertNotIn('- git', rendered)
|
||||
self.assertNotIn('- screen', rendered)
|
||||
self.assertNotIn('- apparmor-utils', rendered)
|
||||
self.assertIn("- sudo", rendered)
|
||||
self.assertNotIn("- git", rendered)
|
||||
self.assertNotIn("- screen", rendered)
|
||||
self.assertNotIn("- apparmor-utils", rendered)
|
||||
|
||||
def test_package_count(self):
|
||||
"""Test that the correct number of packages are included."""
|
||||
# Default: should only have sudo (1 package)
|
||||
rendered_default = self.packages_template.render({})
|
||||
lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')]
|
||||
lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")]
|
||||
self.assertEqual(len(lines_default), 1)
|
||||
|
||||
# Enabled: should have sudo + 7 universal packages (8 total)
|
||||
rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True})
|
||||
lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')]
|
||||
rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True})
|
||||
lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")]
|
||||
self.assertEqual(len(lines_enabled), 8)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
"""
|
||||
Basic sanity tests for Algo VPN that don't require deployment
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -44,11 +45,7 @@ def test_config_file_valid():
|
|||
|
||||
def test_ansible_syntax():
|
||||
"""Check that main playbook has valid syntax"""
|
||||
result = subprocess.run(
|
||||
["ansible-playbook", "main.yml", "--syntax-check"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True)
|
||||
|
||||
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
|
||||
print("✓ Ansible playbook syntax is valid")
|
||||
|
@ -60,11 +57,7 @@ def test_shellcheck():
|
|||
|
||||
for script in shell_scripts:
|
||||
if os.path.exists(script):
|
||||
result = subprocess.run(
|
||||
["shellcheck", script],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
result = subprocess.run(["shellcheck", script], capture_output=True, text=True)
|
||||
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
|
||||
print(f"✓ {script} passed shellcheck")
|
||||
|
||||
|
@ -87,7 +80,7 @@ def test_cloud_init_header_format():
|
|||
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
|
||||
|
||||
with open(cloud_init_file) as f:
|
||||
first_line = f.readline().rstrip('\n\r')
|
||||
first_line = f.readline().rstrip("\n\r")
|
||||
|
||||
# The first line MUST be exactly "#cloud-config" (no space after #)
|
||||
# This regression was introduced in PR #14775 and broke DigitalOcean deployments
|
||||
|
|
|
@ -4,31 +4,30 @@ Test cloud provider instance type configurations
|
|||
Focused on validating that configured instance types are current/valid
|
||||
Based on issues #14730 - Hetzner changed from cx11 to cx22
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
def test_hetzner_server_types():
|
||||
"""Test Hetzner server type configurations (issue #14730)"""
|
||||
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22
|
||||
deprecated_types = ['cx11', 'cpx11']
|
||||
current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42']
|
||||
deprecated_types = ["cx11", "cpx11"]
|
||||
current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"]
|
||||
|
||||
# Test that we're not using deprecated types in any configs
|
||||
test_config = {
|
||||
'cloud_providers': {
|
||||
'hetzner': {
|
||||
'size': 'cx22', # Should be cx22, not cx11
|
||||
'image': 'ubuntu-22.04',
|
||||
'location': 'hel1'
|
||||
"cloud_providers": {
|
||||
"hetzner": {
|
||||
"size": "cx22", # Should be cx22, not cx11
|
||||
"image": "ubuntu-22.04",
|
||||
"location": "hel1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hetzner = test_config['cloud_providers']['hetzner']
|
||||
assert hetzner['size'] not in deprecated_types, \
|
||||
f"Using deprecated Hetzner type: {hetzner['size']}"
|
||||
assert hetzner['size'] in current_types, \
|
||||
f"Unknown Hetzner type: {hetzner['size']}"
|
||||
hetzner = test_config["cloud_providers"]["hetzner"]
|
||||
assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}"
|
||||
assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}"
|
||||
|
||||
print("✓ Hetzner server types test passed")
|
||||
|
||||
|
@ -36,10 +35,10 @@ def test_hetzner_server_types():
|
|||
def test_digitalocean_instance_types():
|
||||
"""Test DigitalOcean droplet size naming"""
|
||||
# DigitalOcean uses format like s-1vcpu-1gb
|
||||
valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb']
|
||||
deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme
|
||||
valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"]
|
||||
deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme
|
||||
|
||||
test_size = 's-2vcpu-2gb'
|
||||
test_size = "s-2vcpu-2gb"
|
||||
assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
|
||||
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
|
||||
|
||||
|
@ -49,10 +48,10 @@ def test_digitalocean_instance_types():
|
|||
def test_aws_instance_types():
|
||||
"""Test AWS EC2 instance type naming"""
|
||||
# Common valid instance types
|
||||
valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large']
|
||||
deprecated_types = ['t1.micro', 'm1.small'] # Very old types
|
||||
valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"]
|
||||
deprecated_types = ["t1.micro", "m1.small"] # Very old types
|
||||
|
||||
test_type = 't3.micro'
|
||||
test_type = "t3.micro"
|
||||
assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
|
||||
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
|
||||
|
||||
|
@ -62,9 +61,10 @@ def test_aws_instance_types():
|
|||
def test_vultr_instance_types():
|
||||
"""Test Vultr instance type naming"""
|
||||
# Vultr uses format like vc2-1c-1gb
|
||||
test_type = 'vc2-1c-1gb'
|
||||
assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \
|
||||
test_type = "vc2-1c-1gb"
|
||||
assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), (
|
||||
f"Invalid Vultr type format: {test_type}"
|
||||
)
|
||||
|
||||
print("✓ Vultr instance types test passed")
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
"""
|
||||
Test configuration file validation without deployment
|
||||
"""
|
||||
|
||||
import configparser
|
||||
import os
|
||||
import re
|
||||
|
@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820
|
|||
config = configparser.ConfigParser()
|
||||
config.read_string(sample_config)
|
||||
|
||||
assert 'Interface' in config, "Missing [Interface] section"
|
||||
assert 'Peer' in config, "Missing [Peer] section"
|
||||
assert "Interface" in config, "Missing [Interface] section"
|
||||
assert "Peer" in config, "Missing [Peer] section"
|
||||
|
||||
# Validate required fields
|
||||
assert config['Interface'].get('PrivateKey'), "Missing PrivateKey"
|
||||
assert config['Interface'].get('Address'), "Missing Address"
|
||||
assert config['Peer'].get('PublicKey'), "Missing PublicKey"
|
||||
assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs"
|
||||
assert config["Interface"].get("PrivateKey"), "Missing PrivateKey"
|
||||
assert config["Interface"].get("Address"), "Missing Address"
|
||||
assert config["Peer"].get("PublicKey"), "Missing PublicKey"
|
||||
assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs"
|
||||
|
||||
print("✓ WireGuard config format validation passed")
|
||||
|
||||
|
@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820
|
|||
def test_base64_key_format():
|
||||
"""Test that keys are in valid base64 format"""
|
||||
# Base64 keys can have variable length, just check format
|
||||
key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$')
|
||||
key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$")
|
||||
|
||||
test_keys = [
|
||||
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
|
||||
|
@ -58,8 +59,8 @@ def test_base64_key_format():
|
|||
|
||||
def test_ip_address_format():
|
||||
"""Test IP address and CIDR notation validation"""
|
||||
ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$')
|
||||
endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$')
|
||||
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$")
|
||||
endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$")
|
||||
|
||||
# Test CIDR notation
|
||||
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
|
||||
|
@ -74,11 +75,7 @@ def test_ip_address_format():
|
|||
def test_mobile_config_xml():
|
||||
"""Test that mobile config files would be valid XML"""
|
||||
# First check if xmllint is available
|
||||
xmllint_check = subprocess.run(
|
||||
['which', 'xmllint'],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True)
|
||||
|
||||
if xmllint_check.returncode != 0:
|
||||
print("⚠ Skipping XML validation test (xmllint not installed)")
|
||||
|
@ -99,17 +96,13 @@ def test_mobile_config_xml():
|
|||
</dict>
|
||||
</plist>"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f:
|
||||
f.write(sample_mobileconfig)
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
# Use xmllint to validate
|
||||
result = subprocess.run(
|
||||
['xmllint', '--noout', temp_file],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True)
|
||||
|
||||
assert result.returncode == 0, f"XML validation failed: {result.stderr}"
|
||||
print("✓ Mobile config XML validation passed")
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Simplified Docker-based localhost deployment tests
|
||||
Verifies services can start and config files exist in expected locations
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -11,7 +12,7 @@ import sys
|
|||
def check_docker_available():
|
||||
"""Check if Docker is available"""
|
||||
try:
|
||||
result = subprocess.run(['docker', '--version'], capture_output=True, text=True)
|
||||
result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
|
||||
return result.returncode == 0
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128
|
|||
"""
|
||||
|
||||
# Just validate the format
|
||||
required_sections = ['[Interface]', '[Peer]']
|
||||
required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs']
|
||||
required_sections = ["[Interface]", "[Peer]"]
|
||||
required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"]
|
||||
|
||||
for section in required_sections:
|
||||
if section not in config:
|
||||
|
@ -68,15 +69,15 @@ conn ikev2-pubkey
|
|||
"""
|
||||
|
||||
# Validate format
|
||||
if 'config setup' not in config:
|
||||
if "config setup" not in config:
|
||||
print("✗ Missing 'config setup' section")
|
||||
return False
|
||||
|
||||
if 'conn %default' not in config:
|
||||
if "conn %default" not in config:
|
||||
print("✗ Missing 'conn %default' section")
|
||||
return False
|
||||
|
||||
if 'keyexchange=ikev2' not in config:
|
||||
if "keyexchange=ikev2" not in config:
|
||||
print("✗ Missing IKEv2 configuration")
|
||||
return False
|
||||
|
||||
|
@ -87,19 +88,19 @@ conn ikev2-pubkey
|
|||
def test_docker_algo_image():
|
||||
"""Test that the Algo Docker image can be built"""
|
||||
# Check if Dockerfile exists
|
||||
if not os.path.exists('Dockerfile'):
|
||||
if not os.path.exists("Dockerfile"):
|
||||
print("✗ Dockerfile not found")
|
||||
return False
|
||||
|
||||
# Read Dockerfile and validate basic structure
|
||||
with open('Dockerfile') as f:
|
||||
with open("Dockerfile") as f:
|
||||
dockerfile_content = f.read()
|
||||
|
||||
required_elements = [
|
||||
'FROM', # Base image
|
||||
'RUN', # Build commands
|
||||
'COPY', # Copy Algo files
|
||||
'python' # Python dependency
|
||||
"FROM", # Base image
|
||||
"RUN", # Build commands
|
||||
"COPY", # Copy Algo files
|
||||
"python", # Python dependency
|
||||
]
|
||||
|
||||
missing = []
|
||||
|
@ -115,16 +116,14 @@ def test_docker_algo_image():
|
|||
return True
|
||||
|
||||
|
||||
|
||||
|
||||
def test_localhost_deployment_requirements():
|
||||
"""Test that localhost deployment requirements are met"""
|
||||
requirements = {
|
||||
'Python 3.8+': sys.version_info >= (3, 8),
|
||||
'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0,
|
||||
'Main playbook exists': os.path.exists('main.yml'),
|
||||
'Project config exists': os.path.exists('pyproject.toml'),
|
||||
'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'),
|
||||
"Python 3.8+": sys.version_info >= (3, 8),
|
||||
"Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0,
|
||||
"Main playbook exists": os.path.exists("main.yml"),
|
||||
"Project config exists": os.path.exists("pyproject.toml"),
|
||||
"Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"),
|
||||
}
|
||||
|
||||
all_met = True
|
||||
|
@ -138,10 +137,6 @@ def test_localhost_deployment_requirements():
|
|||
return all_met
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running Docker localhost deployment tests...")
|
||||
print("=" * 50)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Test that generated configuration files have valid syntax
|
||||
This validates WireGuard, StrongSwan, SSH, and other configs
|
||||
"""
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -11,7 +12,7 @@ import sys
|
|||
def check_command_available(cmd):
|
||||
"""Check if a command is available on the system"""
|
||||
try:
|
||||
subprocess.run([cmd, '--version'], capture_output=True, check=False)
|
||||
subprocess.run([cmd, "--version"], capture_output=True, check=False)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
@ -37,51 +38,50 @@ PersistentKeepalive = 25
|
|||
errors = []
|
||||
|
||||
# Check for required sections
|
||||
if '[Interface]' not in sample_config:
|
||||
if "[Interface]" not in sample_config:
|
||||
errors.append("Missing [Interface] section")
|
||||
if '[Peer]' not in sample_config:
|
||||
if "[Peer]" not in sample_config:
|
||||
errors.append("Missing [Peer] section")
|
||||
|
||||
# Validate Interface section
|
||||
interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL)
|
||||
interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL)
|
||||
if interface_match:
|
||||
interface_section = interface_match.group(1)
|
||||
|
||||
# Check required fields
|
||||
if not re.search(r'Address\s*=', interface_section):
|
||||
if not re.search(r"Address\s*=", interface_section):
|
||||
errors.append("Missing Address in Interface section")
|
||||
if not re.search(r'PrivateKey\s*=', interface_section):
|
||||
if not re.search(r"PrivateKey\s*=", interface_section):
|
||||
errors.append("Missing PrivateKey in Interface section")
|
||||
|
||||
# Validate IP addresses
|
||||
address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section)
|
||||
address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section)
|
||||
if address_match:
|
||||
addresses = address_match.group(1).split(',')
|
||||
addresses = address_match.group(1).split(",")
|
||||
for addr in addresses:
|
||||
addr = addr.strip()
|
||||
# Basic IP validation
|
||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \
|
||||
not re.match(r'^[0-9a-fA-F:]+/\d+$', addr):
|
||||
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr):
|
||||
errors.append(f"Invalid IP address format: {addr}")
|
||||
|
||||
# Validate Peer section
|
||||
peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL)
|
||||
peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL)
|
||||
if peer_match:
|
||||
peer_section = peer_match.group(1)
|
||||
|
||||
# Check required fields
|
||||
if not re.search(r'PublicKey\s*=', peer_section):
|
||||
if not re.search(r"PublicKey\s*=", peer_section):
|
||||
errors.append("Missing PublicKey in Peer section")
|
||||
if not re.search(r'AllowedIPs\s*=', peer_section):
|
||||
if not re.search(r"AllowedIPs\s*=", peer_section):
|
||||
errors.append("Missing AllowedIPs in Peer section")
|
||||
if not re.search(r'Endpoint\s*=', peer_section):
|
||||
if not re.search(r"Endpoint\s*=", peer_section):
|
||||
errors.append("Missing Endpoint in Peer section")
|
||||
|
||||
# Validate endpoint format
|
||||
endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section)
|
||||
endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section)
|
||||
if endpoint_match:
|
||||
endpoint = endpoint_match.group(1).strip()
|
||||
if not re.match(r'^[\d\.\:]+:\d+$', endpoint):
|
||||
if not re.match(r"^[\d\.\:]+:\d+$", endpoint):
|
||||
errors.append(f"Invalid Endpoint format: {endpoint}")
|
||||
|
||||
if errors:
|
||||
|
@ -132,33 +132,32 @@ conn ikev2-pubkey
|
|||
errors = []
|
||||
|
||||
# Check for required sections
|
||||
if 'config setup' not in sample_config:
|
||||
if "config setup" not in sample_config:
|
||||
errors.append("Missing 'config setup' section")
|
||||
if 'conn %default' not in sample_config:
|
||||
if "conn %default" not in sample_config:
|
||||
errors.append("Missing 'conn %default' section")
|
||||
|
||||
# Validate connection settings
|
||||
conn_pattern = re.compile(r'conn\s+(\S+)')
|
||||
conn_pattern = re.compile(r"conn\s+(\S+)")
|
||||
connections = conn_pattern.findall(sample_config)
|
||||
|
||||
if len(connections) < 2: # Should have at least %default and one other
|
||||
errors.append("Not enough connection definitions")
|
||||
|
||||
# Check for required parameters in connections
|
||||
required_params = ['keyexchange', 'left', 'right']
|
||||
required_params = ["keyexchange", "left", "right"]
|
||||
for param in required_params:
|
||||
if f'{param}=' not in sample_config:
|
||||
if f"{param}=" not in sample_config:
|
||||
errors.append(f"Missing required parameter: {param}")
|
||||
|
||||
# Validate IP subnet formats
|
||||
subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)')
|
||||
subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)")
|
||||
for match in subnet_pattern.finditer(sample_config):
|
||||
subnets = match.group(2).split(',')
|
||||
subnets = match.group(2).split(",")
|
||||
for subnet in subnets:
|
||||
subnet = subnet.strip()
|
||||
if subnet != '0.0.0.0/0' and subnet != '::/0':
|
||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \
|
||||
not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet):
|
||||
if subnet != "0.0.0.0/0" and subnet != "::/0":
|
||||
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet):
|
||||
errors.append(f"Invalid subnet format: {subnet}")
|
||||
|
||||
if errors:
|
||||
|
@ -188,21 +187,21 @@ def test_ssh_config_syntax():
|
|||
errors = []
|
||||
|
||||
# Parse SSH config format
|
||||
lines = sample_config.strip().split('\n')
|
||||
lines = sample_config.strip().split("\n")
|
||||
current_host = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
if line.startswith('Host '):
|
||||
if line.startswith("Host "):
|
||||
current_host = line.split()[1]
|
||||
elif current_host and ' ' in line:
|
||||
elif current_host and " " in line:
|
||||
key, value = line.split(None, 1)
|
||||
|
||||
# Validate common SSH options
|
||||
if key == 'Port':
|
||||
if key == "Port":
|
||||
try:
|
||||
port = int(value)
|
||||
if not 1 <= port <= 65535:
|
||||
|
@ -210,7 +209,7 @@ def test_ssh_config_syntax():
|
|||
except ValueError:
|
||||
errors.append(f"Port must be a number: {value}")
|
||||
|
||||
elif key == 'LocalForward':
|
||||
elif key == "LocalForward":
|
||||
# Format: LocalForward [bind_address:]port host:hostport
|
||||
parts = value.split()
|
||||
if len(parts) != 2:
|
||||
|
@ -256,35 +255,35 @@ COMMIT
|
|||
errors = []
|
||||
|
||||
# Check table definitions
|
||||
tables = re.findall(r'\*(\w+)', sample_rules)
|
||||
if 'filter' not in tables:
|
||||
tables = re.findall(r"\*(\w+)", sample_rules)
|
||||
if "filter" not in tables:
|
||||
errors.append("Missing *filter table")
|
||||
if 'nat' not in tables:
|
||||
if "nat" not in tables:
|
||||
errors.append("Missing *nat table")
|
||||
|
||||
# Check for COMMIT statements
|
||||
commit_count = sample_rules.count('COMMIT')
|
||||
commit_count = sample_rules.count("COMMIT")
|
||||
if commit_count != len(tables):
|
||||
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
|
||||
|
||||
# Validate chain policies
|
||||
chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE)
|
||||
chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE)
|
||||
chains = chain_pattern.findall(sample_rules)
|
||||
|
||||
required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')]
|
||||
required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")]
|
||||
for chain, _policy in required_chains:
|
||||
if not any(c[0] == chain for c in chains):
|
||||
errors.append(f"Missing required chain: {chain}")
|
||||
|
||||
# Validate rule syntax
|
||||
rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE)
|
||||
rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE)
|
||||
rules = rule_pattern.findall(sample_rules)
|
||||
|
||||
if len(rules) < 5:
|
||||
errors.append("Insufficient firewall rules")
|
||||
|
||||
# Check for essential security rules
|
||||
if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules:
|
||||
if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules:
|
||||
errors.append("Missing stateful connection tracking rule")
|
||||
|
||||
if errors:
|
||||
|
@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
|
|||
errors = []
|
||||
|
||||
# Parse config
|
||||
for line in sample_config.strip().split('\n'):
|
||||
for line in sample_config.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Most dnsmasq options are key=value or just key
|
||||
if '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
|
||||
# Validate specific options
|
||||
if key == 'interface':
|
||||
if not re.match(r'^[a-zA-Z0-9\-_]+$', value):
|
||||
if key == "interface":
|
||||
if not re.match(r"^[a-zA-Z0-9\-_]+$", value):
|
||||
errors.append(f"Invalid interface name: {value}")
|
||||
|
||||
elif key == 'server':
|
||||
elif key == "server":
|
||||
# Basic IP validation
|
||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \
|
||||
not re.match(r'^[0-9a-fA-F:]+$', value):
|
||||
if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value):
|
||||
errors.append(f"Invalid DNS server IP: {value}")
|
||||
|
||||
elif key == 'cache-size':
|
||||
elif key == "cache-size":
|
||||
try:
|
||||
size = int(value)
|
||||
if size < 0:
|
||||
|
@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
|
|||
errors.append(f"Cache size must be a number: {value}")
|
||||
|
||||
# Check for required options
|
||||
required = ['interface', 'server']
|
||||
required = ["interface", "server"]
|
||||
for req in required:
|
||||
if f'{req}=' not in sample_config:
|
||||
if f"{req}=" not in sample_config:
|
||||
errors.append(f"Missing required option: {req}")
|
||||
|
||||
if errors:
|
||||
|
|
|
@ -14,203 +14,203 @@ from jinja2 import Environment, FileSystemLoader
|
|||
|
||||
def load_template(template_name):
|
||||
"""Load a Jinja2 template from the roles/common/templates directory."""
|
||||
template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates'
|
||||
template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates"
|
||||
env = Environment(loader=FileSystemLoader(str(template_dir)))
|
||||
return env.get_template(template_name)
|
||||
|
||||
|
||||
def test_wireguard_nat_rules_ipv4():
|
||||
"""Test that WireGuard traffic gets proper NAT rules without policy matching."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
# Test with WireGuard enabled
|
||||
result = template.render(
|
||||
ipsec_enabled=False,
|
||||
wireguard_enabled=True,
|
||||
wireguard_network_ipv4='10.49.0.0/16',
|
||||
wireguard_network_ipv4="10.49.0.0/16",
|
||||
wireguard_port=51820,
|
||||
wireguard_port_avoid=53,
|
||||
wireguard_port_actual=51820,
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4=None,
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.49.0.1',
|
||||
local_service_ip="10.49.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# Verify NAT rule exists with output interface and without policy matching
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
# Verify no policy matching in WireGuard NAT rules
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result
|
||||
|
||||
|
||||
def test_ipsec_nat_rules_ipv4():
|
||||
"""Test that IPsec traffic gets proper NAT rules without policy matching."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
# Test with IPsec enabled
|
||||
result = template.render(
|
||||
ipsec_enabled=True,
|
||||
wireguard_enabled=False,
|
||||
strongswan_network='10.48.0.0/16',
|
||||
strongswan_network_ipv6='2001:db8::/48',
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
strongswan_network="10.48.0.0/16",
|
||||
strongswan_network_ipv6="2001:db8::/48",
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4=None,
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.48.0.1',
|
||||
local_service_ip="10.48.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# Verify NAT rule exists with output interface and without policy matching
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
# Verify no policy matching in IPsec NAT rules (this was the bug)
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result
|
||||
|
||||
|
||||
def test_both_vpns_nat_rules_ipv4():
|
||||
"""Test NAT rules when both VPN types are enabled."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
result = template.render(
|
||||
ipsec_enabled=True,
|
||||
wireguard_enabled=True,
|
||||
strongswan_network='10.48.0.0/16',
|
||||
wireguard_network_ipv4='10.49.0.0/16',
|
||||
strongswan_network_ipv6='2001:db8::/48',
|
||||
wireguard_network_ipv6='2001:db8:a160::/48',
|
||||
strongswan_network="10.48.0.0/16",
|
||||
wireguard_network_ipv4="10.49.0.0/16",
|
||||
strongswan_network_ipv6="2001:db8::/48",
|
||||
wireguard_network_ipv6="2001:db8:a160::/48",
|
||||
wireguard_port=51820,
|
||||
wireguard_port_avoid=53,
|
||||
wireguard_port_actual=51820,
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4=None,
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.49.0.1',
|
||||
local_service_ip="10.49.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# Both should have NAT rules with output interface
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
|
||||
# Neither should have policy matching
|
||||
assert '-m policy --pol none' not in result
|
||||
assert "-m policy --pol none" not in result
|
||||
|
||||
|
||||
def test_alternative_ingress_snat():
|
||||
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
result = template.render(
|
||||
ipsec_enabled=True,
|
||||
wireguard_enabled=True,
|
||||
strongswan_network='10.48.0.0/16',
|
||||
wireguard_network_ipv4='10.49.0.0/16',
|
||||
strongswan_network_ipv6='2001:db8::/48',
|
||||
wireguard_network_ipv6='2001:db8:a160::/48',
|
||||
strongswan_network="10.48.0.0/16",
|
||||
wireguard_network_ipv4="10.49.0.0/16",
|
||||
strongswan_network_ipv6="2001:db8::/48",
|
||||
wireguard_network_ipv6="2001:db8:a160::/48",
|
||||
wireguard_port=51820,
|
||||
wireguard_port_avoid=53,
|
||||
wireguard_port_actual=51820,
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
snat_aipv4='192.168.1.100', # Alternative ingress IP
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4="192.168.1.100", # Alternative ingress IP
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.49.0.1',
|
||||
local_service_ip="10.49.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# Should use SNAT with specific IP and output interface instead of MASQUERADE
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
|
||||
assert 'MASQUERADE' not in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
|
||||
assert "MASQUERADE" not in result
|
||||
|
||||
|
||||
def test_ipsec_forward_rule_has_policy_match():
|
||||
"""Test that IPsec FORWARD rules still use policy matching (this is correct)."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
result = template.render(
|
||||
ipsec_enabled=True,
|
||||
wireguard_enabled=False,
|
||||
strongswan_network='10.48.0.0/16',
|
||||
strongswan_network_ipv6='2001:db8::/48',
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
strongswan_network="10.48.0.0/16",
|
||||
strongswan_network_ipv6="2001:db8::/48",
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4=None,
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.48.0.1',
|
||||
local_service_ip="10.48.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# FORWARD rule should have policy match (this is correct and should stay)
|
||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result
|
||||
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result
|
||||
|
||||
|
||||
def test_wireguard_forward_rule_no_policy_match():
|
||||
"""Test that WireGuard FORWARD rules don't use policy matching."""
|
||||
template = load_template('rules.v4.j2')
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
result = template.render(
|
||||
ipsec_enabled=False,
|
||||
wireguard_enabled=True,
|
||||
wireguard_network_ipv4='10.49.0.0/16',
|
||||
wireguard_network_ipv4="10.49.0.0/16",
|
||||
wireguard_port=51820,
|
||||
wireguard_port_avoid=53,
|
||||
wireguard_port_actual=51820,
|
||||
ansible_default_ipv4={'interface': 'eth0'},
|
||||
ansible_default_ipv4={"interface": "eth0"},
|
||||
snat_aipv4=None,
|
||||
BetweenClients_DROP=True,
|
||||
block_smb=True,
|
||||
block_netbios=True,
|
||||
local_service_ip='10.49.0.1',
|
||||
local_service_ip="10.49.0.1",
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
# WireGuard FORWARD rule should NOT have any policy match
|
||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result
|
||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result
|
||||
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result
|
||||
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result
|
||||
|
||||
|
||||
def test_output_interface_in_nat_rules():
|
||||
"""Test that output interface is specified in NAT rules."""
|
||||
template = load_template('rules.v4.j2')
|
||||
|
||||
template = load_template("rules.v4.j2")
|
||||
|
||||
result = template.render(
|
||||
snat_aipv4=False,
|
||||
wireguard_enabled=True,
|
||||
ipsec_enabled=True,
|
||||
wireguard_network_ipv4='10.49.0.0/16',
|
||||
strongswan_network='10.48.0.0/16',
|
||||
ansible_default_ipv4={'interface': 'eth0', 'address': '10.0.0.1'},
|
||||
ansible_default_ipv6={'interface': 'eth0', 'address': 'fd9d:bc11:4020::1'},
|
||||
wireguard_network_ipv4="10.49.0.0/16",
|
||||
strongswan_network="10.48.0.0/16",
|
||||
ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"},
|
||||
ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"},
|
||||
wireguard_port_actual=51820,
|
||||
wireguard_port_avoid=53,
|
||||
wireguard_port=51820,
|
||||
ansible_ssh_port=22,
|
||||
reduce_mtu=0
|
||||
reduce_mtu=0,
|
||||
)
|
||||
|
||||
|
||||
# Check that output interface is specified for both VPNs
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
||||
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||
|
||||
# Ensure we don't have rules without output interface
|
||||
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' not in result
|
||||
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' not in result
|
||||
assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result
|
||||
assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
|
@ -12,7 +12,7 @@ import unittest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Add the library directory to the path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library"))
|
||||
|
||||
|
||||
class TestLightsailBoto3Fix(unittest.TestCase):
|
||||
|
@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
"""Set up test fixtures."""
|
||||
# Mock the ansible module_utils since we're testing outside of Ansible
|
||||
self.mock_modules = {
|
||||
'ansible.module_utils.basic': MagicMock(),
|
||||
'ansible.module_utils.ec2': MagicMock(),
|
||||
'ansible.module_utils.aws.core': MagicMock(),
|
||||
"ansible.module_utils.basic": MagicMock(),
|
||||
"ansible.module_utils.ec2": MagicMock(),
|
||||
"ansible.module_utils.aws.core": MagicMock(),
|
||||
}
|
||||
|
||||
# Apply mocks
|
||||
self.patches = []
|
||||
for module_name, mock_module in self.mock_modules.items():
|
||||
patcher = patch.dict('sys.modules', {module_name: mock_module})
|
||||
patcher = patch.dict("sys.modules", {module_name: mock_module})
|
||||
patcher.start()
|
||||
self.patches.append(patcher)
|
||||
|
||||
|
@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
# Import the module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"lightsail_region_facts",
|
||||
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
||||
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
|
@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
|
||||
# Verify the module loaded
|
||||
self.assertIsNotNone(module)
|
||||
self.assertTrue(hasattr(module, 'main'))
|
||||
self.assertTrue(hasattr(module, "main"))
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Failed to import lightsail_region_facts: {e}")
|
||||
|
@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
def test_get_aws_connection_info_called_without_boto3(self):
|
||||
"""Test that get_aws_connection_info is called without boto3 parameter."""
|
||||
# Mock get_aws_connection_info to track calls
|
||||
mock_get_aws_connection_info = MagicMock(
|
||||
return_value=('us-west-2', None, {})
|
||||
)
|
||||
mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {}))
|
||||
|
||||
with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info):
|
||||
with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info):
|
||||
# Import the module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"lightsail_region_facts",
|
||||
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
||||
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
|
@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
mock_ansible_module.params = {}
|
||||
mock_ansible_module.check_mode = False
|
||||
|
||||
with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module):
|
||||
with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module):
|
||||
# Execute the module
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
|
@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
if call_args:
|
||||
# Check positional arguments
|
||||
if call_args[0]: # args
|
||||
self.assertTrue(len(call_args[0]) <= 1,
|
||||
"get_aws_connection_info should be called with at most 1 positional arg (module)")
|
||||
self.assertTrue(
|
||||
len(call_args[0]) <= 1,
|
||||
"get_aws_connection_info should be called with at most 1 positional arg (module)",
|
||||
)
|
||||
|
||||
# Check keyword arguments
|
||||
if call_args[1]: # kwargs
|
||||
self.assertNotIn('boto3', call_args[1],
|
||||
"get_aws_connection_info should not be called with boto3 parameter")
|
||||
self.assertNotIn(
|
||||
"boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter"
|
||||
)
|
||||
|
||||
def test_no_boto3_parameter_in_source(self):
|
||||
"""Verify that boto3 parameter is not present in the source code."""
|
||||
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
||||
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
|
||||
|
||||
with open(lightsail_path) as f:
|
||||
content = f.read()
|
||||
|
||||
# Check that boto3=True is not in the file
|
||||
self.assertNotIn('boto3=True', content,
|
||||
"boto3=True parameter should not be present in lightsail_region_facts.py")
|
||||
self.assertNotIn(
|
||||
"boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py"
|
||||
)
|
||||
|
||||
# Check that boto3 parameter is not used with get_aws_connection_info
|
||||
self.assertNotIn('get_aws_connection_info(module, boto3', content,
|
||||
"get_aws_connection_info should not be called with boto3 parameter")
|
||||
self.assertNotIn(
|
||||
"get_aws_connection_info(module, boto3",
|
||||
content,
|
||||
"get_aws_connection_info should not be called with boto3 parameter",
|
||||
)
|
||||
|
||||
def test_regression_issue_14822(self):
|
||||
"""
|
||||
|
@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
|||
# The boto3 parameter was deprecated and removed in amazon.aws collection
|
||||
# that comes with Ansible 11.x
|
||||
|
||||
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
||||
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
|
||||
|
||||
with open(lightsail_path) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find the line that calls get_aws_connection_info
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if 'get_aws_connection_info' in line and 'region' in line:
|
||||
if "get_aws_connection_info" in line and "region" in line:
|
||||
# This should be around line 85
|
||||
# Verify it doesn't have boto3=True
|
||||
self.assertNotIn('boto3', line,
|
||||
f"Line {line_num} should not contain boto3 parameter")
|
||||
self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter")
|
||||
|
||||
# Verify the correct format
|
||||
self.assertIn('get_aws_connection_info(module)', line,
|
||||
f"Line {line_num} should call get_aws_connection_info(module) without boto3")
|
||||
self.assertIn(
|
||||
"get_aws_connection_info(module)",
|
||||
line,
|
||||
f"Line {line_num} should call get_aws_connection_info(module) without boto3",
|
||||
)
|
||||
break
|
||||
else:
|
||||
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa
|
|||
Based on issues #14755, #14718 - Apple device compatibility
|
||||
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
|
@ -22,7 +23,7 @@ def find_generated_certificates():
|
|||
config_patterns = [
|
||||
"configs/*/ipsec/.pki/cacert.pem",
|
||||
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
|
||||
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
|
||||
"../../configs/*/ipsec/.pki/cacert.pem", # Alternative path
|
||||
]
|
||||
|
||||
for pattern in config_patterns:
|
||||
|
@ -30,26 +31,23 @@ def find_generated_certificates():
|
|||
if ca_certs:
|
||||
base_path = os.path.dirname(ca_certs[0])
|
||||
return {
|
||||
'ca_cert': ca_certs[0],
|
||||
'base_path': base_path,
|
||||
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
|
||||
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
|
||||
"ca_cert": ca_certs[0],
|
||||
"base_path": base_path,
|
||||
"server_certs": glob.glob(f"{base_path}/certs/*.crt"),
|
||||
"p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def test_openssl_version_detection():
|
||||
"""Test that we can detect OpenSSL version for compatibility checks"""
|
||||
result = subprocess.run(
|
||||
['openssl', 'version'],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
result = subprocess.run(["openssl", "version"], capture_output=True, text=True)
|
||||
|
||||
assert result.returncode == 0, "Failed to get OpenSSL version"
|
||||
|
||||
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
|
||||
version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout)
|
||||
version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout)
|
||||
assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
|
||||
|
||||
major = int(version_match.group(1))
|
||||
|
@ -62,7 +60,7 @@ def test_openssl_version_detection():
|
|||
def validate_ca_certificate_real(cert_files):
|
||||
"""Validate actual Ansible-generated CA certificate"""
|
||||
# Read the actual CA certificate generated by Ansible
|
||||
with open(cert_files['ca_cert'], 'rb') as f:
|
||||
with open(cert_files["ca_cert"], "rb") as f:
|
||||
cert_data = f.read()
|
||||
|
||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||
|
@ -89,30 +87,34 @@ def validate_ca_certificate_real(cert_files):
|
|||
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
|
||||
|
||||
# Verify public domains are excluded
|
||||
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||
if isinstance(constraint, x509.DNSName)]
|
||||
excluded_dns = [
|
||||
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)
|
||||
]
|
||||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||
for domain in public_domains:
|
||||
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
|
||||
|
||||
# Verify private IP ranges are excluded (Issue #75)
|
||||
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||
if isinstance(constraint, x509.IPAddress)]
|
||||
excluded_ips = [
|
||||
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)
|
||||
]
|
||||
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
|
||||
|
||||
# Verify email domains are also excluded (Issue #153)
|
||||
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
|
||||
if isinstance(constraint, x509.RFC822Name)]
|
||||
excluded_emails = [
|
||||
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)
|
||||
]
|
||||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||
for domain in email_domains:
|
||||
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
|
||||
|
||||
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
|
||||
|
||||
|
||||
def validate_ca_certificate_config():
|
||||
"""Validate CA certificate configuration in Ansible files (CI mode)"""
|
||||
# Check that the Ansible task file has proper CA certificate configuration
|
||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
||||
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||
if not openssl_task_file:
|
||||
print("⚠ Could not find openssl.yml task file")
|
||||
return
|
||||
|
@ -122,15 +124,15 @@ def validate_ca_certificate_config():
|
|||
|
||||
# Verify key security configurations are present
|
||||
security_checks = [
|
||||
('name_constraints_permitted', 'Name constraints should be configured'),
|
||||
('name_constraints_excluded', 'Excluded name constraints should be configured'),
|
||||
('extended_key_usage', 'Extended Key Usage should be configured'),
|
||||
('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'),
|
||||
('serverAuth', 'Server authentication EKU should be present'),
|
||||
('clientAuth', 'Client authentication EKU should be present'),
|
||||
('basic_constraints', 'Basic constraints should be configured'),
|
||||
('CA:TRUE', 'CA certificate should be marked as CA'),
|
||||
('pathlen:0', 'Path length constraint should be set')
|
||||
("name_constraints_permitted", "Name constraints should be configured"),
|
||||
("name_constraints_excluded", "Excluded name constraints should be configured"),
|
||||
("extended_key_usage", "Extended Key Usage should be configured"),
|
||||
("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"),
|
||||
("serverAuth", "Server authentication EKU should be present"),
|
||||
("clientAuth", "Client authentication EKU should be present"),
|
||||
("basic_constraints", "Basic constraints should be configured"),
|
||||
("CA:TRUE", "CA certificate should be marked as CA"),
|
||||
("pathlen:0", "Path length constraint should be set"),
|
||||
]
|
||||
|
||||
for check, message in security_checks:
|
||||
|
@ -140,7 +142,9 @@ def validate_ca_certificate_config():
|
|||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||
for domain in public_domains:
|
||||
# Handle both double quotes and single quotes in YAML
|
||||
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded"
|
||||
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, (
|
||||
f"Public domain {domain} should be excluded"
|
||||
)
|
||||
|
||||
# Verify private IP ranges are excluded
|
||||
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
|
||||
|
@ -151,13 +155,16 @@ def validate_ca_certificate_config():
|
|||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||
for domain in email_domains:
|
||||
# Handle both double quotes and single quotes in YAML
|
||||
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded"
|
||||
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, (
|
||||
f"Email domain {domain} should be excluded"
|
||||
)
|
||||
|
||||
# Verify IPv6 constraints are present (Issue #153)
|
||||
assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
|
||||
|
||||
print("✓ CA certificate configuration has proper security constraints")
|
||||
|
||||
|
||||
def test_ca_certificate():
|
||||
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
|
||||
cert_files = find_generated_certificates()
|
||||
|
@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files):
|
|||
# Filter to only actual server certificates (not client certs)
|
||||
# Server certificates contain IP addresses in the filename
|
||||
import re
|
||||
server_certs = [f for f in cert_files['server_certs']
|
||||
if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)]
|
||||
|
||||
server_certs = [
|
||||
f
|
||||
for f in cert_files["server_certs"]
|
||||
if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f)
|
||||
]
|
||||
if not server_certs:
|
||||
print("⚠ No server certificates found")
|
||||
return
|
||||
|
||||
for server_cert_path in server_certs:
|
||||
with open(server_cert_path, 'rb') as f:
|
||||
with open(server_cert_path, "rb") as f:
|
||||
cert_data = f.read()
|
||||
|
||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||
|
@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files):
|
|||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
|
||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
|
||||
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
|
||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
|
||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, (
|
||||
"Server cert should NOT have clientAuth EKU for role separation"
|
||||
)
|
||||
|
||||
# Check SAN extension exists (required for Apple devices)
|
||||
try:
|
||||
|
@ -204,9 +217,10 @@ def validate_server_certificates_real(cert_files):
|
|||
|
||||
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
|
||||
|
||||
|
||||
def validate_server_certificates_config():
|
||||
"""Validate server certificate configuration in Ansible files (CI mode)"""
|
||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
||||
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||
if not openssl_task_file:
|
||||
print("⚠ Could not find openssl.yml task file")
|
||||
return
|
||||
|
@ -215,7 +229,7 @@ def validate_server_certificates_config():
|
|||
content = f.read()
|
||||
|
||||
# Look for server certificate CSR section
|
||||
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
|
||||
server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL)
|
||||
if not server_csr_section:
|
||||
print("⚠ Could not find server certificate CSR section")
|
||||
return
|
||||
|
@ -224,11 +238,11 @@ def validate_server_certificates_config():
|
|||
|
||||
# Check server certificate CSR configuration
|
||||
server_checks = [
|
||||
('subject_alt_name', 'Server certificates should have SAN extension'),
|
||||
('serverAuth', 'Server certificates should have serverAuth EKU'),
|
||||
('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'),
|
||||
('digitalSignature', 'Server certificates should have digital signature usage'),
|
||||
('keyEncipherment', 'Server certificates should have key encipherment usage')
|
||||
("subject_alt_name", "Server certificates should have SAN extension"),
|
||||
("serverAuth", "Server certificates should have serverAuth EKU"),
|
||||
("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"),
|
||||
("digitalSignature", "Server certificates should have digital signature usage"),
|
||||
("keyEncipherment", "Server certificates should have key encipherment usage"),
|
||||
]
|
||||
|
||||
for check, message in server_checks:
|
||||
|
@ -236,15 +250,20 @@ def validate_server_certificates_config():
|
|||
|
||||
# Security check: Server certificates should NOT have clientAuth (Issue #153)
|
||||
# Look for clientAuth in extended_key_usage section, not in comments
|
||||
eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)]
|
||||
has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- '))
|
||||
eku_lines = [
|
||||
line
|
||||
for line in server_section.split("\n")
|
||||
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line)
|
||||
]
|
||||
has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- "))
|
||||
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
|
||||
|
||||
# Verify SAN extension is configured for Apple compatibility
|
||||
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
||||
assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
||||
|
||||
print("✓ Server certificate configuration has proper EKU and SAN settings")
|
||||
|
||||
|
||||
def test_server_certificates():
|
||||
"""Test server certificates - uses real certs if available, else validates config"""
|
||||
cert_files = find_generated_certificates()
|
||||
|
@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files):
|
|||
"""Validate actual Ansible-generated client certificates"""
|
||||
# Find client certificates (not CA cert, not server cert with IP/DNS name)
|
||||
client_certs = []
|
||||
for cert_path in cert_files['server_certs']:
|
||||
if 'cacert.pem' in cert_path:
|
||||
for cert_path in cert_files["server_certs"]:
|
||||
if "cacert.pem" in cert_path:
|
||||
continue
|
||||
|
||||
with open(cert_path, 'rb') as f:
|
||||
with open(cert_path, "rb") as f:
|
||||
cert_data = f.read()
|
||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||
|
||||
# Check if this looks like a client cert vs server cert
|
||||
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
||||
# Server certs typically have IP addresses or domain names as CN
|
||||
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
|
||||
if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4):
|
||||
client_certs.append((cert_path, certificate))
|
||||
|
||||
if not client_certs:
|
||||
|
@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files):
|
|||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
|
||||
|
||||
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
|
||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, (
|
||||
"Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
||||
)
|
||||
|
||||
# Check SAN extension for email
|
||||
try:
|
||||
|
@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files):
|
|||
|
||||
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
|
||||
|
||||
|
||||
def validate_client_certificates_config():
|
||||
"""Validate client certificate configuration in Ansible files (CI mode)"""
|
||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
||||
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||
if not openssl_task_file:
|
||||
print("⚠ Could not find openssl.yml task file")
|
||||
return
|
||||
|
@ -310,7 +332,9 @@ def validate_client_certificates_config():
|
|||
content = f.read()
|
||||
|
||||
# Look for client certificate CSR section
|
||||
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
|
||||
client_csr_section = re.search(
|
||||
r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL
|
||||
)
|
||||
if not client_csr_section:
|
||||
print("⚠ Could not find client certificate CSR section")
|
||||
return
|
||||
|
@ -319,11 +343,11 @@ def validate_client_certificates_config():
|
|||
|
||||
# Check client certificate configuration
|
||||
client_checks = [
|
||||
('clientAuth', 'Client certificates should have clientAuth EKU'),
|
||||
('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'),
|
||||
('digitalSignature', 'Client certificates should have digital signature usage'),
|
||||
('keyEncipherment', 'Client certificates should have key encipherment usage'),
|
||||
('email:', 'Client certificates should have email SAN')
|
||||
("clientAuth", "Client certificates should have clientAuth EKU"),
|
||||
("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"),
|
||||
("digitalSignature", "Client certificates should have digital signature usage"),
|
||||
("keyEncipherment", "Client certificates should have key encipherment usage"),
|
||||
("email:", "Client certificates should have email SAN"),
|
||||
]
|
||||
|
||||
for check, message in client_checks:
|
||||
|
@ -331,15 +355,22 @@ def validate_client_certificates_config():
|
|||
|
||||
# Security check: Client certificates should NOT have serverAuth (Issue #153)
|
||||
# Look for serverAuth in extended_key_usage section, not in comments
|
||||
eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)]
|
||||
has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- '))
|
||||
eku_lines = [
|
||||
line
|
||||
for line in client_section.split("\n")
|
||||
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line)
|
||||
]
|
||||
has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- "))
|
||||
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
|
||||
|
||||
# Verify client certificates use unique email domains (Issue #153)
|
||||
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
|
||||
assert "openssl_constraint_random_id" in client_section, (
|
||||
"Client certificates should use unique email domain per deployment"
|
||||
)
|
||||
|
||||
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
|
||||
|
||||
|
||||
def test_client_certificates():
|
||||
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
|
||||
cert_files = find_generated_certificates()
|
||||
|
@ -351,24 +382,33 @@ def test_client_certificates():
|
|||
|
||||
def validate_pkcs12_files_real(cert_files):
|
||||
"""Validate actual Ansible-generated PKCS#12 files"""
|
||||
if not cert_files.get('p12_files'):
|
||||
if not cert_files.get("p12_files"):
|
||||
print("⚠ No PKCS#12 files found")
|
||||
return
|
||||
|
||||
major, minor = test_openssl_version_detection()
|
||||
|
||||
for p12_file in cert_files['p12_files']:
|
||||
for p12_file in cert_files["p12_files"]:
|
||||
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
|
||||
|
||||
# Test that PKCS#12 file can be read (validates format)
|
||||
legacy_flag = ['-legacy'] if major >= 3 else []
|
||||
legacy_flag = ["-legacy"] if major >= 3 else []
|
||||
|
||||
result = subprocess.run([
|
||||
'openssl', 'pkcs12', '-info',
|
||||
'-in', p12_file,
|
||||
'-passin', 'pass:', # Try empty password first
|
||||
'-noout'
|
||||
] + legacy_flag, capture_output=True, text=True)
|
||||
result = subprocess.run(
|
||||
[
|
||||
"openssl",
|
||||
"pkcs12",
|
||||
"-info",
|
||||
"-in",
|
||||
p12_file,
|
||||
"-passin",
|
||||
"pass:", # Try empty password first
|
||||
"-noout",
|
||||
]
|
||||
+ legacy_flag,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# PKCS#12 files should be readable (even if password-protected)
|
||||
# We're just testing format validity, not trying to extract contents
|
||||
|
@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files):
|
|||
|
||||
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
|
||||
|
||||
|
||||
def validate_pkcs12_files_config():
|
||||
"""Validate PKCS#12 file configuration in Ansible files (CI mode)"""
|
||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
||||
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||
if not openssl_task_file:
|
||||
print("⚠ Could not find openssl.yml task file")
|
||||
return
|
||||
|
@ -390,13 +431,13 @@ def validate_pkcs12_files_config():
|
|||
|
||||
# Check PKCS#12 generation configuration
|
||||
p12_checks = [
|
||||
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
|
||||
('encryption_level', 'PKCS#12 encryption level should be configured'),
|
||||
('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'),
|
||||
('friendly_name', 'PKCS#12 should have friendly names'),
|
||||
('other_certificates', 'PKCS#12 should include CA certificate for full chain'),
|
||||
('passphrase', 'PKCS#12 files should be password protected'),
|
||||
('mode: "0600"', 'PKCS#12 files should have secure permissions')
|
||||
("openssl_pkcs12", "PKCS#12 generation should be configured"),
|
||||
("encryption_level", "PKCS#12 encryption level should be configured"),
|
||||
("compatibility2022", "PKCS#12 should use Apple-compatible encryption"),
|
||||
("friendly_name", "PKCS#12 should have friendly names"),
|
||||
("other_certificates", "PKCS#12 should include CA certificate for full chain"),
|
||||
("passphrase", "PKCS#12 files should be password protected"),
|
||||
('mode: "0600"', "PKCS#12 files should have secure permissions"),
|
||||
]
|
||||
|
||||
for check, message in p12_checks:
|
||||
|
@ -404,6 +445,7 @@ def validate_pkcs12_files_config():
|
|||
|
||||
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
|
||||
|
||||
|
||||
def test_pkcs12_files():
|
||||
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
|
||||
cert_files = find_generated_certificates()
|
||||
|
@ -416,19 +458,19 @@ def test_pkcs12_files():
|
|||
def validate_certificate_chain_real(cert_files):
|
||||
"""Validate actual Ansible-generated certificate chain"""
|
||||
# Load CA certificate
|
||||
with open(cert_files['ca_cert'], 'rb') as f:
|
||||
with open(cert_files["ca_cert"], "rb") as f:
|
||||
ca_cert_data = f.read()
|
||||
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
|
||||
|
||||
# Test that all other certificates are signed by the CA
|
||||
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
|
||||
other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]]
|
||||
|
||||
if not other_certs:
|
||||
print("⚠ No client/server certificates found to validate")
|
||||
return
|
||||
|
||||
for cert_path in other_certs:
|
||||
with open(cert_path, 'rb') as f:
|
||||
with open(cert_path, "rb") as f:
|
||||
cert_data = f.read()
|
||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||
|
||||
|
@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files):
|
|||
|
||||
# Verify certificate is currently valid (not expired)
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
|
||||
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
|
||||
|
@ -445,9 +488,10 @@ def validate_certificate_chain_real(cert_files):
|
|||
|
||||
print("✓ All real certificates properly signed by CA")
|
||||
|
||||
|
||||
def validate_certificate_chain_config():
|
||||
"""Validate certificate chain configuration in Ansible files (CI mode)"""
|
||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
||||
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||
if not openssl_task_file:
|
||||
print("⚠ Could not find openssl.yml task file")
|
||||
return
|
||||
|
@ -457,15 +501,18 @@ def validate_certificate_chain_config():
|
|||
|
||||
# Check certificate signing configuration
|
||||
chain_checks = [
|
||||
('provider: ownca', 'Certificates should be signed by own CA'),
|
||||
('ownca_path', 'CA certificate path should be specified'),
|
||||
('ownca_privatekey_path', 'CA private key path should be specified'),
|
||||
('ownca_privatekey_passphrase', 'CA private key should be password protected'),
|
||||
('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'),
|
||||
('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'),
|
||||
('ownca_not_before: "-1d"', 'Certificates should have backdated start time'),
|
||||
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
|
||||
('type: ECC', 'Should use elliptic curve keys for better security')
|
||||
("provider: ownca", "Certificates should be signed by own CA"),
|
||||
("ownca_path", "CA certificate path should be specified"),
|
||||
("ownca_privatekey_path", "CA private key path should be specified"),
|
||||
("ownca_privatekey_passphrase", "CA private key should be password protected"),
|
||||
("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"),
|
||||
(
|
||||
'ownca_not_after: "+{{ certificate_validity_days }}d"',
|
||||
"Certificates should use configurable validity period",
|
||||
),
|
||||
('ownca_not_before: "-1d"', "Certificates should have backdated start time"),
|
||||
("curve: secp384r1", "Should use strong elliptic curve cryptography"),
|
||||
("type: ECC", "Should use elliptic curve keys for better security"),
|
||||
]
|
||||
|
||||
for check, message in chain_checks:
|
||||
|
@ -473,6 +520,7 @@ def validate_certificate_chain_config():
|
|||
|
||||
print("✓ Certificate chain configuration properly set up for CA signing")
|
||||
|
||||
|
||||
def test_certificate_chain():
|
||||
"""Test certificate chain - uses real certs if available, else validates config"""
|
||||
cert_files = find_generated_certificates()
|
||||
|
@ -499,6 +547,7 @@ def find_ansible_file(relative_path):
|
|||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [
|
||||
test_openssl_version_detection,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Enhanced tests for StrongSwan templates.
|
||||
Tests all strongswan role templates with various configurations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
@ -21,7 +22,7 @@ def mock_to_uuid(value):
|
|||
|
||||
def mock_bool(value):
|
||||
"""Mock the bool filter"""
|
||||
return str(value).lower() in ('true', '1', 'yes', 'on')
|
||||
return str(value).lower() in ("true", "1", "yes", "on")
|
||||
|
||||
|
||||
def mock_version(version_string, comparison):
|
||||
|
@ -33,67 +34,67 @@ def mock_version(version_string, comparison):
|
|||
def mock_b64encode(value):
|
||||
"""Mock base64 encoding"""
|
||||
import base64
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.encode('utf-8')
|
||||
return base64.b64encode(value).decode('ascii')
|
||||
value = value.encode("utf-8")
|
||||
return base64.b64encode(value).decode("ascii")
|
||||
|
||||
|
||||
def mock_b64decode(value):
|
||||
"""Mock base64 decoding"""
|
||||
import base64
|
||||
return base64.b64decode(value).decode('utf-8')
|
||||
|
||||
return base64.b64decode(value).decode("utf-8")
|
||||
|
||||
|
||||
def get_strongswan_test_variables(scenario='default'):
|
||||
def get_strongswan_test_variables(scenario="default"):
|
||||
"""Get test variables for StrongSwan templates with different scenarios."""
|
||||
base_vars = load_test_variables()
|
||||
|
||||
# Add StrongSwan specific variables
|
||||
strongswan_vars = {
|
||||
'ipsec_config_path': '/etc/ipsec.d',
|
||||
'ipsec_pki_path': '/etc/ipsec.d',
|
||||
'strongswan_enabled': True,
|
||||
'strongswan_network': '10.19.48.0/24',
|
||||
'strongswan_network_ipv6': 'fd9d:bc11:4021::/64',
|
||||
'strongswan_log_level': '2',
|
||||
'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()),
|
||||
'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a',
|
||||
'subjectAltName_type': 'IP',
|
||||
'subjectAltName_client': 'IP:10.0.0.1',
|
||||
'ansible_default_ipv6': {
|
||||
'address': '2600:3c01::f03c:91ff:fedf:3b2a'
|
||||
},
|
||||
'openssl_version': '3.0.0',
|
||||
'p12_export_password': 'test-password',
|
||||
'ike_lifetime': '24h',
|
||||
'ipsec_lifetime': '8h',
|
||||
'ike_dpd': '30s',
|
||||
'ipsec_dead_peer_detection': True,
|
||||
'rekey_margin': '3m',
|
||||
'rekeymargin': '3m',
|
||||
'dpddelay': '35s',
|
||||
'keyexchange': 'ikev2',
|
||||
'ike_cipher': 'aes128gcm16-prfsha512-ecp256',
|
||||
'esp_cipher': 'aes128gcm16-ecp256',
|
||||
'leftsourceip': '10.19.48.1',
|
||||
'leftsubnet': '0.0.0.0/0,::/0',
|
||||
'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64',
|
||||
"ipsec_config_path": "/etc/ipsec.d",
|
||||
"ipsec_pki_path": "/etc/ipsec.d",
|
||||
"strongswan_enabled": True,
|
||||
"strongswan_network": "10.19.48.0/24",
|
||||
"strongswan_network_ipv6": "fd9d:bc11:4021::/64",
|
||||
"strongswan_log_level": "2",
|
||||
"openssl_constraint_random_id": "test-" + str(uuid.uuid4()),
|
||||
"subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a",
|
||||
"subjectAltName_type": "IP",
|
||||
"subjectAltName_client": "IP:10.0.0.1",
|
||||
"ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"},
|
||||
"openssl_version": "3.0.0",
|
||||
"p12_export_password": "test-password",
|
||||
"ike_lifetime": "24h",
|
||||
"ipsec_lifetime": "8h",
|
||||
"ike_dpd": "30s",
|
||||
"ipsec_dead_peer_detection": True,
|
||||
"rekey_margin": "3m",
|
||||
"rekeymargin": "3m",
|
||||
"dpddelay": "35s",
|
||||
"keyexchange": "ikev2",
|
||||
"ike_cipher": "aes128gcm16-prfsha512-ecp256",
|
||||
"esp_cipher": "aes128gcm16-ecp256",
|
||||
"leftsourceip": "10.19.48.1",
|
||||
"leftsubnet": "0.0.0.0/0,::/0",
|
||||
"rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64",
|
||||
}
|
||||
|
||||
# Merge with base variables
|
||||
test_vars = {**base_vars, **strongswan_vars}
|
||||
|
||||
# Apply scenario-specific overrides
|
||||
if scenario == 'ipv4_only':
|
||||
test_vars['ipv6_support'] = False
|
||||
test_vars['subjectAltName'] = 'IP:10.0.0.1'
|
||||
test_vars['ansible_default_ipv6'] = None
|
||||
elif scenario == 'dns_hostname':
|
||||
test_vars['IP_subject_alt_name'] = 'vpn.example.com'
|
||||
test_vars['subjectAltName'] = 'DNS:vpn.example.com'
|
||||
test_vars['subjectAltName_type'] = 'DNS'
|
||||
elif scenario == 'openssl_legacy':
|
||||
test_vars['openssl_version'] = '1.1.1'
|
||||
if scenario == "ipv4_only":
|
||||
test_vars["ipv6_support"] = False
|
||||
test_vars["subjectAltName"] = "IP:10.0.0.1"
|
||||
test_vars["ansible_default_ipv6"] = None
|
||||
elif scenario == "dns_hostname":
|
||||
test_vars["IP_subject_alt_name"] = "vpn.example.com"
|
||||
test_vars["subjectAltName"] = "DNS:vpn.example.com"
|
||||
test_vars["subjectAltName_type"] = "DNS"
|
||||
elif scenario == "openssl_legacy":
|
||||
test_vars["openssl_version"] = "1.1.1"
|
||||
|
||||
return test_vars
|
||||
|
||||
|
@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'):
|
|||
def test_strongswan_templates():
|
||||
"""Test all StrongSwan templates with various configurations."""
|
||||
templates = [
|
||||
'roles/strongswan/templates/ipsec.conf.j2',
|
||||
'roles/strongswan/templates/ipsec.secrets.j2',
|
||||
'roles/strongswan/templates/strongswan.conf.j2',
|
||||
'roles/strongswan/templates/charon.conf.j2',
|
||||
'roles/strongswan/templates/client_ipsec.conf.j2',
|
||||
'roles/strongswan/templates/client_ipsec.secrets.j2',
|
||||
'roles/strongswan/templates/100-CustomLimitations.conf.j2',
|
||||
"roles/strongswan/templates/ipsec.conf.j2",
|
||||
"roles/strongswan/templates/ipsec.secrets.j2",
|
||||
"roles/strongswan/templates/strongswan.conf.j2",
|
||||
"roles/strongswan/templates/charon.conf.j2",
|
||||
"roles/strongswan/templates/client_ipsec.conf.j2",
|
||||
"roles/strongswan/templates/client_ipsec.secrets.j2",
|
||||
"roles/strongswan/templates/100-CustomLimitations.conf.j2",
|
||||
]
|
||||
|
||||
scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy']
|
||||
scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"]
|
||||
errors = []
|
||||
tested = 0
|
||||
|
||||
|
@ -127,21 +128,18 @@ def test_strongswan_templates():
|
|||
test_vars = get_strongswan_test_variables(scenario)
|
||||
|
||||
try:
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(template_dir),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||
|
||||
# Add mock filters
|
||||
env.filters['to_uuid'] = mock_to_uuid
|
||||
env.filters['bool'] = mock_bool
|
||||
env.filters['b64encode'] = mock_b64encode
|
||||
env.filters['b64decode'] = mock_b64decode
|
||||
env.tests['version'] = mock_version
|
||||
env.filters["to_uuid"] = mock_to_uuid
|
||||
env.filters["bool"] = mock_bool
|
||||
env.filters["b64encode"] = mock_b64encode
|
||||
env.filters["b64decode"] = mock_b64decode
|
||||
env.tests["version"] = mock_version
|
||||
|
||||
# For client templates, add item context
|
||||
if 'client' in template_name:
|
||||
test_vars['item'] = 'testuser'
|
||||
if "client" in template_name:
|
||||
test_vars["item"] = "testuser"
|
||||
|
||||
template = env.get_template(template_name)
|
||||
output = template.render(**test_vars)
|
||||
|
@ -150,16 +148,16 @@ def test_strongswan_templates():
|
|||
assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
|
||||
|
||||
# Specific validations based on template
|
||||
if 'ipsec.conf' in template_name and 'client' not in template_name:
|
||||
assert 'conn' in output, "Missing connection definition"
|
||||
if scenario != 'ipv4_only' and test_vars.get('ipv6_support'):
|
||||
assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration"
|
||||
if "ipsec.conf" in template_name and "client" not in template_name:
|
||||
assert "conn" in output, "Missing connection definition"
|
||||
if scenario != "ipv4_only" and test_vars.get("ipv6_support"):
|
||||
assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration"
|
||||
|
||||
if 'ipsec.secrets' in template_name:
|
||||
assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method"
|
||||
if "ipsec.secrets" in template_name:
|
||||
assert "PSK" in output or "ECDSA" in output, "Missing authentication method"
|
||||
|
||||
if 'strongswan.conf' in template_name:
|
||||
assert 'charon' in output, "Missing charon configuration"
|
||||
if "strongswan.conf" in template_name:
|
||||
assert "charon" in output, "Missing charon configuration"
|
||||
|
||||
print(f" ✅ {template_name} ({scenario})")
|
||||
|
||||
|
@ -182,7 +180,7 @@ def test_openssl_template_constraints():
|
|||
# This tests the actual openssl.yml task file to ensure our fix works
|
||||
import yaml
|
||||
|
||||
openssl_path = 'roles/strongswan/tasks/openssl.yml'
|
||||
openssl_path = "roles/strongswan/tasks/openssl.yml"
|
||||
if not os.path.exists(openssl_path):
|
||||
print("⚠️ OpenSSL tasks file not found")
|
||||
return True
|
||||
|
@ -194,22 +192,23 @@ def test_openssl_template_constraints():
|
|||
# Find the CA CSR task
|
||||
ca_csr_task = None
|
||||
for task in content:
|
||||
if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'):
|
||||
if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"):
|
||||
ca_csr_task = task
|
||||
break
|
||||
|
||||
if ca_csr_task:
|
||||
# Check that name_constraints_permitted is properly formatted
|
||||
csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {})
|
||||
constraints = csr_module.get('name_constraints_permitted', '')
|
||||
csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {})
|
||||
constraints = csr_module.get("name_constraints_permitted", "")
|
||||
|
||||
# The constraints should be a Jinja2 template without inline comments
|
||||
if '#' in str(constraints):
|
||||
if "#" in str(constraints):
|
||||
# Check if the # is within {{ }}
|
||||
import re
|
||||
jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL)
|
||||
|
||||
jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL)
|
||||
for block in jinja_blocks:
|
||||
if '#' in block:
|
||||
if "#" in block:
|
||||
print("❌ Found inline comment in Jinja2 expression")
|
||||
return False
|
||||
|
||||
|
@ -223,7 +222,7 @@ def test_openssl_template_constraints():
|
|||
|
||||
def test_mobileconfig_template():
|
||||
"""Test the mobileconfig template with various scenarios."""
|
||||
template_path = 'roles/strongswan/templates/mobileconfig.j2'
|
||||
template_path = "roles/strongswan/templates/mobileconfig.j2"
|
||||
|
||||
if not os.path.exists(template_path):
|
||||
print("⚠️ Mobileconfig template not found")
|
||||
|
@ -237,20 +236,20 @@ def test_mobileconfig_template():
|
|||
|
||||
test_cases = [
|
||||
{
|
||||
'name': 'iPhone with cellular on-demand',
|
||||
'algo_ondemand_cellular': 'true',
|
||||
'algo_ondemand_wifi': 'false',
|
||||
"name": "iPhone with cellular on-demand",
|
||||
"algo_ondemand_cellular": "true",
|
||||
"algo_ondemand_wifi": "false",
|
||||
},
|
||||
{
|
||||
'name': 'iPad with WiFi on-demand',
|
||||
'algo_ondemand_cellular': 'false',
|
||||
'algo_ondemand_wifi': 'true',
|
||||
'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi',
|
||||
"name": "iPad with WiFi on-demand",
|
||||
"algo_ondemand_cellular": "false",
|
||||
"algo_ondemand_wifi": "true",
|
||||
"algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi",
|
||||
},
|
||||
{
|
||||
'name': 'Mac without on-demand',
|
||||
'algo_ondemand_cellular': 'false',
|
||||
'algo_ondemand_wifi': 'false',
|
||||
"name": "Mac without on-demand",
|
||||
"algo_ondemand_cellular": "false",
|
||||
"algo_ondemand_wifi": "false",
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -258,43 +257,41 @@ def test_mobileconfig_template():
|
|||
for test_case in test_cases:
|
||||
test_vars = get_strongswan_test_variables()
|
||||
test_vars.update(test_case)
|
||||
|
||||
# Mock Ansible task result format for item
|
||||
class MockTaskResult:
|
||||
def __init__(self, content):
|
||||
self.stdout = content
|
||||
|
||||
test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result
|
||||
test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64
|
||||
test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64
|
||||
test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4())
|
||||
test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12
|
||||
test_vars['algo_server_name'] = 'test-algo-vpn'
|
||||
test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4())
|
||||
test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4())
|
||||
test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64
|
||||
test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result
|
||||
test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64
|
||||
test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64
|
||||
test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4())
|
||||
test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12
|
||||
test_vars["algo_server_name"] = "test-algo-vpn"
|
||||
test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4())
|
||||
test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4())
|
||||
test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64
|
||||
|
||||
try:
|
||||
env = Environment(
|
||||
loader=FileSystemLoader('roles/strongswan/templates'),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined)
|
||||
|
||||
# Add mock filters
|
||||
env.filters['to_uuid'] = mock_to_uuid
|
||||
env.filters['b64encode'] = mock_b64encode
|
||||
env.filters['b64decode'] = mock_b64decode
|
||||
env.filters["to_uuid"] = mock_to_uuid
|
||||
env.filters["b64encode"] = mock_b64encode
|
||||
env.filters["b64decode"] = mock_b64decode
|
||||
|
||||
template = env.get_template('mobileconfig.j2')
|
||||
template = env.get_template("mobileconfig.j2")
|
||||
output = template.render(**test_vars)
|
||||
|
||||
# Validate output
|
||||
assert '<?xml' in output, "Missing XML declaration"
|
||||
assert '<plist' in output, "Missing plist element"
|
||||
assert 'PayloadType' in output, "Missing PayloadType"
|
||||
assert "<?xml" in output, "Missing XML declaration"
|
||||
assert "<plist" in output, "Missing plist element"
|
||||
assert "PayloadType" in output, "Missing PayloadType"
|
||||
|
||||
# Check on-demand configuration
|
||||
if test_case.get('algo_ondemand_cellular') == 'true' or test_case.get('algo_ondemand_wifi') == 'true':
|
||||
assert 'OnDemandEnabled' in output, f"Missing OnDemand config for {test_case['name']}"
|
||||
if test_case.get("algo_ondemand_cellular") == "true" or test_case.get("algo_ondemand_wifi") == "true":
|
||||
assert "OnDemandEnabled" in output, f"Missing OnDemand config for {test_case['name']}"
|
||||
|
||||
print(f" ✅ Mobileconfig: {test_case['name']}")
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Test that Ansible templates render correctly
|
||||
This catches undefined variables, syntax errors, and logic bugs
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
@ -22,20 +23,20 @@ def mock_to_uuid(value):
|
|||
|
||||
def mock_bool(value):
|
||||
"""Mock the bool filter"""
|
||||
return str(value).lower() in ('true', '1', 'yes', 'on')
|
||||
return str(value).lower() in ("true", "1", "yes", "on")
|
||||
|
||||
|
||||
def mock_lookup(type, path):
|
||||
"""Mock the lookup function"""
|
||||
# Return fake data for file lookups
|
||||
if type == 'file':
|
||||
if 'private' in path:
|
||||
return 'MOCK_PRIVATE_KEY_BASE64=='
|
||||
elif 'public' in path:
|
||||
return 'MOCK_PUBLIC_KEY_BASE64=='
|
||||
elif 'preshared' in path:
|
||||
return 'MOCK_PRESHARED_KEY_BASE64=='
|
||||
return 'MOCK_LOOKUP_DATA'
|
||||
if type == "file":
|
||||
if "private" in path:
|
||||
return "MOCK_PRIVATE_KEY_BASE64=="
|
||||
elif "public" in path:
|
||||
return "MOCK_PUBLIC_KEY_BASE64=="
|
||||
elif "preshared" in path:
|
||||
return "MOCK_PRESHARED_KEY_BASE64=="
|
||||
return "MOCK_LOOKUP_DATA"
|
||||
|
||||
|
||||
def get_test_variables():
|
||||
|
@ -47,8 +48,8 @@ def get_test_variables():
|
|||
def find_templates():
|
||||
"""Find all Jinja2 template files in the repo"""
|
||||
templates = []
|
||||
for pattern in ['**/*.j2', '**/*.jinja2', '**/*.yml.j2']:
|
||||
templates.extend(Path('.').glob(pattern))
|
||||
for pattern in ["**/*.j2", "**/*.jinja2", "**/*.yml.j2"]:
|
||||
templates.extend(Path(".").glob(pattern))
|
||||
return templates
|
||||
|
||||
|
||||
|
@ -57,10 +58,10 @@ def test_template_syntax():
|
|||
templates = find_templates()
|
||||
|
||||
# Skip some paths that aren't real templates
|
||||
skip_paths = ['.git/', 'venv/', '.venv/', '.env/', 'configs/']
|
||||
skip_paths = [".git/", "venv/", ".venv/", ".env/", "configs/"]
|
||||
|
||||
# Skip templates that use Ansible-specific filters
|
||||
skip_templates = ['vpn-dict.j2', 'mobileconfig.j2', 'dnscrypt-proxy.toml.j2']
|
||||
skip_templates = ["vpn-dict.j2", "mobileconfig.j2", "dnscrypt-proxy.toml.j2"]
|
||||
|
||||
errors = []
|
||||
skipped = 0
|
||||
|
@ -76,10 +77,7 @@ def test_template_syntax():
|
|||
|
||||
try:
|
||||
template_dir = template_path.parent
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(template_dir),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||
|
||||
# Just try to load the template - this checks syntax
|
||||
env.get_template(template_path.name)
|
||||
|
@ -103,13 +101,13 @@ def test_template_syntax():
|
|||
def test_critical_templates():
|
||||
"""Test that critical templates render with test data"""
|
||||
critical_templates = [
|
||||
'roles/wireguard/templates/client.conf.j2',
|
||||
'roles/strongswan/templates/ipsec.conf.j2',
|
||||
'roles/strongswan/templates/ipsec.secrets.j2',
|
||||
'roles/dns/templates/adblock.sh.j2',
|
||||
'roles/dns/templates/dnsmasq.conf.j2',
|
||||
'roles/common/templates/rules.v4.j2',
|
||||
'roles/common/templates/rules.v6.j2',
|
||||
"roles/wireguard/templates/client.conf.j2",
|
||||
"roles/strongswan/templates/ipsec.conf.j2",
|
||||
"roles/strongswan/templates/ipsec.secrets.j2",
|
||||
"roles/dns/templates/adblock.sh.j2",
|
||||
"roles/dns/templates/dnsmasq.conf.j2",
|
||||
"roles/common/templates/rules.v4.j2",
|
||||
"roles/common/templates/rules.v6.j2",
|
||||
]
|
||||
|
||||
test_vars = get_test_variables()
|
||||
|
@ -123,21 +121,18 @@ def test_critical_templates():
|
|||
template_dir = os.path.dirname(template_path)
|
||||
template_name = os.path.basename(template_path)
|
||||
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(template_dir),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||
|
||||
# Add mock functions
|
||||
env.globals['lookup'] = mock_lookup
|
||||
env.filters['to_uuid'] = mock_to_uuid
|
||||
env.filters['bool'] = mock_bool
|
||||
env.globals["lookup"] = mock_lookup
|
||||
env.filters["to_uuid"] = mock_to_uuid
|
||||
env.filters["bool"] = mock_bool
|
||||
|
||||
template = env.get_template(template_name)
|
||||
|
||||
# Add item context for templates that use loops
|
||||
if 'client' in template_name:
|
||||
test_vars['item'] = ('test-user', 'test-user')
|
||||
if "client" in template_name:
|
||||
test_vars["item"] = ("test-user", "test-user")
|
||||
|
||||
# Try to render
|
||||
output = template.render(**test_vars)
|
||||
|
@ -163,17 +158,17 @@ def test_variable_consistency():
|
|||
"""Check that commonly used variables are defined consistently"""
|
||||
# Variables that should be used consistently across templates
|
||||
common_vars = [
|
||||
'server_name',
|
||||
'IP_subject_alt_name',
|
||||
'wireguard_port',
|
||||
'wireguard_network',
|
||||
'dns_servers',
|
||||
'users',
|
||||
"server_name",
|
||||
"IP_subject_alt_name",
|
||||
"wireguard_port",
|
||||
"wireguard_network",
|
||||
"dns_servers",
|
||||
"users",
|
||||
]
|
||||
|
||||
# Check if main.yml defines these
|
||||
if os.path.exists('main.yml'):
|
||||
with open('main.yml') as f:
|
||||
if os.path.exists("main.yml"):
|
||||
with open("main.yml") as f:
|
||||
content = f.read()
|
||||
|
||||
missing = []
|
||||
|
@ -192,28 +187,19 @@ def test_wireguard_ipv6_endpoints():
|
|||
"""Test that WireGuard client configs properly format IPv6 endpoints"""
|
||||
test_cases = [
|
||||
# IPv4 address - should not be bracketed
|
||||
{
|
||||
'IP_subject_alt_name': '192.168.1.100',
|
||||
'expected_endpoint': 'Endpoint = 192.168.1.100:51820'
|
||||
},
|
||||
{"IP_subject_alt_name": "192.168.1.100", "expected_endpoint": "Endpoint = 192.168.1.100:51820"},
|
||||
# IPv6 address - should be bracketed
|
||||
{
|
||||
'IP_subject_alt_name': '2600:3c01::f03c:91ff:fedf:3b2a',
|
||||
'expected_endpoint': 'Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820'
|
||||
"IP_subject_alt_name": "2600:3c01::f03c:91ff:fedf:3b2a",
|
||||
"expected_endpoint": "Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820",
|
||||
},
|
||||
# Hostname - should not be bracketed
|
||||
{
|
||||
'IP_subject_alt_name': 'vpn.example.com',
|
||||
'expected_endpoint': 'Endpoint = vpn.example.com:51820'
|
||||
},
|
||||
{"IP_subject_alt_name": "vpn.example.com", "expected_endpoint": "Endpoint = vpn.example.com:51820"},
|
||||
# IPv6 with zone ID - should be bracketed
|
||||
{
|
||||
'IP_subject_alt_name': 'fe80::1%eth0',
|
||||
'expected_endpoint': 'Endpoint = [fe80::1%eth0]:51820'
|
||||
},
|
||||
{"IP_subject_alt_name": "fe80::1%eth0", "expected_endpoint": "Endpoint = [fe80::1%eth0]:51820"},
|
||||
]
|
||||
|
||||
template_path = 'roles/wireguard/templates/client.conf.j2'
|
||||
template_path = "roles/wireguard/templates/client.conf.j2"
|
||||
if not os.path.exists(template_path):
|
||||
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
|
||||
return
|
||||
|
@ -225,24 +211,23 @@ def test_wireguard_ipv6_endpoints():
|
|||
try:
|
||||
# Set up test variables
|
||||
test_vars = {**base_vars, **test_case}
|
||||
test_vars['item'] = ('test-user', 'test-user')
|
||||
test_vars["item"] = ("test-user", "test-user")
|
||||
|
||||
# Render template
|
||||
env = Environment(
|
||||
loader=FileSystemLoader('roles/wireguard/templates'),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env.globals['lookup'] = mock_lookup
|
||||
env = Environment(loader=FileSystemLoader("roles/wireguard/templates"), undefined=StrictUndefined)
|
||||
env.globals["lookup"] = mock_lookup
|
||||
|
||||
template = env.get_template('client.conf.j2')
|
||||
template = env.get_template("client.conf.j2")
|
||||
output = template.render(**test_vars)
|
||||
|
||||
# Check if the expected endpoint format is in the output
|
||||
if test_case['expected_endpoint'] not in output:
|
||||
errors.append(f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output")
|
||||
if test_case["expected_endpoint"] not in output:
|
||||
errors.append(
|
||||
f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output"
|
||||
)
|
||||
# Print relevant part of output for debugging
|
||||
for line in output.split('\n'):
|
||||
if 'Endpoint' in line:
|
||||
for line in output.split("\n"):
|
||||
if "Endpoint" in line:
|
||||
errors.append(f" Found: {line.strip()}")
|
||||
|
||||
except Exception as e:
|
||||
|
@ -262,27 +247,27 @@ def test_template_conditionals():
|
|||
test_cases = [
|
||||
# WireGuard enabled, IPsec disabled
|
||||
{
|
||||
'wireguard_enabled': True,
|
||||
'ipsec_enabled': False,
|
||||
'dns_encryption': True,
|
||||
'dns_adblocking': True,
|
||||
'algo_ssh_tunneling': False,
|
||||
"wireguard_enabled": True,
|
||||
"ipsec_enabled": False,
|
||||
"dns_encryption": True,
|
||||
"dns_adblocking": True,
|
||||
"algo_ssh_tunneling": False,
|
||||
},
|
||||
# IPsec enabled, WireGuard disabled
|
||||
{
|
||||
'wireguard_enabled': False,
|
||||
'ipsec_enabled': True,
|
||||
'dns_encryption': False,
|
||||
'dns_adblocking': False,
|
||||
'algo_ssh_tunneling': True,
|
||||
"wireguard_enabled": False,
|
||||
"ipsec_enabled": True,
|
||||
"dns_encryption": False,
|
||||
"dns_adblocking": False,
|
||||
"algo_ssh_tunneling": True,
|
||||
},
|
||||
# Both enabled
|
||||
{
|
||||
'wireguard_enabled': True,
|
||||
'ipsec_enabled': True,
|
||||
'dns_encryption': True,
|
||||
'dns_adblocking': True,
|
||||
'algo_ssh_tunneling': True,
|
||||
"wireguard_enabled": True,
|
||||
"ipsec_enabled": True,
|
||||
"dns_encryption": True,
|
||||
"dns_adblocking": True,
|
||||
"algo_ssh_tunneling": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -294,7 +279,7 @@ def test_template_conditionals():
|
|||
|
||||
# Test a few templates that have conditionals
|
||||
conditional_templates = [
|
||||
'roles/common/templates/rules.v4.j2',
|
||||
"roles/common/templates/rules.v4.j2",
|
||||
]
|
||||
|
||||
for template_path in conditional_templates:
|
||||
|
@ -305,23 +290,19 @@ def test_template_conditionals():
|
|||
template_dir = os.path.dirname(template_path)
|
||||
template_name = os.path.basename(template_path)
|
||||
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(template_dir),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||
|
||||
# Add mock functions
|
||||
env.globals['lookup'] = mock_lookup
|
||||
env.filters['to_uuid'] = mock_to_uuid
|
||||
env.filters['bool'] = mock_bool
|
||||
env.globals["lookup"] = mock_lookup
|
||||
env.filters["to_uuid"] = mock_to_uuid
|
||||
env.filters["bool"] = mock_bool
|
||||
|
||||
template = env.get_template(template_name)
|
||||
output = template.render(**test_vars)
|
||||
|
||||
# Verify conditionals work
|
||||
if test_case.get('wireguard_enabled'):
|
||||
assert str(test_vars['wireguard_port']) in output, \
|
||||
f"WireGuard port missing when enabled (case {i})"
|
||||
if test_case.get("wireguard_enabled"):
|
||||
assert str(test_vars["wireguard_port"]) in output, f"WireGuard port missing when enabled (case {i})"
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Conditional test failed for {template_path} case {i}: {e}")
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Test user management functionality without deployment
|
||||
Based on issues #14745, #14746, #14738, #14726
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
@ -23,15 +24,15 @@ users:
|
|||
"""
|
||||
|
||||
config = yaml.safe_load(test_config)
|
||||
users = config.get('users', [])
|
||||
users = config.get("users", [])
|
||||
|
||||
assert len(users) == 5, f"Expected 5 users, got {len(users)}"
|
||||
assert 'alice' in users, "Missing user 'alice'"
|
||||
assert 'user-with-dash' in users, "Dash in username not handled"
|
||||
assert 'user_with_underscore' in users, "Underscore in username not handled"
|
||||
assert "alice" in users, "Missing user 'alice'"
|
||||
assert "user-with-dash" in users, "Dash in username not handled"
|
||||
assert "user_with_underscore" in users, "Underscore in username not handled"
|
||||
|
||||
# Test that usernames are valid
|
||||
username_pattern = re.compile(r'^[a-zA-Z0-9_-]+$')
|
||||
username_pattern = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
for user in users:
|
||||
assert username_pattern.match(user), f"Invalid username format: {user}"
|
||||
|
||||
|
@ -42,35 +43,27 @@ def test_server_selection_format():
|
|||
"""Test server selection string parsing (issue #14727)"""
|
||||
# Test various server display formats
|
||||
test_cases = [
|
||||
{"display": "1. 192.168.1.100 (algo-server)", "expected_ip": "192.168.1.100", "expected_name": "algo-server"},
|
||||
{"display": "2. 10.0.0.1 (production-vpn)", "expected_ip": "10.0.0.1", "expected_name": "production-vpn"},
|
||||
{
|
||||
'display': '1. 192.168.1.100 (algo-server)',
|
||||
'expected_ip': '192.168.1.100',
|
||||
'expected_name': 'algo-server'
|
||||
"display": "3. vpn.example.com (example-server)",
|
||||
"expected_ip": "vpn.example.com",
|
||||
"expected_name": "example-server",
|
||||
},
|
||||
{
|
||||
'display': '2. 10.0.0.1 (production-vpn)',
|
||||
'expected_ip': '10.0.0.1',
|
||||
'expected_name': 'production-vpn'
|
||||
},
|
||||
{
|
||||
'display': '3. vpn.example.com (example-server)',
|
||||
'expected_ip': 'vpn.example.com',
|
||||
'expected_name': 'example-server'
|
||||
}
|
||||
]
|
||||
|
||||
# Pattern to extract IP and name from display string
|
||||
pattern = re.compile(r'^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$')
|
||||
pattern = re.compile(r"^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$")
|
||||
|
||||
for case in test_cases:
|
||||
match = pattern.match(case['display'])
|
||||
match = pattern.match(case["display"])
|
||||
assert match, f"Failed to parse: {case['display']}"
|
||||
|
||||
ip_or_host = match.group(1)
|
||||
name = match.group(2)
|
||||
|
||||
assert ip_or_host == case['expected_ip'], f"Wrong IP extracted: {ip_or_host}"
|
||||
assert name == case['expected_name'], f"Wrong name extracted: {name}"
|
||||
assert ip_or_host == case["expected_ip"], f"Wrong IP extracted: {ip_or_host}"
|
||||
assert name == case["expected_name"], f"Wrong name extracted: {name}"
|
||||
|
||||
print("✓ Server selection format test passed")
|
||||
|
||||
|
@ -78,12 +71,12 @@ def test_server_selection_format():
|
|||
def test_ssh_key_preservation():
|
||||
"""Test that SSH keys aren't regenerated unnecessarily"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ssh_key_path = os.path.join(tmpdir, 'test_key')
|
||||
ssh_key_path = os.path.join(tmpdir, "test_key")
|
||||
|
||||
# Simulate existing SSH key
|
||||
with open(ssh_key_path, 'w') as f:
|
||||
with open(ssh_key_path, "w") as f:
|
||||
f.write("EXISTING_SSH_KEY_CONTENT")
|
||||
with open(f"{ssh_key_path}.pub", 'w') as f:
|
||||
with open(f"{ssh_key_path}.pub", "w") as f:
|
||||
f.write("ssh-rsa EXISTING_PUBLIC_KEY")
|
||||
|
||||
# Record original content
|
||||
|
@ -105,11 +98,7 @@ def test_ssh_key_preservation():
|
|||
def test_ca_password_handling():
|
||||
"""Test CA password validation and handling"""
|
||||
# Test password requirements
|
||||
valid_passwords = [
|
||||
"SecurePassword123!",
|
||||
"Algo-VPN-2024",
|
||||
"Complex#Pass@Word999"
|
||||
]
|
||||
valid_passwords = ["SecurePassword123!", "Algo-VPN-2024", "Complex#Pass@Word999"]
|
||||
|
||||
invalid_passwords = [
|
||||
"", # Empty
|
||||
|
@ -120,13 +109,13 @@ def test_ca_password_handling():
|
|||
# Basic password validation
|
||||
for pwd in valid_passwords:
|
||||
assert len(pwd) >= 12, f"Password too short: {pwd}"
|
||||
assert ' ' not in pwd, f"Password contains spaces: {pwd}"
|
||||
assert " " not in pwd, f"Password contains spaces: {pwd}"
|
||||
|
||||
for pwd in invalid_passwords:
|
||||
issues = []
|
||||
if len(pwd) < 12:
|
||||
issues.append("too short")
|
||||
if ' ' in pwd:
|
||||
if " " in pwd:
|
||||
issues.append("contains spaces")
|
||||
if not pwd:
|
||||
issues.append("empty")
|
||||
|
@ -137,8 +126,8 @@ def test_ca_password_handling():
|
|||
|
||||
def test_user_config_generation():
|
||||
"""Test that user configs would be generated correctly"""
|
||||
users = ['alice', 'bob', 'charlie']
|
||||
server_name = 'test-server'
|
||||
users = ["alice", "bob", "charlie"]
|
||||
server_name = "test-server"
|
||||
|
||||
# Simulate config file structure
|
||||
for user in users:
|
||||
|
@ -168,7 +157,7 @@ users:
|
|||
"""
|
||||
|
||||
config = yaml.safe_load(test_config)
|
||||
users = config.get('users', [])
|
||||
users = config.get("users", [])
|
||||
|
||||
# Check for duplicates
|
||||
unique_users = list(set(users))
|
||||
|
@ -182,7 +171,7 @@ users:
|
|||
duplicates.append(user)
|
||||
seen.add(user)
|
||||
|
||||
assert 'alice' in duplicates, "Duplicate 'alice' not detected"
|
||||
assert "alice" in duplicates, "Duplicate 'alice' not detected"
|
||||
|
||||
print("✓ Duplicate user handling test passed")
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Test WireGuard key generation - focused on x25519_pubkey module integration
|
||||
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
|
@ -10,13 +11,13 @@ import sys
|
|||
import tempfile
|
||||
|
||||
# Add library directory to path to import our custom module
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library"))
|
||||
|
||||
|
||||
def test_wireguard_tools_available():
|
||||
"""Test that WireGuard tools are available for validation"""
|
||||
try:
|
||||
result = subprocess.run(['wg', '--version'], capture_output=True, text=True)
|
||||
result = subprocess.run(["wg", "--version"], capture_output=True, text=True)
|
||||
assert result.returncode == 0, "WireGuard tools not available"
|
||||
print(f"✓ WireGuard tools available: {result.stdout.strip()}")
|
||||
return True
|
||||
|
@ -29,6 +30,7 @@ def test_x25519_module_import():
|
|||
"""Test that our custom x25519_pubkey module can be imported and used"""
|
||||
try:
|
||||
import x25519_pubkey # noqa: F401
|
||||
|
||||
print("✓ x25519_pubkey module imports successfully")
|
||||
return True
|
||||
except ImportError as e:
|
||||
|
@ -37,16 +39,17 @@ def test_x25519_module_import():
|
|||
|
||||
def generate_test_private_key():
|
||||
"""Generate a test private key using the same method as Algo"""
|
||||
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file:
|
||||
raw_key_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
|
||||
import secrets
|
||||
|
||||
raw_data = secrets.token_bytes(32)
|
||||
|
||||
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
|
||||
with open(raw_key_path, 'wb') as f:
|
||||
with open(raw_key_path, "wb") as f:
|
||||
f.write(raw_data)
|
||||
|
||||
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
|
||||
|
@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file():
|
|||
def exit_json(self, **kwargs):
|
||||
self.result = kwargs
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub:
|
||||
public_key_path = temp_pub.name
|
||||
|
||||
try:
|
||||
|
@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file():
|
|||
|
||||
try:
|
||||
# Mock the module call
|
||||
mock_module = MockModule({
|
||||
'private_key_path': raw_key_path,
|
||||
'public_key_path': public_key_path,
|
||||
'private_key_b64': None
|
||||
})
|
||||
mock_module = MockModule(
|
||||
{"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None}
|
||||
)
|
||||
|
||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||
|
||||
|
@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file():
|
|||
run_module()
|
||||
|
||||
# Check the result
|
||||
assert 'public_key' in mock_module.result
|
||||
assert mock_module.result['changed']
|
||||
assert "public_key" in mock_module.result
|
||||
assert mock_module.result["changed"]
|
||||
assert os.path.exists(public_key_path)
|
||||
|
||||
with open(public_key_path) as f:
|
||||
|
@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string():
|
|||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||
|
||||
try:
|
||||
mock_module = MockModule({
|
||||
'private_key_b64': b64_key,
|
||||
'private_key_path': None,
|
||||
'public_key_path': None
|
||||
})
|
||||
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
|
||||
|
||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||
|
||||
|
@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string():
|
|||
run_module()
|
||||
|
||||
# Check the result
|
||||
assert 'public_key' in mock_module.result
|
||||
derived_pubkey = mock_module.result['public_key']
|
||||
assert "public_key" in mock_module.result
|
||||
derived_pubkey = mock_module.result["public_key"]
|
||||
|
||||
# Validate base64 format
|
||||
try:
|
||||
|
@ -222,21 +219,17 @@ def test_wireguard_validation():
|
|||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||
|
||||
try:
|
||||
mock_module = MockModule({
|
||||
'private_key_b64': b64_key,
|
||||
'private_key_path': None,
|
||||
'public_key_path': None
|
||||
})
|
||||
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
|
||||
|
||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||
run_module()
|
||||
|
||||
derived_pubkey = mock_module.result['public_key']
|
||||
derived_pubkey = mock_module.result["public_key"]
|
||||
|
||||
finally:
|
||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config:
|
||||
# Create a WireGuard config using our keys
|
||||
wg_config = f"""[Interface]
|
||||
PrivateKey = {b64_key}
|
||||
|
@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32
|
|||
|
||||
try:
|
||||
# Test that WireGuard can parse our config
|
||||
result = subprocess.run([
|
||||
'wg-quick', 'strip', config_path
|
||||
], capture_output=True, text=True)
|
||||
result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True)
|
||||
|
||||
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
|
||||
|
||||
# Test key derivation with wg pubkey command
|
||||
wg_result = subprocess.run([
|
||||
'wg', 'pubkey'
|
||||
], input=b64_key, capture_output=True, text=True)
|
||||
wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True)
|
||||
|
||||
if wg_result.returncode == 0:
|
||||
wg_derived = wg_result.stdout.strip()
|
||||
|
@ -286,8 +275,8 @@ def test_key_consistency():
|
|||
raw_key_path, b64_key = generate_test_private_key()
|
||||
|
||||
try:
|
||||
def derive_pubkey_from_same_key():
|
||||
|
||||
def derive_pubkey_from_same_key():
|
||||
class MockModule:
|
||||
def __init__(self, params):
|
||||
self.params = params
|
||||
|
@ -305,16 +294,18 @@ def test_key_consistency():
|
|||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||
|
||||
try:
|
||||
mock_module = MockModule({
|
||||
'private_key_b64': b64_key, # SAME key each time
|
||||
'private_key_path': None,
|
||||
'public_key_path': None
|
||||
})
|
||||
mock_module = MockModule(
|
||||
{
|
||||
"private_key_b64": b64_key, # SAME key each time
|
||||
"private_key_path": None,
|
||||
"public_key_path": None,
|
||||
}
|
||||
)
|
||||
|
||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||
run_module()
|
||||
|
||||
return mock_module.result['public_key']
|
||||
return mock_module.result["public_key"]
|
||||
|
||||
finally:
|
||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||
|
|
|
@ -14,13 +14,13 @@ from pathlib import Path
|
|||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
|
||||
|
||||
|
||||
def find_jinja2_templates(root_dir: str = '.') -> list[Path]:
|
||||
def find_jinja2_templates(root_dir: str = ".") -> list[Path]:
|
||||
"""Find all Jinja2 template files in the project."""
|
||||
templates = []
|
||||
patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2']
|
||||
patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"]
|
||||
|
||||
# Skip these directories
|
||||
skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'}
|
||||
skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"}
|
||||
|
||||
for pattern in patterns:
|
||||
for path in Path(root_dir).glob(pattern):
|
||||
|
@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P
|
|||
errors = []
|
||||
|
||||
# Pattern to find Jinja2 expressions
|
||||
jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL)
|
||||
jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
|
||||
|
||||
for match in jinja_pattern.finditer(template_content):
|
||||
expression = match.group()
|
||||
lines = expression.split('\n')
|
||||
lines = expression.split("\n")
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# Check for # that's not in a string
|
||||
# Simple heuristic: if # appears after non-whitespace and not in quotes
|
||||
if '#' in line:
|
||||
if "#" in line:
|
||||
# Remove quoted strings to avoid false positives
|
||||
cleaned = re.sub(r'"[^"]*"', '', line)
|
||||
cleaned = re.sub(r"'[^']*'", '', cleaned)
|
||||
cleaned = re.sub(r'"[^"]*"', "", line)
|
||||
cleaned = re.sub(r"'[^']*'", "", cleaned)
|
||||
|
||||
if '#' in cleaned:
|
||||
if "#" in cleaned:
|
||||
# Check if it's likely a comment (has text after it)
|
||||
hash_pos = cleaned.index('#')
|
||||
if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\':
|
||||
line_num = template_content[:match.start()].count('\n') + i + 1
|
||||
hash_pos = cleaned.index("#")
|
||||
if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\":
|
||||
line_num = template_content[: match.start()].count("\n") + i + 1
|
||||
errors.append(
|
||||
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
|
||||
f"Move comments outside the expression."
|
||||
|
@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]:
|
|||
|
||||
# Common Ansible variables that are always available
|
||||
ansible_builtins = {
|
||||
'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname',
|
||||
'ansible_distribution', 'ansible_distribution_version', 'ansible_facts',
|
||||
'inventory_hostname', 'hostvars', 'groups', 'group_names',
|
||||
'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host',
|
||||
'item', 'ansible_loop', 'ansible_index', 'lookup'
|
||||
"ansible_default_ipv4",
|
||||
"ansible_default_ipv6",
|
||||
"ansible_hostname",
|
||||
"ansible_distribution",
|
||||
"ansible_distribution_version",
|
||||
"ansible_facts",
|
||||
"inventory_hostname",
|
||||
"hostvars",
|
||||
"groups",
|
||||
"group_names",
|
||||
"play_hosts",
|
||||
"ansible_version",
|
||||
"ansible_user",
|
||||
"ansible_host",
|
||||
"item",
|
||||
"ansible_loop",
|
||||
"ansible_index",
|
||||
"lookup",
|
||||
}
|
||||
|
||||
# Filter out known Ansible variables
|
||||
|
@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]:
|
|||
|
||||
# Only report if there are truly unknown variables
|
||||
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
|
||||
errors.append(
|
||||
f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}"
|
||||
)
|
||||
errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}")
|
||||
|
||||
except Exception:
|
||||
# Don't report parse errors here, they're handled elsewhere
|
||||
|
@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
|
|||
# Skip full parsing for templates that use Ansible-specific features heavily
|
||||
# We still check for inline comments but skip full template parsing
|
||||
ansible_specific_templates = {
|
||||
'dnscrypt-proxy.toml.j2', # Uses |bool filter
|
||||
'mobileconfig.j2', # Uses |to_uuid filter and complex item structures
|
||||
'vpn-dict.j2', # Uses |to_uuid filter
|
||||
"dnscrypt-proxy.toml.j2", # Uses |bool filter
|
||||
"mobileconfig.j2", # Uses |to_uuid filter and complex item structures
|
||||
"vpn-dict.j2", # Uses |to_uuid filter
|
||||
}
|
||||
|
||||
if template_path.name in ansible_specific_templates:
|
||||
|
@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
|
|||
errors.extend(check_inline_comments_in_expressions(template_content, template_path))
|
||||
|
||||
# Try to parse the template
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(template_path.parent),
|
||||
undefined=StrictUndefined
|
||||
)
|
||||
env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined)
|
||||
|
||||
# Add mock Ansible filters to avoid syntax errors
|
||||
env.filters['bool'] = lambda x: x
|
||||
env.filters['to_uuid'] = lambda x: x
|
||||
env.filters['b64encode'] = lambda x: x
|
||||
env.filters['b64decode'] = lambda x: x
|
||||
env.filters['regex_replace'] = lambda x, y, z: x
|
||||
env.filters['default'] = lambda x, d: x if x else d
|
||||
env.filters["bool"] = lambda x: x
|
||||
env.filters["to_uuid"] = lambda x: x
|
||||
env.filters["b64encode"] = lambda x: x
|
||||
env.filters["b64decode"] = lambda x: x
|
||||
env.filters["regex_replace"] = lambda x, y, z: x
|
||||
env.filters["default"] = lambda x, d: x if x else d
|
||||
|
||||
# This will raise TemplateSyntaxError if there's a syntax problem
|
||||
env.get_template(template_path.name)
|
||||
|
@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]:
|
|||
content = f.read()
|
||||
|
||||
# Check for missing spaces around filters
|
||||
if re.search(r'\{\{[^}]+\|[^ ]', content):
|
||||
if re.search(r"\{\{[^}]+\|[^ ]", content):
|
||||
warnings.append(f"{template_path}: Missing space after filter pipe (|)")
|
||||
|
||||
# Check for deprecated 'when' in Jinja2 (should use if)
|
||||
if re.search(r'\{%\s*when\s+', content):
|
||||
if re.search(r"\{%\s*when\s+", content):
|
||||
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
|
||||
|
||||
# Check for extremely long expressions (harder to debug)
|
||||
for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL):
|
||||
for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL):
|
||||
if len(match.group(1)) > 200:
|
||||
line_num = content[:match.start()].count('\n') + 1
|
||||
warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up")
|
||||
line_num = content[: match.start()].count("\n") + 1
|
||||
warnings.append(
|
||||
f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
pass # Ignore errors in anti-pattern checking
|
||||
|
|
Loading…
Add table
Reference in a new issue