mirror of
https://github.com/trailofbits/algo.git
synced 2025-09-05 19:43:22 +02:00
Apply Python linting and formatting
- Run ruff check --fix to fix linting issues - Run ruff format to ensure consistent formatting - All tests still pass after formatting changes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
51847f3fbf
commit
15be88d28b
28 changed files with 1178 additions and 1199 deletions
|
@ -9,11 +9,9 @@ import time
|
||||||
from ansible.module_utils.basic import AnsibleModule, env_fallback
|
from ansible.module_utils.basic import AnsibleModule, env_fallback
|
||||||
from ansible.module_utils.digital_ocean import DigitalOceanHelper
|
from ansible.module_utils.digital_ocean import DigitalOceanHelper
|
||||||
|
|
||||||
ANSIBLE_METADATA = {'metadata_version': '1.1',
|
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||||
'status': ['preview'],
|
|
||||||
'supported_by': 'community'}
|
|
||||||
|
|
||||||
DOCUMENTATION = '''
|
DOCUMENTATION = """
|
||||||
---
|
---
|
||||||
module: digital_ocean_floating_ip
|
module: digital_ocean_floating_ip
|
||||||
short_description: Manage DigitalOcean Floating IPs
|
short_description: Manage DigitalOcean Floating IPs
|
||||||
|
@ -44,10 +42,10 @@ notes:
|
||||||
- Version 2 of DigitalOcean API is used.
|
- Version 2 of DigitalOcean API is used.
|
||||||
requirements:
|
requirements:
|
||||||
- "python >= 2.6"
|
- "python >= 2.6"
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
EXAMPLES = '''
|
EXAMPLES = """
|
||||||
- name: "Create a Floating IP in region lon1"
|
- name: "Create a Floating IP in region lon1"
|
||||||
digital_ocean_floating_ip:
|
digital_ocean_floating_ip:
|
||||||
state: present
|
state: present
|
||||||
|
@ -63,10 +61,10 @@ EXAMPLES = '''
|
||||||
state: absent
|
state: absent
|
||||||
ip: "1.2.3.4"
|
ip: "1.2.3.4"
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
RETURN = '''
|
RETURN = """
|
||||||
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
|
# Digital Ocean API info https://developers.digitalocean.com/documentation/v2/#floating-ips
|
||||||
data:
|
data:
|
||||||
description: a DigitalOcean Floating IP resource
|
description: a DigitalOcean Floating IP resource
|
||||||
|
@ -106,11 +104,10 @@ data:
|
||||||
"region_slug": "nyc3"
|
"region_slug": "nyc3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Response:
|
class Response:
|
||||||
|
|
||||||
def __init__(self, resp, info):
|
def __init__(self, resp, info):
|
||||||
self.body = None
|
self.body = None
|
||||||
if resp:
|
if resp:
|
||||||
|
@ -132,36 +129,37 @@ class Response:
|
||||||
def status_code(self):
|
def status_code(self):
|
||||||
return self.info["status"]
|
return self.info["status"]
|
||||||
|
|
||||||
|
|
||||||
def wait_action(module, rest, ip, action_id, timeout=10):
|
def wait_action(module, rest, ip, action_id, timeout=10):
|
||||||
end_time = time.time() + 10
|
end_time = time.time() + 10
|
||||||
while time.time() < end_time:
|
while time.time() < end_time:
|
||||||
response = rest.get(f'floating_ips/{ip}/actions/{action_id}')
|
response = rest.get(f"floating_ips/{ip}/actions/{action_id}")
|
||||||
# status_code = response.status_code # TODO: check status_code == 200?
|
# status_code = response.status_code # TODO: check status_code == 200?
|
||||||
status = response.json['action']['status']
|
status = response.json["action"]["status"]
|
||||||
if status == 'completed':
|
if status == "completed":
|
||||||
return True
|
return True
|
||||||
elif status == 'errored':
|
elif status == "errored":
|
||||||
module.fail_json(msg=f'Floating ip action error [ip: {ip}: action: {action_id}]', data=json)
|
module.fail_json(msg=f"Floating ip action error [ip: {ip}: action: {action_id}]", data=json)
|
||||||
|
|
||||||
module.fail_json(msg=f'Floating ip action timeout [ip: {ip}: action: {action_id}]', data=json)
|
module.fail_json(msg=f"Floating ip action timeout [ip: {ip}: action: {action_id}]", data=json)
|
||||||
|
|
||||||
|
|
||||||
def core(module):
|
def core(module):
|
||||||
# api_token = module.params['oauth_token'] # unused for now
|
# api_token = module.params['oauth_token'] # unused for now
|
||||||
state = module.params['state']
|
state = module.params["state"]
|
||||||
ip = module.params['ip']
|
ip = module.params["ip"]
|
||||||
droplet_id = module.params['droplet_id']
|
droplet_id = module.params["droplet_id"]
|
||||||
|
|
||||||
rest = DigitalOceanHelper(module)
|
rest = DigitalOceanHelper(module)
|
||||||
|
|
||||||
if state in ('present'):
|
if state in ("present"):
|
||||||
if droplet_id is not None and module.params['ip'] is not None:
|
if droplet_id is not None and module.params["ip"] is not None:
|
||||||
# Lets try to associate the ip to the specified droplet
|
# Lets try to associate the ip to the specified droplet
|
||||||
associate_floating_ips(module, rest)
|
associate_floating_ips(module, rest)
|
||||||
else:
|
else:
|
||||||
create_floating_ips(module, rest)
|
create_floating_ips(module, rest)
|
||||||
|
|
||||||
elif state in ('absent'):
|
elif state in ("absent"):
|
||||||
response = rest.delete(f"floating_ips/{ip}")
|
response = rest.delete(f"floating_ips/{ip}")
|
||||||
status_code = response.status_code
|
status_code = response.status_code
|
||||||
json_data = response.json
|
json_data = response.json
|
||||||
|
@ -174,65 +172,68 @@ def core(module):
|
||||||
|
|
||||||
|
|
||||||
def get_floating_ip_details(module, rest):
|
def get_floating_ip_details(module, rest):
|
||||||
ip = module.params['ip']
|
ip = module.params["ip"]
|
||||||
|
|
||||||
response = rest.get(f"floating_ips/{ip}")
|
response = rest.get(f"floating_ips/{ip}")
|
||||||
status_code = response.status_code
|
status_code = response.status_code
|
||||||
json_data = response.json
|
json_data = response.json
|
||||||
if status_code == 200:
|
if status_code == 200:
|
||||||
return json_data['floating_ip']
|
return json_data["floating_ip"]
|
||||||
else:
|
else:
|
||||||
module.fail_json(msg="Error assigning floating ip [{}: {}]".format(
|
module.fail_json(
|
||||||
status_code, json_data["message"]), region=module.params['region'])
|
msg="Error assigning floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||||
|
region=module.params["region"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def assign_floating_id_to_droplet(module, rest):
|
def assign_floating_id_to_droplet(module, rest):
|
||||||
ip = module.params['ip']
|
ip = module.params["ip"]
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"type": "assign",
|
"type": "assign",
|
||||||
"droplet_id": module.params['droplet_id'],
|
"droplet_id": module.params["droplet_id"],
|
||||||
}
|
}
|
||||||
|
|
||||||
response = rest.post(f"floating_ips/{ip}/actions", data=payload)
|
response = rest.post(f"floating_ips/{ip}/actions", data=payload)
|
||||||
status_code = response.status_code
|
status_code = response.status_code
|
||||||
json_data = response.json
|
json_data = response.json
|
||||||
if status_code == 201:
|
if status_code == 201:
|
||||||
wait_action(module, rest, ip, json_data['action']['id'])
|
wait_action(module, rest, ip, json_data["action"]["id"])
|
||||||
|
|
||||||
module.exit_json(changed=True, data=json_data)
|
module.exit_json(changed=True, data=json_data)
|
||||||
else:
|
else:
|
||||||
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
|
module.fail_json(
|
||||||
status_code, json_data["message"]), region=module.params['region'])
|
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||||
|
region=module.params["region"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def associate_floating_ips(module, rest):
|
def associate_floating_ips(module, rest):
|
||||||
floating_ip = get_floating_ip_details(module, rest)
|
floating_ip = get_floating_ip_details(module, rest)
|
||||||
droplet = floating_ip['droplet']
|
droplet = floating_ip["droplet"]
|
||||||
|
|
||||||
# TODO: If already assigned to a droplet verify if is one of the specified as valid
|
# TODO: If already assigned to a droplet verify if is one of the specified as valid
|
||||||
if droplet is not None and str(droplet['id']) in [module.params['droplet_id']]:
|
if droplet is not None and str(droplet["id"]) in [module.params["droplet_id"]]:
|
||||||
module.exit_json(changed=False)
|
module.exit_json(changed=False)
|
||||||
else:
|
else:
|
||||||
assign_floating_id_to_droplet(module, rest)
|
assign_floating_id_to_droplet(module, rest)
|
||||||
|
|
||||||
|
|
||||||
def create_floating_ips(module, rest):
|
def create_floating_ips(module, rest):
|
||||||
payload = {
|
payload = {}
|
||||||
}
|
|
||||||
floating_ip_data = None
|
floating_ip_data = None
|
||||||
|
|
||||||
if module.params['region'] is not None:
|
if module.params["region"] is not None:
|
||||||
payload["region"] = module.params['region']
|
payload["region"] = module.params["region"]
|
||||||
|
|
||||||
if module.params['droplet_id'] is not None:
|
if module.params["droplet_id"] is not None:
|
||||||
payload["droplet_id"] = module.params['droplet_id']
|
payload["droplet_id"] = module.params["droplet_id"]
|
||||||
|
|
||||||
floating_ips = rest.get_paginated_data(base_url='floating_ips?', data_key_name='floating_ips')
|
floating_ips = rest.get_paginated_data(base_url="floating_ips?", data_key_name="floating_ips")
|
||||||
|
|
||||||
for floating_ip in floating_ips:
|
for floating_ip in floating_ips:
|
||||||
if floating_ip['droplet'] and floating_ip['droplet']['id'] == module.params['droplet_id']:
|
if floating_ip["droplet"] and floating_ip["droplet"]["id"] == module.params["droplet_id"]:
|
||||||
floating_ip_data = {'floating_ip': floating_ip}
|
floating_ip_data = {"floating_ip": floating_ip}
|
||||||
|
|
||||||
if floating_ip_data:
|
if floating_ip_data:
|
||||||
module.exit_json(changed=False, data=floating_ip_data)
|
module.exit_json(changed=False, data=floating_ip_data)
|
||||||
|
@ -244,36 +245,34 @@ def create_floating_ips(module, rest):
|
||||||
if status_code == 202:
|
if status_code == 202:
|
||||||
module.exit_json(changed=True, data=json_data)
|
module.exit_json(changed=True, data=json_data)
|
||||||
else:
|
else:
|
||||||
module.fail_json(msg="Error creating floating ip [{}: {}]".format(
|
module.fail_json(
|
||||||
status_code, json_data["message"]), region=module.params['region'])
|
msg="Error creating floating ip [{}: {}]".format(status_code, json_data["message"]),
|
||||||
|
region=module.params["region"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec={
|
argument_spec={
|
||||||
'state': {'choices': ['present', 'absent'], 'default': 'present'},
|
"state": {"choices": ["present", "absent"], "default": "present"},
|
||||||
'ip': {'aliases': ['id'], 'required': False},
|
"ip": {"aliases": ["id"], "required": False},
|
||||||
'region': {'required': False},
|
"region": {"required": False},
|
||||||
'droplet_id': {'required': False, 'type': 'int'},
|
"droplet_id": {"required": False, "type": "int"},
|
||||||
'oauth_token': {
|
"oauth_token": {
|
||||||
'no_log': True,
|
"no_log": True,
|
||||||
# Support environment variable for DigitalOcean OAuth Token
|
# Support environment variable for DigitalOcean OAuth Token
|
||||||
'fallback': (env_fallback, ['DO_API_TOKEN', 'DO_API_KEY', 'DO_OAUTH_TOKEN']),
|
"fallback": (env_fallback, ["DO_API_TOKEN", "DO_API_KEY", "DO_OAUTH_TOKEN"]),
|
||||||
'required': True,
|
"required": True,
|
||||||
},
|
},
|
||||||
'validate_certs': {'type': 'bool', 'default': True},
|
"validate_certs": {"type": "bool", "default": True},
|
||||||
'timeout': {'type': 'int', 'default': 30},
|
"timeout": {"type": "int", "default": 30},
|
||||||
},
|
},
|
||||||
required_if=[
|
required_if=[("state", "delete", ["ip"])],
|
||||||
('state', 'delete', ['ip'])
|
mutually_exclusive=[["region", "droplet_id"]],
|
||||||
],
|
|
||||||
mutually_exclusive=[
|
|
||||||
['region', 'droplet_id']
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
core(module)
|
core(module)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
|
from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
|
||||||
|
@ -10,7 +9,7 @@ from ansible.module_utils.gcp_utils import GcpModule, GcpSession, navigate_hash
|
||||||
# Documentation
|
# Documentation
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported_by': 'community'}
|
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Main
|
# Main
|
||||||
|
@ -18,20 +17,24 @@ ANSIBLE_METADATA = {'metadata_version': '1.1', 'status': ["preview"], 'supported
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
module = GcpModule(argument_spec={'filters': {'type': 'list', 'elements': 'str'}, 'scope': {'required': True, 'type': 'str'}})
|
module = GcpModule(
|
||||||
|
argument_spec={"filters": {"type": "list", "elements": "str"}, "scope": {"required": True, "type": "str"}}
|
||||||
|
)
|
||||||
|
|
||||||
if module._name == 'gcp_compute_image_facts':
|
if module._name == "gcp_compute_image_facts":
|
||||||
module.deprecate("The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version='2.13')
|
module.deprecate(
|
||||||
|
"The 'gcp_compute_image_facts' module has been renamed to 'gcp_compute_regions_info'", version="2.13"
|
||||||
|
)
|
||||||
|
|
||||||
if not module.params['scopes']:
|
if not module.params["scopes"]:
|
||||||
module.params['scopes'] = ['https://www.googleapis.com/auth/compute']
|
module.params["scopes"] = ["https://www.googleapis.com/auth/compute"]
|
||||||
|
|
||||||
items = fetch_list(module, collection(module), query_options(module.params['filters']))
|
items = fetch_list(module, collection(module), query_options(module.params["filters"]))
|
||||||
if items.get('items'):
|
if items.get("items"):
|
||||||
items = items.get('items')
|
items = items.get("items")
|
||||||
else:
|
else:
|
||||||
items = []
|
items = []
|
||||||
return_value = {'resources': items}
|
return_value = {"resources": items}
|
||||||
module.exit_json(**return_value)
|
module.exit_json(**return_value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,14 +43,14 @@ def collection(module):
|
||||||
|
|
||||||
|
|
||||||
def fetch_list(module, link, query):
|
def fetch_list(module, link, query):
|
||||||
auth = GcpSession(module, 'compute')
|
auth = GcpSession(module, "compute")
|
||||||
response = auth.get(link, params={'filter': query})
|
response = auth.get(link, params={"filter": query})
|
||||||
return return_if_object(module, response)
|
return return_if_object(module, response)
|
||||||
|
|
||||||
|
|
||||||
def query_options(filters):
|
def query_options(filters):
|
||||||
if not filters:
|
if not filters:
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
if len(filters) == 1:
|
if len(filters) == 1:
|
||||||
return filters[0]
|
return filters[0]
|
||||||
|
@ -55,12 +58,12 @@ def query_options(filters):
|
||||||
queries = []
|
queries = []
|
||||||
for f in filters:
|
for f in filters:
|
||||||
# For multiple queries, all queries should have ()
|
# For multiple queries, all queries should have ()
|
||||||
if f[0] != '(' and f[-1] != ')':
|
if f[0] != "(" and f[-1] != ")":
|
||||||
queries.append("({})".format(''.join(f)))
|
queries.append("({})".format("".join(f)))
|
||||||
else:
|
else:
|
||||||
queries.append(f)
|
queries.append(f)
|
||||||
|
|
||||||
return ' '.join(queries)
|
return " ".join(queries)
|
||||||
|
|
||||||
|
|
||||||
def return_if_object(module, response):
|
def return_if_object(module, response):
|
||||||
|
@ -75,11 +78,11 @@ def return_if_object(module, response):
|
||||||
try:
|
try:
|
||||||
module.raise_for_status(response)
|
module.raise_for_status(response)
|
||||||
result = response.json()
|
result = response.json()
|
||||||
except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst:
|
except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
|
||||||
module.fail_json(msg=f"Invalid JSON response with error: {inst}")
|
module.fail_json(msg=f"Invalid JSON response with error: {inst}")
|
||||||
|
|
||||||
if navigate_hash(result, ['error', 'errors']):
|
if navigate_hash(result, ["error", "errors"]):
|
||||||
module.fail_json(msg=navigate_hash(result, ['error', 'errors']))
|
module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,9 @@
|
||||||
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
||||||
|
|
||||||
|
|
||||||
|
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||||
|
|
||||||
ANSIBLE_METADATA = {'metadata_version': '1.1',
|
DOCUMENTATION = """
|
||||||
'status': ['preview'],
|
|
||||||
'supported_by': 'community'}
|
|
||||||
|
|
||||||
DOCUMENTATION = '''
|
|
||||||
---
|
---
|
||||||
module: lightsail_region_facts
|
module: lightsail_region_facts
|
||||||
short_description: Gather facts about AWS Lightsail regions.
|
short_description: Gather facts about AWS Lightsail regions.
|
||||||
|
@ -24,15 +21,15 @@ requirements:
|
||||||
extends_documentation_fragment:
|
extends_documentation_fragment:
|
||||||
- aws
|
- aws
|
||||||
- ec2
|
- ec2
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
EXAMPLES = '''
|
EXAMPLES = """
|
||||||
# Gather facts about all regions
|
# Gather facts about all regions
|
||||||
- lightsail_region_facts:
|
- lightsail_region_facts:
|
||||||
'''
|
"""
|
||||||
|
|
||||||
RETURN = '''
|
RETURN = """
|
||||||
regions:
|
regions:
|
||||||
returned: on success
|
returned: on success
|
||||||
description: >
|
description: >
|
||||||
|
@ -46,12 +43,13 @@ regions:
|
||||||
"displayName": "Virginia",
|
"displayName": "Virginia",
|
||||||
"name": "us-east-1"
|
"name": "us-east-1"
|
||||||
}]"
|
}]"
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import botocore
|
import botocore
|
||||||
|
|
||||||
HAS_BOTOCORE = True
|
HAS_BOTOCORE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_BOTOCORE = False
|
HAS_BOTOCORE = False
|
||||||
|
@ -86,18 +84,19 @@ def main():
|
||||||
|
|
||||||
client = None
|
client = None
|
||||||
try:
|
try:
|
||||||
client = boto3_conn(module, conn_type='client', resource='lightsail',
|
client = boto3_conn(
|
||||||
region=region, endpoint=ec2_url, **aws_connect_kwargs)
|
module, conn_type="client", resource="lightsail", region=region, endpoint=ec2_url, **aws_connect_kwargs
|
||||||
|
)
|
||||||
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
|
except (botocore.exceptions.ClientError, botocore.exceptions.ValidationError) as e:
|
||||||
module.fail_json(msg='Failed while connecting to the lightsail service: %s' % e, exception=traceback.format_exc())
|
module.fail_json(
|
||||||
|
msg="Failed while connecting to the lightsail service: %s" % e, exception=traceback.format_exc()
|
||||||
|
)
|
||||||
|
|
||||||
response = client.get_regions(
|
response = client.get_regions(includeAvailabilityZones=False)
|
||||||
includeAvailabilityZones=False
|
|
||||||
)
|
|
||||||
module.exit_json(changed=False, data=response)
|
module.exit_json(changed=False, data=response)
|
||||||
except (botocore.exceptions.ClientError, Exception) as e:
|
except (botocore.exceptions.ClientError, Exception) as e:
|
||||||
module.fail_json(msg=str(e), exception=traceback.format_exc())
|
module.fail_json(msg=str(e), exception=traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ansible.module_utils.linode import get_user_agent
|
||||||
LINODE_IMP_ERR = None
|
LINODE_IMP_ERR = None
|
||||||
try:
|
try:
|
||||||
from linode_api4 import LinodeClient, StackScript
|
from linode_api4 import LinodeClient, StackScript
|
||||||
|
|
||||||
HAS_LINODE_DEPENDENCY = True
|
HAS_LINODE_DEPENDENCY = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LINODE_IMP_ERR = traceback.format_exc()
|
LINODE_IMP_ERR = traceback.format_exc()
|
||||||
|
@ -21,57 +22,47 @@ def create_stackscript(module, client, **kwargs):
|
||||||
response = client.linode.stackscript_create(**kwargs)
|
response = client.linode.stackscript_create(**kwargs)
|
||||||
return response._raw_json
|
return response._raw_json
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||||
|
|
||||||
|
|
||||||
def stackscript_available(module, client):
|
def stackscript_available(module, client):
|
||||||
"""Try to retrieve a stackscript."""
|
"""Try to retrieve a stackscript."""
|
||||||
try:
|
try:
|
||||||
label = module.params['label']
|
label = module.params["label"]
|
||||||
desc = module.params['description']
|
desc = module.params["description"]
|
||||||
|
|
||||||
result = client.linode.stackscripts(StackScript.label == label,
|
result = client.linode.stackscripts(StackScript.label == label, StackScript.description == desc, mine_only=True)
|
||||||
StackScript.description == desc,
|
|
||||||
mine_only=True
|
|
||||||
)
|
|
||||||
return result[0]
|
return result[0]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||||
|
|
||||||
|
|
||||||
def initialise_module():
|
def initialise_module():
|
||||||
"""Initialise the module parameter specification."""
|
"""Initialise the module parameter specification."""
|
||||||
return AnsibleModule(
|
return AnsibleModule(
|
||||||
argument_spec=dict(
|
argument_spec=dict(
|
||||||
label=dict(type='str', required=True),
|
label=dict(type="str", required=True),
|
||||||
state=dict(
|
state=dict(type="str", required=True, choices=["present", "absent"]),
|
||||||
type='str',
|
|
||||||
required=True,
|
|
||||||
choices=['present', 'absent']
|
|
||||||
),
|
|
||||||
access_token=dict(
|
access_token=dict(
|
||||||
type='str',
|
type="str",
|
||||||
required=True,
|
required=True,
|
||||||
no_log=True,
|
no_log=True,
|
||||||
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
|
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
|
||||||
),
|
),
|
||||||
script=dict(type='str', required=True),
|
script=dict(type="str", required=True),
|
||||||
images=dict(type='list', required=True),
|
images=dict(type="list", required=True),
|
||||||
description=dict(type='str', required=False),
|
description=dict(type="str", required=False),
|
||||||
public=dict(type='bool', required=False, default=False),
|
public=dict(type="bool", required=False, default=False),
|
||||||
),
|
),
|
||||||
supports_check_mode=False
|
supports_check_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_client(module):
|
def build_client(module):
|
||||||
"""Build a LinodeClient."""
|
"""Build a LinodeClient."""
|
||||||
return LinodeClient(
|
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
|
||||||
module.params['access_token'],
|
|
||||||
user_agent=get_user_agent('linode_v4_module')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -79,30 +70,31 @@ def main():
|
||||||
module = initialise_module()
|
module = initialise_module()
|
||||||
|
|
||||||
if not HAS_LINODE_DEPENDENCY:
|
if not HAS_LINODE_DEPENDENCY:
|
||||||
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
|
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
|
||||||
|
|
||||||
client = build_client(module)
|
client = build_client(module)
|
||||||
stackscript = stackscript_available(module, client)
|
stackscript = stackscript_available(module, client)
|
||||||
|
|
||||||
if module.params['state'] == 'present' and stackscript is not None:
|
if module.params["state"] == "present" and stackscript is not None:
|
||||||
module.exit_json(changed=False, stackscript=stackscript._raw_json)
|
module.exit_json(changed=False, stackscript=stackscript._raw_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'present' and stackscript is None:
|
elif module.params["state"] == "present" and stackscript is None:
|
||||||
stackscript_json = create_stackscript(
|
stackscript_json = create_stackscript(
|
||||||
module, client,
|
module,
|
||||||
label=module.params['label'],
|
client,
|
||||||
script=module.params['script'],
|
label=module.params["label"],
|
||||||
images=module.params['images'],
|
script=module.params["script"],
|
||||||
desc=module.params['description'],
|
images=module.params["images"],
|
||||||
public=module.params['public'],
|
desc=module.params["description"],
|
||||||
|
public=module.params["public"],
|
||||||
)
|
)
|
||||||
module.exit_json(changed=True, stackscript=stackscript_json)
|
module.exit_json(changed=True, stackscript=stackscript_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'absent' and stackscript is not None:
|
elif module.params["state"] == "absent" and stackscript is not None:
|
||||||
stackscript.delete()
|
stackscript.delete()
|
||||||
module.exit_json(changed=True, stackscript=stackscript._raw_json)
|
module.exit_json(changed=True, stackscript=stackscript._raw_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'absent' and stackscript is None:
|
elif module.params["state"] == "absent" and stackscript is None:
|
||||||
module.exit_json(changed=False, stackscript={})
|
module.exit_json(changed=False, stackscript={})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ansible.module_utils.linode import get_user_agent
|
||||||
LINODE_IMP_ERR = None
|
LINODE_IMP_ERR = None
|
||||||
try:
|
try:
|
||||||
from linode_api4 import Instance, LinodeClient
|
from linode_api4 import Instance, LinodeClient
|
||||||
|
|
||||||
HAS_LINODE_DEPENDENCY = True
|
HAS_LINODE_DEPENDENCY = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LINODE_IMP_ERR = traceback.format_exc()
|
LINODE_IMP_ERR = traceback.format_exc()
|
||||||
|
@ -21,82 +22,72 @@ except ImportError:
|
||||||
|
|
||||||
def create_linode(module, client, **kwargs):
|
def create_linode(module, client, **kwargs):
|
||||||
"""Creates a Linode instance and handles return format."""
|
"""Creates a Linode instance and handles return format."""
|
||||||
if kwargs['root_pass'] is None:
|
if kwargs["root_pass"] is None:
|
||||||
kwargs.pop('root_pass')
|
kwargs.pop("root_pass")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = client.linode.instance_create(**kwargs)
|
response = client.linode.instance_create(**kwargs)
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(response, tuple):
|
if isinstance(response, tuple):
|
||||||
instance, root_pass = response
|
instance, root_pass = response
|
||||||
instance_json = instance._raw_json
|
instance_json = instance._raw_json
|
||||||
instance_json.update({'root_pass': root_pass})
|
instance_json.update({"root_pass": root_pass})
|
||||||
return instance_json
|
return instance_json
|
||||||
else:
|
else:
|
||||||
return response._raw_json
|
return response._raw_json
|
||||||
except TypeError:
|
except TypeError:
|
||||||
module.fail_json(msg='Unable to parse Linode instance creation'
|
module.fail_json(
|
||||||
' response. Please raise a bug against this'
|
msg="Unable to parse Linode instance creation"
|
||||||
' module on https://github.com/ansible/ansible/issues'
|
" response. Please raise a bug against this"
|
||||||
)
|
" module on https://github.com/ansible/ansible/issues"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def maybe_instance_from_label(module, client):
|
def maybe_instance_from_label(module, client):
|
||||||
"""Try to retrieve an instance based on a label."""
|
"""Try to retrieve an instance based on a label."""
|
||||||
try:
|
try:
|
||||||
label = module.params['label']
|
label = module.params["label"]
|
||||||
result = client.linode.instances(Instance.label == label)
|
result = client.linode.instances(Instance.label == label)
|
||||||
return result[0]
|
return result[0]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
module.fail_json(msg='Unable to query the Linode API. Saw: %s' % exception)
|
module.fail_json(msg="Unable to query the Linode API. Saw: %s" % exception)
|
||||||
|
|
||||||
|
|
||||||
def initialise_module():
|
def initialise_module():
|
||||||
"""Initialise the module parameter specification."""
|
"""Initialise the module parameter specification."""
|
||||||
return AnsibleModule(
|
return AnsibleModule(
|
||||||
argument_spec=dict(
|
argument_spec=dict(
|
||||||
label=dict(type='str', required=True),
|
label=dict(type="str", required=True),
|
||||||
state=dict(
|
state=dict(type="str", required=True, choices=["present", "absent"]),
|
||||||
type='str',
|
|
||||||
required=True,
|
|
||||||
choices=['present', 'absent']
|
|
||||||
),
|
|
||||||
access_token=dict(
|
access_token=dict(
|
||||||
type='str',
|
type="str",
|
||||||
required=True,
|
required=True,
|
||||||
no_log=True,
|
no_log=True,
|
||||||
fallback=(env_fallback, ['LINODE_ACCESS_TOKEN']),
|
fallback=(env_fallback, ["LINODE_ACCESS_TOKEN"]),
|
||||||
),
|
),
|
||||||
authorized_keys=dict(type='list', required=False),
|
authorized_keys=dict(type="list", required=False),
|
||||||
group=dict(type='str', required=False),
|
group=dict(type="str", required=False),
|
||||||
image=dict(type='str', required=False),
|
image=dict(type="str", required=False),
|
||||||
region=dict(type='str', required=False),
|
region=dict(type="str", required=False),
|
||||||
root_pass=dict(type='str', required=False, no_log=True),
|
root_pass=dict(type="str", required=False, no_log=True),
|
||||||
tags=dict(type='list', required=False),
|
tags=dict(type="list", required=False),
|
||||||
type=dict(type='str', required=False),
|
type=dict(type="str", required=False),
|
||||||
stackscript_id=dict(type='int', required=False),
|
stackscript_id=dict(type="int", required=False),
|
||||||
),
|
),
|
||||||
supports_check_mode=False,
|
supports_check_mode=False,
|
||||||
required_one_of=(
|
required_one_of=(["state", "label"],),
|
||||||
['state', 'label'],
|
required_together=(["region", "image", "type"],),
|
||||||
),
|
|
||||||
required_together=(
|
|
||||||
['region', 'image', 'type'],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_client(module):
|
def build_client(module):
|
||||||
"""Build a LinodeClient."""
|
"""Build a LinodeClient."""
|
||||||
return LinodeClient(
|
return LinodeClient(module.params["access_token"], user_agent=get_user_agent("linode_v4_module"))
|
||||||
module.params['access_token'],
|
|
||||||
user_agent=get_user_agent('linode_v4_module')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -104,34 +95,35 @@ def main():
|
||||||
module = initialise_module()
|
module = initialise_module()
|
||||||
|
|
||||||
if not HAS_LINODE_DEPENDENCY:
|
if not HAS_LINODE_DEPENDENCY:
|
||||||
module.fail_json(msg=missing_required_lib('linode-api4'), exception=LINODE_IMP_ERR)
|
module.fail_json(msg=missing_required_lib("linode-api4"), exception=LINODE_IMP_ERR)
|
||||||
|
|
||||||
client = build_client(module)
|
client = build_client(module)
|
||||||
instance = maybe_instance_from_label(module, client)
|
instance = maybe_instance_from_label(module, client)
|
||||||
|
|
||||||
if module.params['state'] == 'present' and instance is not None:
|
if module.params["state"] == "present" and instance is not None:
|
||||||
module.exit_json(changed=False, instance=instance._raw_json)
|
module.exit_json(changed=False, instance=instance._raw_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'present' and instance is None:
|
elif module.params["state"] == "present" and instance is None:
|
||||||
instance_json = create_linode(
|
instance_json = create_linode(
|
||||||
module, client,
|
module,
|
||||||
authorized_keys=module.params['authorized_keys'],
|
client,
|
||||||
group=module.params['group'],
|
authorized_keys=module.params["authorized_keys"],
|
||||||
image=module.params['image'],
|
group=module.params["group"],
|
||||||
label=module.params['label'],
|
image=module.params["image"],
|
||||||
region=module.params['region'],
|
label=module.params["label"],
|
||||||
root_pass=module.params['root_pass'],
|
region=module.params["region"],
|
||||||
tags=module.params['tags'],
|
root_pass=module.params["root_pass"],
|
||||||
ltype=module.params['type'],
|
tags=module.params["tags"],
|
||||||
stackscript_id=module.params['stackscript_id'],
|
ltype=module.params["type"],
|
||||||
|
stackscript_id=module.params["stackscript_id"],
|
||||||
)
|
)
|
||||||
module.exit_json(changed=True, instance=instance_json)
|
module.exit_json(changed=True, instance=instance_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'absent' and instance is not None:
|
elif module.params["state"] == "absent" and instance is not None:
|
||||||
instance.delete()
|
instance.delete()
|
||||||
module.exit_json(changed=True, instance=instance._raw_json)
|
module.exit_json(changed=True, instance=instance._raw_json)
|
||||||
|
|
||||||
elif module.params['state'] == 'absent' and instance is None:
|
elif module.params["state"] == "absent" and instance is None:
|
||||||
module.exit_json(changed=False, instance={})
|
module.exit_json(changed=False, instance={})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,9 @@
|
||||||
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
|
||||||
|
|
||||||
|
|
||||||
|
ANSIBLE_METADATA = {"metadata_version": "1.1", "status": ["preview"], "supported_by": "community"}
|
||||||
|
|
||||||
ANSIBLE_METADATA = {
|
DOCUMENTATION = """
|
||||||
'metadata_version': '1.1',
|
|
||||||
'status': ['preview'],
|
|
||||||
'supported_by': 'community'
|
|
||||||
}
|
|
||||||
|
|
||||||
DOCUMENTATION = '''
|
|
||||||
---
|
---
|
||||||
module: scaleway_compute
|
module: scaleway_compute
|
||||||
short_description: Scaleway compute management module
|
short_description: Scaleway compute management module
|
||||||
|
@ -120,9 +115,9 @@ options:
|
||||||
- If no value provided, the default security group or current security group will be used
|
- If no value provided, the default security group or current security group will be used
|
||||||
required: false
|
required: false
|
||||||
version_added: "2.8"
|
version_added: "2.8"
|
||||||
'''
|
"""
|
||||||
|
|
||||||
EXAMPLES = '''
|
EXAMPLES = """
|
||||||
- name: Create a server
|
- name: Create a server
|
||||||
scaleway_compute:
|
scaleway_compute:
|
||||||
name: foobar
|
name: foobar
|
||||||
|
@ -156,10 +151,10 @@ EXAMPLES = '''
|
||||||
organization: 951df375-e094-4d26-97c1-ba548eeb9c42
|
organization: 951df375-e094-4d26-97c1-ba548eeb9c42
|
||||||
region: ams1
|
region: ams1
|
||||||
commercial_type: VC1S
|
commercial_type: VC1S
|
||||||
'''
|
"""
|
||||||
|
|
||||||
RETURN = '''
|
RETURN = """
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
|
@ -167,19 +162,9 @@ import time
|
||||||
from ansible.module_utils.basic import AnsibleModule
|
from ansible.module_utils.basic import AnsibleModule
|
||||||
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
|
from ansible.module_utils.scaleway import SCALEWAY_LOCATION, Scaleway, scaleway_argument_spec
|
||||||
|
|
||||||
SCALEWAY_SERVER_STATES = (
|
SCALEWAY_SERVER_STATES = ("stopped", "stopping", "starting", "running", "locked")
|
||||||
'stopped',
|
|
||||||
'stopping',
|
|
||||||
'starting',
|
|
||||||
'running',
|
|
||||||
'locked'
|
|
||||||
)
|
|
||||||
|
|
||||||
SCALEWAY_TRANSITIONS_STATES = (
|
SCALEWAY_TRANSITIONS_STATES = ("stopping", "starting", "pending")
|
||||||
"stopping",
|
|
||||||
"starting",
|
|
||||||
"pending"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_image_id(compute_api, image_id):
|
def check_image_id(compute_api, image_id):
|
||||||
|
@ -188,9 +173,11 @@ def check_image_id(compute_api, image_id):
|
||||||
if response.ok and response.json:
|
if response.ok and response.json:
|
||||||
image_ids = [image["id"] for image in response.json["images"]]
|
image_ids = [image["id"] for image in response.json["images"]]
|
||||||
if image_id not in image_ids:
|
if image_id not in image_ids:
|
||||||
compute_api.module.fail_json(msg='Error in getting image %s on %s' % (image_id, compute_api.module.params.get('api_url')))
|
compute_api.module.fail_json(
|
||||||
|
msg="Error in getting image %s on %s" % (image_id, compute_api.module.params.get("api_url"))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get('api_url'))
|
compute_api.module.fail_json(msg="Error in getting images from: %s" % compute_api.module.params.get("api_url"))
|
||||||
|
|
||||||
|
|
||||||
def fetch_state(compute_api, server):
|
def fetch_state(compute_api, server):
|
||||||
|
@ -201,7 +188,7 @@ def fetch_state(compute_api, server):
|
||||||
return "absent"
|
return "absent"
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during state fetching: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during state fetching: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -243,7 +230,7 @@ def public_ip_payload(compute_api, public_ip):
|
||||||
# We check that the IP we want to attach exists, if so its ID is returned
|
# We check that the IP we want to attach exists, if so its ID is returned
|
||||||
response = compute_api.get("ips")
|
response = compute_api.get("ips")
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during public IP validation: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during public IP validation: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
ip_list = []
|
ip_list = []
|
||||||
|
@ -260,14 +247,15 @@ def public_ip_payload(compute_api, public_ip):
|
||||||
def create_server(compute_api, server):
|
def create_server(compute_api, server):
|
||||||
compute_api.module.debug("Starting a create_server")
|
compute_api.module.debug("Starting a create_server")
|
||||||
target_server = None
|
target_server = None
|
||||||
data = {"enable_ipv6": server["enable_ipv6"],
|
data = {
|
||||||
"tags": server["tags"],
|
"enable_ipv6": server["enable_ipv6"],
|
||||||
"commercial_type": server["commercial_type"],
|
"tags": server["tags"],
|
||||||
"image": server["image"],
|
"commercial_type": server["commercial_type"],
|
||||||
"dynamic_ip_required": server["dynamic_ip_required"],
|
"image": server["image"],
|
||||||
"name": server["name"],
|
"dynamic_ip_required": server["dynamic_ip_required"],
|
||||||
"organization": server["organization"]
|
"name": server["name"],
|
||||||
}
|
"organization": server["organization"],
|
||||||
|
}
|
||||||
|
|
||||||
if server["boot_type"]:
|
if server["boot_type"]:
|
||||||
data["boot_type"] = server["boot_type"]
|
data["boot_type"] = server["boot_type"]
|
||||||
|
@ -278,7 +266,7 @@ def create_server(compute_api, server):
|
||||||
response = compute_api.post(path="servers", data=data)
|
response = compute_api.post(path="servers", data=data)
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during server creation: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during server creation: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -304,10 +292,9 @@ def start_server(compute_api, server):
|
||||||
|
|
||||||
|
|
||||||
def perform_action(compute_api, server, action):
|
def perform_action(compute_api, server, action):
|
||||||
response = compute_api.post(path="servers/%s/action" % server["id"],
|
response = compute_api.post(path="servers/%s/action" % server["id"], data={"action": action})
|
||||||
data={"action": action})
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during server %s: (%s) %s' % (action, response.status_code, response.json)
|
msg = "Error during server %s: (%s) %s" % (action, response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
||||||
|
@ -319,7 +306,7 @@ def remove_server(compute_api, server):
|
||||||
compute_api.module.debug("Starting remove server strategy")
|
compute_api.module.debug("Starting remove server strategy")
|
||||||
response = compute_api.delete(path="servers/%s" % server["id"])
|
response = compute_api.delete(path="servers/%s" % server["id"])
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during server deletion: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during server deletion: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=server)
|
||||||
|
@ -341,14 +328,17 @@ def present_strategy(compute_api, wished_server):
|
||||||
else:
|
else:
|
||||||
target_server = query_results[0]
|
target_server = query_results[0]
|
||||||
|
|
||||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
if server_attributes_should_be_changed(
|
||||||
wished_server=wished_server):
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
):
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
|
return changed, {"status": "Server %s attributes would be changed." % target_server["id"]}
|
||||||
|
|
||||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
target_server = server_change_attributes(
|
||||||
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
)
|
||||||
|
|
||||||
return changed, target_server
|
return changed, target_server
|
||||||
|
|
||||||
|
@ -375,7 +365,7 @@ def absent_strategy(compute_api, wished_server):
|
||||||
response = stop_server(compute_api=compute_api, server=target_server)
|
response = stop_server(compute_api=compute_api, server=target_server)
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
err_msg = f'Error while stopping a server before removing it [{response.status_code}: {response.json}]'
|
err_msg = f"Error while stopping a server before removing it [{response.status_code}: {response.json}]"
|
||||||
compute_api.module.fail_json(msg=err_msg)
|
compute_api.module.fail_json(msg=err_msg)
|
||||||
|
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||||
|
@ -383,7 +373,7 @@ def absent_strategy(compute_api, wished_server):
|
||||||
response = remove_server(compute_api=compute_api, server=target_server)
|
response = remove_server(compute_api=compute_api, server=target_server)
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
err_msg = f'Error while removing server [{response.status_code}: {response.json}]'
|
err_msg = f"Error while removing server [{response.status_code}: {response.json}]"
|
||||||
compute_api.module.fail_json(msg=err_msg)
|
compute_api.module.fail_json(msg=err_msg)
|
||||||
|
|
||||||
return changed, {"status": "Server %s deleted" % target_server["id"]}
|
return changed, {"status": "Server %s deleted" % target_server["id"]}
|
||||||
|
@ -403,14 +393,17 @@ def running_strategy(compute_api, wished_server):
|
||||||
else:
|
else:
|
||||||
target_server = query_results[0]
|
target_server = query_results[0]
|
||||||
|
|
||||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
if server_attributes_should_be_changed(
|
||||||
wished_server=wished_server):
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
):
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
|
return changed, {"status": "Server %s attributes would be changed before running it." % target_server["id"]}
|
||||||
|
|
||||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
target_server = server_change_attributes(
|
||||||
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
)
|
||||||
|
|
||||||
current_state = fetch_state(compute_api=compute_api, server=target_server)
|
current_state = fetch_state(compute_api=compute_api, server=target_server)
|
||||||
if current_state not in ("running", "starting"):
|
if current_state not in ("running", "starting"):
|
||||||
|
@ -422,7 +415,7 @@ def running_strategy(compute_api, wished_server):
|
||||||
|
|
||||||
response = start_server(compute_api=compute_api, server=target_server)
|
response = start_server(compute_api=compute_api, server=target_server)
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = f'Error while running server [{response.status_code}: {response.json}]'
|
msg = f"Error while running server [{response.status_code}: {response.json}]"
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
return changed, target_server
|
return changed, target_server
|
||||||
|
@ -435,7 +428,6 @@ def stop_strategy(compute_api, wished_server):
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
if not query_results:
|
if not query_results:
|
||||||
|
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
return changed, {"status": "A server would be created before being stopped."}
|
return changed, {"status": "A server would be created before being stopped."}
|
||||||
|
|
||||||
|
@ -446,15 +438,19 @@ def stop_strategy(compute_api, wished_server):
|
||||||
|
|
||||||
compute_api.module.debug("stop_strategy: Servers are found.")
|
compute_api.module.debug("stop_strategy: Servers are found.")
|
||||||
|
|
||||||
if server_attributes_should_be_changed(compute_api=compute_api, target_server=target_server,
|
if server_attributes_should_be_changed(
|
||||||
wished_server=wished_server):
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
):
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
return changed, {
|
return changed, {
|
||||||
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]}
|
"status": "Server %s attributes would be changed before stopping it." % target_server["id"]
|
||||||
|
}
|
||||||
|
|
||||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
target_server = server_change_attributes(
|
||||||
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
)
|
||||||
|
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||||
|
|
||||||
|
@ -472,7 +468,7 @@ def stop_strategy(compute_api, wished_server):
|
||||||
compute_api.module.debug(response.ok)
|
compute_api.module.debug(response.ok)
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = f'Error while stopping server [{response.status_code}: {response.json}]'
|
msg = f"Error while stopping server [{response.status_code}: {response.json}]"
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
return changed, target_server
|
return changed, target_server
|
||||||
|
@ -492,16 +488,19 @@ def restart_strategy(compute_api, wished_server):
|
||||||
else:
|
else:
|
||||||
target_server = query_results[0]
|
target_server = query_results[0]
|
||||||
|
|
||||||
if server_attributes_should_be_changed(compute_api=compute_api,
|
if server_attributes_should_be_changed(
|
||||||
target_server=target_server,
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
wished_server=wished_server):
|
):
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
return changed, {
|
return changed, {
|
||||||
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]}
|
"status": "Server %s attributes would be changed before rebooting it." % target_server["id"]
|
||||||
|
}
|
||||||
|
|
||||||
target_server = server_change_attributes(compute_api=compute_api, target_server=target_server, wished_server=wished_server)
|
target_server = server_change_attributes(
|
||||||
|
compute_api=compute_api, target_server=target_server, wished_server=wished_server
|
||||||
|
)
|
||||||
|
|
||||||
changed = True
|
changed = True
|
||||||
if compute_api.module.check_mode:
|
if compute_api.module.check_mode:
|
||||||
|
@ -513,14 +512,14 @@ def restart_strategy(compute_api, wished_server):
|
||||||
response = restart_server(compute_api=compute_api, server=target_server)
|
response = restart_server(compute_api=compute_api, server=target_server)
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = f'Error while restarting server that was running [{response.status_code}: {response.json}].'
|
msg = f"Error while restarting server that was running [{response.status_code}: {response.json}]."
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
|
if fetch_state(compute_api=compute_api, server=target_server) in ("stopped",):
|
||||||
response = restart_server(compute_api=compute_api, server=target_server)
|
response = restart_server(compute_api=compute_api, server=target_server)
|
||||||
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
wait_to_complete_state_transition(compute_api=compute_api, server=target_server)
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = f'Error while restarting server that was stopped [{response.status_code}: {response.json}].'
|
msg = f"Error while restarting server that was stopped [{response.status_code}: {response.json}]."
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
return changed, target_server
|
return changed, target_server
|
||||||
|
@ -531,18 +530,17 @@ state_strategy = {
|
||||||
"restarted": restart_strategy,
|
"restarted": restart_strategy,
|
||||||
"stopped": stop_strategy,
|
"stopped": stop_strategy,
|
||||||
"running": running_strategy,
|
"running": running_strategy,
|
||||||
"absent": absent_strategy
|
"absent": absent_strategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def find(compute_api, wished_server, per_page=1):
|
def find(compute_api, wished_server, per_page=1):
|
||||||
compute_api.module.debug("Getting inside find")
|
compute_api.module.debug("Getting inside find")
|
||||||
# Only the name attribute is accepted in the Compute query API
|
# Only the name attribute is accepted in the Compute query API
|
||||||
response = compute_api.get("servers", params={"name": wished_server["name"],
|
response = compute_api.get("servers", params={"name": wished_server["name"], "per_page": per_page})
|
||||||
"per_page": per_page})
|
|
||||||
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during server search: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during server search: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
search_results = response.json["servers"]
|
search_results = response.json["servers"]
|
||||||
|
@ -563,16 +561,22 @@ def server_attributes_should_be_changed(compute_api, target_server, wished_serve
|
||||||
compute_api.module.debug("Checking if server attributes should be changed")
|
compute_api.module.debug("Checking if server attributes should be changed")
|
||||||
compute_api.module.debug("Current Server: %s" % target_server)
|
compute_api.module.debug("Current Server: %s" % target_server)
|
||||||
compute_api.module.debug("Wished Server: %s" % wished_server)
|
compute_api.module.debug("Wished Server: %s" % wished_server)
|
||||||
debug_dict = dict((x, (target_server[x], wished_server[x]))
|
debug_dict = dict(
|
||||||
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
|
(x, (target_server[x], wished_server[x]))
|
||||||
if x in target_server and x in wished_server)
|
for x in PATCH_MUTABLE_SERVER_ATTRIBUTES
|
||||||
|
if x in target_server and x in wished_server
|
||||||
|
)
|
||||||
compute_api.module.debug("Debug dict %s" % debug_dict)
|
compute_api.module.debug("Debug dict %s" % debug_dict)
|
||||||
try:
|
try:
|
||||||
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
|
for key in PATCH_MUTABLE_SERVER_ATTRIBUTES:
|
||||||
if key in target_server and key in wished_server:
|
if key in target_server and key in wished_server:
|
||||||
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
|
# When you are working with dict, only ID matter as we ask user to put only the resource ID in the playbook
|
||||||
if isinstance(target_server[key], dict) and wished_server[key] and "id" in target_server[key].keys(
|
if (
|
||||||
) and target_server[key]["id"] != wished_server[key]:
|
isinstance(target_server[key], dict)
|
||||||
|
and wished_server[key]
|
||||||
|
and "id" in target_server[key].keys()
|
||||||
|
and target_server[key]["id"] != wished_server[key]
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
# Handling other structure compare simply the two objects content
|
# Handling other structure compare simply the two objects content
|
||||||
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
|
elif not isinstance(target_server[key], dict) and target_server[key] != wished_server[key]:
|
||||||
|
@ -598,10 +602,9 @@ def server_change_attributes(compute_api, target_server, wished_server):
|
||||||
elif not isinstance(target_server[key], dict):
|
elif not isinstance(target_server[key], dict):
|
||||||
patch_payload[key] = wished_server[key]
|
patch_payload[key] = wished_server[key]
|
||||||
|
|
||||||
response = compute_api.patch(path="servers/%s" % target_server["id"],
|
response = compute_api.patch(path="servers/%s" % target_server["id"], data=patch_payload)
|
||||||
data=patch_payload)
|
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
msg = 'Error during server attributes patching: (%s) %s' % (response.status_code, response.json)
|
msg = "Error during server attributes patching: (%s) %s" % (response.status_code, response.json)
|
||||||
compute_api.module.fail_json(msg=msg)
|
compute_api.module.fail_json(msg=msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -625,9 +628,9 @@ def core(module):
|
||||||
"boot_type": module.params["boot_type"],
|
"boot_type": module.params["boot_type"],
|
||||||
"tags": module.params["tags"],
|
"tags": module.params["tags"],
|
||||||
"organization": module.params["organization"],
|
"organization": module.params["organization"],
|
||||||
"security_group": module.params["security_group"]
|
"security_group": module.params["security_group"],
|
||||||
}
|
}
|
||||||
module.params['api_url'] = SCALEWAY_LOCATION[region]["api_endpoint"]
|
module.params["api_url"] = SCALEWAY_LOCATION[region]["api_endpoint"]
|
||||||
|
|
||||||
compute_api = Scaleway(module=module)
|
compute_api = Scaleway(module=module)
|
||||||
|
|
||||||
|
@ -643,22 +646,24 @@ def core(module):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
argument_spec = scaleway_argument_spec()
|
argument_spec = scaleway_argument_spec()
|
||||||
argument_spec.update(dict(
|
argument_spec.update(
|
||||||
image=dict(required=True),
|
dict(
|
||||||
name=dict(),
|
image=dict(required=True),
|
||||||
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
|
name=dict(),
|
||||||
commercial_type=dict(required=True),
|
region=dict(required=True, choices=SCALEWAY_LOCATION.keys()),
|
||||||
enable_ipv6=dict(default=False, type="bool"),
|
commercial_type=dict(required=True),
|
||||||
boot_type=dict(choices=['bootscript', 'local']),
|
enable_ipv6=dict(default=False, type="bool"),
|
||||||
public_ip=dict(default="absent"),
|
boot_type=dict(choices=["bootscript", "local"]),
|
||||||
state=dict(choices=state_strategy.keys(), default='present'),
|
public_ip=dict(default="absent"),
|
||||||
tags=dict(type="list", default=[]),
|
state=dict(choices=state_strategy.keys(), default="present"),
|
||||||
organization=dict(required=True),
|
tags=dict(type="list", default=[]),
|
||||||
wait=dict(type="bool", default=False),
|
organization=dict(required=True),
|
||||||
wait_timeout=dict(type="int", default=300),
|
wait=dict(type="bool", default=False),
|
||||||
wait_sleep_time=dict(type="int", default=3),
|
wait_timeout=dict(type="int", default=300),
|
||||||
security_group=dict(),
|
wait_sleep_time=dict(type="int", default=3),
|
||||||
))
|
security_group=dict(),
|
||||||
|
)
|
||||||
|
)
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec=argument_spec,
|
argument_spec=argument_spec,
|
||||||
supports_check_mode=True,
|
supports_check_mode=True,
|
||||||
|
@ -667,5 +672,5 @@ def main():
|
||||||
core(module)
|
core(module)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -32,32 +32,30 @@ Returns:
|
||||||
def run_module():
|
def run_module():
|
||||||
"""
|
"""
|
||||||
Main execution function for the x25519_pubkey Ansible module.
|
Main execution function for the x25519_pubkey Ansible module.
|
||||||
|
|
||||||
Handles parameter validation, private key processing, public key derivation,
|
Handles parameter validation, private key processing, public key derivation,
|
||||||
and optional file output with idempotent behavior.
|
and optional file output with idempotent behavior.
|
||||||
"""
|
"""
|
||||||
module_args = {
|
module_args = {
|
||||||
'private_key_b64': {'type': 'str', 'required': False},
|
"private_key_b64": {"type": "str", "required": False},
|
||||||
'private_key_path': {'type': 'path', 'required': False},
|
"private_key_path": {"type": "path", "required": False},
|
||||||
'public_key_path': {'type': 'path', 'required': False},
|
"public_key_path": {"type": "path", "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'changed': False,
|
"changed": False,
|
||||||
'public_key': '',
|
"public_key": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec=module_args,
|
argument_spec=module_args, required_one_of=[["private_key_b64", "private_key_path"]], supports_check_mode=True
|
||||||
required_one_of=[['private_key_b64', 'private_key_path']],
|
|
||||||
supports_check_mode=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
priv_b64 = None
|
priv_b64 = None
|
||||||
|
|
||||||
if module.params['private_key_path']:
|
if module.params["private_key_path"]:
|
||||||
try:
|
try:
|
||||||
with open(module.params['private_key_path'], 'rb') as f:
|
with open(module.params["private_key_path"], "rb") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
try:
|
try:
|
||||||
# First attempt: assume file contains base64 text data
|
# First attempt: assume file contains base64 text data
|
||||||
|
@ -71,12 +69,14 @@ def run_module():
|
||||||
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
|
# whitespace-like bytes (0x09, 0x0A, etc.) that must be preserved
|
||||||
# Stripping would corrupt the key and cause "got 31 bytes" errors
|
# Stripping would corrupt the key and cause "got 31 bytes" errors
|
||||||
if len(data) != 32:
|
if len(data) != 32:
|
||||||
module.fail_json(msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes")
|
module.fail_json(
|
||||||
|
msg=f"Private key file must be either base64 or exactly 32 raw bytes, got {len(data)} bytes"
|
||||||
|
)
|
||||||
priv_b64 = base64.b64encode(data).decode()
|
priv_b64 = base64.b64encode(data).decode()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
module.fail_json(msg=f"Failed to read private key file: {e}")
|
module.fail_json(msg=f"Failed to read private key file: {e}")
|
||||||
else:
|
else:
|
||||||
priv_b64 = module.params['private_key_b64']
|
priv_b64 = module.params["private_key_b64"]
|
||||||
|
|
||||||
# Validate input parameters
|
# Validate input parameters
|
||||||
if not priv_b64:
|
if not priv_b64:
|
||||||
|
@ -93,15 +93,12 @@ def run_module():
|
||||||
try:
|
try:
|
||||||
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
|
priv_key = x25519.X25519PrivateKey.from_private_bytes(priv_raw)
|
||||||
pub_key = priv_key.public_key()
|
pub_key = priv_key.public_key()
|
||||||
pub_raw = pub_key.public_bytes(
|
pub_raw = pub_key.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
|
||||||
encoding=serialization.Encoding.Raw,
|
|
||||||
format=serialization.PublicFormat.Raw
|
|
||||||
)
|
|
||||||
pub_b64 = base64.b64encode(pub_raw).decode()
|
pub_b64 = base64.b64encode(pub_raw).decode()
|
||||||
result['public_key'] = pub_b64
|
result["public_key"] = pub_b64
|
||||||
|
|
||||||
if module.params['public_key_path']:
|
if module.params["public_key_path"]:
|
||||||
pub_path = module.params['public_key_path']
|
pub_path = module.params["public_key_path"]
|
||||||
existing = None
|
existing = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -112,13 +109,13 @@ def run_module():
|
||||||
|
|
||||||
if existing != pub_b64:
|
if existing != pub_b64:
|
||||||
try:
|
try:
|
||||||
with open(pub_path, 'w') as f:
|
with open(pub_path, "w") as f:
|
||||||
f.write(pub_b64)
|
f.write(pub_b64)
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
module.fail_json(msg=f"Failed to write public key file: {e}")
|
module.fail_json(msg=f"Failed to write public key file: {e}")
|
||||||
|
|
||||||
result['public_key_path'] = pub_path
|
result["public_key_path"] = pub_path
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
module.fail_json(msg=f"Failed to derive public key: {e}")
|
module.fail_json(msg=f"Failed to derive public key: {e}")
|
||||||
|
@ -131,5 +128,5 @@ def main():
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Track test effectiveness by analyzing CI failures and correlating with issues/PRs
|
Track test effectiveness by analyzing CI failures and correlating with issues/PRs
|
||||||
This helps identify which tests actually catch bugs vs just failing randomly
|
This helps identify which tests actually catch bugs vs just failing randomly
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
@ -12,7 +13,7 @@ from pathlib import Path
|
||||||
|
|
||||||
def get_github_api_data(endpoint):
|
def get_github_api_data(endpoint):
|
||||||
"""Fetch data from GitHub API"""
|
"""Fetch data from GitHub API"""
|
||||||
cmd = ['gh', 'api', endpoint]
|
cmd = ["gh", "api", endpoint]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
print(f"Error fetching {endpoint}: {result.stderr}")
|
print(f"Error fetching {endpoint}: {result.stderr}")
|
||||||
|
@ -25,40 +26,38 @@ def analyze_workflow_runs(repo_owner, repo_name, days_back=30):
|
||||||
since = (datetime.now() - timedelta(days=days_back)).isoformat()
|
since = (datetime.now() - timedelta(days=days_back)).isoformat()
|
||||||
|
|
||||||
# Get workflow runs
|
# Get workflow runs
|
||||||
runs = get_github_api_data(
|
runs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure")
|
||||||
f'/repos/{repo_owner}/{repo_name}/actions/runs?created=>{since}&status=failure'
|
|
||||||
)
|
|
||||||
|
|
||||||
if not runs:
|
if not runs:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
test_failures = defaultdict(list)
|
test_failures = defaultdict(list)
|
||||||
|
|
||||||
for run in runs.get('workflow_runs', []):
|
for run in runs.get("workflow_runs", []):
|
||||||
# Get jobs for this run
|
# Get jobs for this run
|
||||||
jobs = get_github_api_data(
|
jobs = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/actions/runs/{run['id']}/jobs")
|
||||||
f'/repos/{repo_owner}/{repo_name}/actions/runs/{run["id"]}/jobs'
|
|
||||||
)
|
|
||||||
|
|
||||||
if not jobs:
|
if not jobs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for job in jobs.get('jobs', []):
|
for job in jobs.get("jobs", []):
|
||||||
if job['conclusion'] == 'failure':
|
if job["conclusion"] == "failure":
|
||||||
# Try to extract which test failed from logs
|
# Try to extract which test failed from logs
|
||||||
logs_url = job.get('logs_url')
|
logs_url = job.get("logs_url")
|
||||||
if logs_url:
|
if logs_url:
|
||||||
# Parse logs to find test failures
|
# Parse logs to find test failures
|
||||||
test_name = extract_failed_test(job['name'], run['id'])
|
test_name = extract_failed_test(job["name"], run["id"])
|
||||||
if test_name:
|
if test_name:
|
||||||
test_failures[test_name].append({
|
test_failures[test_name].append(
|
||||||
'run_id': run['id'],
|
{
|
||||||
'run_number': run['run_number'],
|
"run_id": run["id"],
|
||||||
'date': run['created_at'],
|
"run_number": run["run_number"],
|
||||||
'branch': run['head_branch'],
|
"date": run["created_at"],
|
||||||
'commit': run['head_sha'][:7],
|
"branch": run["head_branch"],
|
||||||
'pr': extract_pr_number(run)
|
"commit": run["head_sha"][:7],
|
||||||
})
|
"pr": extract_pr_number(run),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return test_failures
|
return test_failures
|
||||||
|
|
||||||
|
@ -67,47 +66,44 @@ def extract_failed_test(job_name, run_id):
|
||||||
"""Extract test name from job - this is simplified"""
|
"""Extract test name from job - this is simplified"""
|
||||||
# Map job names to test categories
|
# Map job names to test categories
|
||||||
job_to_tests = {
|
job_to_tests = {
|
||||||
'Basic sanity tests': 'test_basic_sanity',
|
"Basic sanity tests": "test_basic_sanity",
|
||||||
'Ansible syntax check': 'ansible_syntax',
|
"Ansible syntax check": "ansible_syntax",
|
||||||
'Docker build test': 'docker_tests',
|
"Docker build test": "docker_tests",
|
||||||
'Configuration generation test': 'config_generation',
|
"Configuration generation test": "config_generation",
|
||||||
'Ansible dry-run validation': 'ansible_dry_run'
|
"Ansible dry-run validation": "ansible_dry_run",
|
||||||
}
|
}
|
||||||
return job_to_tests.get(job_name, job_name)
|
return job_to_tests.get(job_name, job_name)
|
||||||
|
|
||||||
|
|
||||||
def extract_pr_number(run):
|
def extract_pr_number(run):
|
||||||
"""Extract PR number from workflow run"""
|
"""Extract PR number from workflow run"""
|
||||||
for pr in run.get('pull_requests', []):
|
for pr in run.get("pull_requests", []):
|
||||||
return pr['number']
|
return pr["number"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def correlate_with_issues(repo_owner, repo_name, test_failures):
|
def correlate_with_issues(repo_owner, repo_name, test_failures):
|
||||||
"""Correlate test failures with issues/PRs that fixed them"""
|
"""Correlate test failures with issues/PRs that fixed them"""
|
||||||
correlations = defaultdict(lambda: {'caught_bugs': 0, 'false_positives': 0})
|
correlations = defaultdict(lambda: {"caught_bugs": 0, "false_positives": 0})
|
||||||
|
|
||||||
for test_name, failures in test_failures.items():
|
for test_name, failures in test_failures.items():
|
||||||
for failure in failures:
|
for failure in failures:
|
||||||
if failure['pr']:
|
if failure["pr"]:
|
||||||
# Check if PR was merged (indicating it fixed a real issue)
|
# Check if PR was merged (indicating it fixed a real issue)
|
||||||
pr = get_github_api_data(
|
pr = get_github_api_data(f"/repos/{repo_owner}/{repo_name}/pulls/{failure['pr']}")
|
||||||
f'/repos/{repo_owner}/{repo_name}/pulls/{failure["pr"]}'
|
|
||||||
)
|
|
||||||
|
|
||||||
if pr and pr.get('merged'):
|
if pr and pr.get("merged"):
|
||||||
# Check PR title/body for bug indicators
|
# Check PR title/body for bug indicators
|
||||||
title = pr.get('title', '').lower()
|
title = pr.get("title", "").lower()
|
||||||
body = pr.get('body', '').lower()
|
body = pr.get("body", "").lower()
|
||||||
|
|
||||||
bug_keywords = ['fix', 'bug', 'error', 'issue', 'broken', 'fail']
|
bug_keywords = ["fix", "bug", "error", "issue", "broken", "fail"]
|
||||||
is_bug_fix = any(keyword in title or keyword in body
|
is_bug_fix = any(keyword in title or keyword in body for keyword in bug_keywords)
|
||||||
for keyword in bug_keywords)
|
|
||||||
|
|
||||||
if is_bug_fix:
|
if is_bug_fix:
|
||||||
correlations[test_name]['caught_bugs'] += 1
|
correlations[test_name]["caught_bugs"] += 1
|
||||||
else:
|
else:
|
||||||
correlations[test_name]['false_positives'] += 1
|
correlations[test_name]["false_positives"] += 1
|
||||||
|
|
||||||
return correlations
|
return correlations
|
||||||
|
|
||||||
|
@ -133,8 +129,8 @@ def generate_effectiveness_report(test_failures, correlations):
|
||||||
scores = []
|
scores = []
|
||||||
for test_name, failures in test_failures.items():
|
for test_name, failures in test_failures.items():
|
||||||
failure_count = len(failures)
|
failure_count = len(failures)
|
||||||
caught = correlations[test_name]['caught_bugs']
|
caught = correlations[test_name]["caught_bugs"]
|
||||||
false_pos = correlations[test_name]['false_positives']
|
false_pos = correlations[test_name]["false_positives"]
|
||||||
|
|
||||||
# Calculate effectiveness (bugs caught / total failures)
|
# Calculate effectiveness (bugs caught / total failures)
|
||||||
if failure_count > 0:
|
if failure_count > 0:
|
||||||
|
@ -159,12 +155,12 @@ def generate_effectiveness_report(test_failures, correlations):
|
||||||
elif effectiveness > 0.8:
|
elif effectiveness > 0.8:
|
||||||
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
|
report.append(f"- ✅ `{test_name}` is highly effective ({effectiveness:.0%})")
|
||||||
|
|
||||||
return '\n'.join(report)
|
return "\n".join(report)
|
||||||
|
|
||||||
|
|
||||||
def save_metrics(test_failures, correlations):
|
def save_metrics(test_failures, correlations):
|
||||||
"""Save metrics to JSON for historical tracking"""
|
"""Save metrics to JSON for historical tracking"""
|
||||||
metrics_file = Path('.metrics/test-effectiveness.json')
|
metrics_file = Path(".metrics/test-effectiveness.json")
|
||||||
metrics_file.parent.mkdir(exist_ok=True)
|
metrics_file.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# Load existing metrics
|
# Load existing metrics
|
||||||
|
@ -176,38 +172,34 @@ def save_metrics(test_failures, correlations):
|
||||||
|
|
||||||
# Add current metrics
|
# Add current metrics
|
||||||
current = {
|
current = {
|
||||||
'date': datetime.now().isoformat(),
|
"date": datetime.now().isoformat(),
|
||||||
'test_failures': {
|
"test_failures": {test: len(failures) for test, failures in test_failures.items()},
|
||||||
test: len(failures) for test, failures in test_failures.items()
|
"effectiveness": {
|
||||||
},
|
|
||||||
'effectiveness': {
|
|
||||||
test: {
|
test: {
|
||||||
'caught_bugs': data['caught_bugs'],
|
"caught_bugs": data["caught_bugs"],
|
||||||
'false_positives': data['false_positives'],
|
"false_positives": data["false_positives"],
|
||||||
'score': data['caught_bugs'] / (data['caught_bugs'] + data['false_positives'])
|
"score": data["caught_bugs"] / (data["caught_bugs"] + data["false_positives"])
|
||||||
if (data['caught_bugs'] + data['false_positives']) > 0 else 0
|
if (data["caught_bugs"] + data["false_positives"]) > 0
|
||||||
|
else 0,
|
||||||
}
|
}
|
||||||
for test, data in correlations.items()
|
for test, data in correlations.items()
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
historical.append(current)
|
historical.append(current)
|
||||||
|
|
||||||
# Keep last 12 months of data
|
# Keep last 12 months of data
|
||||||
cutoff = datetime.now() - timedelta(days=365)
|
cutoff = datetime.now() - timedelta(days=365)
|
||||||
historical = [
|
historical = [h for h in historical if datetime.fromisoformat(h["date"]) > cutoff]
|
||||||
h for h in historical
|
|
||||||
if datetime.fromisoformat(h['date']) > cutoff
|
|
||||||
]
|
|
||||||
|
|
||||||
with open(metrics_file, 'w') as f:
|
with open(metrics_file, "w") as f:
|
||||||
json.dump(historical, f, indent=2)
|
json.dump(historical, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# Configure these for your repo
|
# Configure these for your repo
|
||||||
REPO_OWNER = 'trailofbits'
|
REPO_OWNER = "trailofbits"
|
||||||
REPO_NAME = 'algo'
|
REPO_NAME = "algo"
|
||||||
|
|
||||||
print("Analyzing test effectiveness...")
|
print("Analyzing test effectiveness...")
|
||||||
|
|
||||||
|
@ -223,9 +215,9 @@ if __name__ == '__main__':
|
||||||
print("\n" + report)
|
print("\n" + report)
|
||||||
|
|
||||||
# Save report
|
# Save report
|
||||||
report_file = Path('.metrics/test-effectiveness-report.md')
|
report_file = Path(".metrics/test-effectiveness-report.md")
|
||||||
report_file.parent.mkdir(exist_ok=True)
|
report_file.parent.mkdir(exist_ok=True)
|
||||||
with open(report_file, 'w') as f:
|
with open(report_file, "w") as f:
|
||||||
f.write(report)
|
f.write(report)
|
||||||
print(f"\nReport saved to: {report_file}")
|
print(f"\nReport saved to: {report_file}")
|
||||||
|
|
||||||
|
|
3
tests/fixtures/__init__.py
vendored
3
tests/fixtures/__init__.py
vendored
|
@ -1,4 +1,5 @@
|
||||||
"""Test fixtures for Algo unit tests"""
|
"""Test fixtures for Algo unit tests"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -6,7 +7,7 @@ import yaml
|
||||||
|
|
||||||
def load_test_variables():
|
def load_test_variables():
|
||||||
"""Load test variables from YAML fixture"""
|
"""Load test variables from YAML fixture"""
|
||||||
fixture_path = Path(__file__).parent / 'test_variables.yml'
|
fixture_path = Path(__file__).parent / "test_variables.yml"
|
||||||
with open(fixture_path) as f:
|
with open(fixture_path) as f:
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
|
@ -2,49 +2,56 @@
|
||||||
"""
|
"""
|
||||||
Wrapper for Ansible's service module that always succeeds for known services
|
Wrapper for Ansible's service module that always succeeds for known services
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Parse module arguments
|
# Parse module arguments
|
||||||
args = json.loads(sys.stdin.read())
|
args = json.loads(sys.stdin.read())
|
||||||
module_args = args.get('ANSIBLE_MODULE_ARGS', {})
|
module_args = args.get("ANSIBLE_MODULE_ARGS", {})
|
||||||
|
|
||||||
service_name = module_args.get('name', '')
|
service_name = module_args.get("name", "")
|
||||||
state = module_args.get('state', 'started')
|
state = module_args.get("state", "started")
|
||||||
|
|
||||||
# Known services that should always succeed
|
# Known services that should always succeed
|
||||||
known_services = [
|
known_services = [
|
||||||
'netfilter-persistent', 'iptables', 'wg-quick@wg0', 'strongswan-starter',
|
"netfilter-persistent",
|
||||||
'ipsec', 'apparmor', 'unattended-upgrades', 'systemd-networkd',
|
"iptables",
|
||||||
'systemd-resolved', 'rsyslog', 'ipfw', 'cron'
|
"wg-quick@wg0",
|
||||||
|
"strongswan-starter",
|
||||||
|
"ipsec",
|
||||||
|
"apparmor",
|
||||||
|
"unattended-upgrades",
|
||||||
|
"systemd-networkd",
|
||||||
|
"systemd-resolved",
|
||||||
|
"rsyslog",
|
||||||
|
"ipfw",
|
||||||
|
"cron",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check if it's a known service
|
# Check if it's a known service
|
||||||
service_found = False
|
service_found = False
|
||||||
for svc in known_services:
|
for svc in known_services:
|
||||||
if service_name == svc or service_name.startswith(svc + '.'):
|
if service_name == svc or service_name.startswith(svc + "."):
|
||||||
service_found = True
|
service_found = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if service_found:
|
if service_found:
|
||||||
# Return success
|
# Return success
|
||||||
result = {
|
result = {
|
||||||
'changed': True if state in ['started', 'stopped', 'restarted', 'reloaded'] else False,
|
"changed": True if state in ["started", "stopped", "restarted", "reloaded"] else False,
|
||||||
'name': service_name,
|
"name": service_name,
|
||||||
'state': state,
|
"state": state,
|
||||||
'status': {
|
"status": {
|
||||||
'LoadState': 'loaded',
|
"LoadState": "loaded",
|
||||||
'ActiveState': 'active' if state != 'stopped' else 'inactive',
|
"ActiveState": "active" if state != "stopped" else "inactive",
|
||||||
'SubState': 'running' if state != 'stopped' else 'dead'
|
"SubState": "running" if state != "stopped" else "dead",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
print(json.dumps(result))
|
print(json.dumps(result))
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
# Service not found
|
# Service not found
|
||||||
error = {
|
error = {"failed": True, "msg": f"Could not find the requested service {service_name}: "}
|
||||||
'failed': True,
|
|
||||||
'msg': f'Could not find the requested service {service_name}: '
|
|
||||||
}
|
|
||||||
print(json.dumps(error))
|
print(json.dumps(error))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
|
@ -9,44 +9,44 @@ from ansible.module_utils.basic import AnsibleModule
|
||||||
def main():
|
def main():
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec={
|
argument_spec={
|
||||||
'name': {'type': 'list', 'aliases': ['pkg', 'package']},
|
"name": {"type": "list", "aliases": ["pkg", "package"]},
|
||||||
'state': {'type': 'str', 'default': 'present', 'choices': ['present', 'absent', 'latest', 'build-dep', 'fixed']},
|
"state": {
|
||||||
'update_cache': {'type': 'bool', 'default': False},
|
"type": "str",
|
||||||
'cache_valid_time': {'type': 'int', 'default': 0},
|
"default": "present",
|
||||||
'install_recommends': {'type': 'bool'},
|
"choices": ["present", "absent", "latest", "build-dep", "fixed"],
|
||||||
'force': {'type': 'bool', 'default': False},
|
},
|
||||||
'allow_unauthenticated': {'type': 'bool', 'default': False},
|
"update_cache": {"type": "bool", "default": False},
|
||||||
'allow_downgrade': {'type': 'bool', 'default': False},
|
"cache_valid_time": {"type": "int", "default": 0},
|
||||||
'allow_change_held_packages': {'type': 'bool', 'default': False},
|
"install_recommends": {"type": "bool"},
|
||||||
'dpkg_options': {'type': 'str', 'default': 'force-confdef,force-confold'},
|
"force": {"type": "bool", "default": False},
|
||||||
'autoremove': {'type': 'bool', 'default': False},
|
"allow_unauthenticated": {"type": "bool", "default": False},
|
||||||
'purge': {'type': 'bool', 'default': False},
|
"allow_downgrade": {"type": "bool", "default": False},
|
||||||
'force_apt_get': {'type': 'bool', 'default': False},
|
"allow_change_held_packages": {"type": "bool", "default": False},
|
||||||
|
"dpkg_options": {"type": "str", "default": "force-confdef,force-confold"},
|
||||||
|
"autoremove": {"type": "bool", "default": False},
|
||||||
|
"purge": {"type": "bool", "default": False},
|
||||||
|
"force_apt_get": {"type": "bool", "default": False},
|
||||||
},
|
},
|
||||||
supports_check_mode=True
|
supports_check_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
name = module.params['name']
|
name = module.params["name"]
|
||||||
state = module.params['state']
|
state = module.params["state"]
|
||||||
update_cache = module.params['update_cache']
|
update_cache = module.params["update_cache"]
|
||||||
|
|
||||||
result = {
|
result = {"changed": False, "cache_updated": False, "cache_update_time": 0}
|
||||||
'changed': False,
|
|
||||||
'cache_updated': False,
|
|
||||||
'cache_update_time': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log the operation
|
# Log the operation
|
||||||
with open('/var/log/mock-apt-module.log', 'a') as f:
|
with open("/var/log/mock-apt-module.log", "a") as f:
|
||||||
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
|
f.write(f"apt module called: name={name}, state={state}, update_cache={update_cache}\n")
|
||||||
|
|
||||||
# Handle cache update
|
# Handle cache update
|
||||||
if update_cache:
|
if update_cache:
|
||||||
# In Docker, apt-get update was already run in entrypoint
|
# In Docker, apt-get update was already run in entrypoint
|
||||||
# Just pretend it succeeded
|
# Just pretend it succeeded
|
||||||
result['cache_updated'] = True
|
result["cache_updated"] = True
|
||||||
result['cache_update_time'] = 1754231778 # Fixed timestamp
|
result["cache_update_time"] = 1754231778 # Fixed timestamp
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
|
|
||||||
# Handle package installation/removal
|
# Handle package installation/removal
|
||||||
if name:
|
if name:
|
||||||
|
@ -56,40 +56,41 @@ def main():
|
||||||
installed_packages = []
|
installed_packages = []
|
||||||
for pkg in packages:
|
for pkg in packages:
|
||||||
# Use dpkg to check if package is installed
|
# Use dpkg to check if package is installed
|
||||||
check_cmd = ['dpkg', '-s', pkg]
|
check_cmd = ["dpkg", "-s", pkg]
|
||||||
rc = subprocess.run(check_cmd, capture_output=True)
|
rc = subprocess.run(check_cmd, capture_output=True)
|
||||||
if rc.returncode == 0:
|
if rc.returncode == 0:
|
||||||
installed_packages.append(pkg)
|
installed_packages.append(pkg)
|
||||||
|
|
||||||
if state in ['present', 'latest']:
|
if state in ["present", "latest"]:
|
||||||
# Check if we need to install anything
|
# Check if we need to install anything
|
||||||
missing_packages = [p for p in packages if p not in installed_packages]
|
missing_packages = [p for p in packages if p not in installed_packages]
|
||||||
|
|
||||||
if missing_packages:
|
if missing_packages:
|
||||||
# Log what we would install
|
# Log what we would install
|
||||||
with open('/var/log/mock-apt-module.log', 'a') as f:
|
with open("/var/log/mock-apt-module.log", "a") as f:
|
||||||
f.write(f"Would install packages: {missing_packages}\n")
|
f.write(f"Would install packages: {missing_packages}\n")
|
||||||
|
|
||||||
# For our test purposes, these packages are pre-installed in Docker
|
# For our test purposes, these packages are pre-installed in Docker
|
||||||
# Just report success
|
# Just report success
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
result['stdout'] = f"Mock: Packages {missing_packages} are already available"
|
result["stdout"] = f"Mock: Packages {missing_packages} are already available"
|
||||||
result['stderr'] = ""
|
result["stderr"] = ""
|
||||||
else:
|
else:
|
||||||
result['stdout'] = "All packages are already installed"
|
result["stdout"] = "All packages are already installed"
|
||||||
|
|
||||||
elif state == 'absent':
|
elif state == "absent":
|
||||||
# Check if we need to remove anything
|
# Check if we need to remove anything
|
||||||
present_packages = [p for p in packages if p in installed_packages]
|
present_packages = [p for p in packages if p in installed_packages]
|
||||||
|
|
||||||
if present_packages:
|
if present_packages:
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
result['stdout'] = f"Mock: Would remove packages {present_packages}"
|
result["stdout"] = f"Mock: Would remove packages {present_packages}"
|
||||||
else:
|
else:
|
||||||
result['stdout'] = "No packages to remove"
|
result["stdout"] = "No packages to remove"
|
||||||
|
|
||||||
# Always report success for our testing
|
# Always report success for our testing
|
||||||
module.exit_json(**result)
|
module.exit_json(**result)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -9,79 +9,72 @@ from ansible.module_utils.basic import AnsibleModule
|
||||||
def main():
|
def main():
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec={
|
argument_spec={
|
||||||
'_raw_params': {'type': 'str'},
|
"_raw_params": {"type": "str"},
|
||||||
'cmd': {'type': 'str'},
|
"cmd": {"type": "str"},
|
||||||
'creates': {'type': 'path'},
|
"creates": {"type": "path"},
|
||||||
'removes': {'type': 'path'},
|
"removes": {"type": "path"},
|
||||||
'chdir': {'type': 'path'},
|
"chdir": {"type": "path"},
|
||||||
'executable': {'type': 'path'},
|
"executable": {"type": "path"},
|
||||||
'warn': {'type': 'bool', 'default': False},
|
"warn": {"type": "bool", "default": False},
|
||||||
'stdin': {'type': 'str'},
|
"stdin": {"type": "str"},
|
||||||
'stdin_add_newline': {'type': 'bool', 'default': True},
|
"stdin_add_newline": {"type": "bool", "default": True},
|
||||||
'strip_empty_ends': {'type': 'bool', 'default': True},
|
"strip_empty_ends": {"type": "bool", "default": True},
|
||||||
'_uses_shell': {'type': 'bool', 'default': False},
|
"_uses_shell": {"type": "bool", "default": False},
|
||||||
},
|
},
|
||||||
supports_check_mode=True
|
supports_check_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the command
|
# Get the command
|
||||||
raw_params = module.params.get('_raw_params')
|
raw_params = module.params.get("_raw_params")
|
||||||
cmd = module.params.get('cmd') or raw_params
|
cmd = module.params.get("cmd") or raw_params
|
||||||
|
|
||||||
if not cmd:
|
if not cmd:
|
||||||
module.fail_json(msg="no command given")
|
module.fail_json(msg="no command given")
|
||||||
|
|
||||||
result = {
|
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
|
||||||
'changed': False,
|
|
||||||
'cmd': cmd,
|
|
||||||
'rc': 0,
|
|
||||||
'stdout': '',
|
|
||||||
'stderr': '',
|
|
||||||
'stdout_lines': [],
|
|
||||||
'stderr_lines': []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log the operation
|
# Log the operation
|
||||||
with open('/var/log/mock-command-module.log', 'a') as f:
|
with open("/var/log/mock-command-module.log", "a") as f:
|
||||||
f.write(f"command module called: cmd={cmd}\n")
|
f.write(f"command module called: cmd={cmd}\n")
|
||||||
|
|
||||||
# Handle specific commands
|
# Handle specific commands
|
||||||
if 'apparmor_status' in cmd:
|
if "apparmor_status" in cmd:
|
||||||
# Pretend apparmor is not installed/active
|
# Pretend apparmor is not installed/active
|
||||||
result['rc'] = 127
|
result["rc"] = 127
|
||||||
result['stderr'] = "apparmor_status: command not found"
|
result["stderr"] = "apparmor_status: command not found"
|
||||||
result['msg'] = "[Errno 2] No such file or directory: b'apparmor_status'"
|
result["msg"] = "[Errno 2] No such file or directory: b'apparmor_status'"
|
||||||
module.fail_json(msg=result['msg'], **result)
|
module.fail_json(msg=result["msg"], **result)
|
||||||
elif 'netplan apply' in cmd:
|
elif "netplan apply" in cmd:
|
||||||
# Pretend netplan succeeded
|
# Pretend netplan succeeded
|
||||||
result['stdout'] = "Mock: netplan configuration applied"
|
result["stdout"] = "Mock: netplan configuration applied"
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
elif 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
|
elif "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
|
||||||
# Routing cache flush
|
# Routing cache flush
|
||||||
result['stdout'] = "1"
|
result["stdout"] = "1"
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
else:
|
else:
|
||||||
# For other commands, try to run them
|
# For other commands, try to run them
|
||||||
try:
|
try:
|
||||||
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get('chdir'))
|
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd=module.params.get("chdir"))
|
||||||
result['rc'] = proc.returncode
|
result["rc"] = proc.returncode
|
||||||
result['stdout'] = proc.stdout
|
result["stdout"] = proc.stdout
|
||||||
result['stderr'] = proc.stderr
|
result["stderr"] = proc.stderr
|
||||||
result['stdout_lines'] = proc.stdout.splitlines()
|
result["stdout_lines"] = proc.stdout.splitlines()
|
||||||
result['stderr_lines'] = proc.stderr.splitlines()
|
result["stderr_lines"] = proc.stderr.splitlines()
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result['rc'] = 1
|
result["rc"] = 1
|
||||||
result['stderr'] = str(e)
|
result["stderr"] = str(e)
|
||||||
result['msg'] = str(e)
|
result["msg"] = str(e)
|
||||||
module.fail_json(msg=result['msg'], **result)
|
module.fail_json(msg=result["msg"], **result)
|
||||||
|
|
||||||
if result['rc'] == 0:
|
if result["rc"] == 0:
|
||||||
module.exit_json(**result)
|
module.exit_json(**result)
|
||||||
else:
|
else:
|
||||||
if 'msg' not in result:
|
if "msg" not in result:
|
||||||
result['msg'] = f"Command failed with return code {result['rc']}"
|
result["msg"] = f"Command failed with return code {result['rc']}"
|
||||||
module.fail_json(msg=result['msg'], **result)
|
module.fail_json(msg=result["msg"], **result)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -9,73 +9,71 @@ from ansible.module_utils.basic import AnsibleModule
|
||||||
def main():
|
def main():
|
||||||
module = AnsibleModule(
|
module = AnsibleModule(
|
||||||
argument_spec={
|
argument_spec={
|
||||||
'_raw_params': {'type': 'str'},
|
"_raw_params": {"type": "str"},
|
||||||
'cmd': {'type': 'str'},
|
"cmd": {"type": "str"},
|
||||||
'creates': {'type': 'path'},
|
"creates": {"type": "path"},
|
||||||
'removes': {'type': 'path'},
|
"removes": {"type": "path"},
|
||||||
'chdir': {'type': 'path'},
|
"chdir": {"type": "path"},
|
||||||
'executable': {'type': 'path', 'default': '/bin/sh'},
|
"executable": {"type": "path", "default": "/bin/sh"},
|
||||||
'warn': {'type': 'bool', 'default': False},
|
"warn": {"type": "bool", "default": False},
|
||||||
'stdin': {'type': 'str'},
|
"stdin": {"type": "str"},
|
||||||
'stdin_add_newline': {'type': 'bool', 'default': True},
|
"stdin_add_newline": {"type": "bool", "default": True},
|
||||||
},
|
},
|
||||||
supports_check_mode=True
|
supports_check_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the command
|
# Get the command
|
||||||
raw_params = module.params.get('_raw_params')
|
raw_params = module.params.get("_raw_params")
|
||||||
cmd = module.params.get('cmd') or raw_params
|
cmd = module.params.get("cmd") or raw_params
|
||||||
|
|
||||||
if not cmd:
|
if not cmd:
|
||||||
module.fail_json(msg="no command given")
|
module.fail_json(msg="no command given")
|
||||||
|
|
||||||
result = {
|
result = {"changed": False, "cmd": cmd, "rc": 0, "stdout": "", "stderr": "", "stdout_lines": [], "stderr_lines": []}
|
||||||
'changed': False,
|
|
||||||
'cmd': cmd,
|
|
||||||
'rc': 0,
|
|
||||||
'stdout': '',
|
|
||||||
'stderr': '',
|
|
||||||
'stdout_lines': [],
|
|
||||||
'stderr_lines': []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log the operation
|
# Log the operation
|
||||||
with open('/var/log/mock-shell-module.log', 'a') as f:
|
with open("/var/log/mock-shell-module.log", "a") as f:
|
||||||
f.write(f"shell module called: cmd={cmd}\n")
|
f.write(f"shell module called: cmd={cmd}\n")
|
||||||
|
|
||||||
# Handle specific commands
|
# Handle specific commands
|
||||||
if 'echo 1 > /proc/sys/net/ipv4/route/flush' in cmd:
|
if "echo 1 > /proc/sys/net/ipv4/route/flush" in cmd:
|
||||||
# Routing cache flush - just pretend it worked
|
# Routing cache flush - just pretend it worked
|
||||||
result['stdout'] = ""
|
result["stdout"] = ""
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
elif 'ifconfig lo100' in cmd:
|
elif "ifconfig lo100" in cmd:
|
||||||
# BSD loopback commands - simulate success
|
# BSD loopback commands - simulate success
|
||||||
result['stdout'] = "0"
|
result["stdout"] = "0"
|
||||||
result['changed'] = True
|
result["changed"] = True
|
||||||
else:
|
else:
|
||||||
# For other commands, try to run them
|
# For other commands, try to run them
|
||||||
try:
|
try:
|
||||||
proc = subprocess.run(cmd, shell=True, capture_output=True, text=True,
|
proc = subprocess.run(
|
||||||
executable=module.params.get('executable'),
|
cmd,
|
||||||
cwd=module.params.get('chdir'))
|
shell=True,
|
||||||
result['rc'] = proc.returncode
|
capture_output=True,
|
||||||
result['stdout'] = proc.stdout
|
text=True,
|
||||||
result['stderr'] = proc.stderr
|
executable=module.params.get("executable"),
|
||||||
result['stdout_lines'] = proc.stdout.splitlines()
|
cwd=module.params.get("chdir"),
|
||||||
result['stderr_lines'] = proc.stderr.splitlines()
|
)
|
||||||
result['changed'] = True
|
result["rc"] = proc.returncode
|
||||||
|
result["stdout"] = proc.stdout
|
||||||
|
result["stderr"] = proc.stderr
|
||||||
|
result["stdout_lines"] = proc.stdout.splitlines()
|
||||||
|
result["stderr_lines"] = proc.stderr.splitlines()
|
||||||
|
result["changed"] = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result['rc'] = 1
|
result["rc"] = 1
|
||||||
result['stderr'] = str(e)
|
result["stderr"] = str(e)
|
||||||
result['msg'] = str(e)
|
result["msg"] = str(e)
|
||||||
module.fail_json(msg=result['msg'], **result)
|
module.fail_json(msg=result["msg"], **result)
|
||||||
|
|
||||||
if result['rc'] == 0:
|
if result["rc"] == 0:
|
||||||
module.exit_json(**result)
|
module.exit_json(**result)
|
||||||
else:
|
else:
|
||||||
if 'msg' not in result:
|
if "msg" not in result:
|
||||||
result['msg'] = f"Command failed with return code {result['rc']}"
|
result["msg"] = f"Command failed with return code {result['rc']}"
|
||||||
module.fail_json(msg=result['msg'], **result)
|
module.fail_json(msg=result["msg"], **result)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -24,6 +24,7 @@ import yaml
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
|
||||||
def create_expected_cloud_init():
|
def create_expected_cloud_init():
|
||||||
"""
|
"""
|
||||||
Create the expected cloud-init content that should be generated
|
Create the expected cloud-init content that should be generated
|
||||||
|
@ -74,6 +75,7 @@ runcmd:
|
||||||
- systemctl restart sshd.service
|
- systemctl restart sshd.service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TestCloudInitTemplate:
|
class TestCloudInitTemplate:
|
||||||
"""Test class for cloud-init template validation."""
|
"""Test class for cloud-init template validation."""
|
||||||
|
|
||||||
|
@ -98,10 +100,7 @@ class TestCloudInitTemplate:
|
||||||
|
|
||||||
parsed = self.test_yaml_validity()
|
parsed = self.test_yaml_validity()
|
||||||
|
|
||||||
required_sections = [
|
required_sections = ["package_update", "package_upgrade", "packages", "users", "write_files", "runcmd"]
|
||||||
'package_update', 'package_upgrade', 'packages',
|
|
||||||
'users', 'write_files', 'runcmd'
|
|
||||||
]
|
|
||||||
|
|
||||||
missing = [section for section in required_sections if section not in parsed]
|
missing = [section for section in required_sections if section not in parsed]
|
||||||
assert not missing, f"Missing required sections: {missing}"
|
assert not missing, f"Missing required sections: {missing}"
|
||||||
|
@ -114,35 +113,30 @@ class TestCloudInitTemplate:
|
||||||
|
|
||||||
parsed = self.test_yaml_validity()
|
parsed = self.test_yaml_validity()
|
||||||
|
|
||||||
write_files = parsed.get('write_files', [])
|
write_files = parsed.get("write_files", [])
|
||||||
assert write_files, "write_files section should be present"
|
assert write_files, "write_files section should be present"
|
||||||
|
|
||||||
# Find sshd_config file
|
# Find sshd_config file
|
||||||
sshd_config = None
|
sshd_config = None
|
||||||
for file_entry in write_files:
|
for file_entry in write_files:
|
||||||
if file_entry.get('path') == '/etc/ssh/sshd_config':
|
if file_entry.get("path") == "/etc/ssh/sshd_config":
|
||||||
sshd_config = file_entry
|
sshd_config = file_entry
|
||||||
break
|
break
|
||||||
|
|
||||||
assert sshd_config, "sshd_config file should be in write_files"
|
assert sshd_config, "sshd_config file should be in write_files"
|
||||||
|
|
||||||
content = sshd_config.get('content', '')
|
content = sshd_config.get("content", "")
|
||||||
assert content, "sshd_config should have content"
|
assert content, "sshd_config should have content"
|
||||||
|
|
||||||
# Check required SSH configurations
|
# Check required SSH configurations
|
||||||
required_configs = [
|
required_configs = ["Port 4160", "AllowGroups algo", "PermitRootLogin no", "PasswordAuthentication no"]
|
||||||
'Port 4160',
|
|
||||||
'AllowGroups algo',
|
|
||||||
'PermitRootLogin no',
|
|
||||||
'PasswordAuthentication no'
|
|
||||||
]
|
|
||||||
|
|
||||||
missing = [config for config in required_configs if config not in content]
|
missing = [config for config in required_configs if config not in content]
|
||||||
assert not missing, f"Missing SSH configurations: {missing}"
|
assert not missing, f"Missing SSH configurations: {missing}"
|
||||||
|
|
||||||
# Verify proper formatting - first line should be Port directive
|
# Verify proper formatting - first line should be Port directive
|
||||||
lines = content.strip().split('\n')
|
lines = content.strip().split("\n")
|
||||||
assert lines[0].strip() == 'Port 4160', f"First line should be 'Port 4160', got: {repr(lines[0])}"
|
assert lines[0].strip() == "Port 4160", f"First line should be 'Port 4160', got: {repr(lines[0])}"
|
||||||
|
|
||||||
print("✅ SSH configuration correct")
|
print("✅ SSH configuration correct")
|
||||||
|
|
||||||
|
@ -152,26 +146,26 @@ class TestCloudInitTemplate:
|
||||||
|
|
||||||
parsed = self.test_yaml_validity()
|
parsed = self.test_yaml_validity()
|
||||||
|
|
||||||
users = parsed.get('users', [])
|
users = parsed.get("users", [])
|
||||||
assert users, "users section should be present"
|
assert users, "users section should be present"
|
||||||
|
|
||||||
# Find algo user
|
# Find algo user
|
||||||
algo_user = None
|
algo_user = None
|
||||||
for user in users:
|
for user in users:
|
||||||
if isinstance(user, dict) and user.get('name') == 'algo':
|
if isinstance(user, dict) and user.get("name") == "algo":
|
||||||
algo_user = user
|
algo_user = user
|
||||||
break
|
break
|
||||||
|
|
||||||
assert algo_user, "algo user should be defined"
|
assert algo_user, "algo user should be defined"
|
||||||
|
|
||||||
# Check required user properties
|
# Check required user properties
|
||||||
required_props = ['sudo', 'groups', 'shell', 'ssh_authorized_keys']
|
required_props = ["sudo", "groups", "shell", "ssh_authorized_keys"]
|
||||||
missing = [prop for prop in required_props if prop not in algo_user]
|
missing = [prop for prop in required_props if prop not in algo_user]
|
||||||
assert not missing, f"algo user missing properties: {missing}"
|
assert not missing, f"algo user missing properties: {missing}"
|
||||||
|
|
||||||
# Verify sudo configuration
|
# Verify sudo configuration
|
||||||
sudo_config = algo_user.get('sudo', '')
|
sudo_config = algo_user.get("sudo", "")
|
||||||
assert 'NOPASSWD:ALL' in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
|
assert "NOPASSWD:ALL" in sudo_config, f"sudo config should allow passwordless access: {sudo_config}"
|
||||||
|
|
||||||
print("✅ User creation correct")
|
print("✅ User creation correct")
|
||||||
|
|
||||||
|
@ -181,13 +175,13 @@ class TestCloudInitTemplate:
|
||||||
|
|
||||||
parsed = self.test_yaml_validity()
|
parsed = self.test_yaml_validity()
|
||||||
|
|
||||||
runcmd = parsed.get('runcmd', [])
|
runcmd = parsed.get("runcmd", [])
|
||||||
assert runcmd, "runcmd section should be present"
|
assert runcmd, "runcmd section should be present"
|
||||||
|
|
||||||
# Check for SSH restart command
|
# Check for SSH restart command
|
||||||
ssh_restart_found = False
|
ssh_restart_found = False
|
||||||
for cmd in runcmd:
|
for cmd in runcmd:
|
||||||
if 'systemctl restart sshd' in str(cmd):
|
if "systemctl restart sshd" in str(cmd):
|
||||||
ssh_restart_found = True
|
ssh_restart_found = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -202,18 +196,18 @@ class TestCloudInitTemplate:
|
||||||
cloud_init_content = create_expected_cloud_init()
|
cloud_init_content = create_expected_cloud_init()
|
||||||
|
|
||||||
# Extract the sshd_config content lines
|
# Extract the sshd_config content lines
|
||||||
lines = cloud_init_content.split('\n')
|
lines = cloud_init_content.split("\n")
|
||||||
in_sshd_content = False
|
in_sshd_content = False
|
||||||
sshd_lines = []
|
sshd_lines = []
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if 'content: |' in line:
|
if "content: |" in line:
|
||||||
in_sshd_content = True
|
in_sshd_content = True
|
||||||
continue
|
continue
|
||||||
elif in_sshd_content:
|
elif in_sshd_content:
|
||||||
if line.strip() == '' and len(sshd_lines) > 0:
|
if line.strip() == "" and len(sshd_lines) > 0:
|
||||||
break
|
break
|
||||||
if line.startswith('runcmd:'):
|
if line.startswith("runcmd:"):
|
||||||
break
|
break
|
||||||
sshd_lines.append(line)
|
sshd_lines.append(line)
|
||||||
|
|
||||||
|
@ -225,11 +219,13 @@ class TestCloudInitTemplate:
|
||||||
|
|
||||||
for line in non_empty_lines:
|
for line in non_empty_lines:
|
||||||
# Each line should start with exactly 6 spaces
|
# Each line should start with exactly 6 spaces
|
||||||
assert line.startswith(' ') and not line.startswith(' '), \
|
assert line.startswith(" ") and not line.startswith(" "), (
|
||||||
f"Line should have exactly 6 spaces indentation: {repr(line)}"
|
f"Line should have exactly 6 spaces indentation: {repr(line)}"
|
||||||
|
)
|
||||||
|
|
||||||
print("✅ Indentation is consistent")
|
print("✅ Indentation is consistent")
|
||||||
|
|
||||||
|
|
||||||
def run_tests():
|
def run_tests():
|
||||||
"""Run all tests manually (for non-pytest usage)."""
|
"""Run all tests manually (for non-pytest usage)."""
|
||||||
print("🚀 Cloud-init template validation tests")
|
print("🚀 Cloud-init template validation tests")
|
||||||
|
@ -258,6 +254,7 @@ def run_tests():
|
||||||
print(f"❌ Unexpected error: {e}")
|
print(f"❌ Unexpected error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
success = run_tests()
|
success = run_tests()
|
||||||
sys.exit(0 if success else 1)
|
sys.exit(0 if success else 1)
|
||||||
|
|
|
@ -30,49 +30,49 @@ packages:
|
||||||
rendered = self.packages_template.render({})
|
rendered = self.packages_template.render({})
|
||||||
|
|
||||||
# Should only have sudo package
|
# Should only have sudo package
|
||||||
self.assertIn('- sudo', rendered)
|
self.assertIn("- sudo", rendered)
|
||||||
self.assertNotIn('- git', rendered)
|
self.assertNotIn("- git", rendered)
|
||||||
self.assertNotIn('- screen', rendered)
|
self.assertNotIn("- screen", rendered)
|
||||||
self.assertNotIn('- apparmor-utils', rendered)
|
self.assertNotIn("- apparmor-utils", rendered)
|
||||||
|
|
||||||
def test_preinstall_enabled(self):
|
def test_preinstall_enabled(self):
|
||||||
"""Test that package pre-installation works when enabled."""
|
"""Test that package pre-installation works when enabled."""
|
||||||
# Test with pre-installation enabled
|
# Test with pre-installation enabled
|
||||||
rendered = self.packages_template.render({'performance_preinstall_packages': True})
|
rendered = self.packages_template.render({"performance_preinstall_packages": True})
|
||||||
|
|
||||||
# Should have sudo and all universal packages
|
# Should have sudo and all universal packages
|
||||||
self.assertIn('- sudo', rendered)
|
self.assertIn("- sudo", rendered)
|
||||||
self.assertIn('- git', rendered)
|
self.assertIn("- git", rendered)
|
||||||
self.assertIn('- screen', rendered)
|
self.assertIn("- screen", rendered)
|
||||||
self.assertIn('- apparmor-utils', rendered)
|
self.assertIn("- apparmor-utils", rendered)
|
||||||
self.assertIn('- uuid-runtime', rendered)
|
self.assertIn("- uuid-runtime", rendered)
|
||||||
self.assertIn('- coreutils', rendered)
|
self.assertIn("- coreutils", rendered)
|
||||||
self.assertIn('- iptables-persistent', rendered)
|
self.assertIn("- iptables-persistent", rendered)
|
||||||
self.assertIn('- cgroup-tools', rendered)
|
self.assertIn("- cgroup-tools", rendered)
|
||||||
|
|
||||||
def test_preinstall_disabled_explicitly(self):
|
def test_preinstall_disabled_explicitly(self):
|
||||||
"""Test that package pre-installation is disabled when set to false."""
|
"""Test that package pre-installation is disabled when set to false."""
|
||||||
# Test with pre-installation explicitly disabled
|
# Test with pre-installation explicitly disabled
|
||||||
rendered = self.packages_template.render({'performance_preinstall_packages': False})
|
rendered = self.packages_template.render({"performance_preinstall_packages": False})
|
||||||
|
|
||||||
# Should only have sudo package
|
# Should only have sudo package
|
||||||
self.assertIn('- sudo', rendered)
|
self.assertIn("- sudo", rendered)
|
||||||
self.assertNotIn('- git', rendered)
|
self.assertNotIn("- git", rendered)
|
||||||
self.assertNotIn('- screen', rendered)
|
self.assertNotIn("- screen", rendered)
|
||||||
self.assertNotIn('- apparmor-utils', rendered)
|
self.assertNotIn("- apparmor-utils", rendered)
|
||||||
|
|
||||||
def test_package_count(self):
|
def test_package_count(self):
|
||||||
"""Test that the correct number of packages are included."""
|
"""Test that the correct number of packages are included."""
|
||||||
# Default: should only have sudo (1 package)
|
# Default: should only have sudo (1 package)
|
||||||
rendered_default = self.packages_template.render({})
|
rendered_default = self.packages_template.render({})
|
||||||
lines_default = [line.strip() for line in rendered_default.split('\n') if line.strip().startswith('- ')]
|
lines_default = [line.strip() for line in rendered_default.split("\n") if line.strip().startswith("- ")]
|
||||||
self.assertEqual(len(lines_default), 1)
|
self.assertEqual(len(lines_default), 1)
|
||||||
|
|
||||||
# Enabled: should have sudo + 7 universal packages (8 total)
|
# Enabled: should have sudo + 7 universal packages (8 total)
|
||||||
rendered_enabled = self.packages_template.render({'performance_preinstall_packages': True})
|
rendered_enabled = self.packages_template.render({"performance_preinstall_packages": True})
|
||||||
lines_enabled = [line.strip() for line in rendered_enabled.split('\n') if line.strip().startswith('- ')]
|
lines_enabled = [line.strip() for line in rendered_enabled.split("\n") if line.strip().startswith("- ")]
|
||||||
self.assertEqual(len(lines_enabled), 8)
|
self.assertEqual(len(lines_enabled), 8)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
"""
|
"""
|
||||||
Basic sanity tests for Algo VPN that don't require deployment
|
Basic sanity tests for Algo VPN that don't require deployment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -44,11 +45,7 @@ def test_config_file_valid():
|
||||||
|
|
||||||
def test_ansible_syntax():
|
def test_ansible_syntax():
|
||||||
"""Check that main playbook has valid syntax"""
|
"""Check that main playbook has valid syntax"""
|
||||||
result = subprocess.run(
|
result = subprocess.run(["ansible-playbook", "main.yml", "--syntax-check"], capture_output=True, text=True)
|
||||||
["ansible-playbook", "main.yml", "--syntax-check"],
|
|
||||||
capture_output=True,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
|
assert result.returncode == 0, f"Ansible syntax check failed:\n{result.stderr}"
|
||||||
print("✓ Ansible playbook syntax is valid")
|
print("✓ Ansible playbook syntax is valid")
|
||||||
|
@ -60,11 +57,7 @@ def test_shellcheck():
|
||||||
|
|
||||||
for script in shell_scripts:
|
for script in shell_scripts:
|
||||||
if os.path.exists(script):
|
if os.path.exists(script):
|
||||||
result = subprocess.run(
|
result = subprocess.run(["shellcheck", script], capture_output=True, text=True)
|
||||||
["shellcheck", script],
|
|
||||||
capture_output=True,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
|
assert result.returncode == 0, f"Shellcheck failed for {script}:\n{result.stdout}"
|
||||||
print(f"✓ {script} passed shellcheck")
|
print(f"✓ {script} passed shellcheck")
|
||||||
|
|
||||||
|
@ -87,7 +80,7 @@ def test_cloud_init_header_format():
|
||||||
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
|
assert os.path.exists(cloud_init_file), f"{cloud_init_file} not found"
|
||||||
|
|
||||||
with open(cloud_init_file) as f:
|
with open(cloud_init_file) as f:
|
||||||
first_line = f.readline().rstrip('\n\r')
|
first_line = f.readline().rstrip("\n\r")
|
||||||
|
|
||||||
# The first line MUST be exactly "#cloud-config" (no space after #)
|
# The first line MUST be exactly "#cloud-config" (no space after #)
|
||||||
# This regression was introduced in PR #14775 and broke DigitalOcean deployments
|
# This regression was introduced in PR #14775 and broke DigitalOcean deployments
|
||||||
|
|
|
@ -4,31 +4,30 @@ Test cloud provider instance type configurations
|
||||||
Focused on validating that configured instance types are current/valid
|
Focused on validating that configured instance types are current/valid
|
||||||
Based on issues #14730 - Hetzner changed from cx11 to cx22
|
Based on issues #14730 - Hetzner changed from cx11 to cx22
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def test_hetzner_server_types():
|
def test_hetzner_server_types():
|
||||||
"""Test Hetzner server type configurations (issue #14730)"""
|
"""Test Hetzner server type configurations (issue #14730)"""
|
||||||
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22
|
# Hetzner deprecated cx11 and cpx11 - smallest is now cx22
|
||||||
deprecated_types = ['cx11', 'cpx11']
|
deprecated_types = ["cx11", "cpx11"]
|
||||||
current_types = ['cx22', 'cpx22', 'cx32', 'cpx32', 'cx42', 'cpx42']
|
current_types = ["cx22", "cpx22", "cx32", "cpx32", "cx42", "cpx42"]
|
||||||
|
|
||||||
# Test that we're not using deprecated types in any configs
|
# Test that we're not using deprecated types in any configs
|
||||||
test_config = {
|
test_config = {
|
||||||
'cloud_providers': {
|
"cloud_providers": {
|
||||||
'hetzner': {
|
"hetzner": {
|
||||||
'size': 'cx22', # Should be cx22, not cx11
|
"size": "cx22", # Should be cx22, not cx11
|
||||||
'image': 'ubuntu-22.04',
|
"image": "ubuntu-22.04",
|
||||||
'location': 'hel1'
|
"location": "hel1",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hetzner = test_config['cloud_providers']['hetzner']
|
hetzner = test_config["cloud_providers"]["hetzner"]
|
||||||
assert hetzner['size'] not in deprecated_types, \
|
assert hetzner["size"] not in deprecated_types, f"Using deprecated Hetzner type: {hetzner['size']}"
|
||||||
f"Using deprecated Hetzner type: {hetzner['size']}"
|
assert hetzner["size"] in current_types, f"Unknown Hetzner type: {hetzner['size']}"
|
||||||
assert hetzner['size'] in current_types, \
|
|
||||||
f"Unknown Hetzner type: {hetzner['size']}"
|
|
||||||
|
|
||||||
print("✓ Hetzner server types test passed")
|
print("✓ Hetzner server types test passed")
|
||||||
|
|
||||||
|
@ -36,10 +35,10 @@ def test_hetzner_server_types():
|
||||||
def test_digitalocean_instance_types():
|
def test_digitalocean_instance_types():
|
||||||
"""Test DigitalOcean droplet size naming"""
|
"""Test DigitalOcean droplet size naming"""
|
||||||
# DigitalOcean uses format like s-1vcpu-1gb
|
# DigitalOcean uses format like s-1vcpu-1gb
|
||||||
valid_sizes = ['s-1vcpu-1gb', 's-2vcpu-2gb', 's-2vcpu-4gb', 's-4vcpu-8gb']
|
valid_sizes = ["s-1vcpu-1gb", "s-2vcpu-2gb", "s-2vcpu-4gb", "s-4vcpu-8gb"]
|
||||||
deprecated_sizes = ['512mb', '1gb', '2gb'] # Old naming scheme
|
deprecated_sizes = ["512mb", "1gb", "2gb"] # Old naming scheme
|
||||||
|
|
||||||
test_size = 's-2vcpu-2gb'
|
test_size = "s-2vcpu-2gb"
|
||||||
assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
|
assert test_size in valid_sizes, f"Invalid DO size: {test_size}"
|
||||||
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
|
assert test_size not in deprecated_sizes, f"Using deprecated DO size: {test_size}"
|
||||||
|
|
||||||
|
@ -49,10 +48,10 @@ def test_digitalocean_instance_types():
|
||||||
def test_aws_instance_types():
|
def test_aws_instance_types():
|
||||||
"""Test AWS EC2 instance type naming"""
|
"""Test AWS EC2 instance type naming"""
|
||||||
# Common valid instance types
|
# Common valid instance types
|
||||||
valid_types = ['t2.micro', 't3.micro', 't3.small', 't3.medium', 'm5.large']
|
valid_types = ["t2.micro", "t3.micro", "t3.small", "t3.medium", "m5.large"]
|
||||||
deprecated_types = ['t1.micro', 'm1.small'] # Very old types
|
deprecated_types = ["t1.micro", "m1.small"] # Very old types
|
||||||
|
|
||||||
test_type = 't3.micro'
|
test_type = "t3.micro"
|
||||||
assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
|
assert test_type in valid_types, f"Unknown EC2 type: {test_type}"
|
||||||
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
|
assert test_type not in deprecated_types, f"Using deprecated EC2 type: {test_type}"
|
||||||
|
|
||||||
|
@ -62,9 +61,10 @@ def test_aws_instance_types():
|
||||||
def test_vultr_instance_types():
|
def test_vultr_instance_types():
|
||||||
"""Test Vultr instance type naming"""
|
"""Test Vultr instance type naming"""
|
||||||
# Vultr uses format like vc2-1c-1gb
|
# Vultr uses format like vc2-1c-1gb
|
||||||
test_type = 'vc2-1c-1gb'
|
test_type = "vc2-1c-1gb"
|
||||||
assert any(test_type.startswith(prefix) for prefix in ['vc2-', 'vhf-', 'vhp-']), \
|
assert any(test_type.startswith(prefix) for prefix in ["vc2-", "vhf-", "vhp-"]), (
|
||||||
f"Invalid Vultr type format: {test_type}"
|
f"Invalid Vultr type format: {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
print("✓ Vultr instance types test passed")
|
print("✓ Vultr instance types test passed")
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
"""
|
"""
|
||||||
Test configuration file validation without deployment
|
Test configuration file validation without deployment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -28,14 +29,14 @@ Endpoint = 192.168.1.1:51820
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read_string(sample_config)
|
config.read_string(sample_config)
|
||||||
|
|
||||||
assert 'Interface' in config, "Missing [Interface] section"
|
assert "Interface" in config, "Missing [Interface] section"
|
||||||
assert 'Peer' in config, "Missing [Peer] section"
|
assert "Peer" in config, "Missing [Peer] section"
|
||||||
|
|
||||||
# Validate required fields
|
# Validate required fields
|
||||||
assert config['Interface'].get('PrivateKey'), "Missing PrivateKey"
|
assert config["Interface"].get("PrivateKey"), "Missing PrivateKey"
|
||||||
assert config['Interface'].get('Address'), "Missing Address"
|
assert config["Interface"].get("Address"), "Missing Address"
|
||||||
assert config['Peer'].get('PublicKey'), "Missing PublicKey"
|
assert config["Peer"].get("PublicKey"), "Missing PublicKey"
|
||||||
assert config['Peer'].get('AllowedIPs'), "Missing AllowedIPs"
|
assert config["Peer"].get("AllowedIPs"), "Missing AllowedIPs"
|
||||||
|
|
||||||
print("✓ WireGuard config format validation passed")
|
print("✓ WireGuard config format validation passed")
|
||||||
|
|
||||||
|
@ -43,7 +44,7 @@ Endpoint = 192.168.1.1:51820
|
||||||
def test_base64_key_format():
|
def test_base64_key_format():
|
||||||
"""Test that keys are in valid base64 format"""
|
"""Test that keys are in valid base64 format"""
|
||||||
# Base64 keys can have variable length, just check format
|
# Base64 keys can have variable length, just check format
|
||||||
key_pattern = re.compile(r'^[A-Za-z0-9+/]+=*$')
|
key_pattern = re.compile(r"^[A-Za-z0-9+/]+=*$")
|
||||||
|
|
||||||
test_keys = [
|
test_keys = [
|
||||||
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
|
"aGVsbG8gd29ybGQgdGhpcyBpcyBub3QgYSByZWFsIGtleQo=",
|
||||||
|
@ -58,8 +59,8 @@ def test_base64_key_format():
|
||||||
|
|
||||||
def test_ip_address_format():
|
def test_ip_address_format():
|
||||||
"""Test IP address and CIDR notation validation"""
|
"""Test IP address and CIDR notation validation"""
|
||||||
ip_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$')
|
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$")
|
||||||
endpoint_pattern = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$')
|
endpoint_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5}$")
|
||||||
|
|
||||||
# Test CIDR notation
|
# Test CIDR notation
|
||||||
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
|
assert ip_pattern.match("10.19.49.2/32"), "Invalid CIDR notation"
|
||||||
|
@ -74,11 +75,7 @@ def test_ip_address_format():
|
||||||
def test_mobile_config_xml():
|
def test_mobile_config_xml():
|
||||||
"""Test that mobile config files would be valid XML"""
|
"""Test that mobile config files would be valid XML"""
|
||||||
# First check if xmllint is available
|
# First check if xmllint is available
|
||||||
xmllint_check = subprocess.run(
|
xmllint_check = subprocess.run(["which", "xmllint"], capture_output=True, text=True)
|
||||||
['which', 'xmllint'],
|
|
||||||
capture_output=True,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if xmllint_check.returncode != 0:
|
if xmllint_check.returncode != 0:
|
||||||
print("⚠ Skipping XML validation test (xmllint not installed)")
|
print("⚠ Skipping XML validation test (xmllint not installed)")
|
||||||
|
@ -99,17 +96,13 @@ def test_mobile_config_xml():
|
||||||
</dict>
|
</dict>
|
||||||
</plist>"""
|
</plist>"""
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mobileconfig', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".mobileconfig", delete=False) as f:
|
||||||
f.write(sample_mobileconfig)
|
f.write(sample_mobileconfig)
|
||||||
temp_file = f.name
|
temp_file = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use xmllint to validate
|
# Use xmllint to validate
|
||||||
result = subprocess.run(
|
result = subprocess.run(["xmllint", "--noout", temp_file], capture_output=True, text=True)
|
||||||
['xmllint', '--noout', temp_file],
|
|
||||||
capture_output=True,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.returncode == 0, f"XML validation failed: {result.stderr}"
|
assert result.returncode == 0, f"XML validation failed: {result.stderr}"
|
||||||
print("✓ Mobile config XML validation passed")
|
print("✓ Mobile config XML validation passed")
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Simplified Docker-based localhost deployment tests
|
Simplified Docker-based localhost deployment tests
|
||||||
Verifies services can start and config files exist in expected locations
|
Verifies services can start and config files exist in expected locations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -11,7 +12,7 @@ import sys
|
||||||
def check_docker_available():
|
def check_docker_available():
|
||||||
"""Check if Docker is available"""
|
"""Check if Docker is available"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(['docker', '--version'], capture_output=True, text=True)
|
result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
|
||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return False
|
return False
|
||||||
|
@ -31,8 +32,8 @@ AllowedIPs = 10.19.49.2/32,fd9d:bc11:4020::2/128
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Just validate the format
|
# Just validate the format
|
||||||
required_sections = ['[Interface]', '[Peer]']
|
required_sections = ["[Interface]", "[Peer]"]
|
||||||
required_fields = ['PrivateKey', 'Address', 'PublicKey', 'AllowedIPs']
|
required_fields = ["PrivateKey", "Address", "PublicKey", "AllowedIPs"]
|
||||||
|
|
||||||
for section in required_sections:
|
for section in required_sections:
|
||||||
if section not in config:
|
if section not in config:
|
||||||
|
@ -68,15 +69,15 @@ conn ikev2-pubkey
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Validate format
|
# Validate format
|
||||||
if 'config setup' not in config:
|
if "config setup" not in config:
|
||||||
print("✗ Missing 'config setup' section")
|
print("✗ Missing 'config setup' section")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if 'conn %default' not in config:
|
if "conn %default" not in config:
|
||||||
print("✗ Missing 'conn %default' section")
|
print("✗ Missing 'conn %default' section")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if 'keyexchange=ikev2' not in config:
|
if "keyexchange=ikev2" not in config:
|
||||||
print("✗ Missing IKEv2 configuration")
|
print("✗ Missing IKEv2 configuration")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -87,19 +88,19 @@ conn ikev2-pubkey
|
||||||
def test_docker_algo_image():
|
def test_docker_algo_image():
|
||||||
"""Test that the Algo Docker image can be built"""
|
"""Test that the Algo Docker image can be built"""
|
||||||
# Check if Dockerfile exists
|
# Check if Dockerfile exists
|
||||||
if not os.path.exists('Dockerfile'):
|
if not os.path.exists("Dockerfile"):
|
||||||
print("✗ Dockerfile not found")
|
print("✗ Dockerfile not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Read Dockerfile and validate basic structure
|
# Read Dockerfile and validate basic structure
|
||||||
with open('Dockerfile') as f:
|
with open("Dockerfile") as f:
|
||||||
dockerfile_content = f.read()
|
dockerfile_content = f.read()
|
||||||
|
|
||||||
required_elements = [
|
required_elements = [
|
||||||
'FROM', # Base image
|
"FROM", # Base image
|
||||||
'RUN', # Build commands
|
"RUN", # Build commands
|
||||||
'COPY', # Copy Algo files
|
"COPY", # Copy Algo files
|
||||||
'python' # Python dependency
|
"python", # Python dependency
|
||||||
]
|
]
|
||||||
|
|
||||||
missing = []
|
missing = []
|
||||||
|
@ -115,16 +116,14 @@ def test_docker_algo_image():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_localhost_deployment_requirements():
|
def test_localhost_deployment_requirements():
|
||||||
"""Test that localhost deployment requirements are met"""
|
"""Test that localhost deployment requirements are met"""
|
||||||
requirements = {
|
requirements = {
|
||||||
'Python 3.8+': sys.version_info >= (3, 8),
|
"Python 3.8+": sys.version_info >= (3, 8),
|
||||||
'Ansible installed': subprocess.run(['which', 'ansible'], capture_output=True).returncode == 0,
|
"Ansible installed": subprocess.run(["which", "ansible"], capture_output=True).returncode == 0,
|
||||||
'Main playbook exists': os.path.exists('main.yml'),
|
"Main playbook exists": os.path.exists("main.yml"),
|
||||||
'Project config exists': os.path.exists('pyproject.toml'),
|
"Project config exists": os.path.exists("pyproject.toml"),
|
||||||
'Config template exists': os.path.exists('config.cfg.example') or os.path.exists('config.cfg'),
|
"Config template exists": os.path.exists("config.cfg.example") or os.path.exists("config.cfg"),
|
||||||
}
|
}
|
||||||
|
|
||||||
all_met = True
|
all_met = True
|
||||||
|
@ -138,10 +137,6 @@ def test_localhost_deployment_requirements():
|
||||||
return all_met
|
return all_met
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Running Docker localhost deployment tests...")
|
print("Running Docker localhost deployment tests...")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Test that generated configuration files have valid syntax
|
Test that generated configuration files have valid syntax
|
||||||
This validates WireGuard, StrongSwan, SSH, and other configs
|
This validates WireGuard, StrongSwan, SSH, and other configs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -11,7 +12,7 @@ import sys
|
||||||
def check_command_available(cmd):
|
def check_command_available(cmd):
|
||||||
"""Check if a command is available on the system"""
|
"""Check if a command is available on the system"""
|
||||||
try:
|
try:
|
||||||
subprocess.run([cmd, '--version'], capture_output=True, check=False)
|
subprocess.run([cmd, "--version"], capture_output=True, check=False)
|
||||||
return True
|
return True
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return False
|
return False
|
||||||
|
@ -37,51 +38,50 @@ PersistentKeepalive = 25
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Check for required sections
|
# Check for required sections
|
||||||
if '[Interface]' not in sample_config:
|
if "[Interface]" not in sample_config:
|
||||||
errors.append("Missing [Interface] section")
|
errors.append("Missing [Interface] section")
|
||||||
if '[Peer]' not in sample_config:
|
if "[Peer]" not in sample_config:
|
||||||
errors.append("Missing [Peer] section")
|
errors.append("Missing [Peer] section")
|
||||||
|
|
||||||
# Validate Interface section
|
# Validate Interface section
|
||||||
interface_match = re.search(r'\[Interface\](.*?)\[Peer\]', sample_config, re.DOTALL)
|
interface_match = re.search(r"\[Interface\](.*?)\[Peer\]", sample_config, re.DOTALL)
|
||||||
if interface_match:
|
if interface_match:
|
||||||
interface_section = interface_match.group(1)
|
interface_section = interface_match.group(1)
|
||||||
|
|
||||||
# Check required fields
|
# Check required fields
|
||||||
if not re.search(r'Address\s*=', interface_section):
|
if not re.search(r"Address\s*=", interface_section):
|
||||||
errors.append("Missing Address in Interface section")
|
errors.append("Missing Address in Interface section")
|
||||||
if not re.search(r'PrivateKey\s*=', interface_section):
|
if not re.search(r"PrivateKey\s*=", interface_section):
|
||||||
errors.append("Missing PrivateKey in Interface section")
|
errors.append("Missing PrivateKey in Interface section")
|
||||||
|
|
||||||
# Validate IP addresses
|
# Validate IP addresses
|
||||||
address_match = re.search(r'Address\s*=\s*([^\n]+)', interface_section)
|
address_match = re.search(r"Address\s*=\s*([^\n]+)", interface_section)
|
||||||
if address_match:
|
if address_match:
|
||||||
addresses = address_match.group(1).split(',')
|
addresses = address_match.group(1).split(",")
|
||||||
for addr in addresses:
|
for addr in addresses:
|
||||||
addr = addr.strip()
|
addr = addr.strip()
|
||||||
# Basic IP validation
|
# Basic IP validation
|
||||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', addr) and \
|
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", addr) and not re.match(r"^[0-9a-fA-F:]+/\d+$", addr):
|
||||||
not re.match(r'^[0-9a-fA-F:]+/\d+$', addr):
|
|
||||||
errors.append(f"Invalid IP address format: {addr}")
|
errors.append(f"Invalid IP address format: {addr}")
|
||||||
|
|
||||||
# Validate Peer section
|
# Validate Peer section
|
||||||
peer_match = re.search(r'\[Peer\](.*)', sample_config, re.DOTALL)
|
peer_match = re.search(r"\[Peer\](.*)", sample_config, re.DOTALL)
|
||||||
if peer_match:
|
if peer_match:
|
||||||
peer_section = peer_match.group(1)
|
peer_section = peer_match.group(1)
|
||||||
|
|
||||||
# Check required fields
|
# Check required fields
|
||||||
if not re.search(r'PublicKey\s*=', peer_section):
|
if not re.search(r"PublicKey\s*=", peer_section):
|
||||||
errors.append("Missing PublicKey in Peer section")
|
errors.append("Missing PublicKey in Peer section")
|
||||||
if not re.search(r'AllowedIPs\s*=', peer_section):
|
if not re.search(r"AllowedIPs\s*=", peer_section):
|
||||||
errors.append("Missing AllowedIPs in Peer section")
|
errors.append("Missing AllowedIPs in Peer section")
|
||||||
if not re.search(r'Endpoint\s*=', peer_section):
|
if not re.search(r"Endpoint\s*=", peer_section):
|
||||||
errors.append("Missing Endpoint in Peer section")
|
errors.append("Missing Endpoint in Peer section")
|
||||||
|
|
||||||
# Validate endpoint format
|
# Validate endpoint format
|
||||||
endpoint_match = re.search(r'Endpoint\s*=\s*([^\n]+)', peer_section)
|
endpoint_match = re.search(r"Endpoint\s*=\s*([^\n]+)", peer_section)
|
||||||
if endpoint_match:
|
if endpoint_match:
|
||||||
endpoint = endpoint_match.group(1).strip()
|
endpoint = endpoint_match.group(1).strip()
|
||||||
if not re.match(r'^[\d\.\:]+:\d+$', endpoint):
|
if not re.match(r"^[\d\.\:]+:\d+$", endpoint):
|
||||||
errors.append(f"Invalid Endpoint format: {endpoint}")
|
errors.append(f"Invalid Endpoint format: {endpoint}")
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
|
@ -132,33 +132,32 @@ conn ikev2-pubkey
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Check for required sections
|
# Check for required sections
|
||||||
if 'config setup' not in sample_config:
|
if "config setup" not in sample_config:
|
||||||
errors.append("Missing 'config setup' section")
|
errors.append("Missing 'config setup' section")
|
||||||
if 'conn %default' not in sample_config:
|
if "conn %default" not in sample_config:
|
||||||
errors.append("Missing 'conn %default' section")
|
errors.append("Missing 'conn %default' section")
|
||||||
|
|
||||||
# Validate connection settings
|
# Validate connection settings
|
||||||
conn_pattern = re.compile(r'conn\s+(\S+)')
|
conn_pattern = re.compile(r"conn\s+(\S+)")
|
||||||
connections = conn_pattern.findall(sample_config)
|
connections = conn_pattern.findall(sample_config)
|
||||||
|
|
||||||
if len(connections) < 2: # Should have at least %default and one other
|
if len(connections) < 2: # Should have at least %default and one other
|
||||||
errors.append("Not enough connection definitions")
|
errors.append("Not enough connection definitions")
|
||||||
|
|
||||||
# Check for required parameters in connections
|
# Check for required parameters in connections
|
||||||
required_params = ['keyexchange', 'left', 'right']
|
required_params = ["keyexchange", "left", "right"]
|
||||||
for param in required_params:
|
for param in required_params:
|
||||||
if f'{param}=' not in sample_config:
|
if f"{param}=" not in sample_config:
|
||||||
errors.append(f"Missing required parameter: {param}")
|
errors.append(f"Missing required parameter: {param}")
|
||||||
|
|
||||||
# Validate IP subnet formats
|
# Validate IP subnet formats
|
||||||
subnet_pattern = re.compile(r'(left|right)subnet\s*=\s*([^\n]+)')
|
subnet_pattern = re.compile(r"(left|right)subnet\s*=\s*([^\n]+)")
|
||||||
for match in subnet_pattern.finditer(sample_config):
|
for match in subnet_pattern.finditer(sample_config):
|
||||||
subnets = match.group(2).split(',')
|
subnets = match.group(2).split(",")
|
||||||
for subnet in subnets:
|
for subnet in subnets:
|
||||||
subnet = subnet.strip()
|
subnet = subnet.strip()
|
||||||
if subnet != '0.0.0.0/0' and subnet != '::/0':
|
if subnet != "0.0.0.0/0" and subnet != "::/0":
|
||||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+/\d+$', subnet) and \
|
if not re.match(r"^\d+\.\d+\.\d+\.\d+/\d+$", subnet) and not re.match(r"^[0-9a-fA-F:]+/\d+$", subnet):
|
||||||
not re.match(r'^[0-9a-fA-F:]+/\d+$', subnet):
|
|
||||||
errors.append(f"Invalid subnet format: {subnet}")
|
errors.append(f"Invalid subnet format: {subnet}")
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
|
@ -188,21 +187,21 @@ def test_ssh_config_syntax():
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Parse SSH config format
|
# Parse SSH config format
|
||||||
lines = sample_config.strip().split('\n')
|
lines = sample_config.strip().split("\n")
|
||||||
current_host = None
|
current_host = None
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line or line.startswith('#'):
|
if not line or line.startswith("#"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if line.startswith('Host '):
|
if line.startswith("Host "):
|
||||||
current_host = line.split()[1]
|
current_host = line.split()[1]
|
||||||
elif current_host and ' ' in line:
|
elif current_host and " " in line:
|
||||||
key, value = line.split(None, 1)
|
key, value = line.split(None, 1)
|
||||||
|
|
||||||
# Validate common SSH options
|
# Validate common SSH options
|
||||||
if key == 'Port':
|
if key == "Port":
|
||||||
try:
|
try:
|
||||||
port = int(value)
|
port = int(value)
|
||||||
if not 1 <= port <= 65535:
|
if not 1 <= port <= 65535:
|
||||||
|
@ -210,7 +209,7 @@ def test_ssh_config_syntax():
|
||||||
except ValueError:
|
except ValueError:
|
||||||
errors.append(f"Port must be a number: {value}")
|
errors.append(f"Port must be a number: {value}")
|
||||||
|
|
||||||
elif key == 'LocalForward':
|
elif key == "LocalForward":
|
||||||
# Format: LocalForward [bind_address:]port host:hostport
|
# Format: LocalForward [bind_address:]port host:hostport
|
||||||
parts = value.split()
|
parts = value.split()
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
|
@ -256,35 +255,35 @@ COMMIT
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Check table definitions
|
# Check table definitions
|
||||||
tables = re.findall(r'\*(\w+)', sample_rules)
|
tables = re.findall(r"\*(\w+)", sample_rules)
|
||||||
if 'filter' not in tables:
|
if "filter" not in tables:
|
||||||
errors.append("Missing *filter table")
|
errors.append("Missing *filter table")
|
||||||
if 'nat' not in tables:
|
if "nat" not in tables:
|
||||||
errors.append("Missing *nat table")
|
errors.append("Missing *nat table")
|
||||||
|
|
||||||
# Check for COMMIT statements
|
# Check for COMMIT statements
|
||||||
commit_count = sample_rules.count('COMMIT')
|
commit_count = sample_rules.count("COMMIT")
|
||||||
if commit_count != len(tables):
|
if commit_count != len(tables):
|
||||||
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
|
errors.append(f"Number of COMMIT statements ({commit_count}) doesn't match tables ({len(tables)})")
|
||||||
|
|
||||||
# Validate chain policies
|
# Validate chain policies
|
||||||
chain_pattern = re.compile(r'^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]', re.MULTILINE)
|
chain_pattern = re.compile(r"^:(\w+)\s+(ACCEPT|DROP|REJECT)\s+\[\d+:\d+\]", re.MULTILINE)
|
||||||
chains = chain_pattern.findall(sample_rules)
|
chains = chain_pattern.findall(sample_rules)
|
||||||
|
|
||||||
required_chains = [('INPUT', 'DROP'), ('FORWARD', 'DROP'), ('OUTPUT', 'ACCEPT')]
|
required_chains = [("INPUT", "DROP"), ("FORWARD", "DROP"), ("OUTPUT", "ACCEPT")]
|
||||||
for chain, _policy in required_chains:
|
for chain, _policy in required_chains:
|
||||||
if not any(c[0] == chain for c in chains):
|
if not any(c[0] == chain for c in chains):
|
||||||
errors.append(f"Missing required chain: {chain}")
|
errors.append(f"Missing required chain: {chain}")
|
||||||
|
|
||||||
# Validate rule syntax
|
# Validate rule syntax
|
||||||
rule_pattern = re.compile(r'^-[AI]\s+(\w+)', re.MULTILINE)
|
rule_pattern = re.compile(r"^-[AI]\s+(\w+)", re.MULTILINE)
|
||||||
rules = rule_pattern.findall(sample_rules)
|
rules = rule_pattern.findall(sample_rules)
|
||||||
|
|
||||||
if len(rules) < 5:
|
if len(rules) < 5:
|
||||||
errors.append("Insufficient firewall rules")
|
errors.append("Insufficient firewall rules")
|
||||||
|
|
||||||
# Check for essential security rules
|
# Check for essential security rules
|
||||||
if '-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT' not in sample_rules:
|
if "-A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT" not in sample_rules:
|
||||||
errors.append("Missing stateful connection tracking rule")
|
errors.append("Missing stateful connection tracking rule")
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
|
@ -320,27 +319,26 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Parse config
|
# Parse config
|
||||||
for line in sample_config.strip().split('\n'):
|
for line in sample_config.strip().split("\n"):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line or line.startswith('#'):
|
if not line or line.startswith("#"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Most dnsmasq options are key=value or just key
|
# Most dnsmasq options are key=value or just key
|
||||||
if '=' in line:
|
if "=" in line:
|
||||||
key, value = line.split('=', 1)
|
key, value = line.split("=", 1)
|
||||||
|
|
||||||
# Validate specific options
|
# Validate specific options
|
||||||
if key == 'interface':
|
if key == "interface":
|
||||||
if not re.match(r'^[a-zA-Z0-9\-_]+$', value):
|
if not re.match(r"^[a-zA-Z0-9\-_]+$", value):
|
||||||
errors.append(f"Invalid interface name: {value}")
|
errors.append(f"Invalid interface name: {value}")
|
||||||
|
|
||||||
elif key == 'server':
|
elif key == "server":
|
||||||
# Basic IP validation
|
# Basic IP validation
|
||||||
if not re.match(r'^\d+\.\d+\.\d+\.\d+$', value) and \
|
if not re.match(r"^\d+\.\d+\.\d+\.\d+$", value) and not re.match(r"^[0-9a-fA-F:]+$", value):
|
||||||
not re.match(r'^[0-9a-fA-F:]+$', value):
|
|
||||||
errors.append(f"Invalid DNS server IP: {value}")
|
errors.append(f"Invalid DNS server IP: {value}")
|
||||||
|
|
||||||
elif key == 'cache-size':
|
elif key == "cache-size":
|
||||||
try:
|
try:
|
||||||
size = int(value)
|
size = int(value)
|
||||||
if size < 0:
|
if size < 0:
|
||||||
|
@ -349,9 +347,9 @@ addn-hosts=/var/lib/algo/dns/adblock.hosts
|
||||||
errors.append(f"Cache size must be a number: {value}")
|
errors.append(f"Cache size must be a number: {value}")
|
||||||
|
|
||||||
# Check for required options
|
# Check for required options
|
||||||
required = ['interface', 'server']
|
required = ["interface", "server"]
|
||||||
for req in required:
|
for req in required:
|
||||||
if f'{req}=' not in sample_config:
|
if f"{req}=" not in sample_config:
|
||||||
errors.append(f"Missing required option: {req}")
|
errors.append(f"Missing required option: {req}")
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
|
|
|
@ -14,203 +14,203 @@ from jinja2 import Environment, FileSystemLoader
|
||||||
|
|
||||||
def load_template(template_name):
|
def load_template(template_name):
|
||||||
"""Load a Jinja2 template from the roles/common/templates directory."""
|
"""Load a Jinja2 template from the roles/common/templates directory."""
|
||||||
template_dir = Path(__file__).parent.parent.parent / 'roles' / 'common' / 'templates'
|
template_dir = Path(__file__).parent.parent.parent / "roles" / "common" / "templates"
|
||||||
env = Environment(loader=FileSystemLoader(str(template_dir)))
|
env = Environment(loader=FileSystemLoader(str(template_dir)))
|
||||||
return env.get_template(template_name)
|
return env.get_template(template_name)
|
||||||
|
|
||||||
|
|
||||||
def test_wireguard_nat_rules_ipv4():
|
def test_wireguard_nat_rules_ipv4():
|
||||||
"""Test that WireGuard traffic gets proper NAT rules without policy matching."""
|
"""Test that WireGuard traffic gets proper NAT rules without policy matching."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
# Test with WireGuard enabled
|
# Test with WireGuard enabled
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=False,
|
ipsec_enabled=False,
|
||||||
wireguard_enabled=True,
|
wireguard_enabled=True,
|
||||||
wireguard_network_ipv4='10.49.0.0/16',
|
wireguard_network_ipv4="10.49.0.0/16",
|
||||||
wireguard_port=51820,
|
wireguard_port=51820,
|
||||||
wireguard_port_avoid=53,
|
wireguard_port_avoid=53,
|
||||||
wireguard_port_actual=51820,
|
wireguard_port_actual=51820,
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4=None,
|
snat_aipv4=None,
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.49.0.1',
|
local_service_ip="10.49.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify NAT rule exists with output interface and without policy matching
|
# Verify NAT rule exists with output interface and without policy matching
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
# Verify no policy matching in WireGuard NAT rules
|
# Verify no policy matching in WireGuard NAT rules
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -m policy' not in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -m policy" not in result
|
||||||
|
|
||||||
|
|
||||||
def test_ipsec_nat_rules_ipv4():
|
def test_ipsec_nat_rules_ipv4():
|
||||||
"""Test that IPsec traffic gets proper NAT rules without policy matching."""
|
"""Test that IPsec traffic gets proper NAT rules without policy matching."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
# Test with IPsec enabled
|
# Test with IPsec enabled
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=True,
|
ipsec_enabled=True,
|
||||||
wireguard_enabled=False,
|
wireguard_enabled=False,
|
||||||
strongswan_network='10.48.0.0/16',
|
strongswan_network="10.48.0.0/16",
|
||||||
strongswan_network_ipv6='2001:db8::/48',
|
strongswan_network_ipv6="2001:db8::/48",
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4=None,
|
snat_aipv4=None,
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.48.0.1',
|
local_service_ip="10.48.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify NAT rule exists with output interface and without policy matching
|
# Verify NAT rule exists with output interface and without policy matching
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
# Verify no policy matching in IPsec NAT rules (this was the bug)
|
# Verify no policy matching in IPsec NAT rules (this was the bug)
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none' not in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -m policy --pol none" not in result
|
||||||
|
|
||||||
|
|
||||||
def test_both_vpns_nat_rules_ipv4():
|
def test_both_vpns_nat_rules_ipv4():
|
||||||
"""Test NAT rules when both VPN types are enabled."""
|
"""Test NAT rules when both VPN types are enabled."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=True,
|
ipsec_enabled=True,
|
||||||
wireguard_enabled=True,
|
wireguard_enabled=True,
|
||||||
strongswan_network='10.48.0.0/16',
|
strongswan_network="10.48.0.0/16",
|
||||||
wireguard_network_ipv4='10.49.0.0/16',
|
wireguard_network_ipv4="10.49.0.0/16",
|
||||||
strongswan_network_ipv6='2001:db8::/48',
|
strongswan_network_ipv6="2001:db8::/48",
|
||||||
wireguard_network_ipv6='2001:db8:a160::/48',
|
wireguard_network_ipv6="2001:db8:a160::/48",
|
||||||
wireguard_port=51820,
|
wireguard_port=51820,
|
||||||
wireguard_port_avoid=53,
|
wireguard_port_avoid=53,
|
||||||
wireguard_port_actual=51820,
|
wireguard_port_actual=51820,
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4=None,
|
snat_aipv4=None,
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.49.0.1',
|
local_service_ip="10.49.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Both should have NAT rules with output interface
|
# Both should have NAT rules with output interface
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
|
|
||||||
# Neither should have policy matching
|
# Neither should have policy matching
|
||||||
assert '-m policy --pol none' not in result
|
assert "-m policy --pol none" not in result
|
||||||
|
|
||||||
|
|
||||||
def test_alternative_ingress_snat():
|
def test_alternative_ingress_snat():
|
||||||
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
|
"""Test that alternative ingress IP uses SNAT instead of MASQUERADE."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=True,
|
ipsec_enabled=True,
|
||||||
wireguard_enabled=True,
|
wireguard_enabled=True,
|
||||||
strongswan_network='10.48.0.0/16',
|
strongswan_network="10.48.0.0/16",
|
||||||
wireguard_network_ipv4='10.49.0.0/16',
|
wireguard_network_ipv4="10.49.0.0/16",
|
||||||
strongswan_network_ipv6='2001:db8::/48',
|
strongswan_network_ipv6="2001:db8::/48",
|
||||||
wireguard_network_ipv6='2001:db8:a160::/48',
|
wireguard_network_ipv6="2001:db8:a160::/48",
|
||||||
wireguard_port=51820,
|
wireguard_port=51820,
|
||||||
wireguard_port_avoid=53,
|
wireguard_port_avoid=53,
|
||||||
wireguard_port_actual=51820,
|
wireguard_port_actual=51820,
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4='192.168.1.100', # Alternative ingress IP
|
snat_aipv4="192.168.1.100", # Alternative ingress IP
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.49.0.1',
|
local_service_ip="10.49.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should use SNAT with specific IP and output interface instead of MASQUERADE
|
# Should use SNAT with specific IP and output interface instead of MASQUERADE
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100' in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j SNAT --to 192.168.1.100" in result
|
||||||
assert 'MASQUERADE' not in result
|
assert "MASQUERADE" not in result
|
||||||
|
|
||||||
|
|
||||||
def test_ipsec_forward_rule_has_policy_match():
|
def test_ipsec_forward_rule_has_policy_match():
|
||||||
"""Test that IPsec FORWARD rules still use policy matching (this is correct)."""
|
"""Test that IPsec FORWARD rules still use policy matching (this is correct)."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=True,
|
ipsec_enabled=True,
|
||||||
wireguard_enabled=False,
|
wireguard_enabled=False,
|
||||||
strongswan_network='10.48.0.0/16',
|
strongswan_network="10.48.0.0/16",
|
||||||
strongswan_network_ipv6='2001:db8::/48',
|
strongswan_network_ipv6="2001:db8::/48",
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4=None,
|
snat_aipv4=None,
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.48.0.1',
|
local_service_ip="10.48.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FORWARD rule should have policy match (this is correct and should stay)
|
# FORWARD rule should have policy match (this is correct and should stay)
|
||||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT' in result
|
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.48.0.0/16 -m policy --pol ipsec --dir in -j ACCEPT" in result
|
||||||
|
|
||||||
|
|
||||||
def test_wireguard_forward_rule_no_policy_match():
|
def test_wireguard_forward_rule_no_policy_match():
|
||||||
"""Test that WireGuard FORWARD rules don't use policy matching."""
|
"""Test that WireGuard FORWARD rules don't use policy matching."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
result = template.render(
|
result = template.render(
|
||||||
ipsec_enabled=False,
|
ipsec_enabled=False,
|
||||||
wireguard_enabled=True,
|
wireguard_enabled=True,
|
||||||
wireguard_network_ipv4='10.49.0.0/16',
|
wireguard_network_ipv4="10.49.0.0/16",
|
||||||
wireguard_port=51820,
|
wireguard_port=51820,
|
||||||
wireguard_port_avoid=53,
|
wireguard_port_avoid=53,
|
||||||
wireguard_port_actual=51820,
|
wireguard_port_actual=51820,
|
||||||
ansible_default_ipv4={'interface': 'eth0'},
|
ansible_default_ipv4={"interface": "eth0"},
|
||||||
snat_aipv4=None,
|
snat_aipv4=None,
|
||||||
BetweenClients_DROP=True,
|
BetweenClients_DROP=True,
|
||||||
block_smb=True,
|
block_smb=True,
|
||||||
block_netbios=True,
|
block_netbios=True,
|
||||||
local_service_ip='10.49.0.1',
|
local_service_ip="10.49.0.1",
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# WireGuard FORWARD rule should NOT have any policy match
|
# WireGuard FORWARD rule should NOT have any policy match
|
||||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT' in result
|
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -j ACCEPT" in result
|
||||||
assert '-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy' not in result
|
assert "-A FORWARD -m conntrack --ctstate NEW -s 10.49.0.0/16 -m policy" not in result
|
||||||
|
|
||||||
|
|
||||||
def test_output_interface_in_nat_rules():
|
def test_output_interface_in_nat_rules():
|
||||||
"""Test that output interface is specified in NAT rules."""
|
"""Test that output interface is specified in NAT rules."""
|
||||||
template = load_template('rules.v4.j2')
|
template = load_template("rules.v4.j2")
|
||||||
|
|
||||||
result = template.render(
|
result = template.render(
|
||||||
snat_aipv4=False,
|
snat_aipv4=False,
|
||||||
wireguard_enabled=True,
|
wireguard_enabled=True,
|
||||||
ipsec_enabled=True,
|
ipsec_enabled=True,
|
||||||
wireguard_network_ipv4='10.49.0.0/16',
|
wireguard_network_ipv4="10.49.0.0/16",
|
||||||
strongswan_network='10.48.0.0/16',
|
strongswan_network="10.48.0.0/16",
|
||||||
ansible_default_ipv4={'interface': 'eth0', 'address': '10.0.0.1'},
|
ansible_default_ipv4={"interface": "eth0", "address": "10.0.0.1"},
|
||||||
ansible_default_ipv6={'interface': 'eth0', 'address': 'fd9d:bc11:4020::1'},
|
ansible_default_ipv6={"interface": "eth0", "address": "fd9d:bc11:4020::1"},
|
||||||
wireguard_port_actual=51820,
|
wireguard_port_actual=51820,
|
||||||
wireguard_port_avoid=53,
|
wireguard_port_avoid=53,
|
||||||
wireguard_port=51820,
|
wireguard_port=51820,
|
||||||
ansible_ssh_port=22,
|
ansible_ssh_port=22,
|
||||||
reduce_mtu=0
|
reduce_mtu=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that output interface is specified for both VPNs
|
# Check that output interface is specified for both VPNs
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE' in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -o eth0 -j MASQUERADE" in result
|
||||||
|
|
||||||
# Ensure we don't have rules without output interface
|
# Ensure we don't have rules without output interface
|
||||||
assert '-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE' not in result
|
assert "-A POSTROUTING -s 10.49.0.0/16 -j MASQUERADE" not in result
|
||||||
assert '-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE' not in result
|
assert "-A POSTROUTING -s 10.48.0.0/16 -j MASQUERADE" not in result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|
|
@ -12,7 +12,7 @@ import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
# Add the library directory to the path
|
# Add the library directory to the path
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../library'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../library"))
|
||||||
|
|
||||||
|
|
||||||
class TestLightsailBoto3Fix(unittest.TestCase):
|
class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
|
@ -22,15 +22,15 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
"""Set up test fixtures."""
|
"""Set up test fixtures."""
|
||||||
# Mock the ansible module_utils since we're testing outside of Ansible
|
# Mock the ansible module_utils since we're testing outside of Ansible
|
||||||
self.mock_modules = {
|
self.mock_modules = {
|
||||||
'ansible.module_utils.basic': MagicMock(),
|
"ansible.module_utils.basic": MagicMock(),
|
||||||
'ansible.module_utils.ec2': MagicMock(),
|
"ansible.module_utils.ec2": MagicMock(),
|
||||||
'ansible.module_utils.aws.core': MagicMock(),
|
"ansible.module_utils.aws.core": MagicMock(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Apply mocks
|
# Apply mocks
|
||||||
self.patches = []
|
self.patches = []
|
||||||
for module_name, mock_module in self.mock_modules.items():
|
for module_name, mock_module in self.mock_modules.items():
|
||||||
patcher = patch.dict('sys.modules', {module_name: mock_module})
|
patcher = patch.dict("sys.modules", {module_name: mock_module})
|
||||||
patcher.start()
|
patcher.start()
|
||||||
self.patches.append(patcher)
|
self.patches.append(patcher)
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
# Import the module
|
# Import the module
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
"lightsail_region_facts",
|
"lightsail_region_facts",
|
||||||
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
|
||||||
)
|
)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
|
|
||||||
# Verify the module loaded
|
# Verify the module loaded
|
||||||
self.assertIsNotNone(module)
|
self.assertIsNotNone(module)
|
||||||
self.assertTrue(hasattr(module, 'main'))
|
self.assertTrue(hasattr(module, "main"))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.fail(f"Failed to import lightsail_region_facts: {e}")
|
self.fail(f"Failed to import lightsail_region_facts: {e}")
|
||||||
|
@ -62,15 +62,13 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
def test_get_aws_connection_info_called_without_boto3(self):
|
def test_get_aws_connection_info_called_without_boto3(self):
|
||||||
"""Test that get_aws_connection_info is called without boto3 parameter."""
|
"""Test that get_aws_connection_info is called without boto3 parameter."""
|
||||||
# Mock get_aws_connection_info to track calls
|
# Mock get_aws_connection_info to track calls
|
||||||
mock_get_aws_connection_info = MagicMock(
|
mock_get_aws_connection_info = MagicMock(return_value=("us-west-2", None, {}))
|
||||||
return_value=('us-west-2', None, {})
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch('ansible.module_utils.ec2.get_aws_connection_info', mock_get_aws_connection_info):
|
with patch("ansible.module_utils.ec2.get_aws_connection_info", mock_get_aws_connection_info):
|
||||||
# Import the module
|
# Import the module
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
"lightsail_region_facts",
|
"lightsail_region_facts",
|
||||||
os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py"),
|
||||||
)
|
)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
|
||||||
|
@ -79,7 +77,7 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
mock_ansible_module.params = {}
|
mock_ansible_module.params = {}
|
||||||
mock_ansible_module.check_mode = False
|
mock_ansible_module.check_mode = False
|
||||||
|
|
||||||
with patch('ansible.module_utils.basic.AnsibleModule', return_value=mock_ansible_module):
|
with patch("ansible.module_utils.basic.AnsibleModule", return_value=mock_ansible_module):
|
||||||
# Execute the module
|
# Execute the module
|
||||||
try:
|
try:
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
@ -100,28 +98,35 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
if call_args:
|
if call_args:
|
||||||
# Check positional arguments
|
# Check positional arguments
|
||||||
if call_args[0]: # args
|
if call_args[0]: # args
|
||||||
self.assertTrue(len(call_args[0]) <= 1,
|
self.assertTrue(
|
||||||
"get_aws_connection_info should be called with at most 1 positional arg (module)")
|
len(call_args[0]) <= 1,
|
||||||
|
"get_aws_connection_info should be called with at most 1 positional arg (module)",
|
||||||
|
)
|
||||||
|
|
||||||
# Check keyword arguments
|
# Check keyword arguments
|
||||||
if call_args[1]: # kwargs
|
if call_args[1]: # kwargs
|
||||||
self.assertNotIn('boto3', call_args[1],
|
self.assertNotIn(
|
||||||
"get_aws_connection_info should not be called with boto3 parameter")
|
"boto3", call_args[1], "get_aws_connection_info should not be called with boto3 parameter"
|
||||||
|
)
|
||||||
|
|
||||||
def test_no_boto3_parameter_in_source(self):
|
def test_no_boto3_parameter_in_source(self):
|
||||||
"""Verify that boto3 parameter is not present in the source code."""
|
"""Verify that boto3 parameter is not present in the source code."""
|
||||||
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
|
||||||
|
|
||||||
with open(lightsail_path) as f:
|
with open(lightsail_path) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Check that boto3=True is not in the file
|
# Check that boto3=True is not in the file
|
||||||
self.assertNotIn('boto3=True', content,
|
self.assertNotIn(
|
||||||
"boto3=True parameter should not be present in lightsail_region_facts.py")
|
"boto3=True", content, "boto3=True parameter should not be present in lightsail_region_facts.py"
|
||||||
|
)
|
||||||
|
|
||||||
# Check that boto3 parameter is not used with get_aws_connection_info
|
# Check that boto3 parameter is not used with get_aws_connection_info
|
||||||
self.assertNotIn('get_aws_connection_info(module, boto3', content,
|
self.assertNotIn(
|
||||||
"get_aws_connection_info should not be called with boto3 parameter")
|
"get_aws_connection_info(module, boto3",
|
||||||
|
content,
|
||||||
|
"get_aws_connection_info should not be called with boto3 parameter",
|
||||||
|
)
|
||||||
|
|
||||||
def test_regression_issue_14822(self):
|
def test_regression_issue_14822(self):
|
||||||
"""
|
"""
|
||||||
|
@ -132,26 +137,28 @@ class TestLightsailBoto3Fix(unittest.TestCase):
|
||||||
# The boto3 parameter was deprecated and removed in amazon.aws collection
|
# The boto3 parameter was deprecated and removed in amazon.aws collection
|
||||||
# that comes with Ansible 11.x
|
# that comes with Ansible 11.x
|
||||||
|
|
||||||
lightsail_path = os.path.join(os.path.dirname(__file__), '../../library/lightsail_region_facts.py')
|
lightsail_path = os.path.join(os.path.dirname(__file__), "../../library/lightsail_region_facts.py")
|
||||||
|
|
||||||
with open(lightsail_path) as f:
|
with open(lightsail_path) as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
# Find the line that calls get_aws_connection_info
|
# Find the line that calls get_aws_connection_info
|
||||||
for line_num, line in enumerate(lines, 1):
|
for line_num, line in enumerate(lines, 1):
|
||||||
if 'get_aws_connection_info' in line and 'region' in line:
|
if "get_aws_connection_info" in line and "region" in line:
|
||||||
# This should be around line 85
|
# This should be around line 85
|
||||||
# Verify it doesn't have boto3=True
|
# Verify it doesn't have boto3=True
|
||||||
self.assertNotIn('boto3', line,
|
self.assertNotIn("boto3", line, f"Line {line_num} should not contain boto3 parameter")
|
||||||
f"Line {line_num} should not contain boto3 parameter")
|
|
||||||
|
|
||||||
# Verify the correct format
|
# Verify the correct format
|
||||||
self.assertIn('get_aws_connection_info(module)', line,
|
self.assertIn(
|
||||||
f"Line {line_num} should call get_aws_connection_info(module) without boto3")
|
"get_aws_connection_info(module)",
|
||||||
|
line,
|
||||||
|
f"Line {line_num} should call get_aws_connection_info(module) without boto3",
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
|
self.fail("Could not find get_aws_connection_info call in lightsail_region_facts.py")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -5,6 +5,7 @@ Hybrid approach: validates actual certificates when available, else tests templa
|
||||||
Based on issues #14755, #14718 - Apple device compatibility
|
Based on issues #14755, #14718 - Apple device compatibility
|
||||||
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
|
Issues #75, #153 - Security enhancements (name constraints, EKU restrictions)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -22,7 +23,7 @@ def find_generated_certificates():
|
||||||
config_patterns = [
|
config_patterns = [
|
||||||
"configs/*/ipsec/.pki/cacert.pem",
|
"configs/*/ipsec/.pki/cacert.pem",
|
||||||
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
|
"../configs/*/ipsec/.pki/cacert.pem", # From tests/unit directory
|
||||||
"../../configs/*/ipsec/.pki/cacert.pem" # Alternative path
|
"../../configs/*/ipsec/.pki/cacert.pem", # Alternative path
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in config_patterns:
|
for pattern in config_patterns:
|
||||||
|
@ -30,26 +31,23 @@ def find_generated_certificates():
|
||||||
if ca_certs:
|
if ca_certs:
|
||||||
base_path = os.path.dirname(ca_certs[0])
|
base_path = os.path.dirname(ca_certs[0])
|
||||||
return {
|
return {
|
||||||
'ca_cert': ca_certs[0],
|
"ca_cert": ca_certs[0],
|
||||||
'base_path': base_path,
|
"base_path": base_path,
|
||||||
'server_certs': glob.glob(f"{base_path}/certs/*.crt"),
|
"server_certs": glob.glob(f"{base_path}/certs/*.crt"),
|
||||||
'p12_files': glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12")
|
"p12_files": glob.glob(f"{base_path.replace('/.pki', '')}/manual/*.p12"),
|
||||||
}
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def test_openssl_version_detection():
|
def test_openssl_version_detection():
|
||||||
"""Test that we can detect OpenSSL version for compatibility checks"""
|
"""Test that we can detect OpenSSL version for compatibility checks"""
|
||||||
result = subprocess.run(
|
result = subprocess.run(["openssl", "version"], capture_output=True, text=True)
|
||||||
['openssl', 'version'],
|
|
||||||
capture_output=True,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.returncode == 0, "Failed to get OpenSSL version"
|
assert result.returncode == 0, "Failed to get OpenSSL version"
|
||||||
|
|
||||||
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
|
# Parse version - e.g., "OpenSSL 3.0.2 15 Mar 2022"
|
||||||
version_match = re.search(r'OpenSSL\s+(\d+)\.(\d+)\.(\d+)', result.stdout)
|
version_match = re.search(r"OpenSSL\s+(\d+)\.(\d+)\.(\d+)", result.stdout)
|
||||||
assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
|
assert version_match, f"Can't parse OpenSSL version: {result.stdout}"
|
||||||
|
|
||||||
major = int(version_match.group(1))
|
major = int(version_match.group(1))
|
||||||
|
@ -62,7 +60,7 @@ def test_openssl_version_detection():
|
||||||
def validate_ca_certificate_real(cert_files):
|
def validate_ca_certificate_real(cert_files):
|
||||||
"""Validate actual Ansible-generated CA certificate"""
|
"""Validate actual Ansible-generated CA certificate"""
|
||||||
# Read the actual CA certificate generated by Ansible
|
# Read the actual CA certificate generated by Ansible
|
||||||
with open(cert_files['ca_cert'], 'rb') as f:
|
with open(cert_files["ca_cert"], "rb") as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
|
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
@ -89,30 +87,34 @@ def validate_ca_certificate_real(cert_files):
|
||||||
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
|
assert name_constraints.excluded_subtrees is not None, "CA should have excluded name constraints"
|
||||||
|
|
||||||
# Verify public domains are excluded
|
# Verify public domains are excluded
|
||||||
excluded_dns = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_dns = [
|
||||||
if isinstance(constraint, x509.DNSName)]
|
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.DNSName)
|
||||||
|
]
|
||||||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in public_domains:
|
for domain in public_domains:
|
||||||
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
|
assert domain in excluded_dns, f"CA should exclude public domain {domain}"
|
||||||
|
|
||||||
# Verify private IP ranges are excluded (Issue #75)
|
# Verify private IP ranges are excluded (Issue #75)
|
||||||
excluded_ips = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_ips = [
|
||||||
if isinstance(constraint, x509.IPAddress)]
|
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.IPAddress)
|
||||||
|
]
|
||||||
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
|
assert len(excluded_ips) > 0, "CA should exclude private IP ranges"
|
||||||
|
|
||||||
# Verify email domains are also excluded (Issue #153)
|
# Verify email domains are also excluded (Issue #153)
|
||||||
excluded_emails = [constraint.value for constraint in name_constraints.excluded_subtrees
|
excluded_emails = [
|
||||||
if isinstance(constraint, x509.RFC822Name)]
|
constraint.value for constraint in name_constraints.excluded_subtrees if isinstance(constraint, x509.RFC822Name)
|
||||||
|
]
|
||||||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in email_domains:
|
for domain in email_domains:
|
||||||
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
|
assert domain in excluded_emails, f"CA should exclude email domain {domain}"
|
||||||
|
|
||||||
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
|
print(f"✓ Real CA certificate has proper security constraints: {cert_files['ca_cert']}")
|
||||||
|
|
||||||
|
|
||||||
def validate_ca_certificate_config():
|
def validate_ca_certificate_config():
|
||||||
"""Validate CA certificate configuration in Ansible files (CI mode)"""
|
"""Validate CA certificate configuration in Ansible files (CI mode)"""
|
||||||
# Check that the Ansible task file has proper CA certificate configuration
|
# Check that the Ansible task file has proper CA certificate configuration
|
||||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
@ -122,15 +124,15 @@ def validate_ca_certificate_config():
|
||||||
|
|
||||||
# Verify key security configurations are present
|
# Verify key security configurations are present
|
||||||
security_checks = [
|
security_checks = [
|
||||||
('name_constraints_permitted', 'Name constraints should be configured'),
|
("name_constraints_permitted", "Name constraints should be configured"),
|
||||||
('name_constraints_excluded', 'Excluded name constraints should be configured'),
|
("name_constraints_excluded", "Excluded name constraints should be configured"),
|
||||||
('extended_key_usage', 'Extended Key Usage should be configured'),
|
("extended_key_usage", "Extended Key Usage should be configured"),
|
||||||
('1.3.6.1.5.5.7.3.17', 'IPsec End Entity OID should be present'),
|
("1.3.6.1.5.5.7.3.17", "IPsec End Entity OID should be present"),
|
||||||
('serverAuth', 'Server authentication EKU should be present'),
|
("serverAuth", "Server authentication EKU should be present"),
|
||||||
('clientAuth', 'Client authentication EKU should be present'),
|
("clientAuth", "Client authentication EKU should be present"),
|
||||||
('basic_constraints', 'Basic constraints should be configured'),
|
("basic_constraints", "Basic constraints should be configured"),
|
||||||
('CA:TRUE', 'CA certificate should be marked as CA'),
|
("CA:TRUE", "CA certificate should be marked as CA"),
|
||||||
('pathlen:0', 'Path length constraint should be set')
|
("pathlen:0", "Path length constraint should be set"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in security_checks:
|
for check, message in security_checks:
|
||||||
|
@ -140,7 +142,9 @@ def validate_ca_certificate_config():
|
||||||
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
public_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in public_domains:
|
for domain in public_domains:
|
||||||
# Handle both double quotes and single quotes in YAML
|
# Handle both double quotes and single quotes in YAML
|
||||||
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, f"Public domain {domain} should be excluded"
|
assert f'"DNS:{domain}"' in content or f"'DNS:{domain}'" in content, (
|
||||||
|
f"Public domain {domain} should be excluded"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify private IP ranges are excluded
|
# Verify private IP ranges are excluded
|
||||||
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
|
private_ranges = ["10.0.0.0", "172.16.0.0", "192.168.0.0"]
|
||||||
|
@ -151,13 +155,16 @@ def validate_ca_certificate_config():
|
||||||
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
email_domains = [".com", ".org", ".net", ".gov", ".edu", ".mil", ".int"]
|
||||||
for domain in email_domains:
|
for domain in email_domains:
|
||||||
# Handle both double quotes and single quotes in YAML
|
# Handle both double quotes and single quotes in YAML
|
||||||
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, f"Email domain {domain} should be excluded"
|
assert f'"email:{domain}"' in content or f"'email:{domain}'" in content, (
|
||||||
|
f"Email domain {domain} should be excluded"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify IPv6 constraints are present (Issue #153)
|
# Verify IPv6 constraints are present (Issue #153)
|
||||||
assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
|
assert "IP:::/0" in content, "IPv6 all addresses should be excluded"
|
||||||
|
|
||||||
print("✓ CA certificate configuration has proper security constraints")
|
print("✓ CA certificate configuration has proper security constraints")
|
||||||
|
|
||||||
|
|
||||||
def test_ca_certificate():
|
def test_ca_certificate():
|
||||||
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
|
"""Test CA certificate - uses real certs if available, else validates config (Issue #75, #153)"""
|
||||||
cert_files = find_generated_certificates()
|
cert_files = find_generated_certificates()
|
||||||
|
@ -172,14 +179,18 @@ def validate_server_certificates_real(cert_files):
|
||||||
# Filter to only actual server certificates (not client certs)
|
# Filter to only actual server certificates (not client certs)
|
||||||
# Server certificates contain IP addresses in the filename
|
# Server certificates contain IP addresses in the filename
|
||||||
import re
|
import re
|
||||||
server_certs = [f for f in cert_files['server_certs']
|
|
||||||
if not f.endswith('/cacert.pem') and re.search(r'\d+\.\d+\.\d+\.\d+\.crt$', f)]
|
server_certs = [
|
||||||
|
f
|
||||||
|
for f in cert_files["server_certs"]
|
||||||
|
if not f.endswith("/cacert.pem") and re.search(r"\d+\.\d+\.\d+\.\d+\.crt$", f)
|
||||||
|
]
|
||||||
if not server_certs:
|
if not server_certs:
|
||||||
print("⚠ No server certificates found")
|
print("⚠ No server certificates found")
|
||||||
return
|
return
|
||||||
|
|
||||||
for server_cert_path in server_certs:
|
for server_cert_path in server_certs:
|
||||||
with open(server_cert_path, 'rb') as f:
|
with open(server_cert_path, "rb") as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
|
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
@ -193,7 +204,9 @@ def validate_server_certificates_real(cert_files):
|
||||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
|
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH in eku, "Server cert must have serverAuth EKU"
|
||||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
|
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Server cert should have IPsec End Entity EKU"
|
||||||
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
|
# Security check: Server certificates should NOT have clientAuth to prevent role confusion (Issue #153)
|
||||||
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, "Server cert should NOT have clientAuth EKU for role separation"
|
assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH not in eku, (
|
||||||
|
"Server cert should NOT have clientAuth EKU for role separation"
|
||||||
|
)
|
||||||
|
|
||||||
# Check SAN extension exists (required for Apple devices)
|
# Check SAN extension exists (required for Apple devices)
|
||||||
try:
|
try:
|
||||||
|
@ -204,9 +217,10 @@ def validate_server_certificates_real(cert_files):
|
||||||
|
|
||||||
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
|
print(f"✓ Real server certificate valid: {os.path.basename(server_cert_path)}")
|
||||||
|
|
||||||
|
|
||||||
def validate_server_certificates_config():
|
def validate_server_certificates_config():
|
||||||
"""Validate server certificate configuration in Ansible files (CI mode)"""
|
"""Validate server certificate configuration in Ansible files (CI mode)"""
|
||||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
@ -215,7 +229,7 @@ def validate_server_certificates_config():
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Look for server certificate CSR section
|
# Look for server certificate CSR section
|
||||||
server_csr_section = re.search(r'Create CSRs for server certificate.*?register: server_csr', content, re.DOTALL)
|
server_csr_section = re.search(r"Create CSRs for server certificate.*?register: server_csr", content, re.DOTALL)
|
||||||
if not server_csr_section:
|
if not server_csr_section:
|
||||||
print("⚠ Could not find server certificate CSR section")
|
print("⚠ Could not find server certificate CSR section")
|
||||||
return
|
return
|
||||||
|
@ -224,11 +238,11 @@ def validate_server_certificates_config():
|
||||||
|
|
||||||
# Check server certificate CSR configuration
|
# Check server certificate CSR configuration
|
||||||
server_checks = [
|
server_checks = [
|
||||||
('subject_alt_name', 'Server certificates should have SAN extension'),
|
("subject_alt_name", "Server certificates should have SAN extension"),
|
||||||
('serverAuth', 'Server certificates should have serverAuth EKU'),
|
("serverAuth", "Server certificates should have serverAuth EKU"),
|
||||||
('1.3.6.1.5.5.7.3.17', 'Server certificates should have IPsec End Entity EKU'),
|
("1.3.6.1.5.5.7.3.17", "Server certificates should have IPsec End Entity EKU"),
|
||||||
('digitalSignature', 'Server certificates should have digital signature usage'),
|
("digitalSignature", "Server certificates should have digital signature usage"),
|
||||||
('keyEncipherment', 'Server certificates should have key encipherment usage')
|
("keyEncipherment", "Server certificates should have key encipherment usage"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in server_checks:
|
for check, message in server_checks:
|
||||||
|
@ -236,15 +250,20 @@ def validate_server_certificates_config():
|
||||||
|
|
||||||
# Security check: Server certificates should NOT have clientAuth (Issue #153)
|
# Security check: Server certificates should NOT have clientAuth (Issue #153)
|
||||||
# Look for clientAuth in extended_key_usage section, not in comments
|
# Look for clientAuth in extended_key_usage section, not in comments
|
||||||
eku_lines = [line for line in server_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'clientAuth' in line)]
|
eku_lines = [
|
||||||
has_client_auth = any('clientAuth' in line for line in eku_lines if line.strip().startswith('- '))
|
line
|
||||||
|
for line in server_section.split("\n")
|
||||||
|
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "clientAuth" in line)
|
||||||
|
]
|
||||||
|
has_client_auth = any("clientAuth" in line for line in eku_lines if line.strip().startswith("- "))
|
||||||
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
|
assert not has_client_auth, "Server certificates should NOT have clientAuth EKU for role separation"
|
||||||
|
|
||||||
# Verify SAN extension is configured for Apple compatibility
|
# Verify SAN extension is configured for Apple compatibility
|
||||||
assert 'subjectAltName' in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
assert "subjectAltName" in server_section, "Server certificates missing SAN configuration for Apple compatibility"
|
||||||
|
|
||||||
print("✓ Server certificate configuration has proper EKU and SAN settings")
|
print("✓ Server certificate configuration has proper EKU and SAN settings")
|
||||||
|
|
||||||
|
|
||||||
def test_server_certificates():
|
def test_server_certificates():
|
||||||
"""Test server certificates - uses real certs if available, else validates config"""
|
"""Test server certificates - uses real certs if available, else validates config"""
|
||||||
cert_files = find_generated_certificates()
|
cert_files = find_generated_certificates()
|
||||||
|
@ -258,18 +277,18 @@ def validate_client_certificates_real(cert_files):
|
||||||
"""Validate actual Ansible-generated client certificates"""
|
"""Validate actual Ansible-generated client certificates"""
|
||||||
# Find client certificates (not CA cert, not server cert with IP/DNS name)
|
# Find client certificates (not CA cert, not server cert with IP/DNS name)
|
||||||
client_certs = []
|
client_certs = []
|
||||||
for cert_path in cert_files['server_certs']:
|
for cert_path in cert_files["server_certs"]:
|
||||||
if 'cacert.pem' in cert_path:
|
if "cacert.pem" in cert_path:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with open(cert_path, 'rb') as f:
|
with open(cert_path, "rb") as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
# Check if this looks like a client cert vs server cert
|
# Check if this looks like a client cert vs server cert
|
||||||
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
||||||
# Server certs typically have IP addresses or domain names as CN
|
# Server certs typically have IP addresses or domain names as CN
|
||||||
if not (cn.replace('.', '').isdigit() or '.' in cn and len(cn.split('.')) == 4):
|
if not (cn.replace(".", "").isdigit() or "." in cn and len(cn.split(".")) == 4):
|
||||||
client_certs.append((cert_path, certificate))
|
client_certs.append((cert_path, certificate))
|
||||||
|
|
||||||
if not client_certs:
|
if not client_certs:
|
||||||
|
@ -287,7 +306,9 @@ def validate_client_certificates_real(cert_files):
|
||||||
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
|
assert x509.ObjectIdentifier("1.3.6.1.5.5.7.3.17") in eku, "Client cert should have IPsec End Entity EKU"
|
||||||
|
|
||||||
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
|
# Security check: Client certificates should NOT have serverAuth (prevents impersonation) (Issue #153)
|
||||||
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, "Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
assert x509.oid.ExtendedKeyUsageOID.SERVER_AUTH not in eku, (
|
||||||
|
"Client cert must NOT have serverAuth EKU to prevent server impersonation"
|
||||||
|
)
|
||||||
|
|
||||||
# Check SAN extension for email
|
# Check SAN extension for email
|
||||||
try:
|
try:
|
||||||
|
@ -299,9 +320,10 @@ def validate_client_certificates_real(cert_files):
|
||||||
|
|
||||||
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
|
print(f"✓ Real client certificate valid: {os.path.basename(cert_path)}")
|
||||||
|
|
||||||
|
|
||||||
def validate_client_certificates_config():
|
def validate_client_certificates_config():
|
||||||
"""Validate client certificate configuration in Ansible files (CI mode)"""
|
"""Validate client certificate configuration in Ansible files (CI mode)"""
|
||||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
@ -310,7 +332,9 @@ def validate_client_certificates_config():
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Look for client certificate CSR section
|
# Look for client certificate CSR section
|
||||||
client_csr_section = re.search(r'Create CSRs for client certificates.*?register: client_csr_jobs', content, re.DOTALL)
|
client_csr_section = re.search(
|
||||||
|
r"Create CSRs for client certificates.*?register: client_csr_jobs", content, re.DOTALL
|
||||||
|
)
|
||||||
if not client_csr_section:
|
if not client_csr_section:
|
||||||
print("⚠ Could not find client certificate CSR section")
|
print("⚠ Could not find client certificate CSR section")
|
||||||
return
|
return
|
||||||
|
@ -319,11 +343,11 @@ def validate_client_certificates_config():
|
||||||
|
|
||||||
# Check client certificate configuration
|
# Check client certificate configuration
|
||||||
client_checks = [
|
client_checks = [
|
||||||
('clientAuth', 'Client certificates should have clientAuth EKU'),
|
("clientAuth", "Client certificates should have clientAuth EKU"),
|
||||||
('1.3.6.1.5.5.7.3.17', 'Client certificates should have IPsec End Entity EKU'),
|
("1.3.6.1.5.5.7.3.17", "Client certificates should have IPsec End Entity EKU"),
|
||||||
('digitalSignature', 'Client certificates should have digital signature usage'),
|
("digitalSignature", "Client certificates should have digital signature usage"),
|
||||||
('keyEncipherment', 'Client certificates should have key encipherment usage'),
|
("keyEncipherment", "Client certificates should have key encipherment usage"),
|
||||||
('email:', 'Client certificates should have email SAN')
|
("email:", "Client certificates should have email SAN"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in client_checks:
|
for check, message in client_checks:
|
||||||
|
@ -331,15 +355,22 @@ def validate_client_certificates_config():
|
||||||
|
|
||||||
# Security check: Client certificates should NOT have serverAuth (Issue #153)
|
# Security check: Client certificates should NOT have serverAuth (Issue #153)
|
||||||
# Look for serverAuth in extended_key_usage section, not in comments
|
# Look for serverAuth in extended_key_usage section, not in comments
|
||||||
eku_lines = [line for line in client_section.split('\n') if 'extended_key_usage:' in line or (line.strip().startswith('- ') and 'serverAuth' in line)]
|
eku_lines = [
|
||||||
has_server_auth = any('serverAuth' in line for line in eku_lines if line.strip().startswith('- '))
|
line
|
||||||
|
for line in client_section.split("\n")
|
||||||
|
if "extended_key_usage:" in line or (line.strip().startswith("- ") and "serverAuth" in line)
|
||||||
|
]
|
||||||
|
has_server_auth = any("serverAuth" in line for line in eku_lines if line.strip().startswith("- "))
|
||||||
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
|
assert not has_server_auth, "Client certificates must NOT have serverAuth EKU to prevent server impersonation"
|
||||||
|
|
||||||
# Verify client certificates use unique email domains (Issue #153)
|
# Verify client certificates use unique email domains (Issue #153)
|
||||||
assert 'openssl_constraint_random_id' in client_section, "Client certificates should use unique email domain per deployment"
|
assert "openssl_constraint_random_id" in client_section, (
|
||||||
|
"Client certificates should use unique email domain per deployment"
|
||||||
|
)
|
||||||
|
|
||||||
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
|
print("✓ Client certificate configuration has proper EKU restrictions (no serverAuth)")
|
||||||
|
|
||||||
|
|
||||||
def test_client_certificates():
|
def test_client_certificates():
|
||||||
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
|
"""Test client certificates - uses real certs if available, else validates config (Issue #75, #153)"""
|
||||||
cert_files = find_generated_certificates()
|
cert_files = find_generated_certificates()
|
||||||
|
@ -351,24 +382,33 @@ def test_client_certificates():
|
||||||
|
|
||||||
def validate_pkcs12_files_real(cert_files):
|
def validate_pkcs12_files_real(cert_files):
|
||||||
"""Validate actual Ansible-generated PKCS#12 files"""
|
"""Validate actual Ansible-generated PKCS#12 files"""
|
||||||
if not cert_files.get('p12_files'):
|
if not cert_files.get("p12_files"):
|
||||||
print("⚠ No PKCS#12 files found")
|
print("⚠ No PKCS#12 files found")
|
||||||
return
|
return
|
||||||
|
|
||||||
major, minor = test_openssl_version_detection()
|
major, minor = test_openssl_version_detection()
|
||||||
|
|
||||||
for p12_file in cert_files['p12_files']:
|
for p12_file in cert_files["p12_files"]:
|
||||||
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
|
assert os.path.exists(p12_file), f"PKCS#12 file should exist: {p12_file}"
|
||||||
|
|
||||||
# Test that PKCS#12 file can be read (validates format)
|
# Test that PKCS#12 file can be read (validates format)
|
||||||
legacy_flag = ['-legacy'] if major >= 3 else []
|
legacy_flag = ["-legacy"] if major >= 3 else []
|
||||||
|
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
'openssl', 'pkcs12', '-info',
|
[
|
||||||
'-in', p12_file,
|
"openssl",
|
||||||
'-passin', 'pass:', # Try empty password first
|
"pkcs12",
|
||||||
'-noout'
|
"-info",
|
||||||
] + legacy_flag, capture_output=True, text=True)
|
"-in",
|
||||||
|
p12_file,
|
||||||
|
"-passin",
|
||||||
|
"pass:", # Try empty password first
|
||||||
|
"-noout",
|
||||||
|
]
|
||||||
|
+ legacy_flag,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
|
||||||
# PKCS#12 files should be readable (even if password-protected)
|
# PKCS#12 files should be readable (even if password-protected)
|
||||||
# We're just testing format validity, not trying to extract contents
|
# We're just testing format validity, not trying to extract contents
|
||||||
|
@ -378,9 +418,10 @@ def validate_pkcs12_files_real(cert_files):
|
||||||
|
|
||||||
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
|
print(f"✓ Real PKCS#12 file exists: {os.path.basename(p12_file)}")
|
||||||
|
|
||||||
|
|
||||||
def validate_pkcs12_files_config():
|
def validate_pkcs12_files_config():
|
||||||
"""Validate PKCS#12 file configuration in Ansible files (CI mode)"""
|
"""Validate PKCS#12 file configuration in Ansible files (CI mode)"""
|
||||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
@ -390,13 +431,13 @@ def validate_pkcs12_files_config():
|
||||||
|
|
||||||
# Check PKCS#12 generation configuration
|
# Check PKCS#12 generation configuration
|
||||||
p12_checks = [
|
p12_checks = [
|
||||||
('openssl_pkcs12', 'PKCS#12 generation should be configured'),
|
("openssl_pkcs12", "PKCS#12 generation should be configured"),
|
||||||
('encryption_level', 'PKCS#12 encryption level should be configured'),
|
("encryption_level", "PKCS#12 encryption level should be configured"),
|
||||||
('compatibility2022', 'PKCS#12 should use Apple-compatible encryption'),
|
("compatibility2022", "PKCS#12 should use Apple-compatible encryption"),
|
||||||
('friendly_name', 'PKCS#12 should have friendly names'),
|
("friendly_name", "PKCS#12 should have friendly names"),
|
||||||
('other_certificates', 'PKCS#12 should include CA certificate for full chain'),
|
("other_certificates", "PKCS#12 should include CA certificate for full chain"),
|
||||||
('passphrase', 'PKCS#12 files should be password protected'),
|
("passphrase", "PKCS#12 files should be password protected"),
|
||||||
('mode: "0600"', 'PKCS#12 files should have secure permissions')
|
('mode: "0600"', "PKCS#12 files should have secure permissions"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in p12_checks:
|
for check, message in p12_checks:
|
||||||
|
@ -404,6 +445,7 @@ def validate_pkcs12_files_config():
|
||||||
|
|
||||||
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
|
print("✓ PKCS#12 configuration has proper Apple device compatibility settings")
|
||||||
|
|
||||||
|
|
||||||
def test_pkcs12_files():
|
def test_pkcs12_files():
|
||||||
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
|
"""Test PKCS#12 files - uses real files if available, else validates config (Issue #14755, #14718)"""
|
||||||
cert_files = find_generated_certificates()
|
cert_files = find_generated_certificates()
|
||||||
|
@ -416,19 +458,19 @@ def test_pkcs12_files():
|
||||||
def validate_certificate_chain_real(cert_files):
|
def validate_certificate_chain_real(cert_files):
|
||||||
"""Validate actual Ansible-generated certificate chain"""
|
"""Validate actual Ansible-generated certificate chain"""
|
||||||
# Load CA certificate
|
# Load CA certificate
|
||||||
with open(cert_files['ca_cert'], 'rb') as f:
|
with open(cert_files["ca_cert"], "rb") as f:
|
||||||
ca_cert_data = f.read()
|
ca_cert_data = f.read()
|
||||||
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
|
ca_certificate = x509.load_pem_x509_certificate(ca_cert_data)
|
||||||
|
|
||||||
# Test that all other certificates are signed by the CA
|
# Test that all other certificates are signed by the CA
|
||||||
other_certs = [f for f in cert_files['server_certs'] if f != cert_files['ca_cert']]
|
other_certs = [f for f in cert_files["server_certs"] if f != cert_files["ca_cert"]]
|
||||||
|
|
||||||
if not other_certs:
|
if not other_certs:
|
||||||
print("⚠ No client/server certificates found to validate")
|
print("⚠ No client/server certificates found to validate")
|
||||||
return
|
return
|
||||||
|
|
||||||
for cert_path in other_certs:
|
for cert_path in other_certs:
|
||||||
with open(cert_path, 'rb') as f:
|
with open(cert_path, "rb") as f:
|
||||||
cert_data = f.read()
|
cert_data = f.read()
|
||||||
certificate = x509.load_pem_x509_certificate(cert_data)
|
certificate = x509.load_pem_x509_certificate(cert_data)
|
||||||
|
|
||||||
|
@ -437,6 +479,7 @@ def validate_certificate_chain_real(cert_files):
|
||||||
|
|
||||||
# Verify certificate is currently valid (not expired)
|
# Verify certificate is currently valid (not expired)
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
|
assert certificate.not_valid_before_utc <= now, f"Certificate {cert_path} not yet valid"
|
||||||
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
|
assert certificate.not_valid_after_utc >= now, f"Certificate {cert_path} has expired"
|
||||||
|
@ -445,9 +488,10 @@ def validate_certificate_chain_real(cert_files):
|
||||||
|
|
||||||
print("✓ All real certificates properly signed by CA")
|
print("✓ All real certificates properly signed by CA")
|
||||||
|
|
||||||
|
|
||||||
def validate_certificate_chain_config():
|
def validate_certificate_chain_config():
|
||||||
"""Validate certificate chain configuration in Ansible files (CI mode)"""
|
"""Validate certificate chain configuration in Ansible files (CI mode)"""
|
||||||
openssl_task_file = find_ansible_file('roles/strongswan/tasks/openssl.yml')
|
openssl_task_file = find_ansible_file("roles/strongswan/tasks/openssl.yml")
|
||||||
if not openssl_task_file:
|
if not openssl_task_file:
|
||||||
print("⚠ Could not find openssl.yml task file")
|
print("⚠ Could not find openssl.yml task file")
|
||||||
return
|
return
|
||||||
|
@ -457,15 +501,18 @@ def validate_certificate_chain_config():
|
||||||
|
|
||||||
# Check certificate signing configuration
|
# Check certificate signing configuration
|
||||||
chain_checks = [
|
chain_checks = [
|
||||||
('provider: ownca', 'Certificates should be signed by own CA'),
|
("provider: ownca", "Certificates should be signed by own CA"),
|
||||||
('ownca_path', 'CA certificate path should be specified'),
|
("ownca_path", "CA certificate path should be specified"),
|
||||||
('ownca_privatekey_path', 'CA private key path should be specified'),
|
("ownca_privatekey_path", "CA private key path should be specified"),
|
||||||
('ownca_privatekey_passphrase', 'CA private key should be password protected'),
|
("ownca_privatekey_passphrase", "CA private key should be password protected"),
|
||||||
('certificate_validity_days: 3650', 'Certificate validity should be configurable (default 10 years)'),
|
("certificate_validity_days: 3650", "Certificate validity should be configurable (default 10 years)"),
|
||||||
('ownca_not_after: "+{{ certificate_validity_days }}d"', 'Certificates should use configurable validity period'),
|
(
|
||||||
('ownca_not_before: "-1d"', 'Certificates should have backdated start time'),
|
'ownca_not_after: "+{{ certificate_validity_days }}d"',
|
||||||
('curve: secp384r1', 'Should use strong elliptic curve cryptography'),
|
"Certificates should use configurable validity period",
|
||||||
('type: ECC', 'Should use elliptic curve keys for better security')
|
),
|
||||||
|
('ownca_not_before: "-1d"', "Certificates should have backdated start time"),
|
||||||
|
("curve: secp384r1", "Should use strong elliptic curve cryptography"),
|
||||||
|
("type: ECC", "Should use elliptic curve keys for better security"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for check, message in chain_checks:
|
for check, message in chain_checks:
|
||||||
|
@ -473,6 +520,7 @@ def validate_certificate_chain_config():
|
||||||
|
|
||||||
print("✓ Certificate chain configuration properly set up for CA signing")
|
print("✓ Certificate chain configuration properly set up for CA signing")
|
||||||
|
|
||||||
|
|
||||||
def test_certificate_chain():
|
def test_certificate_chain():
|
||||||
"""Test certificate chain - uses real certs if available, else validates config"""
|
"""Test certificate chain - uses real certs if available, else validates config"""
|
||||||
cert_files = find_generated_certificates()
|
cert_files = find_generated_certificates()
|
||||||
|
@ -499,6 +547,7 @@ def find_ansible_file(relative_path):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tests = [
|
tests = [
|
||||||
test_openssl_version_detection,
|
test_openssl_version_detection,
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Enhanced tests for StrongSwan templates.
|
Enhanced tests for StrongSwan templates.
|
||||||
Tests all strongswan role templates with various configurations.
|
Tests all strongswan role templates with various configurations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -21,7 +22,7 @@ def mock_to_uuid(value):
|
||||||
|
|
||||||
def mock_bool(value):
|
def mock_bool(value):
|
||||||
"""Mock the bool filter"""
|
"""Mock the bool filter"""
|
||||||
return str(value).lower() in ('true', '1', 'yes', 'on')
|
return str(value).lower() in ("true", "1", "yes", "on")
|
||||||
|
|
||||||
|
|
||||||
def mock_version(version_string, comparison):
|
def mock_version(version_string, comparison):
|
||||||
|
@ -33,67 +34,67 @@ def mock_version(version_string, comparison):
|
||||||
def mock_b64encode(value):
|
def mock_b64encode(value):
|
||||||
"""Mock base64 encoding"""
|
"""Mock base64 encoding"""
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = value.encode('utf-8')
|
value = value.encode("utf-8")
|
||||||
return base64.b64encode(value).decode('ascii')
|
return base64.b64encode(value).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
def mock_b64decode(value):
|
def mock_b64decode(value):
|
||||||
"""Mock base64 decoding"""
|
"""Mock base64 decoding"""
|
||||||
import base64
|
import base64
|
||||||
return base64.b64decode(value).decode('utf-8')
|
|
||||||
|
return base64.b64decode(value).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def get_strongswan_test_variables(scenario='default'):
|
def get_strongswan_test_variables(scenario="default"):
|
||||||
"""Get test variables for StrongSwan templates with different scenarios."""
|
"""Get test variables for StrongSwan templates with different scenarios."""
|
||||||
base_vars = load_test_variables()
|
base_vars = load_test_variables()
|
||||||
|
|
||||||
# Add StrongSwan specific variables
|
# Add StrongSwan specific variables
|
||||||
strongswan_vars = {
|
strongswan_vars = {
|
||||||
'ipsec_config_path': '/etc/ipsec.d',
|
"ipsec_config_path": "/etc/ipsec.d",
|
||||||
'ipsec_pki_path': '/etc/ipsec.d',
|
"ipsec_pki_path": "/etc/ipsec.d",
|
||||||
'strongswan_enabled': True,
|
"strongswan_enabled": True,
|
||||||
'strongswan_network': '10.19.48.0/24',
|
"strongswan_network": "10.19.48.0/24",
|
||||||
'strongswan_network_ipv6': 'fd9d:bc11:4021::/64',
|
"strongswan_network_ipv6": "fd9d:bc11:4021::/64",
|
||||||
'strongswan_log_level': '2',
|
"strongswan_log_level": "2",
|
||||||
'openssl_constraint_random_id': 'test-' + str(uuid.uuid4()),
|
"openssl_constraint_random_id": "test-" + str(uuid.uuid4()),
|
||||||
'subjectAltName': 'IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a',
|
"subjectAltName": "IP:10.0.0.1,IP:2600:3c01::f03c:91ff:fedf:3b2a",
|
||||||
'subjectAltName_type': 'IP',
|
"subjectAltName_type": "IP",
|
||||||
'subjectAltName_client': 'IP:10.0.0.1',
|
"subjectAltName_client": "IP:10.0.0.1",
|
||||||
'ansible_default_ipv6': {
|
"ansible_default_ipv6": {"address": "2600:3c01::f03c:91ff:fedf:3b2a"},
|
||||||
'address': '2600:3c01::f03c:91ff:fedf:3b2a'
|
"openssl_version": "3.0.0",
|
||||||
},
|
"p12_export_password": "test-password",
|
||||||
'openssl_version': '3.0.0',
|
"ike_lifetime": "24h",
|
||||||
'p12_export_password': 'test-password',
|
"ipsec_lifetime": "8h",
|
||||||
'ike_lifetime': '24h',
|
"ike_dpd": "30s",
|
||||||
'ipsec_lifetime': '8h',
|
"ipsec_dead_peer_detection": True,
|
||||||
'ike_dpd': '30s',
|
"rekey_margin": "3m",
|
||||||
'ipsec_dead_peer_detection': True,
|
"rekeymargin": "3m",
|
||||||
'rekey_margin': '3m',
|
"dpddelay": "35s",
|
||||||
'rekeymargin': '3m',
|
"keyexchange": "ikev2",
|
||||||
'dpddelay': '35s',
|
"ike_cipher": "aes128gcm16-prfsha512-ecp256",
|
||||||
'keyexchange': 'ikev2',
|
"esp_cipher": "aes128gcm16-ecp256",
|
||||||
'ike_cipher': 'aes128gcm16-prfsha512-ecp256',
|
"leftsourceip": "10.19.48.1",
|
||||||
'esp_cipher': 'aes128gcm16-ecp256',
|
"leftsubnet": "0.0.0.0/0,::/0",
|
||||||
'leftsourceip': '10.19.48.1',
|
"rightsourceip": "10.19.48.2/24,fd9d:bc11:4021::2/64",
|
||||||
'leftsubnet': '0.0.0.0/0,::/0',
|
|
||||||
'rightsourceip': '10.19.48.2/24,fd9d:bc11:4021::2/64',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Merge with base variables
|
# Merge with base variables
|
||||||
test_vars = {**base_vars, **strongswan_vars}
|
test_vars = {**base_vars, **strongswan_vars}
|
||||||
|
|
||||||
# Apply scenario-specific overrides
|
# Apply scenario-specific overrides
|
||||||
if scenario == 'ipv4_only':
|
if scenario == "ipv4_only":
|
||||||
test_vars['ipv6_support'] = False
|
test_vars["ipv6_support"] = False
|
||||||
test_vars['subjectAltName'] = 'IP:10.0.0.1'
|
test_vars["subjectAltName"] = "IP:10.0.0.1"
|
||||||
test_vars['ansible_default_ipv6'] = None
|
test_vars["ansible_default_ipv6"] = None
|
||||||
elif scenario == 'dns_hostname':
|
elif scenario == "dns_hostname":
|
||||||
test_vars['IP_subject_alt_name'] = 'vpn.example.com'
|
test_vars["IP_subject_alt_name"] = "vpn.example.com"
|
||||||
test_vars['subjectAltName'] = 'DNS:vpn.example.com'
|
test_vars["subjectAltName"] = "DNS:vpn.example.com"
|
||||||
test_vars['subjectAltName_type'] = 'DNS'
|
test_vars["subjectAltName_type"] = "DNS"
|
||||||
elif scenario == 'openssl_legacy':
|
elif scenario == "openssl_legacy":
|
||||||
test_vars['openssl_version'] = '1.1.1'
|
test_vars["openssl_version"] = "1.1.1"
|
||||||
|
|
||||||
return test_vars
|
return test_vars
|
||||||
|
|
||||||
|
@ -101,16 +102,16 @@ def get_strongswan_test_variables(scenario='default'):
|
||||||
def test_strongswan_templates():
|
def test_strongswan_templates():
|
||||||
"""Test all StrongSwan templates with various configurations."""
|
"""Test all StrongSwan templates with various configurations."""
|
||||||
templates = [
|
templates = [
|
||||||
'roles/strongswan/templates/ipsec.conf.j2',
|
"roles/strongswan/templates/ipsec.conf.j2",
|
||||||
'roles/strongswan/templates/ipsec.secrets.j2',
|
"roles/strongswan/templates/ipsec.secrets.j2",
|
||||||
'roles/strongswan/templates/strongswan.conf.j2',
|
"roles/strongswan/templates/strongswan.conf.j2",
|
||||||
'roles/strongswan/templates/charon.conf.j2',
|
"roles/strongswan/templates/charon.conf.j2",
|
||||||
'roles/strongswan/templates/client_ipsec.conf.j2',
|
"roles/strongswan/templates/client_ipsec.conf.j2",
|
||||||
'roles/strongswan/templates/client_ipsec.secrets.j2',
|
"roles/strongswan/templates/client_ipsec.secrets.j2",
|
||||||
'roles/strongswan/templates/100-CustomLimitations.conf.j2',
|
"roles/strongswan/templates/100-CustomLimitations.conf.j2",
|
||||||
]
|
]
|
||||||
|
|
||||||
scenarios = ['default', 'ipv4_only', 'dns_hostname', 'openssl_legacy']
|
scenarios = ["default", "ipv4_only", "dns_hostname", "openssl_legacy"]
|
||||||
errors = []
|
errors = []
|
||||||
tested = 0
|
tested = 0
|
||||||
|
|
||||||
|
@ -127,21 +128,18 @@ def test_strongswan_templates():
|
||||||
test_vars = get_strongswan_test_variables(scenario)
|
test_vars = get_strongswan_test_variables(scenario)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader(template_dir),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add mock filters
|
# Add mock filters
|
||||||
env.filters['to_uuid'] = mock_to_uuid
|
env.filters["to_uuid"] = mock_to_uuid
|
||||||
env.filters['bool'] = mock_bool
|
env.filters["bool"] = mock_bool
|
||||||
env.filters['b64encode'] = mock_b64encode
|
env.filters["b64encode"] = mock_b64encode
|
||||||
env.filters['b64decode'] = mock_b64decode
|
env.filters["b64decode"] = mock_b64decode
|
||||||
env.tests['version'] = mock_version
|
env.tests["version"] = mock_version
|
||||||
|
|
||||||
# For client templates, add item context
|
# For client templates, add item context
|
||||||
if 'client' in template_name:
|
if "client" in template_name:
|
||||||
test_vars['item'] = 'testuser'
|
test_vars["item"] = "testuser"
|
||||||
|
|
||||||
template = env.get_template(template_name)
|
template = env.get_template(template_name)
|
||||||
output = template.render(**test_vars)
|
output = template.render(**test_vars)
|
||||||
|
@ -150,16 +148,16 @@ def test_strongswan_templates():
|
||||||
assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
|
assert len(output) > 0, f"Empty output from {template_path} ({scenario})"
|
||||||
|
|
||||||
# Specific validations based on template
|
# Specific validations based on template
|
||||||
if 'ipsec.conf' in template_name and 'client' not in template_name:
|
if "ipsec.conf" in template_name and "client" not in template_name:
|
||||||
assert 'conn' in output, "Missing connection definition"
|
assert "conn" in output, "Missing connection definition"
|
||||||
if scenario != 'ipv4_only' and test_vars.get('ipv6_support'):
|
if scenario != "ipv4_only" and test_vars.get("ipv6_support"):
|
||||||
assert '::/0' in output or 'fd9d:bc11' in output, "Missing IPv6 configuration"
|
assert "::/0" in output or "fd9d:bc11" in output, "Missing IPv6 configuration"
|
||||||
|
|
||||||
if 'ipsec.secrets' in template_name:
|
if "ipsec.secrets" in template_name:
|
||||||
assert 'PSK' in output or 'ECDSA' in output, "Missing authentication method"
|
assert "PSK" in output or "ECDSA" in output, "Missing authentication method"
|
||||||
|
|
||||||
if 'strongswan.conf' in template_name:
|
if "strongswan.conf" in template_name:
|
||||||
assert 'charon' in output, "Missing charon configuration"
|
assert "charon" in output, "Missing charon configuration"
|
||||||
|
|
||||||
print(f" ✅ {template_name} ({scenario})")
|
print(f" ✅ {template_name} ({scenario})")
|
||||||
|
|
||||||
|
@ -182,7 +180,7 @@ def test_openssl_template_constraints():
|
||||||
# This tests the actual openssl.yml task file to ensure our fix works
|
# This tests the actual openssl.yml task file to ensure our fix works
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
openssl_path = 'roles/strongswan/tasks/openssl.yml'
|
openssl_path = "roles/strongswan/tasks/openssl.yml"
|
||||||
if not os.path.exists(openssl_path):
|
if not os.path.exists(openssl_path):
|
||||||
print("⚠️ OpenSSL tasks file not found")
|
print("⚠️ OpenSSL tasks file not found")
|
||||||
return True
|
return True
|
||||||
|
@ -194,22 +192,23 @@ def test_openssl_template_constraints():
|
||||||
# Find the CA CSR task
|
# Find the CA CSR task
|
||||||
ca_csr_task = None
|
ca_csr_task = None
|
||||||
for task in content:
|
for task in content:
|
||||||
if isinstance(task, dict) and task.get('name', '').startswith('Create certificate signing request'):
|
if isinstance(task, dict) and task.get("name", "").startswith("Create certificate signing request"):
|
||||||
ca_csr_task = task
|
ca_csr_task = task
|
||||||
break
|
break
|
||||||
|
|
||||||
if ca_csr_task:
|
if ca_csr_task:
|
||||||
# Check that name_constraints_permitted is properly formatted
|
# Check that name_constraints_permitted is properly formatted
|
||||||
csr_module = ca_csr_task.get('community.crypto.openssl_csr_pipe', {})
|
csr_module = ca_csr_task.get("community.crypto.openssl_csr_pipe", {})
|
||||||
constraints = csr_module.get('name_constraints_permitted', '')
|
constraints = csr_module.get("name_constraints_permitted", "")
|
||||||
|
|
||||||
# The constraints should be a Jinja2 template without inline comments
|
# The constraints should be a Jinja2 template without inline comments
|
||||||
if '#' in str(constraints):
|
if "#" in str(constraints):
|
||||||
# Check if the # is within {{ }}
|
# Check if the # is within {{ }}
|
||||||
import re
|
import re
|
||||||
jinja_blocks = re.findall(r'\{\{.*?\}\}', str(constraints), re.DOTALL)
|
|
||||||
|
jinja_blocks = re.findall(r"\{\{.*?\}\}", str(constraints), re.DOTALL)
|
||||||
for block in jinja_blocks:
|
for block in jinja_blocks:
|
||||||
if '#' in block:
|
if "#" in block:
|
||||||
print("❌ Found inline comment in Jinja2 expression")
|
print("❌ Found inline comment in Jinja2 expression")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -223,7 +222,7 @@ def test_openssl_template_constraints():
|
||||||
|
|
||||||
def test_mobileconfig_template():
|
def test_mobileconfig_template():
|
||||||
"""Test the mobileconfig template with various scenarios."""
|
"""Test the mobileconfig template with various scenarios."""
|
||||||
template_path = 'roles/strongswan/templates/mobileconfig.j2'
|
template_path = "roles/strongswan/templates/mobileconfig.j2"
|
||||||
|
|
||||||
if not os.path.exists(template_path):
|
if not os.path.exists(template_path):
|
||||||
print("⚠️ Mobileconfig template not found")
|
print("⚠️ Mobileconfig template not found")
|
||||||
|
@ -237,20 +236,20 @@ def test_mobileconfig_template():
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
{
|
{
|
||||||
'name': 'iPhone with cellular on-demand',
|
"name": "iPhone with cellular on-demand",
|
||||||
'algo_ondemand_cellular': 'true',
|
"algo_ondemand_cellular": "true",
|
||||||
'algo_ondemand_wifi': 'false',
|
"algo_ondemand_wifi": "false",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'iPad with WiFi on-demand',
|
"name": "iPad with WiFi on-demand",
|
||||||
'algo_ondemand_cellular': 'false',
|
"algo_ondemand_cellular": "false",
|
||||||
'algo_ondemand_wifi': 'true',
|
"algo_ondemand_wifi": "true",
|
||||||
'algo_ondemand_wifi_exclude': 'MyHomeNetwork,OfficeWiFi',
|
"algo_ondemand_wifi_exclude": "MyHomeNetwork,OfficeWiFi",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'Mac without on-demand',
|
"name": "Mac without on-demand",
|
||||||
'algo_ondemand_cellular': 'false',
|
"algo_ondemand_cellular": "false",
|
||||||
'algo_ondemand_wifi': 'false',
|
"algo_ondemand_wifi": "false",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -258,43 +257,41 @@ def test_mobileconfig_template():
|
||||||
for test_case in test_cases:
|
for test_case in test_cases:
|
||||||
test_vars = get_strongswan_test_variables()
|
test_vars = get_strongswan_test_variables()
|
||||||
test_vars.update(test_case)
|
test_vars.update(test_case)
|
||||||
|
|
||||||
# Mock Ansible task result format for item
|
# Mock Ansible task result format for item
|
||||||
class MockTaskResult:
|
class MockTaskResult:
|
||||||
def __init__(self, content):
|
def __init__(self, content):
|
||||||
self.stdout = content
|
self.stdout = content
|
||||||
|
|
||||||
test_vars['item'] = ('testuser', MockTaskResult('TU9DS19QS0NTMTJfQ09OVEVOVA==')) # Tuple with mock result
|
test_vars["item"] = ("testuser", MockTaskResult("TU9DS19QS0NTMTJfQ09OVEVOVA==")) # Tuple with mock result
|
||||||
test_vars['PayloadContentCA_base64'] = 'TU9DS19DQV9DRVJUX0JBU0U2NA==' # Valid base64
|
test_vars["PayloadContentCA_base64"] = "TU9DS19DQV9DRVJUX0JBU0U2NA==" # Valid base64
|
||||||
test_vars['PayloadContentUser_base64'] = 'TU9DS19VU0VSX0NFUlRfQkFTRTY0' # Valid base64
|
test_vars["PayloadContentUser_base64"] = "TU9DS19VU0VSX0NFUlRfQkFTRTY0" # Valid base64
|
||||||
test_vars['pkcs12_PayloadCertificateUUID'] = str(uuid.uuid4())
|
test_vars["pkcs12_PayloadCertificateUUID"] = str(uuid.uuid4())
|
||||||
test_vars['PayloadContent'] = 'TU9DS19QS0NTMTJfQ09OVEVOVA==' # Valid base64 for PKCS12
|
test_vars["PayloadContent"] = "TU9DS19QS0NTMTJfQ09OVEVOVA==" # Valid base64 for PKCS12
|
||||||
test_vars['algo_server_name'] = 'test-algo-vpn'
|
test_vars["algo_server_name"] = "test-algo-vpn"
|
||||||
test_vars['VPN_PayloadIdentifier'] = str(uuid.uuid4())
|
test_vars["VPN_PayloadIdentifier"] = str(uuid.uuid4())
|
||||||
test_vars['CA_PayloadIdentifier'] = str(uuid.uuid4())
|
test_vars["CA_PayloadIdentifier"] = str(uuid.uuid4())
|
||||||
test_vars['PayloadContentCA'] = 'TU9DS19DQV9DRVJUX0NPTlRFTlQ=' # Valid base64
|
test_vars["PayloadContentCA"] = "TU9DS19DQV9DRVJUX0NPTlRFTlQ=" # Valid base64
|
||||||
|
|
||||||
try:
|
try:
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader("roles/strongswan/templates"), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader('roles/strongswan/templates'),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add mock filters
|
# Add mock filters
|
||||||
env.filters['to_uuid'] = mock_to_uuid
|
env.filters["to_uuid"] = mock_to_uuid
|
||||||
env.filters['b64encode'] = mock_b64encode
|
env.filters["b64encode"] = mock_b64encode
|
||||||
env.filters['b64decode'] = mock_b64decode
|
env.filters["b64decode"] = mock_b64decode
|
||||||
|
|
||||||
template = env.get_template('mobileconfig.j2')
|
template = env.get_template("mobileconfig.j2")
|
||||||
output = template.render(**test_vars)
|
output = template.render(**test_vars)
|
||||||
|
|
||||||
# Validate output
|
# Validate output
|
||||||
assert '<?xml' in output, "Missing XML declaration"
|
assert "<?xml" in output, "Missing XML declaration"
|
||||||
assert '<plist' in output, "Missing plist element"
|
assert "<plist" in output, "Missing plist element"
|
||||||
assert 'PayloadType' in output, "Missing PayloadType"
|
assert "PayloadType" in output, "Missing PayloadType"
|
||||||
|
|
||||||
# Check on-demand configuration
|
# Check on-demand configuration
|
||||||
if test_case.get('algo_ondemand_cellular') == 'true' or test_case.get('algo_ondemand_wifi') == 'true':
|
if test_case.get("algo_ondemand_cellular") == "true" or test_case.get("algo_ondemand_wifi") == "true":
|
||||||
assert 'OnDemandEnabled' in output, f"Missing OnDemand config for {test_case['name']}"
|
assert "OnDemandEnabled" in output, f"Missing OnDemand config for {test_case['name']}"
|
||||||
|
|
||||||
print(f" ✅ Mobileconfig: {test_case['name']}")
|
print(f" ✅ Mobileconfig: {test_case['name']}")
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Test that Ansible templates render correctly
|
Test that Ansible templates render correctly
|
||||||
This catches undefined variables, syntax errors, and logic bugs
|
This catches undefined variables, syntax errors, and logic bugs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -22,20 +23,20 @@ def mock_to_uuid(value):
|
||||||
|
|
||||||
def mock_bool(value):
|
def mock_bool(value):
|
||||||
"""Mock the bool filter"""
|
"""Mock the bool filter"""
|
||||||
return str(value).lower() in ('true', '1', 'yes', 'on')
|
return str(value).lower() in ("true", "1", "yes", "on")
|
||||||
|
|
||||||
|
|
||||||
def mock_lookup(type, path):
|
def mock_lookup(type, path):
|
||||||
"""Mock the lookup function"""
|
"""Mock the lookup function"""
|
||||||
# Return fake data for file lookups
|
# Return fake data for file lookups
|
||||||
if type == 'file':
|
if type == "file":
|
||||||
if 'private' in path:
|
if "private" in path:
|
||||||
return 'MOCK_PRIVATE_KEY_BASE64=='
|
return "MOCK_PRIVATE_KEY_BASE64=="
|
||||||
elif 'public' in path:
|
elif "public" in path:
|
||||||
return 'MOCK_PUBLIC_KEY_BASE64=='
|
return "MOCK_PUBLIC_KEY_BASE64=="
|
||||||
elif 'preshared' in path:
|
elif "preshared" in path:
|
||||||
return 'MOCK_PRESHARED_KEY_BASE64=='
|
return "MOCK_PRESHARED_KEY_BASE64=="
|
||||||
return 'MOCK_LOOKUP_DATA'
|
return "MOCK_LOOKUP_DATA"
|
||||||
|
|
||||||
|
|
||||||
def get_test_variables():
|
def get_test_variables():
|
||||||
|
@ -47,8 +48,8 @@ def get_test_variables():
|
||||||
def find_templates():
|
def find_templates():
|
||||||
"""Find all Jinja2 template files in the repo"""
|
"""Find all Jinja2 template files in the repo"""
|
||||||
templates = []
|
templates = []
|
||||||
for pattern in ['**/*.j2', '**/*.jinja2', '**/*.yml.j2']:
|
for pattern in ["**/*.j2", "**/*.jinja2", "**/*.yml.j2"]:
|
||||||
templates.extend(Path('.').glob(pattern))
|
templates.extend(Path(".").glob(pattern))
|
||||||
return templates
|
return templates
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,10 +58,10 @@ def test_template_syntax():
|
||||||
templates = find_templates()
|
templates = find_templates()
|
||||||
|
|
||||||
# Skip some paths that aren't real templates
|
# Skip some paths that aren't real templates
|
||||||
skip_paths = ['.git/', 'venv/', '.venv/', '.env/', 'configs/']
|
skip_paths = [".git/", "venv/", ".venv/", ".env/", "configs/"]
|
||||||
|
|
||||||
# Skip templates that use Ansible-specific filters
|
# Skip templates that use Ansible-specific filters
|
||||||
skip_templates = ['vpn-dict.j2', 'mobileconfig.j2', 'dnscrypt-proxy.toml.j2']
|
skip_templates = ["vpn-dict.j2", "mobileconfig.j2", "dnscrypt-proxy.toml.j2"]
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
skipped = 0
|
skipped = 0
|
||||||
|
@ -76,10 +77,7 @@ def test_template_syntax():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
template_dir = template_path.parent
|
template_dir = template_path.parent
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader(template_dir),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Just try to load the template - this checks syntax
|
# Just try to load the template - this checks syntax
|
||||||
env.get_template(template_path.name)
|
env.get_template(template_path.name)
|
||||||
|
@ -103,13 +101,13 @@ def test_template_syntax():
|
||||||
def test_critical_templates():
|
def test_critical_templates():
|
||||||
"""Test that critical templates render with test data"""
|
"""Test that critical templates render with test data"""
|
||||||
critical_templates = [
|
critical_templates = [
|
||||||
'roles/wireguard/templates/client.conf.j2',
|
"roles/wireguard/templates/client.conf.j2",
|
||||||
'roles/strongswan/templates/ipsec.conf.j2',
|
"roles/strongswan/templates/ipsec.conf.j2",
|
||||||
'roles/strongswan/templates/ipsec.secrets.j2',
|
"roles/strongswan/templates/ipsec.secrets.j2",
|
||||||
'roles/dns/templates/adblock.sh.j2',
|
"roles/dns/templates/adblock.sh.j2",
|
||||||
'roles/dns/templates/dnsmasq.conf.j2',
|
"roles/dns/templates/dnsmasq.conf.j2",
|
||||||
'roles/common/templates/rules.v4.j2',
|
"roles/common/templates/rules.v4.j2",
|
||||||
'roles/common/templates/rules.v6.j2',
|
"roles/common/templates/rules.v6.j2",
|
||||||
]
|
]
|
||||||
|
|
||||||
test_vars = get_test_variables()
|
test_vars = get_test_variables()
|
||||||
|
@ -123,21 +121,18 @@ def test_critical_templates():
|
||||||
template_dir = os.path.dirname(template_path)
|
template_dir = os.path.dirname(template_path)
|
||||||
template_name = os.path.basename(template_path)
|
template_name = os.path.basename(template_path)
|
||||||
|
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader(template_dir),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add mock functions
|
# Add mock functions
|
||||||
env.globals['lookup'] = mock_lookup
|
env.globals["lookup"] = mock_lookup
|
||||||
env.filters['to_uuid'] = mock_to_uuid
|
env.filters["to_uuid"] = mock_to_uuid
|
||||||
env.filters['bool'] = mock_bool
|
env.filters["bool"] = mock_bool
|
||||||
|
|
||||||
template = env.get_template(template_name)
|
template = env.get_template(template_name)
|
||||||
|
|
||||||
# Add item context for templates that use loops
|
# Add item context for templates that use loops
|
||||||
if 'client' in template_name:
|
if "client" in template_name:
|
||||||
test_vars['item'] = ('test-user', 'test-user')
|
test_vars["item"] = ("test-user", "test-user")
|
||||||
|
|
||||||
# Try to render
|
# Try to render
|
||||||
output = template.render(**test_vars)
|
output = template.render(**test_vars)
|
||||||
|
@ -163,17 +158,17 @@ def test_variable_consistency():
|
||||||
"""Check that commonly used variables are defined consistently"""
|
"""Check that commonly used variables are defined consistently"""
|
||||||
# Variables that should be used consistently across templates
|
# Variables that should be used consistently across templates
|
||||||
common_vars = [
|
common_vars = [
|
||||||
'server_name',
|
"server_name",
|
||||||
'IP_subject_alt_name',
|
"IP_subject_alt_name",
|
||||||
'wireguard_port',
|
"wireguard_port",
|
||||||
'wireguard_network',
|
"wireguard_network",
|
||||||
'dns_servers',
|
"dns_servers",
|
||||||
'users',
|
"users",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check if main.yml defines these
|
# Check if main.yml defines these
|
||||||
if os.path.exists('main.yml'):
|
if os.path.exists("main.yml"):
|
||||||
with open('main.yml') as f:
|
with open("main.yml") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
missing = []
|
missing = []
|
||||||
|
@ -192,28 +187,19 @@ def test_wireguard_ipv6_endpoints():
|
||||||
"""Test that WireGuard client configs properly format IPv6 endpoints"""
|
"""Test that WireGuard client configs properly format IPv6 endpoints"""
|
||||||
test_cases = [
|
test_cases = [
|
||||||
# IPv4 address - should not be bracketed
|
# IPv4 address - should not be bracketed
|
||||||
{
|
{"IP_subject_alt_name": "192.168.1.100", "expected_endpoint": "Endpoint = 192.168.1.100:51820"},
|
||||||
'IP_subject_alt_name': '192.168.1.100',
|
|
||||||
'expected_endpoint': 'Endpoint = 192.168.1.100:51820'
|
|
||||||
},
|
|
||||||
# IPv6 address - should be bracketed
|
# IPv6 address - should be bracketed
|
||||||
{
|
{
|
||||||
'IP_subject_alt_name': '2600:3c01::f03c:91ff:fedf:3b2a',
|
"IP_subject_alt_name": "2600:3c01::f03c:91ff:fedf:3b2a",
|
||||||
'expected_endpoint': 'Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820'
|
"expected_endpoint": "Endpoint = [2600:3c01::f03c:91ff:fedf:3b2a]:51820",
|
||||||
},
|
},
|
||||||
# Hostname - should not be bracketed
|
# Hostname - should not be bracketed
|
||||||
{
|
{"IP_subject_alt_name": "vpn.example.com", "expected_endpoint": "Endpoint = vpn.example.com:51820"},
|
||||||
'IP_subject_alt_name': 'vpn.example.com',
|
|
||||||
'expected_endpoint': 'Endpoint = vpn.example.com:51820'
|
|
||||||
},
|
|
||||||
# IPv6 with zone ID - should be bracketed
|
# IPv6 with zone ID - should be bracketed
|
||||||
{
|
{"IP_subject_alt_name": "fe80::1%eth0", "expected_endpoint": "Endpoint = [fe80::1%eth0]:51820"},
|
||||||
'IP_subject_alt_name': 'fe80::1%eth0',
|
|
||||||
'expected_endpoint': 'Endpoint = [fe80::1%eth0]:51820'
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
template_path = 'roles/wireguard/templates/client.conf.j2'
|
template_path = "roles/wireguard/templates/client.conf.j2"
|
||||||
if not os.path.exists(template_path):
|
if not os.path.exists(template_path):
|
||||||
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
|
print(f"⚠ Skipping IPv6 endpoint test - {template_path} not found")
|
||||||
return
|
return
|
||||||
|
@ -225,24 +211,23 @@ def test_wireguard_ipv6_endpoints():
|
||||||
try:
|
try:
|
||||||
# Set up test variables
|
# Set up test variables
|
||||||
test_vars = {**base_vars, **test_case}
|
test_vars = {**base_vars, **test_case}
|
||||||
test_vars['item'] = ('test-user', 'test-user')
|
test_vars["item"] = ("test-user", "test-user")
|
||||||
|
|
||||||
# Render template
|
# Render template
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader("roles/wireguard/templates"), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader('roles/wireguard/templates'),
|
env.globals["lookup"] = mock_lookup
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
env.globals['lookup'] = mock_lookup
|
|
||||||
|
|
||||||
template = env.get_template('client.conf.j2')
|
template = env.get_template("client.conf.j2")
|
||||||
output = template.render(**test_vars)
|
output = template.render(**test_vars)
|
||||||
|
|
||||||
# Check if the expected endpoint format is in the output
|
# Check if the expected endpoint format is in the output
|
||||||
if test_case['expected_endpoint'] not in output:
|
if test_case["expected_endpoint"] not in output:
|
||||||
errors.append(f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output")
|
errors.append(
|
||||||
|
f"Expected '{test_case['expected_endpoint']}' for IP '{test_case['IP_subject_alt_name']}' but not found in output"
|
||||||
|
)
|
||||||
# Print relevant part of output for debugging
|
# Print relevant part of output for debugging
|
||||||
for line in output.split('\n'):
|
for line in output.split("\n"):
|
||||||
if 'Endpoint' in line:
|
if "Endpoint" in line:
|
||||||
errors.append(f" Found: {line.strip()}")
|
errors.append(f" Found: {line.strip()}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -262,27 +247,27 @@ def test_template_conditionals():
|
||||||
test_cases = [
|
test_cases = [
|
||||||
# WireGuard enabled, IPsec disabled
|
# WireGuard enabled, IPsec disabled
|
||||||
{
|
{
|
||||||
'wireguard_enabled': True,
|
"wireguard_enabled": True,
|
||||||
'ipsec_enabled': False,
|
"ipsec_enabled": False,
|
||||||
'dns_encryption': True,
|
"dns_encryption": True,
|
||||||
'dns_adblocking': True,
|
"dns_adblocking": True,
|
||||||
'algo_ssh_tunneling': False,
|
"algo_ssh_tunneling": False,
|
||||||
},
|
},
|
||||||
# IPsec enabled, WireGuard disabled
|
# IPsec enabled, WireGuard disabled
|
||||||
{
|
{
|
||||||
'wireguard_enabled': False,
|
"wireguard_enabled": False,
|
||||||
'ipsec_enabled': True,
|
"ipsec_enabled": True,
|
||||||
'dns_encryption': False,
|
"dns_encryption": False,
|
||||||
'dns_adblocking': False,
|
"dns_adblocking": False,
|
||||||
'algo_ssh_tunneling': True,
|
"algo_ssh_tunneling": True,
|
||||||
},
|
},
|
||||||
# Both enabled
|
# Both enabled
|
||||||
{
|
{
|
||||||
'wireguard_enabled': True,
|
"wireguard_enabled": True,
|
||||||
'ipsec_enabled': True,
|
"ipsec_enabled": True,
|
||||||
'dns_encryption': True,
|
"dns_encryption": True,
|
||||||
'dns_adblocking': True,
|
"dns_adblocking": True,
|
||||||
'algo_ssh_tunneling': True,
|
"algo_ssh_tunneling": True,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -294,7 +279,7 @@ def test_template_conditionals():
|
||||||
|
|
||||||
# Test a few templates that have conditionals
|
# Test a few templates that have conditionals
|
||||||
conditional_templates = [
|
conditional_templates = [
|
||||||
'roles/common/templates/rules.v4.j2',
|
"roles/common/templates/rules.v4.j2",
|
||||||
]
|
]
|
||||||
|
|
||||||
for template_path in conditional_templates:
|
for template_path in conditional_templates:
|
||||||
|
@ -305,23 +290,19 @@ def test_template_conditionals():
|
||||||
template_dir = os.path.dirname(template_path)
|
template_dir = os.path.dirname(template_path)
|
||||||
template_name = os.path.basename(template_path)
|
template_name = os.path.basename(template_path)
|
||||||
|
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader(template_dir), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader(template_dir),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add mock functions
|
# Add mock functions
|
||||||
env.globals['lookup'] = mock_lookup
|
env.globals["lookup"] = mock_lookup
|
||||||
env.filters['to_uuid'] = mock_to_uuid
|
env.filters["to_uuid"] = mock_to_uuid
|
||||||
env.filters['bool'] = mock_bool
|
env.filters["bool"] = mock_bool
|
||||||
|
|
||||||
template = env.get_template(template_name)
|
template = env.get_template(template_name)
|
||||||
output = template.render(**test_vars)
|
output = template.render(**test_vars)
|
||||||
|
|
||||||
# Verify conditionals work
|
# Verify conditionals work
|
||||||
if test_case.get('wireguard_enabled'):
|
if test_case.get("wireguard_enabled"):
|
||||||
assert str(test_vars['wireguard_port']) in output, \
|
assert str(test_vars["wireguard_port"]) in output, f"WireGuard port missing when enabled (case {i})"
|
||||||
f"WireGuard port missing when enabled (case {i})"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Conditional test failed for {template_path} case {i}: {e}")
|
print(f"✗ Conditional test failed for {template_path} case {i}: {e}")
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Test user management functionality without deployment
|
Test user management functionality without deployment
|
||||||
Based on issues #14745, #14746, #14738, #14726
|
Based on issues #14745, #14746, #14738, #14726
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
@ -23,15 +24,15 @@ users:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = yaml.safe_load(test_config)
|
config = yaml.safe_load(test_config)
|
||||||
users = config.get('users', [])
|
users = config.get("users", [])
|
||||||
|
|
||||||
assert len(users) == 5, f"Expected 5 users, got {len(users)}"
|
assert len(users) == 5, f"Expected 5 users, got {len(users)}"
|
||||||
assert 'alice' in users, "Missing user 'alice'"
|
assert "alice" in users, "Missing user 'alice'"
|
||||||
assert 'user-with-dash' in users, "Dash in username not handled"
|
assert "user-with-dash" in users, "Dash in username not handled"
|
||||||
assert 'user_with_underscore' in users, "Underscore in username not handled"
|
assert "user_with_underscore" in users, "Underscore in username not handled"
|
||||||
|
|
||||||
# Test that usernames are valid
|
# Test that usernames are valid
|
||||||
username_pattern = re.compile(r'^[a-zA-Z0-9_-]+$')
|
username_pattern = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||||
for user in users:
|
for user in users:
|
||||||
assert username_pattern.match(user), f"Invalid username format: {user}"
|
assert username_pattern.match(user), f"Invalid username format: {user}"
|
||||||
|
|
||||||
|
@ -42,35 +43,27 @@ def test_server_selection_format():
|
||||||
"""Test server selection string parsing (issue #14727)"""
|
"""Test server selection string parsing (issue #14727)"""
|
||||||
# Test various server display formats
|
# Test various server display formats
|
||||||
test_cases = [
|
test_cases = [
|
||||||
|
{"display": "1. 192.168.1.100 (algo-server)", "expected_ip": "192.168.1.100", "expected_name": "algo-server"},
|
||||||
|
{"display": "2. 10.0.0.1 (production-vpn)", "expected_ip": "10.0.0.1", "expected_name": "production-vpn"},
|
||||||
{
|
{
|
||||||
'display': '1. 192.168.1.100 (algo-server)',
|
"display": "3. vpn.example.com (example-server)",
|
||||||
'expected_ip': '192.168.1.100',
|
"expected_ip": "vpn.example.com",
|
||||||
'expected_name': 'algo-server'
|
"expected_name": "example-server",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
'display': '2. 10.0.0.1 (production-vpn)',
|
|
||||||
'expected_ip': '10.0.0.1',
|
|
||||||
'expected_name': 'production-vpn'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'display': '3. vpn.example.com (example-server)',
|
|
||||||
'expected_ip': 'vpn.example.com',
|
|
||||||
'expected_name': 'example-server'
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Pattern to extract IP and name from display string
|
# Pattern to extract IP and name from display string
|
||||||
pattern = re.compile(r'^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$')
|
pattern = re.compile(r"^\d+\.\s+([^\s]+)\s+\(([^)]+)\)$")
|
||||||
|
|
||||||
for case in test_cases:
|
for case in test_cases:
|
||||||
match = pattern.match(case['display'])
|
match = pattern.match(case["display"])
|
||||||
assert match, f"Failed to parse: {case['display']}"
|
assert match, f"Failed to parse: {case['display']}"
|
||||||
|
|
||||||
ip_or_host = match.group(1)
|
ip_or_host = match.group(1)
|
||||||
name = match.group(2)
|
name = match.group(2)
|
||||||
|
|
||||||
assert ip_or_host == case['expected_ip'], f"Wrong IP extracted: {ip_or_host}"
|
assert ip_or_host == case["expected_ip"], f"Wrong IP extracted: {ip_or_host}"
|
||||||
assert name == case['expected_name'], f"Wrong name extracted: {name}"
|
assert name == case["expected_name"], f"Wrong name extracted: {name}"
|
||||||
|
|
||||||
print("✓ Server selection format test passed")
|
print("✓ Server selection format test passed")
|
||||||
|
|
||||||
|
@ -78,12 +71,12 @@ def test_server_selection_format():
|
||||||
def test_ssh_key_preservation():
|
def test_ssh_key_preservation():
|
||||||
"""Test that SSH keys aren't regenerated unnecessarily"""
|
"""Test that SSH keys aren't regenerated unnecessarily"""
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
ssh_key_path = os.path.join(tmpdir, 'test_key')
|
ssh_key_path = os.path.join(tmpdir, "test_key")
|
||||||
|
|
||||||
# Simulate existing SSH key
|
# Simulate existing SSH key
|
||||||
with open(ssh_key_path, 'w') as f:
|
with open(ssh_key_path, "w") as f:
|
||||||
f.write("EXISTING_SSH_KEY_CONTENT")
|
f.write("EXISTING_SSH_KEY_CONTENT")
|
||||||
with open(f"{ssh_key_path}.pub", 'w') as f:
|
with open(f"{ssh_key_path}.pub", "w") as f:
|
||||||
f.write("ssh-rsa EXISTING_PUBLIC_KEY")
|
f.write("ssh-rsa EXISTING_PUBLIC_KEY")
|
||||||
|
|
||||||
# Record original content
|
# Record original content
|
||||||
|
@ -105,11 +98,7 @@ def test_ssh_key_preservation():
|
||||||
def test_ca_password_handling():
|
def test_ca_password_handling():
|
||||||
"""Test CA password validation and handling"""
|
"""Test CA password validation and handling"""
|
||||||
# Test password requirements
|
# Test password requirements
|
||||||
valid_passwords = [
|
valid_passwords = ["SecurePassword123!", "Algo-VPN-2024", "Complex#Pass@Word999"]
|
||||||
"SecurePassword123!",
|
|
||||||
"Algo-VPN-2024",
|
|
||||||
"Complex#Pass@Word999"
|
|
||||||
]
|
|
||||||
|
|
||||||
invalid_passwords = [
|
invalid_passwords = [
|
||||||
"", # Empty
|
"", # Empty
|
||||||
|
@ -120,13 +109,13 @@ def test_ca_password_handling():
|
||||||
# Basic password validation
|
# Basic password validation
|
||||||
for pwd in valid_passwords:
|
for pwd in valid_passwords:
|
||||||
assert len(pwd) >= 12, f"Password too short: {pwd}"
|
assert len(pwd) >= 12, f"Password too short: {pwd}"
|
||||||
assert ' ' not in pwd, f"Password contains spaces: {pwd}"
|
assert " " not in pwd, f"Password contains spaces: {pwd}"
|
||||||
|
|
||||||
for pwd in invalid_passwords:
|
for pwd in invalid_passwords:
|
||||||
issues = []
|
issues = []
|
||||||
if len(pwd) < 12:
|
if len(pwd) < 12:
|
||||||
issues.append("too short")
|
issues.append("too short")
|
||||||
if ' ' in pwd:
|
if " " in pwd:
|
||||||
issues.append("contains spaces")
|
issues.append("contains spaces")
|
||||||
if not pwd:
|
if not pwd:
|
||||||
issues.append("empty")
|
issues.append("empty")
|
||||||
|
@ -137,8 +126,8 @@ def test_ca_password_handling():
|
||||||
|
|
||||||
def test_user_config_generation():
|
def test_user_config_generation():
|
||||||
"""Test that user configs would be generated correctly"""
|
"""Test that user configs would be generated correctly"""
|
||||||
users = ['alice', 'bob', 'charlie']
|
users = ["alice", "bob", "charlie"]
|
||||||
server_name = 'test-server'
|
server_name = "test-server"
|
||||||
|
|
||||||
# Simulate config file structure
|
# Simulate config file structure
|
||||||
for user in users:
|
for user in users:
|
||||||
|
@ -168,7 +157,7 @@ users:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = yaml.safe_load(test_config)
|
config = yaml.safe_load(test_config)
|
||||||
users = config.get('users', [])
|
users = config.get("users", [])
|
||||||
|
|
||||||
# Check for duplicates
|
# Check for duplicates
|
||||||
unique_users = list(set(users))
|
unique_users = list(set(users))
|
||||||
|
@ -182,7 +171,7 @@ users:
|
||||||
duplicates.append(user)
|
duplicates.append(user)
|
||||||
seen.add(user)
|
seen.add(user)
|
||||||
|
|
||||||
assert 'alice' in duplicates, "Duplicate 'alice' not detected"
|
assert "alice" in duplicates, "Duplicate 'alice' not detected"
|
||||||
|
|
||||||
print("✓ Duplicate user handling test passed")
|
print("✓ Duplicate user handling test passed")
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Test WireGuard key generation - focused on x25519_pubkey module integration
|
Test WireGuard key generation - focused on x25519_pubkey module integration
|
||||||
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
|
Addresses test gap identified in tests/README.md line 63-67: WireGuard private/public key generation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -10,13 +11,13 @@ import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
# Add library directory to path to import our custom module
|
# Add library directory to path to import our custom module
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'library'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "library"))
|
||||||
|
|
||||||
|
|
||||||
def test_wireguard_tools_available():
|
def test_wireguard_tools_available():
|
||||||
"""Test that WireGuard tools are available for validation"""
|
"""Test that WireGuard tools are available for validation"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(['wg', '--version'], capture_output=True, text=True)
|
result = subprocess.run(["wg", "--version"], capture_output=True, text=True)
|
||||||
assert result.returncode == 0, "WireGuard tools not available"
|
assert result.returncode == 0, "WireGuard tools not available"
|
||||||
print(f"✓ WireGuard tools available: {result.stdout.strip()}")
|
print(f"✓ WireGuard tools available: {result.stdout.strip()}")
|
||||||
return True
|
return True
|
||||||
|
@ -29,6 +30,7 @@ def test_x25519_module_import():
|
||||||
"""Test that our custom x25519_pubkey module can be imported and used"""
|
"""Test that our custom x25519_pubkey module can be imported and used"""
|
||||||
try:
|
try:
|
||||||
import x25519_pubkey # noqa: F401
|
import x25519_pubkey # noqa: F401
|
||||||
|
|
||||||
print("✓ x25519_pubkey module imports successfully")
|
print("✓ x25519_pubkey module imports successfully")
|
||||||
return True
|
return True
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
@ -37,16 +39,17 @@ def test_x25519_module_import():
|
||||||
|
|
||||||
def generate_test_private_key():
|
def generate_test_private_key():
|
||||||
"""Generate a test private key using the same method as Algo"""
|
"""Generate a test private key using the same method as Algo"""
|
||||||
with tempfile.NamedTemporaryFile(suffix='.raw', delete=False) as temp_file:
|
with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as temp_file:
|
||||||
raw_key_path = temp_file.name
|
raw_key_path = temp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
|
# Generate 32 random bytes for X25519 private key (same as community.crypto does)
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
raw_data = secrets.token_bytes(32)
|
raw_data = secrets.token_bytes(32)
|
||||||
|
|
||||||
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
|
# Write raw key to file (like community.crypto openssl_privatekey with format: raw)
|
||||||
with open(raw_key_path, 'wb') as f:
|
with open(raw_key_path, "wb") as f:
|
||||||
f.write(raw_data)
|
f.write(raw_data)
|
||||||
|
|
||||||
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
|
assert len(raw_data) == 32, f"Private key should be 32 bytes, got {len(raw_data)}"
|
||||||
|
@ -83,7 +86,7 @@ def test_x25519_pubkey_from_raw_file():
|
||||||
def exit_json(self, **kwargs):
|
def exit_json(self, **kwargs):
|
||||||
self.result = kwargs
|
self.result = kwargs
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix='.pub', delete=False) as temp_pub:
|
with tempfile.NamedTemporaryFile(suffix=".pub", delete=False) as temp_pub:
|
||||||
public_key_path = temp_pub.name
|
public_key_path = temp_pub.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -95,11 +98,9 @@ def test_x25519_pubkey_from_raw_file():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Mock the module call
|
# Mock the module call
|
||||||
mock_module = MockModule({
|
mock_module = MockModule(
|
||||||
'private_key_path': raw_key_path,
|
{"private_key_path": raw_key_path, "public_key_path": public_key_path, "private_key_b64": None}
|
||||||
'public_key_path': public_key_path,
|
)
|
||||||
'private_key_b64': None
|
|
||||||
})
|
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
|
|
||||||
|
@ -107,8 +108,8 @@ def test_x25519_pubkey_from_raw_file():
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
# Check the result
|
# Check the result
|
||||||
assert 'public_key' in mock_module.result
|
assert "public_key" in mock_module.result
|
||||||
assert mock_module.result['changed']
|
assert mock_module.result["changed"]
|
||||||
assert os.path.exists(public_key_path)
|
assert os.path.exists(public_key_path)
|
||||||
|
|
||||||
with open(public_key_path) as f:
|
with open(public_key_path) as f:
|
||||||
|
@ -160,11 +161,7 @@ def test_x25519_pubkey_from_b64_string():
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
|
||||||
'private_key_b64': b64_key,
|
|
||||||
'private_key_path': None,
|
|
||||||
'public_key_path': None
|
|
||||||
})
|
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
|
|
||||||
|
@ -172,8 +169,8 @@ def test_x25519_pubkey_from_b64_string():
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
# Check the result
|
# Check the result
|
||||||
assert 'public_key' in mock_module.result
|
assert "public_key" in mock_module.result
|
||||||
derived_pubkey = mock_module.result['public_key']
|
derived_pubkey = mock_module.result["public_key"]
|
||||||
|
|
||||||
# Validate base64 format
|
# Validate base64 format
|
||||||
try:
|
try:
|
||||||
|
@ -222,21 +219,17 @@ def test_wireguard_validation():
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule({"private_key_b64": b64_key, "private_key_path": None, "public_key_path": None})
|
||||||
'private_key_b64': b64_key,
|
|
||||||
'private_key_path': None,
|
|
||||||
'public_key_path': None
|
|
||||||
})
|
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
derived_pubkey = mock_module.result['public_key']
|
derived_pubkey = mock_module.result["public_key"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as temp_config:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as temp_config:
|
||||||
# Create a WireGuard config using our keys
|
# Create a WireGuard config using our keys
|
||||||
wg_config = f"""[Interface]
|
wg_config = f"""[Interface]
|
||||||
PrivateKey = {b64_key}
|
PrivateKey = {b64_key}
|
||||||
|
@ -251,16 +244,12 @@ AllowedIPs = 10.19.49.2/32
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test that WireGuard can parse our config
|
# Test that WireGuard can parse our config
|
||||||
result = subprocess.run([
|
result = subprocess.run(["wg-quick", "strip", config_path], capture_output=True, text=True)
|
||||||
'wg-quick', 'strip', config_path
|
|
||||||
], capture_output=True, text=True)
|
|
||||||
|
|
||||||
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
|
assert result.returncode == 0, f"WireGuard rejected our config: {result.stderr}"
|
||||||
|
|
||||||
# Test key derivation with wg pubkey command
|
# Test key derivation with wg pubkey command
|
||||||
wg_result = subprocess.run([
|
wg_result = subprocess.run(["wg", "pubkey"], input=b64_key, capture_output=True, text=True)
|
||||||
'wg', 'pubkey'
|
|
||||||
], input=b64_key, capture_output=True, text=True)
|
|
||||||
|
|
||||||
if wg_result.returncode == 0:
|
if wg_result.returncode == 0:
|
||||||
wg_derived = wg_result.stdout.strip()
|
wg_derived = wg_result.stdout.strip()
|
||||||
|
@ -286,8 +275,8 @@ def test_key_consistency():
|
||||||
raw_key_path, b64_key = generate_test_private_key()
|
raw_key_path, b64_key = generate_test_private_key()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
def derive_pubkey_from_same_key():
|
|
||||||
|
|
||||||
|
def derive_pubkey_from_same_key():
|
||||||
class MockModule:
|
class MockModule:
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
self.params = params
|
self.params = params
|
||||||
|
@ -305,16 +294,18 @@ def test_key_consistency():
|
||||||
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
original_AnsibleModule = x25519_pubkey.AnsibleModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mock_module = MockModule({
|
mock_module = MockModule(
|
||||||
'private_key_b64': b64_key, # SAME key each time
|
{
|
||||||
'private_key_path': None,
|
"private_key_b64": b64_key, # SAME key each time
|
||||||
'public_key_path': None
|
"private_key_path": None,
|
||||||
})
|
"public_key_path": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
x25519_pubkey.AnsibleModule = lambda **kwargs: mock_module
|
||||||
run_module()
|
run_module()
|
||||||
|
|
||||||
return mock_module.result['public_key']
|
return mock_module.result["public_key"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
x25519_pubkey.AnsibleModule = original_AnsibleModule
|
||||||
|
|
|
@ -14,13 +14,13 @@ from pathlib import Path
|
||||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
|
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateSyntaxError, meta
|
||||||
|
|
||||||
|
|
||||||
def find_jinja2_templates(root_dir: str = '.') -> list[Path]:
|
def find_jinja2_templates(root_dir: str = ".") -> list[Path]:
|
||||||
"""Find all Jinja2 template files in the project."""
|
"""Find all Jinja2 template files in the project."""
|
||||||
templates = []
|
templates = []
|
||||||
patterns = ['**/*.j2', '**/*.jinja2', '**/*.yml.j2', '**/*.conf.j2']
|
patterns = ["**/*.j2", "**/*.jinja2", "**/*.yml.j2", "**/*.conf.j2"]
|
||||||
|
|
||||||
# Skip these directories
|
# Skip these directories
|
||||||
skip_dirs = {'.git', '.venv', 'venv', '.env', 'configs', '__pycache__', '.cache'}
|
skip_dirs = {".git", ".venv", "venv", ".env", "configs", "__pycache__", ".cache"}
|
||||||
|
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
for path in Path(root_dir).glob(pattern):
|
for path in Path(root_dir).glob(pattern):
|
||||||
|
@ -39,25 +39,25 @@ def check_inline_comments_in_expressions(template_content: str, template_path: P
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Pattern to find Jinja2 expressions
|
# Pattern to find Jinja2 expressions
|
||||||
jinja_pattern = re.compile(r'\{\{.*?\}\}|\{%.*?%\}', re.DOTALL)
|
jinja_pattern = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
|
||||||
|
|
||||||
for match in jinja_pattern.finditer(template_content):
|
for match in jinja_pattern.finditer(template_content):
|
||||||
expression = match.group()
|
expression = match.group()
|
||||||
lines = expression.split('\n')
|
lines = expression.split("\n")
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
# Check for # that's not in a string
|
# Check for # that's not in a string
|
||||||
# Simple heuristic: if # appears after non-whitespace and not in quotes
|
# Simple heuristic: if # appears after non-whitespace and not in quotes
|
||||||
if '#' in line:
|
if "#" in line:
|
||||||
# Remove quoted strings to avoid false positives
|
# Remove quoted strings to avoid false positives
|
||||||
cleaned = re.sub(r'"[^"]*"', '', line)
|
cleaned = re.sub(r'"[^"]*"', "", line)
|
||||||
cleaned = re.sub(r"'[^']*'", '', cleaned)
|
cleaned = re.sub(r"'[^']*'", "", cleaned)
|
||||||
|
|
||||||
if '#' in cleaned:
|
if "#" in cleaned:
|
||||||
# Check if it's likely a comment (has text after it)
|
# Check if it's likely a comment (has text after it)
|
||||||
hash_pos = cleaned.index('#')
|
hash_pos = cleaned.index("#")
|
||||||
if hash_pos > 0 and cleaned[hash_pos-1:hash_pos] != '\\':
|
if hash_pos > 0 and cleaned[hash_pos - 1 : hash_pos] != "\\":
|
||||||
line_num = template_content[:match.start()].count('\n') + i + 1
|
line_num = template_content[: match.start()].count("\n") + i + 1
|
||||||
errors.append(
|
errors.append(
|
||||||
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
|
f"{template_path}:{line_num}: Inline comment (#) found in Jinja2 expression. "
|
||||||
f"Move comments outside the expression."
|
f"Move comments outside the expression."
|
||||||
|
@ -83,11 +83,24 @@ def check_undefined_variables(template_path: Path) -> list[str]:
|
||||||
|
|
||||||
# Common Ansible variables that are always available
|
# Common Ansible variables that are always available
|
||||||
ansible_builtins = {
|
ansible_builtins = {
|
||||||
'ansible_default_ipv4', 'ansible_default_ipv6', 'ansible_hostname',
|
"ansible_default_ipv4",
|
||||||
'ansible_distribution', 'ansible_distribution_version', 'ansible_facts',
|
"ansible_default_ipv6",
|
||||||
'inventory_hostname', 'hostvars', 'groups', 'group_names',
|
"ansible_hostname",
|
||||||
'play_hosts', 'ansible_version', 'ansible_user', 'ansible_host',
|
"ansible_distribution",
|
||||||
'item', 'ansible_loop', 'ansible_index', 'lookup'
|
"ansible_distribution_version",
|
||||||
|
"ansible_facts",
|
||||||
|
"inventory_hostname",
|
||||||
|
"hostvars",
|
||||||
|
"groups",
|
||||||
|
"group_names",
|
||||||
|
"play_hosts",
|
||||||
|
"ansible_version",
|
||||||
|
"ansible_user",
|
||||||
|
"ansible_host",
|
||||||
|
"item",
|
||||||
|
"ansible_loop",
|
||||||
|
"ansible_index",
|
||||||
|
"lookup",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Filter out known Ansible variables
|
# Filter out known Ansible variables
|
||||||
|
@ -95,9 +108,7 @@ def check_undefined_variables(template_path: Path) -> list[str]:
|
||||||
|
|
||||||
# Only report if there are truly unknown variables
|
# Only report if there are truly unknown variables
|
||||||
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
|
if unknown_vars and len(unknown_vars) < 20: # Avoid noise from templates with many vars
|
||||||
errors.append(
|
errors.append(f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}")
|
||||||
f"{template_path}: Uses undefined variables: {', '.join(sorted(unknown_vars))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Don't report parse errors here, they're handled elsewhere
|
# Don't report parse errors here, they're handled elsewhere
|
||||||
|
@ -116,9 +127,9 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
|
||||||
# Skip full parsing for templates that use Ansible-specific features heavily
|
# Skip full parsing for templates that use Ansible-specific features heavily
|
||||||
# We still check for inline comments but skip full template parsing
|
# We still check for inline comments but skip full template parsing
|
||||||
ansible_specific_templates = {
|
ansible_specific_templates = {
|
||||||
'dnscrypt-proxy.toml.j2', # Uses |bool filter
|
"dnscrypt-proxy.toml.j2", # Uses |bool filter
|
||||||
'mobileconfig.j2', # Uses |to_uuid filter and complex item structures
|
"mobileconfig.j2", # Uses |to_uuid filter and complex item structures
|
||||||
'vpn-dict.j2', # Uses |to_uuid filter
|
"vpn-dict.j2", # Uses |to_uuid filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if template_path.name in ansible_specific_templates:
|
if template_path.name in ansible_specific_templates:
|
||||||
|
@ -139,18 +150,15 @@ def validate_template_syntax(template_path: Path) -> tuple[bool, list[str]]:
|
||||||
errors.extend(check_inline_comments_in_expressions(template_content, template_path))
|
errors.extend(check_inline_comments_in_expressions(template_content, template_path))
|
||||||
|
|
||||||
# Try to parse the template
|
# Try to parse the template
|
||||||
env = Environment(
|
env = Environment(loader=FileSystemLoader(template_path.parent), undefined=StrictUndefined)
|
||||||
loader=FileSystemLoader(template_path.parent),
|
|
||||||
undefined=StrictUndefined
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add mock Ansible filters to avoid syntax errors
|
# Add mock Ansible filters to avoid syntax errors
|
||||||
env.filters['bool'] = lambda x: x
|
env.filters["bool"] = lambda x: x
|
||||||
env.filters['to_uuid'] = lambda x: x
|
env.filters["to_uuid"] = lambda x: x
|
||||||
env.filters['b64encode'] = lambda x: x
|
env.filters["b64encode"] = lambda x: x
|
||||||
env.filters['b64decode'] = lambda x: x
|
env.filters["b64decode"] = lambda x: x
|
||||||
env.filters['regex_replace'] = lambda x, y, z: x
|
env.filters["regex_replace"] = lambda x, y, z: x
|
||||||
env.filters['default'] = lambda x, d: x if x else d
|
env.filters["default"] = lambda x, d: x if x else d
|
||||||
|
|
||||||
# This will raise TemplateSyntaxError if there's a syntax problem
|
# This will raise TemplateSyntaxError if there's a syntax problem
|
||||||
env.get_template(template_path.name)
|
env.get_template(template_path.name)
|
||||||
|
@ -178,18 +186,20 @@ def check_common_antipatterns(template_path: Path) -> list[str]:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Check for missing spaces around filters
|
# Check for missing spaces around filters
|
||||||
if re.search(r'\{\{[^}]+\|[^ ]', content):
|
if re.search(r"\{\{[^}]+\|[^ ]", content):
|
||||||
warnings.append(f"{template_path}: Missing space after filter pipe (|)")
|
warnings.append(f"{template_path}: Missing space after filter pipe (|)")
|
||||||
|
|
||||||
# Check for deprecated 'when' in Jinja2 (should use if)
|
# Check for deprecated 'when' in Jinja2 (should use if)
|
||||||
if re.search(r'\{%\s*when\s+', content):
|
if re.search(r"\{%\s*when\s+", content):
|
||||||
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
|
warnings.append(f"{template_path}: Use 'if' instead of 'when' in Jinja2 templates")
|
||||||
|
|
||||||
# Check for extremely long expressions (harder to debug)
|
# Check for extremely long expressions (harder to debug)
|
||||||
for match in re.finditer(r'\{\{(.+?)\}\}', content, re.DOTALL):
|
for match in re.finditer(r"\{\{(.+?)\}\}", content, re.DOTALL):
|
||||||
if len(match.group(1)) > 200:
|
if len(match.group(1)) > 200:
|
||||||
line_num = content[:match.start()].count('\n') + 1
|
line_num = content[: match.start()].count("\n") + 1
|
||||||
warnings.append(f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up")
|
warnings.append(
|
||||||
|
f"{template_path}:{line_num}: Very long expression (>200 chars), consider breaking it up"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Ignore errors in anti-pattern checking
|
pass # Ignore errors in anti-pattern checking
|
||||||
|
|
Loading…
Add table
Reference in a new issue