feat(sdk): limit selected high-volume AWS resource analysis (#11228)

This commit is contained in:
Hugo Pereira Brito
2026-06-30 15:49:12 +01:00
committed by GitHub
parent 34e8e3ca61
commit c46cbaaa4a
25 changed files with 2392 additions and 247 deletions
+47
View File
@@ -3,6 +3,7 @@ constraint surface (CIDRs, account IDs, port ranges, enums, thresholds)."""
import pytest
from prowler.config.scan_config_schema import SCAN_CONFIG_SCHEMA
from prowler.config.schema.aws import AWSProviderConfig
from prowler.config.schema.validator import validate_provider_config
@@ -11,6 +12,52 @@ def _validate(raw):
return validate_provider_config("aws", raw, AWSProviderConfig)
RESOURCE_LIMIT_KEYS = [
"max_scanned_resources_per_service",
"max_ebs_snapshots",
"max_backup_recovery_points",
"max_cloudwatch_log_groups",
"max_lambda_functions",
"max_ecs_task_definitions",
"max_codeartifact_packages",
]
class Test_AWS_Resource_Limits:
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_positive_values_round_trip(self, key):
assert _validate({key: 100}) == {key: 100}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_null_values_round_trip(self, key):
assert _validate({key: None}) == {key: None}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_zero_disable_sentinel_round_trips(self, key):
assert _validate({key: 0}) == {key: 0}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_numeric_strings_are_coerced_to_int(self, key):
assert _validate({key: "100"}) == {key: 100}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_disable_sentinel_minus_one_round_trips(self, key):
assert _validate({key: -1}) == {key: -1}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
@pytest.mark.parametrize("value", [True, False])
def test_booleans_are_dropped_not_coerced_to_int(self, key, value):
assert _validate({key: value}) == {}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_invalid_strings_are_dropped(self, key):
assert _validate({key: "not-an-int"}) == {}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_keys_are_exposed_in_scan_config_schema(self, key):
assert key in SCAN_CONFIG_SCHEMA["properties"]["aws"]["properties"]
class Test_AWS_Threat_Detection_Thresholds:
"""All threat detection thresholds are documented as fractions in 0..1.
The biggest risk of mistyping them is silently disabling the check."""
+156
View File
@@ -0,0 +1,156 @@
from prowler.lib.resource_limit import (
get_resource_scan_limit,
iter_limited_paginator_items,
limit_resources,
)
class FakePaginator:
def __init__(self, pages):
self.pages = pages
self.paginate_calls = []
self.pages_requested = 0
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
for page in self.pages:
self.pages_requested += 1
yield page
class Test_limit_resources:
def test_no_limit_returns_all_in_order(self):
resources = ["PASS", "FAIL", "PASS"]
result = list(limit_resources(iter(resources), None))
assert result == ["PASS", "FAIL", "PASS"]
class Test_iter_limited_paginator_items:
def test_positive_limit_stops_without_page_size(self):
paginator = FakePaginator(
[
{"Items": [1, 2]},
{"Items": [3, 4]},
{"Items": [5]},
]
)
result = list(iter_limited_paginator_items(paginator, "Items", 3))
assert result == [1, 2, 3]
assert paginator.paginate_calls == [{}]
assert paginator.pages_requested == 2
def test_absurd_limit_is_not_sent_as_page_size(self):
paginator = FakePaginator([{"Items": [1, 2]}])
result = list(iter_limited_paginator_items(paginator, "Items", 200000))
assert result == [1, 2]
assert paginator.paginate_calls == [{}]
def test_operation_parameters_are_forwarded_unchanged(self):
paginator = FakePaginator([{"Snapshots": ["snapshot"]}])
result = list(
iter_limited_paginator_items(
paginator,
"Snapshots",
1,
OwnerIds=["self"],
)
)
assert result == ["snapshot"]
assert paginator.paginate_calls == [{"OwnerIds": ["self"]}]
def test_item_filter_limits_selected_items_only(self):
paginator = FakePaginator(
[
{"Items": [{"arn": "skip"}, {"arn": "first"}]},
{"Items": [{"arn": "second"}, {"arn": "third"}]},
]
)
result = list(
iter_limited_paginator_items(
paginator,
"Items",
2,
item_filter=lambda item: item["arn"] != "skip",
)
)
assert result == [{"arn": "first"}, {"arn": "second"}]
assert paginator.pages_requested == 2
def test_limit_zero_or_negative_is_unlimited(self):
resources = list(range(5))
assert list(limit_resources(iter(resources), 0)) == resources
assert list(limit_resources(iter(resources), -3)) == resources
def test_positive_limit_stops_after_selected_resources(self):
pulled = []
def gen():
for i in range(1000):
pulled.append(i)
yield i
result = list(limit_resources(gen(), 100))
assert result == list(range(100))
assert len(pulled) == 100
def test_does_not_reorder_or_inspect_resource_status(self):
resources = ["PASS", "FAIL", "PASS", "FAIL"]
result = list(limit_resources(iter(resources), 3))
assert result == ["PASS", "FAIL", "PASS"]
class Test_get_resource_scan_limit:
def test_per_service_override_wins(self):
config = {
"max_scanned_resources_per_service": 100,
"max_ecs_task_definitions": 25,
}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 25
def test_falls_back_to_global_default(self):
config = {"max_scanned_resources_per_service": 50}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50
def test_null_per_service_override_falls_back_to_global_default(self):
config = {
"max_scanned_resources_per_service": 50,
"max_ecs_task_definitions": None,
}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50
def test_default_is_unlimited_when_unset(self):
assert get_resource_scan_limit({}, "max_ecs_task_definitions") is None
def test_null_per_service_override_falls_back_to_unlimited_global_default(self):
config = {"max_ecs_task_definitions": None}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") is None
def test_non_positive_means_unlimited(self):
assert (
get_resource_scan_limit(
{"max_scanned_resources_per_service": 0}, "max_lambda_functions"
)
is None
)
assert (
get_resource_scan_limit(
{"max_lambda_functions": -1}, "max_lambda_functions"
)
is None
)
@@ -6,10 +6,16 @@ from re import search
from unittest.mock import patch
import mock
import pytest
from boto3 import client, resource
from botocore.client import ClientError
from moto import mock_aws
from prowler.providers.aws.services.awslambda.awslambda_service import AuthType, Lambda
from prowler.providers.aws.services.awslambda.awslambda_service import (
AuthType,
Function,
Lambda,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -85,6 +91,367 @@ class Test_Lambda_Service:
awslambda = Lambda(set_mocked_aws_provider([AWS_REGION_US_EAST_1]))
assert awslambda.service == "lambda"
def test_function_limit_selects_latest_functions_for_analysis(self):
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {
"old": Function(
name="old",
arn="old",
security_groups=[],
last_modified="2024-01-01T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
"new": Function(
name="new",
arn="new",
security_groups=[],
last_modified="2024-01-02T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
}
awslambda.function_limit = 1
awslambda._select_functions_for_analysis()
assert list(awslambda.functions) == ["new"]
def test_function_limit_selects_global_latest_across_regions(self):
class FakePaginator:
def __init__(self, functions):
self.functions = functions
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Functions": self.functions}]
class FakeLambdaClient:
def __init__(self, region, functions):
self.region = region
self.functions = functions
def get_paginator(self, name):
assert name == "list_functions"
return FakePaginator(self.functions)
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {}
awslambda.security_groups_in_use = set()
awslambda.regions_with_functions = set()
awslambda.function_limit = 1
awslambda.audit_resources = []
old_client = FakeLambdaClient(
AWS_REGION_EU_WEST_1,
[
{
"FunctionName": "old",
"FunctionArn": "arn:aws:lambda:eu-west-1:123456789012:function:old",
"LastModified": "2024-01-01T00:00:00.000+0000",
}
],
)
new_client = FakeLambdaClient(
AWS_REGION_US_EAST_1,
[
{
"FunctionName": "new",
"FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:new",
"LastModified": "2024-01-02T00:00:00.000+0000",
}
],
)
awslambda._list_functions(old_client)
awslambda._list_functions(new_client)
awslambda._select_functions_for_analysis()
assert [function.name for function in awslambda.functions.values()] == ["new"]
def test_function_limit_keeps_complete_auxiliary_indexes(self):
class FakePaginator:
def __init__(self, functions):
self.functions = functions
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Functions": self.functions}]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_functions"
return FakePaginator(
[
{
"FunctionName": "old",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:old"
),
"LastModified": "2024-01-01T00:00:00.000+0000",
"VpcConfig": {"SecurityGroupIds": ["sg-old"]},
},
{
"FunctionName": "new",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:new"
),
"LastModified": "2024-01-02T00:00:00.000+0000",
"VpcConfig": {"SecurityGroupIds": ["sg-new"]},
},
]
)
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {}
awslambda.security_groups_in_use = set()
awslambda.regions_with_functions = set()
awslambda.function_limit = 1
awslambda.audit_resources = []
awslambda._list_functions(FakeLambdaClient())
awslambda._select_functions_for_analysis()
assert [function.name for function in awslambda.functions.values()] == ["new"]
assert awslambda.security_groups_in_use == {"sg-old", "sg-new"}
assert awslambda.regions_with_functions == {AWS_REGION_US_EAST_1}
def test_list_event_source_mappings_uses_selected_functions_as_api_scope(self):
class FakePaginator:
def __init__(self):
self.paginate_calls = []
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
function_name = kwargs["FunctionName"]
return [
{
"EventSourceMappings": [
{
"UUID": f"{function_name}-mapping",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:{function_name}:1"
),
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
"BatchSize": 10,
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def __init__(self):
self.paginator = FakePaginator()
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return self.paginator
selected_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
other_region_arn = (
f"arn:aws:lambda:{AWS_REGION_EU_WEST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:other-region"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 1
awslambda.functions = {
selected_arn: Function(
name="selected",
arn=selected_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
other_region_arn: Function(
name="other-region",
arn=other_region_arn,
security_groups=[],
region=AWS_REGION_EU_WEST_1,
),
}
regional_client = FakeLambdaClient()
awslambda._list_event_source_mappings(regional_client)
assert regional_client.paginator.paginate_calls == [
{"FunctionName": "selected"}
]
assert len(awslambda.functions[selected_arn].event_source_mappings) == 1
assert (
awslambda.functions[selected_arn].event_source_mappings[0].uuid
== "selected-mapping"
)
assert not awslambda.functions[other_region_arn].event_source_mappings
def test_list_event_source_mappings_keeps_unlimited_regional_api_scope(self):
class FakePaginator:
def __init__(self):
self.paginate_calls = []
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
return [
{
"EventSourceMappings": [
{
"UUID": "selected-mapping",
"FunctionArn": selected_arn,
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def __init__(self):
self.paginator = FakePaginator()
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return self.paginator
selected_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = None
awslambda.functions = {
selected_arn: Function(
name="selected",
arn=selected_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
)
}
regional_client = FakeLambdaClient()
awslambda._list_event_source_mappings(regional_client)
assert regional_client.paginator.paginate_calls == [{}]
assert len(awslambda.functions[selected_arn].event_source_mappings) == 1
def test_list_event_source_mappings_continues_after_invalid_parameter_value(self):
class FakePaginator:
def paginate(self, **kwargs):
function_name = kwargs["FunctionName"]
if function_name == "deleted":
raise ClientError(
{
"Error": {
"Code": "InvalidParameterValueException",
"Message": "Function no longer exists",
}
},
"ListEventSourceMappings",
)
return [
{
"EventSourceMappings": [
{
"UUID": f"{function_name}-mapping",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:{function_name}"
),
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return FakePaginator()
deleted_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:deleted"
)
remaining_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:remaining"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 2
awslambda.functions = {
deleted_arn: Function(
name="deleted",
arn=deleted_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
remaining_arn: Function(
name="remaining",
arn=remaining_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
}
awslambda._list_event_source_mappings(FakeLambdaClient())
assert not awslambda.functions[deleted_arn].event_source_mappings
assert len(awslambda.functions[remaining_arn].event_source_mappings) == 1
assert (
awslambda.functions[remaining_arn].event_source_mappings[0].uuid
== "remaining-mapping"
)
def test_list_event_source_mappings_raises_non_transient_client_error(self):
class FakePaginator:
def paginate(self, **kwargs):
raise ClientError(
{
"Error": {
"Code": "AccessDeniedException",
"Message": "Access denied",
}
},
"ListEventSourceMappings",
)
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return FakePaginator()
function_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 1
awslambda.functions = {
function_arn: Function(
name="selected",
arn=function_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
)
}
with pytest.raises(ClientError) as error:
awslambda._list_event_source_mappings(FakeLambdaClient())
assert error.value.response["Error"]["Code"] == "AccessDeniedException"
@mock_aws
def test_list_functions(self):
# Create IAM Lambda Role
@@ -253,3 +620,63 @@ class Test_Lambda_Service:
f"{tmp_dir_name}/{files_in_zip[0]}", "r"
) as lambda_code_file:
assert lambda_code_file.read() == LAMBDA_FUNCTION_CODE
@mock_aws
def test_function_limit_exposes_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
assert len(awslambda.functions) == 1
@mock_aws
def test_get_function_code_fetches_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
fetched = []
def fetch_function_code(function_name, _function_region):
fetched.append(function_name)
return mock.MagicMock()
awslambda._fetch_function_code = fetch_function_code
assert len(list(awslambda._get_function_code())) == 1
assert len(fetched) == 1
@@ -1,11 +1,16 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
import botocore
from boto3 import client
from moto import mock_aws
from prowler.providers.aws.services.backup.backup_service import Backup
from prowler.providers.aws.services.backup.backup_service import (
Backup,
BackupVault,
RecoveryPoint,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -292,3 +297,248 @@ class TestBackupService:
assert backup.recovery_points[0].backup_vault_region == "eu-west-1"
assert backup.recovery_points[0].tags == []
assert backup.recovery_points[0].encrypted is True
def test_recovery_point_limit_bounds_tag_calls_to_selected_points(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 2),
},
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 1),
},
]
}
]
class FakeBackupClient:
def __init__(self):
self.tag_calls = []
def get_paginator(self, name):
assert name == "list_recovery_points_by_backup_vault"
return FakePaginator()
def list_tags(self, **kwargs):
self.tag_calls.append(kwargs["ResourceArn"])
return {"Tags": {}}
regional_client = FakeBackupClient()
backup = Backup.__new__(Backup)
backup.backup_vaults = [
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:vault",
name="vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=2,
locked=False,
)
]
backup.recovery_points = []
backup.recovery_point_limit = 1
backup.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert [rp.id for rp in backup.recovery_points] == ["new"]
assert regional_client.tag_calls == [
"arn:aws:backup:eu-west-1:123456789012:recovery-point:new"
]
def test_recovery_point_limit_selects_global_newest_across_vaults(self):
class FakePaginator:
def __init__(self, recovery_points_by_vault):
self.recovery_points_by_vault = recovery_points_by_vault
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [
{
"RecoveryPoints": self.recovery_points_by_vault[
kwargs["BackupVaultName"]
]
}
]
class FakeBackupClient:
def __init__(self, recovery_points_by_vault):
self.recovery_points_by_vault = recovery_points_by_vault
def get_paginator(self, name):
assert name == "list_recovery_points_by_backup_vault"
return FakePaginator(self.recovery_points_by_vault)
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 1
backup.recovery_points = []
backup.backup_vaults = [
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:old-vault",
name="old-vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=1,
locked=False,
),
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:new-vault",
name="new-vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=1,
locked=False,
),
]
backup.regional_clients = {
AWS_REGION_EU_WEST_1: FakeBackupClient(
{
"old-vault": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 1),
}
],
"new-vault": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 2),
}
],
}
)
}
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
assert [rp.id for rp in backup.recovery_points] == ["new"]
def test_recovery_point_limit_exposes_only_selected_resources(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = []
backup.backup_vaults = [
BackupVault(
arn="arn",
name="vault",
region="eu-west-1",
encryption="",
recovery_points=3,
locked=False,
)
]
class Paginator:
def paginate(self, **_kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": f"arn:aws:backup:eu-west-1:123456789012:recovery-point:{i}",
"IsEncrypted": True,
}
for i in range(3)
]
}
]
backup.regional_clients = {
"eu-west-1": SimpleNamespace(get_paginator=lambda _: Paginator())
}
tagged = []
def list_tags(recovery_point):
tagged.append(recovery_point.arn)
backup._list_tags = list_tags
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert len(backup.recovery_points) == 2
assert len(tagged) == 2
def test_recovery_point_limit_uses_deterministic_tie_breaker(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = [
RecoveryPoint(
arn="arn:aws:backup:us-east-1:123456789012:recovery-point:z",
id="z",
region="us-east-1",
backup_vault_name="vault-b",
encrypted=True,
backup_vault_region="us-east-1",
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:b",
id="b",
region="eu-west-1",
backup_vault_name="vault-b",
encrypted=True,
backup_vault_region="eu-west-1",
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:a",
id="a",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
),
]
backup._select_recovery_points_for_analysis()
assert [rp.id for rp in backup.recovery_points] == ["a", "b"]
def test_recovery_point_limit_keeps_newest_before_tie_breaker(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = [
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:older-a",
id="older-a",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
creation_date=datetime(2024, 1, 1),
),
RecoveryPoint(
arn="arn:aws:backup:us-east-1:123456789012:recovery-point:newer-z",
id="newer-z",
region="us-east-1",
backup_vault_name="vault-z",
encrypted=True,
backup_vault_region="us-east-1",
creation_date=datetime(2024, 1, 2),
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:missing-date",
id="missing-date",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
),
]
backup._select_recovery_points_for_analysis()
assert [rp.id for rp in backup.recovery_points] == ["newer-z", "older-a"]
@@ -227,6 +227,72 @@ class Test_bedrock_model_invocation_logs_encryption_enabled:
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == []
def test_cloudwatch_logging_uses_complete_log_group_index(self):
from prowler.providers.aws.services.bedrock.bedrock_service import (
LoggingConfiguration,
)
from prowler.providers.aws.services.cloudwatch.cloudwatch_service import (
LogGroup,
)
bedrock_client = mock.MagicMock()
bedrock_client.logging_configurations = {
AWS_REGION_US_EAST_1: LoggingConfiguration(
enabled=True,
cloudwatch_log_group="Test",
)
}
bedrock_client._get_model_invocation_logging_arn_template.return_value = (
"arn:aws:bedrock:us-east-1:123456789012:model-invocation-logging"
)
logs_client = mock.MagicMock()
logs_client.audited_partition = "aws"
logs_client.audited_account = "123456789012"
logs_client.log_groups = {}
logs_client.all_log_groups = {
"arn:aws:logs:us-east-1:123456789012:log-group:Test:*": LogGroup(
arn="arn:aws:logs:us-east-1:123456789012:log-group:Test:*",
name="Test",
retention_days=30,
never_expire=False,
kms_id=None,
region=AWS_REGION_US_EAST_1,
)
}
s3_client = mock.MagicMock()
s3_client.audited_partition = "aws"
s3_client.buckets = {}
with (
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.bedrock_client",
new=bedrock_client,
),
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.logs_client",
new=logs_client,
),
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.s3_client",
new=s3_client,
),
):
from prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled import (
bedrock_model_invocation_logs_encryption_enabled,
)
check = bedrock_model_invocation_logs_encryption_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Bedrock Model Invocation logs are not encrypted in CloudWatch Log Group: Test."
)
@mock_aws
def test_s3_and_cloudwatch_logging_encrypted(self):
logs_client = client("logs", region_name=AWS_REGION_US_EAST_1)
@@ -4,6 +4,7 @@ from moto import mock_aws
from prowler.providers.aws.services.cloudwatch.cloudwatch_service import (
CloudWatch,
LogGroup,
Logs,
)
from prowler.providers.aws.services.cloudwatch.lib.metric_filters import (
@@ -188,7 +189,9 @@ class Test_CloudWatch_Service:
arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*"
logs = Logs(aws_provider)
assert len(logs.log_groups) == 1
assert len(logs.all_log_groups) == 1
assert arn in logs.log_groups
assert arn in logs.all_log_groups
assert logs.log_groups[arn].name == "/log-group/test"
assert logs.log_groups[arn].retention_days == 400
assert logs.log_groups[arn].kms_id == "test_kms_id"
@@ -212,7 +215,9 @@ class Test_CloudWatch_Service:
arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*"
logs = Logs(aws_provider)
assert len(logs.log_groups) == 1
assert len(logs.all_log_groups) == 1
assert arn in logs.log_groups
assert arn in logs.all_log_groups
assert logs.log_groups[arn].name == "/log-group/test"
assert logs.log_groups[arn].never_expire
# Since it never expires we don't use the retention_days
@@ -221,6 +226,190 @@ class Test_CloudWatch_Service:
assert logs.log_groups[arn].region == AWS_REGION_US_EAST_1
assert logs.log_groups[arn].tags == [{}]
def test_log_group_limit_exposes_only_selected_resources(self):
class FakeLogsClient:
def __init__(self):
self.filter_calls = []
def filter_log_events(self, **kwargs):
self.filter_calls.append(kwargs["logGroupName"])
return {"events": []}
regional_client = FakeLogsClient()
logs = Logs.__new__(Logs)
logs.log_group_limit = 1
logs._log_groups_hydrated = set()
logs.regional_clients = {AWS_REGION_US_EAST_1: regional_client}
logs.events_per_log_group_threshold = 1000
logs.log_groups = {
f"arn:{i}": LogGroup(
arn=f"arn:{i}",
name=f"log-{i}",
retention_days=30,
never_expire=False,
kms_id=None,
creation_time=i,
region=AWS_REGION_US_EAST_1,
)
for i in range(3)
}
tagged = []
def list_tags(log_group):
tagged.append(log_group.arn)
logs._list_tags_for_resource = list_tags
logs._select_log_groups_for_analysis()
for log_group in logs.log_groups.values():
logs._list_tags_for_resource(log_group)
logs._get_log_events(log_group)
assert list(logs.log_groups) == ["arn:2"]
assert tagged == ["arn:2"]
assert regional_client.filter_calls == ["log-2"]
def test_log_group_limit_selects_global_newest_across_regions(self):
class FakePaginator:
def __init__(self, log_groups):
self.log_groups = log_groups
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"logGroups": self.log_groups}]
class FakeLogsClient:
def __init__(self, region, log_groups):
self.region = region
self.log_groups = log_groups
def get_paginator(self, name):
assert name == "describe_log_groups"
return FakePaginator(self.log_groups)
logs = Logs.__new__(Logs)
logs.all_log_groups = {}
logs.log_groups = {}
logs.log_group_limit = 1
logs.audit_resources = []
logs._describe_log_groups(
FakeLogsClient(
"eu-west-1",
[
{
"arn": "arn:aws:logs:eu-west-1:123456789012:log-group:old:*",
"logGroupName": "old",
"creationTime": 1,
}
],
)
)
logs._describe_log_groups(
FakeLogsClient(
AWS_REGION_US_EAST_1,
[
{
"arn": "arn:aws:logs:us-east-1:123456789012:log-group:new:*",
"logGroupName": "new",
"creationTime": 2,
}
],
)
)
logs._select_log_groups_for_analysis()
assert [log_group.name for log_group in logs.log_groups.values()] == ["new"]
assert [log_group.name for log_group in logs.all_log_groups.values()] == [
"old",
"new",
]
def test_metric_filters_use_complete_log_group_index(self):
class FakePaginator:
def paginate(self):
return [
{
"metricFilters": [
{
"filterName": "test-filter",
"filterPattern": "test-pattern",
"logGroupName": "old",
"metricTransformations": [
{"metricName": "test-metric"}
],
}
]
}
]
class FakeLogsClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "describe_metric_filters"
return FakePaginator()
logs = Logs.__new__(Logs)
old_log_group = LogGroup(
arn="arn:old",
name="old",
retention_days=30,
never_expire=False,
kms_id=None,
creation_time=1,
region=AWS_REGION_US_EAST_1,
)
logs.audited_partition = "aws"
logs.audited_account = AWS_ACCOUNT_NUMBER
logs.audit_resources = []
logs.metric_filters = []
logs.log_groups = {}
logs.all_log_groups = {old_log_group.arn: old_log_group}
logs._log_groups_hydrated = set()
logs._list_tags_for_resource = lambda log_group: None
logs._describe_metric_filters(FakeLogsClient())
assert len(logs.metric_filters) == 1
assert logs.metric_filters[0].log_group == old_log_group
def test_log_group_collection_recovers_all_log_groups_after_access_denied(self):
class FakePaginator:
def paginate(self):
return [
{
"logGroups": [
{
"arn": "arn:aws:logs:us-east-1:123456789012:log-group:success:*",
"logGroupName": "success",
"creationTime": 1,
}
]
}
]
class FakeLogsClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "describe_log_groups"
return FakePaginator()
logs = Logs.__new__(Logs)
logs.all_log_groups = None
logs.log_groups = None
logs.audit_resources = []
logs._describe_log_groups(FakeLogsClient())
assert list(logs.all_log_groups) == [
"arn:aws:logs:us-east-1:123456789012:log-group:success:*"
]
assert list(logs.log_groups) == [
"arn:aws:logs:us-east-1:123456789012:log-group:success:*"
]
class Test_build_metric_filter_pattern:
@pytest.mark.parametrize("bad_operator", ["==", "~=", "<", "<>", ">=", ""])
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import patch
import botocore
@@ -6,6 +7,7 @@ from prowler.providers.aws.services.codeartifact.codeartifact_service import (
CodeArtifact,
LatestPackageVersionStatus,
OriginInformationValues,
Repository,
RestrictionValues,
)
from tests.providers.aws.utils import (
@@ -208,6 +210,104 @@ class Test_CodeArtifact_Service:
== OriginInformationValues.INTERNAL
)
def test_package_limit_bounds_package_version_lookups_to_selected_packages(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"packages": [
{
"format": "pypi",
"package": "first-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
{
"format": "pypi",
"package": "second-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
]
}
]
class FakeCodeArtifactClient:
def __init__(self):
self.version_calls = []
def get_paginator(self, name):
assert name == "list_packages"
return FakePaginator()
def list_package_versions(self, **kwargs):
self.version_calls.append(kwargs["package"])
return {
"versions": [
{
"version": "1.0.0",
"status": "Published",
"origin": {"originType": "INTERNAL"},
}
]
}
regional_client = FakeCodeArtifactClient()
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.repositories = {
TEST_REPOSITORY_ARN: Repository(
name="test-repository",
arn=TEST_REPOSITORY_ARN,
domain_name="test-domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
}
codeartifact._packages_listed = set()
codeartifact.package_limit = 1
codeartifact.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
pairs = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in pairs] == ["first-package"]
assert regional_client.version_calls == ["first-package"]
def test_package_limit_exposes_only_selected_packages(self):
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.package_limit = 2
codeartifact._packages_listed = set()
repository = Repository(
name="repository",
arn="repo",
domain_name="domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
codeartifact.repositories = {repository.arn: repository}
enriched = []
def iter_repository_packages(repository, limit=None):
for index in range(3):
if limit is not None and index >= limit:
return
enriched.append(index)
yield SimpleNamespace(name=f"package-{index}")
codeartifact._iter_repository_packages = iter_repository_packages
packages = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in packages] == ["package-0", "package-1"]
assert enriched == [0, 1]
def mock_make_api_call_no_namespace(self, operation_name, kwarg):
"""Mock for packages without a namespace to exercise the else branch"""
@@ -14,6 +14,57 @@ EXAMPLE_AMI_ID = "ami-12c6146b"
class Test_ec2_securitygroup_not_used:
def test_ec2_sg_used_by_lambda_outside_selected_analysis_limit(self):
from prowler.providers.aws.services.ec2.ec2_service import SecurityGroup
sg_id = "sg-limited-out"
sg_name = "lambda-sg"
security_group = SecurityGroup(
name=sg_name,
region=AWS_REGION_US_EAST_1,
arn=f"arn:aws:ec2:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:security-group/{sg_id}",
id=sg_id,
vpc_id="vpc-test",
associated_sgs=[],
network_interfaces=[],
ingress_rules=[],
egress_rules=[],
tags=[],
)
ec2_client = mock.MagicMock()
ec2_client.security_groups = {security_group.arn: security_group}
awslambda_client = mock.MagicMock()
awslambda_client.functions = {}
awslambda_client.security_groups_in_use = {sg_id}
aws_provider = set_mocked_aws_provider()
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.ec2_client",
new=ec2_client,
),
mock.patch(
"prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.awslambda_client",
new=awslambda_client,
),
):
from prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used import (
ec2_securitygroup_not_used,
)
result = ec2_securitygroup_not_used().execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Security group {sg_name} ({sg_id}) it is being used."
)
@mock_aws
def test_ec2_default_sgs(self):
# Create EC2 Mocked Resources
@@ -11,7 +11,7 @@ from freezegun import freeze_time
from moto import mock_aws
from prowler.config.config import encoding_format_utf_8
from prowler.providers.aws.services.ec2.ec2_service import EC2
from prowler.providers.aws.services.ec2.ec2_service import EC2, Snapshot
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -103,6 +103,99 @@ class Test_EC2_Service:
ec2 = EC2(aws_provider)
assert ec2.audited_account == AWS_ACCOUNT_NUMBER
def test_snapshot_limit_bounds_public_attribute_calls_to_latest_selected(self):
class FakeEC2Client:
def __init__(self):
self.calls = []
def describe_snapshot_attribute(self, **kwargs):
self.calls.append(kwargs["SnapshotId"])
return {"CreateVolumePermissions": []}
regional_client = FakeEC2Client()
ec2 = EC2.__new__(EC2)
ec2.snapshots = [
Snapshot(
id="snap-old",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-old",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 1),
volume="vol-old",
),
Snapshot(
id="snap-new",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-new",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 2),
volume="vol-new",
),
]
ec2.snapshot_limit = 1
ec2.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
ec2._select_snapshots_for_analysis()
for snapshot in ec2.snapshots:
ec2._determine_public_snapshots(snapshot)
assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"]
assert regional_client.calls == ["snap-new"]
def test_snapshot_limit_preserves_volume_index_and_selects_global_latest(self):
class FakePaginator:
def __init__(self, snapshots):
self.snapshots = snapshots
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Snapshots": self.snapshots}]
class FakeEC2Client:
def __init__(self, region, snapshots):
self.region = region
self.snapshots = snapshots
def get_paginator(self, name):
assert name == "describe_snapshots"
return FakePaginator(self.snapshots)
ec2 = EC2.__new__(EC2)
ec2.snapshots = []
ec2.volumes_with_snapshots = {}
ec2.regions_with_snapshots = {}
ec2.snapshot_limit = 1
ec2.audit_resources = []
ec2.audited_partition = "aws"
ec2.audited_account = AWS_ACCOUNT_NUMBER
old_client = FakeEC2Client(
AWS_REGION_EU_WEST_1,
[
{
"SnapshotId": "snap-old",
"VolumeId": "vol-old",
"StartTime": datetime(2024, 1, 1),
}
],
)
new_client = FakeEC2Client(
AWS_REGION_US_EAST_1,
[
{
"SnapshotId": "snap-new",
"VolumeId": "vol-new",
"StartTime": datetime(2024, 1, 2),
}
],
)
ec2._describe_snapshots(old_client)
ec2._describe_snapshots(new_client)
ec2._select_snapshots_for_analysis()
assert ec2.volumes_with_snapshots == {"vol-old": True, "vol-new": True}
assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"]
# Test EC2 Describe Instances
@mock_aws
@freeze_time(MOCK_DATETIME)
@@ -346,6 +439,24 @@ class Test_EC2_Service:
assert not snapshot.encrypted
assert snapshot.public
@mock_aws
def test_snapshot_limit_exposes_only_selected_snapshots(self):
ec2_client = client("ec2", region_name=AWS_REGION_US_EAST_1)
ec2_resource = resource("ec2", region_name=AWS_REGION_US_EAST_1)
volume_id = ec2_resource.create_volume(
AvailabilityZone="us-east-1a",
Size=80,
VolumeType="gp2",
).id
for _ in range(3):
ec2_client.create_snapshot(VolumeId=volume_id)
aws_provider = set_mocked_aws_provider(
[AWS_REGION_US_EAST_1], audit_config={"max_ebs_snapshots": 1}
)
ec2 = EC2(aws_provider)
assert len(ec2.snapshots) == 1
# Test EC2 Instance User Data
@mock_aws
def test_get_instance_user_data(self):
@@ -3,7 +3,11 @@ from unittest.mock import patch
import botocore
from prowler.providers.aws.services.ecs.ecs_service import ECS
from tests.providers.aws.utils import AWS_REGION_EU_WEST_1, set_mocked_aws_provider
from tests.providers.aws.utils import (
AWS_REGION_EU_WEST_1,
AWS_REGION_US_EAST_1,
set_mocked_aws_provider,
)
make_api_call = botocore.client.BaseClient._make_api_call
@@ -115,6 +119,23 @@ def mock_generate_regional_clients(provider, service):
return {AWS_REGION_EU_WEST_1: regional_client}
def mock_generate_multi_region_clients(provider, service):
eu_west_1_client = provider._session.current_session.client(
service, region_name=AWS_REGION_EU_WEST_1
)
eu_west_1_client.region = AWS_REGION_EU_WEST_1
us_east_1_client = provider._session.current_session.client(
service, region_name=AWS_REGION_US_EAST_1
)
us_east_1_client.region = AWS_REGION_US_EAST_1
return {
AWS_REGION_EU_WEST_1: eu_west_1_client,
AWS_REGION_US_EAST_1: us_east_1_client,
}
@patch(
"prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
new=mock_generate_regional_clients,
@@ -139,7 +160,6 @@ class Test_ECS_Service:
ecs = ECS(aws_provider)
assert ecs.session.__class__.__name__ == "Session"
# Test list ECS task definitions
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
def test_list_task_definitions(self):
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
@@ -201,6 +221,169 @@ class Test_ECS_Service:
.readonly_rootfilesystem
)
def test_task_definitions_are_loaded_once_for_analysis(self):
describe_calls = []
list_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
list_calls.append(kwarg)
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == [
"3",
"2",
"1",
]
assert list_calls == [{"sort": "DESC"}]
assert len(describe_calls) == 3
def test_task_definition_limit_exposes_only_selected_resources(self):
describe_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 2}
)
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == ["3", "2"]
assert len(describe_calls) == 2
def test_task_definition_limit_bounds_describe_calls(self):
describe_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return mock_make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 1}
)
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == ["3"]
assert describe_calls == [
"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3"
]
def test_task_definition_limit_does_not_starve_later_regions(self):
describe_calls = []
def counting_make_api_call(self, operation_name, kwarg):
region = self.meta.region_name
if operation_name == "ListTaskDefinitions":
task_definition_revisions = {
AWS_REGION_EU_WEST_1: (3, 2, 1),
AWS_REGION_US_EAST_1: (9,),
}[region]
return {
"taskDefinitionArns": [
f"arn:aws:ecs:{region}:123456789012:task-definition/fam:{revision}"
for revision in task_definition_revisions
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
if operation_name == "ListClusters":
return {"clusterArns": []}
return mock_make_api_call(self, operation_name, kwarg)
with (
patch(
"prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
new=mock_generate_multi_region_clients,
),
patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
),
):
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1],
audit_config={"max_ecs_task_definitions": 2},
)
ecs = ECS(aws_provider)
assert [td.region for td in ecs.task_definitions.values()] == [
AWS_REGION_EU_WEST_1,
AWS_REGION_US_EAST_1,
]
assert set(describe_calls) == {
"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3",
"arn:aws:ecs:us-east-1:123456789012:task-definition/fam:9",
}
# Test list ECS clusters
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
def test_list_clusters(self):
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest import mock
from prowler.providers.aws.services.inspector2.inspector2_service import Inspector
@@ -13,6 +14,65 @@ FINDING_ARN = (
class Test_inspector2_is_enabled:
def test_lambda_disabled_with_region_hidden_by_function_analysis_limit(self):
inspector2_client = mock.MagicMock()
inspector2_client.provider = SimpleNamespace(scan_unused_services=False)
inspector2_client.inspectors = [
Inspector(
id=AWS_ACCOUNT_NUMBER,
arn=f"arn:aws:inspector2:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:inspector2",
status="ENABLED",
ec2_status="ENABLED",
ecr_status="ENABLED",
lambda_status="DISABLED",
lambda_code_status="ENABLED",
region=AWS_REGION_EU_WEST_1,
)
]
awslambda_client = mock.MagicMock()
awslambda_client.functions = {}
awslambda_client.regions_with_functions = {AWS_REGION_EU_WEST_1}
ec2_client = mock.MagicMock()
ec2_client.instances = []
ecr_client = mock.MagicMock()
ecr_client.registries = {AWS_REGION_EU_WEST_1: SimpleNamespace(repositories=[])}
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.inspector2_client",
new=inspector2_client,
),
mock.patch(
"prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.awslambda_client",
new=awslambda_client,
),
mock.patch(
"prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ec2_client",
new=ec2_client,
),
mock.patch(
"prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ecr_client",
new=ecr_client,
),
):
from prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled import (
inspector2_is_enabled,
)
result = inspector2_is_enabled().execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Inspector2 is not enabled for the following services: Lambda."
)
def test_inspector2_disabled(self):
# Mock the inspector2 client
inspector2_client = mock.MagicMock