Compare commits

...

10 Commits

Author SHA1 Message Date
Hugo P.Brito 1250417ebe refactor(sdk): move AWS resource limits into services 2026-06-10 16:34:40 +02:00
Hugo P.Brito 0f102a6b8d feat(sdk): apply resource limits to more AWS checks 2026-06-10 14:00:32 +02:00
Hugo P.Brito 2499e8b3ab fix(sdk): disable resource analysis limits by default
- Keep AWS resource analysis limits opt-in

- Document unlimited default behavior

- Update limit resolution tests
2026-06-10 13:09:58 +02:00
Hugo P.Brito 2376115ff1 feat(sdk): limit AWS resource analysis
- Scope resource scan limits to selected high-volume AWS paths

- Remove fail-driven finding prioritization semantics

- Add tests for bounded expensive AWS calls
2026-06-10 12:51:24 +02:00
Hugo P.Brito d4e19dca37 docs(sdk): document per-service resource scan limit config 2026-05-27 11:47:23 +02:00
Hugo P.Brito 67eb40494c feat(sdk): apply resource scan limit to CloudWatch log groups
- Lazy iter_log_groups generator deferring heavy log-event fetch
- Migrate the 4 cloudwatch_log_group_* checks to the shared limiter
- Keep log group tags eager (consumed by metric-filter checks)
2026-05-27 11:45:54 +02:00
Hugo P.Brito ec2c78180b feat(sdk): apply resource scan limit to Lambda functions
- Lazy iter_functions generator with on-demand per-function hydration
  of policy, URL config, tags and event source mappings
- Migrate the 12 awslambda_function_* checks to the shared limiter
2026-05-27 11:45:54 +02:00
Hugo P.Brito e8e6f2b5b4 feat(sdk): apply resource scan limit to EBS snapshots and Backup
- Lazy iter_recovery_points generator for Backup recovery points
- Lazy iter_snapshots with deferred public-status hydration for EBS
- Migrate EBS snapshot and Backup recovery point checks
2026-05-27 11:45:54 +02:00
Hugo P.Brito 1aa652d780 docs(sdk): changelog for per-service resource scan limit 2026-05-27 11:45:53 +02:00
Hugo P.Brito 2a24008d46 feat(sdk): fail-driven per-service AWS resource scan limit
- Add limited_findings helper and configurable per-service limit
- Lazy memoized ECS task definition and CodeArtifact package fetch
- Prioritize FAIL findings within the configurable limit
2026-05-27 11:44:08 +02:00
17 changed files with 1017 additions and 203 deletions
@@ -77,6 +77,25 @@ The following list includes all the AWS checks with configurable variables that
| `opensearch_service_domains_not_publicly_accessible` | `trusted_ips` | List of Strings |
### Resource Scan Limit
Some AWS services accumulate large numbers of resources (EBS snapshots, backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages). Scanning every resource increases scan time, cost, API throttling, and finding volume. By default, Prowler scans every resource. Configure a positive resource scan limit to cap how many resources Prowler analyzes for these high-volume AWS resource paths, using the latest resources where AWS API ordering or resource timestamps support it. Otherwise, Prowler uses best-effort API order.
The global default applies to the supported resources below and is overridable per resource. The default value is `0`, which disables the limit and scans every resource. `0`, negative values, and `null` are unlimited; positive values enable limits. The limit applies to resources selected for analysis, not to findings; a selected resource may produce zero, one, or many findings.
Exact list API call reduction depends on each AWS API's ordering and pagination capabilities. When Prowler must enumerate candidates locally to select the latest resources, list calls may still read candidates, but expensive per-resource enrichment calls are bounded to the selected resources for the supported paths below.
| Value | Scope | Type |
|------------------------------------|---------------------------------------------------------|---------|
| `max_scanned_resources_per_service`| Global default for supported high-volume AWS resources (default `0`, disabled/unlimited) | Integer |
| `max_ebs_snapshots` | EBS snapshots (`ec2_ebs_*` checks) | Integer |
| `max_backup_recovery_points` | Backup recovery points (`backup_recovery_point_*`) | Integer |
| `max_cloudwatch_log_groups` | CloudWatch log groups (`cloudwatch_log_group_*`) | Integer |
| `max_lambda_functions` | Lambda functions (`awslambda_function_*`) | Integer |
| `max_ecs_task_definitions` | ECS task definitions (`ecs_task_definitions_*`) | Integer |
| `max_codeartifact_packages` | CodeArtifact packages (`codeartifact_packages_*`) | Integer |
## Azure
### Configurable Checks
@@ -181,6 +200,19 @@ aws:
# AWS Global Configuration
# aws.mute_non_default_regions --> Set to True to muted failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config
mute_non_default_regions: False
# AWS Resource Scan Limit Configuration
# Disabled by default: scan every resource unless a positive limit is configured.
# Findings are not capped. Set to 0 (or a negative value) to disable the limit.
# aws.max_scanned_resources_per_service --> global default for all services below
max_scanned_resources_per_service: 0
# Per-service overrides. Leave as null to fall back to the global default.
max_ebs_snapshots: null
max_backup_recovery_points: null
max_cloudwatch_log_groups: null
max_lambda_functions: null
max_ecs_task_definitions: null
max_codeartifact_packages: null
# If you want to mute failed findings only in specific regions, create a file with the following syntax and run it with `prowler aws -w mutelist.yaml`:
# Mutelist:
# Accounts:
+4
View File
@@ -10,6 +10,10 @@ All notable changes to the **Prowler SDK** are documented in this file.
- AWS AI Security Framework compliance for AWS provider [(#11353)](https://github.com/prowler-cloud/prowler/pull/11353)
- `storage_account_public_network_access_disabled` check for Azure provider and remapped the Azure CIS "Public Network Access is Disabled" requirements to it [(#11334)](https://github.com/prowler-cloud/prowler/pull/11334)
### 🔄 Changed
- AWS scans for EBS snapshots, Backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages now support configurable resource analysis limits via `aws.max_scanned_resources_per_service`; limits are disabled by default and only positive values cap analyzed resources [(#11228)](https://github.com/prowler-cloud/prowler/pull/11228)
### 🐞 Fixed
- Compliance CSV row count now matches the UI per requirement by sourcing rows from the framework JSON's `requirement.Checks` instead of the stale `finding.compliance` snapshot [(#11370)](https://github.com/prowler-cloud/prowler/pull/11370)
+26
View File
@@ -3,6 +3,32 @@ aws:
# AWS Global Configuration
# aws.mute_non_default_regions --> Set to True to muted failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config
mute_non_default_regions: False
# AWS Resource Scan Limit Configuration
# Limits the number of resources scanned per service for services that can
# accumulate huge numbers of resources (EBS snapshots, backup recovery
# points, CloudWatch log groups, Lambda functions, ECS task definitions,
# CodeArtifact packages). Limits apply to resources analyzed, not findings:
# a selected resource can produce zero, one, or many findings. Where the AWS
# API supports server-side ordering the latest resources are scanned first;
# otherwise it is best-effort API order.
# Disabled by default: scan every resource unless a positive limit is configured.
# Set to 0 (or a negative value) to disable the limit (scan every resource).
# aws.max_scanned_resources_per_service --> global default for all services below
max_scanned_resources_per_service: 0
# Per-service overrides. Leave as null to fall back to the global default.
# aws.max_ebs_snapshots --> ec2_ebs_* checks (EBS snapshots)
max_ebs_snapshots: null
# aws.max_backup_recovery_points --> backup_recovery_point_* checks
max_backup_recovery_points: null
# aws.max_cloudwatch_log_groups --> cloudwatch_log_group_* checks
max_cloudwatch_log_groups: null
# aws.max_lambda_functions --> awslambda_function_* checks
max_lambda_functions: null
# aws.max_ecs_task_definitions --> ecs_task_definitions_* checks
max_ecs_task_definitions: null
# aws.max_codeartifact_packages --> codeartifact_packages_* checks
max_codeartifact_packages: null
# aws.disallowed_regions --> List of AWS regions to exclude from the scan.
# Also settable via the PROWLER_AWS_DISALLOWED_REGIONS environment variable or
# the --excluded-region CLI flag. Precedence: CLI > env var > config file.
+50
View File
@@ -0,0 +1,50 @@
"""Scoped resource scan limits for high-volume AWS resources.
Some AWS services accumulate huge numbers of resources (EBS snapshots, backup
recovery points, log groups, Lambda functions, ECS task definitions,
CodeArtifact packages). Scanning all of them causes API throttling, slow
scans, cost and noisy findings.
``get_resource_scan_limit`` resolves the configured number of resources to
analyze for a supported resource path. A limited resource can produce zero,
one, or many findings; findings are not capped or re-ordered here.
"""
from collections.abc import Iterable, Iterator
from itertools import islice
from typing import Optional, TypeVar
GLOBAL_LIMIT_KEY = "max_scanned_resources_per_service"
T = TypeVar("T")
def get_resource_scan_limit(audit_config: dict, service_key: str) -> Optional[int]:
"""Resolve the resource scan limit for a service.
Precedence: per-service key (``service_key``) > global
``max_scanned_resources_per_service`` > unlimited.
A non-positive resolved value means **unlimited** (``None``), preserving
the legacy behavior as an explicit opt-out.
Args:
audit_config: The provider ``audit_config`` dictionary.
service_key: The per-service config key, e.g. ``max_lambda_functions``.
Returns:
The limit as a positive ``int``, or ``None`` for unlimited.
"""
value = audit_config.get(service_key)
if value is None:
value = audit_config.get(GLOBAL_LIMIT_KEY)
if value is None or value <= 0:
return None
return int(value)
def limit_resources(resources: Iterable[T], limit: Optional[int]) -> Iterator[T]:
"""Yield up to ``limit`` resources without changing resource order."""
if not limit or limit <= 0:
yield from resources
return
yield from islice(resources, limit)
@@ -9,6 +9,7 @@ import requests
from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -18,8 +19,14 @@ class Lambda(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Functions are listed first, then trimmed to the subset selected for
# analysis before expensive per-function detail is hydrated.
self.functions = {}
self.function_limit = get_resource_scan_limit(
self.audit_config, "max_lambda_functions"
)
self.__threading_call__(self._list_functions)
self._select_functions_for_analysis()
self._list_tags_for_resource()
self.__threading_call__(self._get_policy)
self.__threading_call__(self._get_function_url_config)
@@ -48,6 +55,10 @@ class Lambda(AWSService):
subnet_ids=set(vpc_config.get("SubnetIds", [])),
region=regional_client.region,
)
if "LastModified" in function:
self.functions[lambda_arn].last_modified = function[
"LastModified"
]
if "Runtime" in function:
self.functions[lambda_arn].runtime = function["Runtime"]
if "Environment" in function:
@@ -76,6 +87,19 @@ class Lambda(AWSService):
f" {error}"
)
def _select_functions_for_analysis(self):
self.functions = {
function.arn: function
for function in limit_resources(
sorted(
self.functions.values(),
key=lambda f: f.last_modified or "",
reverse=True,
),
self.function_limit,
)
}
def _list_event_source_mappings(self, regional_client):
logger.info("Lambda - Listing Event Source Mappings...")
try:
@@ -158,10 +182,9 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
self.functions[function.arn].policy = {}
except Exception as error:
logger.error(
f"{regional_client.region} --"
f"{function.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
@@ -187,10 +210,9 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
self.functions[function.arn].url_config = None
except Exception as error:
logger.error(
f"{regional_client.region} --"
f"{function.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
@@ -206,10 +228,9 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
function.tags = []
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{function.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
@@ -259,6 +280,7 @@ class Function(BaseModel):
name: str
arn: str
security_groups: list
last_modified: Optional[str] = None
runtime: Optional[str] = None
environment: Optional[dict] = None
region: str
@@ -4,6 +4,7 @@ from typing import Optional
from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -27,8 +28,13 @@ class Backup(AWSService):
self.__threading_call__(self._list_backup_report_plans)
self.protected_resources = []
self.__threading_call__(self._list_backup_selections)
# Recovery points are listed first, then only the selected subset is
# tagged and exposed for checks.
self.recovery_points = []
self.__threading_call__(self._list_recovery_points)
self.recovery_point_limit = get_resource_scan_limit(
self.audit_config, "max_backup_recovery_points"
)
self._list_recovery_points()
self.__threading_call__(self._list_tags, self.recovery_points)
def _list_backup_vaults(self, regional_client):
@@ -183,38 +189,52 @@ class Backup(AWSService):
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _list_recovery_points(self, regional_client):
def _list_recovery_points(self):
logger.info("Backup - Listing Recovery Points...")
try:
if self.backup_vaults:
for backup_vault in self.backup_vaults:
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
arn = recovery_point.get("RecoveryPointArn")
if arn:
self.recovery_points.append(
RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get(
"IsEncrypted", False
),
backup_vault_region=backup_vault.region,
region=regional_client.region,
tags=[],
)
)
candidates = []
for backup_vault in self.backup_vaults or []:
regional_client = self.regional_clients[backup_vault.region]
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
arn = recovery_point.get("RecoveryPointArn")
if not arn:
continue
candidates.append((backup_vault, recovery_point))
for backup_vault, recovery_point in limit_resources(
sorted(
candidates,
key=lambda candidate: (
candidate[1]["CreationDate"].timestamp()
if candidate[1].get("CreationDate")
else 0.0
),
reverse=True,
),
self.recovery_point_limit,
):
arn = recovery_point.get("RecoveryPointArn")
rp = RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get("IsEncrypted", False),
creation_date=recovery_point.get("CreationDate"),
backup_vault_region=backup_vault.region,
region=backup_vault.region,
tags=[],
)
self.recovery_points.append(rp)
except ClientError as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
@@ -256,4 +276,5 @@ class RecoveryPoint(BaseModel):
backup_vault_name: str
encrypted: bool
backup_vault_region: str
creation_date: Optional[datetime] = None
tags: Optional[list] = None
@@ -5,6 +5,7 @@ from typing import Optional
from botocore.exceptions import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -83,8 +84,17 @@ class Logs(AWSService):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
self.log_group_arn_template = f"arn:{self.audited_partition}:logs:{self.region}:{self.audited_account}:log-group"
# Log groups are listed first, then only the selected subset is enriched
# and exposed for checks.
self.log_groups = {}
self._log_groups_hydrated = set()
self.log_group_limit = get_resource_scan_limit(
self.audit_config, "max_cloudwatch_log_groups"
)
# The threshold for number of events to return per log group.
self.events_per_log_group_threshold = 1000
self.__threading_call__(self._describe_log_groups)
self._select_log_groups_for_analysis()
self.resource_policies = {}
self.__threading_call__(self._describe_resource_policies)
self.metric_filters = []
@@ -94,14 +104,26 @@ class Logs(AWSService):
"cloudwatch_log_group_no_secrets_in_logs"
in provider.audit_metadata.expected_checks
):
self.events_per_log_group_threshold = (
1000 # The threshold for number of events to return per log group.
)
self.__threading_call__(self._get_log_events)
self.__threading_call__(self._get_log_events, self.log_groups.values())
self.__threading_call__(
self._list_tags_for_resource, self.log_groups.values()
)
def _select_log_groups_for_analysis(self):
if not self.log_groups:
return
self.log_groups = {
log_group.arn: log_group
for log_group in limit_resources(
sorted(
self.log_groups.values(),
key=lambda lg: lg.creation_time or 0,
reverse=True,
),
self.log_group_limit,
)
}
def _describe_metric_filters(self, regional_client):
logger.info("CloudWatch Logs - Describing metric filters...")
try:
@@ -123,6 +145,9 @@ class Logs(AWSService):
log_group = lg
break
if log_group and log_group.arn not in self._log_groups_hydrated:
self._list_tags_for_resource(log_group)
self.metric_filters.append(
MetricFilter(
arn=arn,
@@ -174,6 +199,7 @@ class Logs(AWSService):
retention_days=retention_days,
never_expire=never_expire,
kms_id=kms,
creation_time=log_group.get("creationTime"),
region=regional_client.region,
)
except ClientError as error:
@@ -192,37 +218,24 @@ class Logs(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _get_log_events(self, regional_client):
regional_log_groups = [
log_group
for log_group in self.log_groups.values()
if log_group.region == regional_client.region
]
total_log_groups = len(regional_log_groups)
def _get_log_events(self, log_group):
logger.info(
f"CloudWatch Logs - Retrieving log events for {total_log_groups} log groups in {regional_client.region}..."
f"CloudWatch Logs - Retrieving log events for log group {log_group.name}..."
)
try:
for count, log_group in enumerate(regional_log_groups, start=1):
events = regional_client.filter_log_events(
logGroupName=log_group.name,
limit=self.events_per_log_group_threshold,
)["events"]
for event in events:
if event["logStreamName"] not in log_group.log_streams:
log_group.log_streams[event["logStreamName"]] = []
log_group.log_streams[event["logStreamName"]].append(event)
if count % 10 == 0:
logger.info(
f"CloudWatch Logs - Retrieved log events for {count}/{total_log_groups} log groups in {regional_client.region}..."
)
regional_client = self.regional_clients[log_group.region]
events = regional_client.filter_log_events(
logGroupName=log_group.name,
limit=self.events_per_log_group_threshold,
)["events"]
for event in events:
if event["logStreamName"] not in log_group.log_streams:
log_group.log_streams[event["logStreamName"]] = []
log_group.log_streams[event["logStreamName"]].append(event)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{log_group.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
logger.info(
f"CloudWatch Logs - Finished retrieving log events in {regional_client.region}..."
)
def _describe_resource_policies(self, regional_client):
logger.info("CloudWatch Logs - Describing resource policies...")
@@ -257,6 +270,8 @@ class Logs(AWSService):
)
def _list_tags_for_resource(self, log_group):
if log_group.arn in self._log_groups_hydrated:
return
logger.info(f"CloudWatch Logs - List Tags for Log Group {log_group.name}...")
try:
regional_client = self.regional_clients[log_group.region]
@@ -264,6 +279,7 @@ class Logs(AWSService):
resourceArn=log_group.arn
)["tags"]
log_group.tags = [response]
self._log_groups_hydrated.add(log_group.arn)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
@@ -292,6 +308,7 @@ class LogGroup(BaseModel):
retention_days: int
never_expire: bool
kms_id: Optional[str]
creation_time: Optional[int] = None
region: str
log_streams: dict[str, list[str]] = (
{}
@@ -1,9 +1,10 @@
from enum import Enum
from typing import Optional
from typing import Iterator, Optional, Tuple
from botocore.exceptions import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -15,8 +16,15 @@ class CodeArtifact(AWSService):
super().__init__(__class__.__name__, provider)
# repositories is a dictionary containing all the codeartifact service information
self.repositories = {}
# repository ARNs whose selected packages have been listed and memoized
# into repository.packages.
self._packages_listed = set()
self.package_limit = get_resource_scan_limit(
self.audit_config, "max_codeartifact_packages"
)
self.__threading_call__(self._list_repositories)
self.__threading_call__(self._list_packages)
for _ in self._load_packages_for_analysis():
pass
self._list_tags_for_resource()
def _list_repositories(self, regional_client):
@@ -51,124 +59,128 @@ class CodeArtifact(AWSService):
f" {error}"
)
def _list_packages(self, regional_client):
logger.info("CodeArtifact - Listing Packages and retrieving information...")
for repository in self.repositories:
try:
if self.repositories[repository].region == regional_client.region:
list_packages_paginator = regional_client.get_paginator(
"list_packages"
)
list_packages_parameters = {
"domain": self.repositories[repository].domain_name,
"domainOwner": self.repositories[repository].domain_owner,
"repository": self.repositories[repository].name,
def _iter_repository_packages(self, repository) -> Iterator["Package"]:
"""Yield packages for a single repository, hydrating each one lazily.
Each package requires an extra ``list_package_versions`` call to
resolve its latest version, so producing them lazily lets the resource
limit stop before extra package version calls.
"""
regional_client = self.regional_clients[repository.region]
try:
list_packages_paginator = regional_client.get_paginator("list_packages")
list_packages_parameters = {
"domain": repository.domain_name,
"domainOwner": repository.domain_owner,
"repository": repository.name,
}
for page in list_packages_paginator.paginate(**list_packages_parameters):
for package in page["packages"]:
# Package information
package_format = package["format"]
package_namespace = package.get("namespace")
package_name = package["package"]
package_origin_configuration_restrictions_publish = package[
"originConfiguration"
]["restrictions"]["publish"]
package_origin_configuration_restrictions_upstream = package[
"originConfiguration"
]["restrictions"]["upstream"]
# Get Latest Package Version
list_package_versions_parameters = {
"domain": repository.domain_name,
"domainOwner": repository.domain_owner,
"repository": repository.name,
"format": package_format,
"package": package_name,
"sortBy": "PUBLISHED_TIME",
"maxResults": 1,
}
packages = []
for page in list_packages_paginator.paginate(
**list_packages_parameters
):
for package in page["packages"]:
# Package information
package_format = package["format"]
package_namespace = package.get("namespace")
package_name = package["package"]
package_origin_configuration_restrictions_publish = package[
"originConfiguration"
]["restrictions"]["publish"]
package_origin_configuration_restrictions_upstream = (
package["originConfiguration"]["restrictions"][
"upstream"
]
)
# Get Latest Package Version
if package_namespace:
latest_version_information = (
regional_client.list_package_versions(
domain=self.repositories[
repository
].domain_name,
domainOwner=self.repositories[
repository
].domain_owner,
repository=self.repositories[repository].name,
format=package_format,
namespace=package_namespace,
package=package_name,
sortBy="PUBLISHED_TIME",
maxResults=1,
)
)
else:
latest_version_information = (
regional_client.list_package_versions(
domain=self.repositories[
repository
].domain_name,
domainOwner=self.repositories[
repository
].domain_owner,
repository=self.repositories[repository].name,
format=package_format,
package=package_name,
sortBy="PUBLISHED_TIME",
maxResults=1,
)
)
latest_version = ""
latest_origin_type = "UNKNOWN"
latest_status = "Published"
if latest_version_information.get("versions"):
latest_version = latest_version_information["versions"][
0
].get("version")
latest_origin_type = (
latest_version_information["versions"][0]
.get("origin", {})
.get("originType", "UNKNOWN")
)
latest_status = latest_version_information["versions"][
0
].get("status", "Published")
packages.append(
Package(
name=package_name,
namespace=package_namespace,
format=package_format,
origin_configuration=OriginConfiguration(
restrictions=Restrictions(
publish=package_origin_configuration_restrictions_publish,
upstream=package_origin_configuration_restrictions_upstream,
)
),
latest_version=LatestPackageVersion(
version=latest_version,
status=latest_status,
origin=OriginInformation(
origin_type=latest_origin_type
),
),
)
)
# Save all the packages information
self.repositories[repository].packages = packages
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{regional_client.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
if package_namespace:
list_package_versions_parameters["namespace"] = (
package_namespace
)
latest_version_information = regional_client.list_package_versions(
**list_package_versions_parameters
)
continue
latest_version = ""
latest_origin_type = "UNKNOWN"
latest_status = "Published"
if latest_version_information.get("versions"):
latest_version = latest_version_information["versions"][0].get(
"version"
)
latest_origin_type = (
latest_version_information["versions"][0]
.get("origin", {})
.get("originType", "UNKNOWN")
)
latest_status = latest_version_information["versions"][0].get(
"status", "Published"
)
except Exception as error:
logger.error(
f"{regional_client.region} --"
yield Package(
name=package_name,
namespace=package_namespace,
format=package_format,
origin_configuration=OriginConfiguration(
restrictions=Restrictions(
publish=package_origin_configuration_restrictions_publish,
upstream=package_origin_configuration_restrictions_upstream,
)
),
latest_version=LatestPackageVersion(
version=latest_version,
status=latest_status,
origin=OriginInformation(origin_type=latest_origin_type),
),
)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
else:
logger.error(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
except Exception as error:
logger.error(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
def _load_packages_for_analysis(self) -> Iterator[Tuple["Repository", "Package"]]:
"""Yield the ``(repository, package)`` pairs selected for analysis.
Package listing stays in the service layer so checks receive only the
selected packages and remain unaware of resource-analysis limits.
"""
yielded = 0
for repository in list(self.repositories.values()):
if repository.arn in self._packages_listed:
for package in repository.packages:
yield repository, package
yielded += 1
if self.package_limit and yielded >= self.package_limit:
return
continue
collected = []
for package in self._iter_repository_packages(repository):
collected.append(package)
repository.packages = collected
yield repository, package
yielded += 1
if self.package_limit and yielded >= self.package_limit:
self._packages_listed.add(repository.arn)
return
self._packages_listed.add(repository.arn)
def _list_tags_for_resource(self):
logger.info("CodeArtifact - List Tags...")
@@ -5,6 +5,7 @@ from typing import Optional, Union
from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -26,7 +27,12 @@ class EC2(AWSService):
self.snapshots = []
self.volumes_with_snapshots = {}
self.regions_with_snapshots = {}
# Snapshots are listed first, then limited before public status is hydrated.
self.snapshot_limit = get_resource_scan_limit(
self.audit_config, "max_ebs_snapshots"
)
self.__threading_call__(self._describe_snapshots)
self._select_snapshots_for_analysis()
self.__threading_call__(self._determine_public_snapshots, self.snapshots)
self.network_interfaces = {}
self.__threading_call__(self._describe_network_interfaces)
@@ -207,6 +213,7 @@ class EC2(AWSService):
arn=arn,
region=regional_client.region,
encrypted=snapshot.get("Encrypted", False),
start_time=snapshot.get("StartTime"),
tags=snapshot.get("Tags"),
volume=snapshot["VolumeId"],
)
@@ -243,6 +250,18 @@ class EC2(AWSService):
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _select_snapshots_for_analysis(self):
self.snapshots = list(
limit_resources(
sorted(
self.snapshots,
key=lambda s: (s.start_time.timestamp() if s.start_time else 0.0),
reverse=True,
),
self.snapshot_limit,
)
)
def _describe_network_interfaces(self, regional_client):
try:
# Get Network Interfaces with Public IPs
@@ -686,6 +705,7 @@ class Snapshot(BaseModel):
region: str
encrypted: bool
public: bool = False
start_time: Optional[datetime] = None
tags: Optional[list] = []
volume: Optional[str]
@@ -1,8 +1,10 @@
from datetime import datetime
from re import sub
from typing import Optional
from pydantic.v1 import BaseModel
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -12,39 +14,73 @@ class ECS(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Task definition ARNs are listed first, then only the selected subset
# is described and exposed for checks.
self.task_definitions = {}
self._task_definition_arns = None
self.task_definition_limit = get_resource_scan_limit(
self.audit_config, "max_ecs_task_definitions"
)
self.services = {}
self.clusters = {}
self.task_sets = {}
self.__threading_call__(self._list_task_definitions)
self.__threading_call__(
self._describe_task_definition, self.task_definitions.values()
)
for _ in self._load_task_definitions_for_analysis():
pass
self.__threading_call__(self._list_clusters)
self.__threading_call__(self._describe_clusters, self.clusters.values())
self.__threading_call__(self._describe_services, self.clusters.values())
def _list_task_definitions(self, regional_client):
def _list_task_definition_arns(self) -> list:
"""List task definition ARNs newest-first, memoized.
Uses the ``list_task_definitions`` server-side ``sort=DESC`` so the
latest revisions are scanned first across all regions.
"""
if self._task_definition_arns is not None:
return self._task_definition_arns
logger.info("ECS - Listing Task Definitions...")
try:
list_ecs_paginator = regional_client.get_paginator("list_task_definitions")
for page in list_ecs_paginator.paginate():
for task_definition in page["taskDefinitionArns"]:
if not self.audit_resources or (
is_resource_filtered(task_definition, self.audit_resources)
):
self.task_definitions[task_definition] = TaskDefinition(
# we want the family name without the revision
name=sub(":.*", "", task_definition.split("/")[-1]),
arn=task_definition,
revision=task_definition.split(":")[-1],
region=regional_client.region,
environment_variables=[],
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
arns = []
for region, regional_client in self.regional_clients.items():
try:
list_ecs_paginator = regional_client.get_paginator(
"list_task_definitions"
)
for page in list_ecs_paginator.paginate(sort="DESC"):
for task_definition in page["taskDefinitionArns"]:
if not self.audit_resources or (
is_resource_filtered(task_definition, self.audit_resources)
):
arns.append((task_definition, region))
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
self._task_definition_arns = arns
return arns
def _load_task_definitions_for_analysis(self):
"""Yield task definitions lazily, describing each one on demand.
Resources already fetched are memoized in ``self.task_definitions`` and
reused across checks (checks run sequentially, so no locking is needed).
The configured resource limit bounds ``describe_task_definition`` calls.
"""
for arn, region in limit_resources(
self._list_task_definition_arns(), self.task_definition_limit
):
task_definition = self.task_definitions.get(arn)
if task_definition is None:
task_definition = TaskDefinition(
# we want the family name without the revision
name=sub(":.*", "", arn.split("/")[-1]),
arn=arn,
revision=arn.split(":")[-1],
region=region,
environment_variables=[],
)
self._describe_task_definition(task_definition)
self.task_definitions[arn] = task_definition
yield task_definition
def _describe_task_definition(self, task_definition):
logger.info("ECS - Describing Task Definition...")
@@ -84,6 +120,9 @@ class ECS(AWSService):
)
)
task_definition.pid_mode = response["taskDefinition"].get("pidMode", "")
task_definition.registered_at = response["taskDefinition"].get(
"registeredAt"
)
task_definition.tags = response.get("tags")
task_definition.network_mode = response["taskDefinition"].get(
"networkMode", "bridge"
@@ -208,6 +247,7 @@ class TaskDefinition(BaseModel):
region: str
container_definitions: list[ContainerDefinition] = []
pid_mode: Optional[str]
registered_at: Optional[datetime] = None
tags: Optional[list] = []
network_mode: Optional[str]
+71
View File
@@ -0,0 +1,71 @@
from prowler.lib.check.resource_limit import get_resource_scan_limit, limit_resources
class Test_limit_resources:
def test_no_limit_returns_all_in_order(self):
resources = ["PASS", "FAIL", "PASS"]
result = list(limit_resources(iter(resources), None))
assert result == ["PASS", "FAIL", "PASS"]
def test_limit_zero_or_negative_is_unlimited(self):
resources = list(range(5))
assert list(limit_resources(iter(resources), 0)) == resources
assert list(limit_resources(iter(resources), -3)) == resources
def test_positive_limit_stops_after_selected_resources(self):
pulled = []
def gen():
for i in range(1000):
pulled.append(i)
yield i
result = list(limit_resources(gen(), 100))
assert result == list(range(100))
assert len(pulled) == 100
def test_does_not_reorder_or_inspect_resource_status(self):
resources = ["PASS", "FAIL", "PASS", "FAIL"]
result = list(limit_resources(iter(resources), 3))
assert result == ["PASS", "FAIL", "PASS"]
class Test_get_resource_scan_limit:
def test_per_service_override_wins(self):
config = {
"max_scanned_resources_per_service": 100,
"max_ecs_task_definitions": 25,
}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 25
def test_falls_back_to_global_default(self):
config = {"max_scanned_resources_per_service": 50}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50
def test_default_is_unlimited_when_unset(self):
assert get_resource_scan_limit({}, "max_ecs_task_definitions") is None
def test_null_per_service_override_falls_back_to_unlimited_global_default(self):
config = {"max_ecs_task_definitions": None}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") is None
def test_non_positive_means_unlimited(self):
assert (
get_resource_scan_limit(
{"max_scanned_resources_per_service": 0}, "max_lambda_functions"
)
is None
)
assert (
get_resource_scan_limit(
{"max_lambda_functions": -1}, "max_lambda_functions"
)
is None
)
@@ -9,7 +9,11 @@ import mock
from boto3 import client, resource
from moto import mock_aws
from prowler.providers.aws.services.awslambda.awslambda_service import AuthType, Lambda
from prowler.providers.aws.services.awslambda.awslambda_service import (
AuthType,
Function,
Lambda,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -85,6 +89,30 @@ class Test_Lambda_Service:
awslambda = Lambda(set_mocked_aws_provider([AWS_REGION_US_EAST_1]))
assert awslambda.service == "lambda"
def test_function_limit_selects_latest_functions_for_analysis(self):
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {
"old": Function(
name="old",
arn="old",
security_groups=[],
last_modified="2024-01-01T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
"new": Function(
name="new",
arn="new",
security_groups=[],
last_modified="2024-01-02T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
}
awslambda.function_limit = 1
awslambda._select_functions_for_analysis()
assert list(awslambda.functions) == ["new"]
@mock_aws
def test_list_functions(self):
# Create IAM Lambda Role
@@ -253,3 +281,63 @@ class Test_Lambda_Service:
f"{tmp_dir_name}/{files_in_zip[0]}", "r"
) as lambda_code_file:
assert lambda_code_file.read() == LAMBDA_FUNCTION_CODE
@mock_aws
def test_function_limit_exposes_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
assert len(awslambda.functions) == 1
@mock_aws
def test_get_function_code_fetches_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
fetched = []
def fetch_function_code(function_name, _function_region):
fetched.append(function_name)
return mock.MagicMock()
awslambda._fetch_function_code = fetch_function_code
assert len(list(awslambda._get_function_code())) == 1
assert len(fetched) == 1
@@ -1,11 +1,12 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
import botocore
from boto3 import client
from moto import mock_aws
from prowler.providers.aws.services.backup.backup_service import Backup
from prowler.providers.aws.services.backup.backup_service import Backup, BackupVault
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -292,3 +293,106 @@ class TestBackupService:
assert backup.recovery_points[0].backup_vault_region == "eu-west-1"
assert backup.recovery_points[0].tags == []
assert backup.recovery_points[0].encrypted is True
def test_recovery_point_limit_bounds_tag_calls_to_selected_points(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 2),
},
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 1),
},
]
}
]
class FakeBackupClient:
def __init__(self):
self.tag_calls = []
def get_paginator(self, name):
assert name == "list_recovery_points_by_backup_vault"
return FakePaginator()
def list_tags(self, **kwargs):
self.tag_calls.append(kwargs["ResourceArn"])
return {"Tags": {}}
regional_client = FakeBackupClient()
backup = Backup.__new__(Backup)
backup.backup_vaults = [
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:vault",
name="vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=2,
locked=False,
)
]
backup.recovery_points = []
backup.recovery_point_limit = 1
backup.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
backup._list_recovery_points()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert [rp.id for rp in backup.recovery_points] == ["new"]
assert regional_client.tag_calls == [
"arn:aws:backup:eu-west-1:123456789012:recovery-point:new"
]
def test_recovery_point_limit_exposes_only_selected_resources(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = []
backup.backup_vaults = [
BackupVault(
arn="arn",
name="vault",
region="eu-west-1",
encryption="",
recovery_points=3,
locked=False,
)
]
class Paginator:
def paginate(self, **_kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": f"arn:aws:backup:eu-west-1:123456789012:recovery-point:{i}",
"IsEncrypted": True,
}
for i in range(3)
]
}
]
backup.regional_clients = {
"eu-west-1": SimpleNamespace(get_paginator=lambda _: Paginator())
}
tagged = []
def list_tags(recovery_point):
tagged.append(recovery_point.arn)
backup._list_tags = list_tags
backup._list_recovery_points()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert len(backup.recovery_points) == 2
assert len(tagged) == 2
@@ -3,6 +3,7 @@ from moto import mock_aws
from prowler.providers.aws.services.cloudwatch.cloudwatch_service import (
CloudWatch,
LogGroup,
Logs,
)
from tests.providers.aws.utils import (
@@ -216,3 +217,46 @@ class Test_CloudWatch_Service:
assert logs.log_groups[arn].kms_id == "test_kms_id"
assert logs.log_groups[arn].region == AWS_REGION_US_EAST_1
assert logs.log_groups[arn].tags == [{}]
def test_log_group_limit_exposes_only_selected_resources(self):
class FakeLogsClient:
def __init__(self):
self.filter_calls = []
def filter_log_events(self, **kwargs):
self.filter_calls.append(kwargs["logGroupName"])
return {"events": []}
regional_client = FakeLogsClient()
logs = Logs.__new__(Logs)
logs.log_group_limit = 1
logs._log_groups_hydrated = set()
logs.regional_clients = {AWS_REGION_US_EAST_1: regional_client}
logs.events_per_log_group_threshold = 1000
logs.log_groups = {
f"arn:{i}": LogGroup(
arn=f"arn:{i}",
name=f"log-{i}",
retention_days=30,
never_expire=False,
kms_id=None,
creation_time=i,
region=AWS_REGION_US_EAST_1,
)
for i in range(3)
}
tagged = []
def list_tags(log_group):
tagged.append(log_group.arn)
logs._list_tags_for_resource = list_tags
logs._select_log_groups_for_analysis()
for log_group in logs.log_groups.values():
logs._list_tags_for_resource(log_group)
logs._get_log_events(log_group)
assert list(logs.log_groups) == ["arn:2"]
assert tagged == ["arn:2"]
assert regional_client.filter_calls == ["log-2"]
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import patch
import botocore
@@ -6,6 +7,7 @@ from prowler.providers.aws.services.codeartifact.codeartifact_service import (
CodeArtifact,
LatestPackageVersionStatus,
OriginInformationValues,
Repository,
RestrictionValues,
)
from tests.providers.aws.utils import (
@@ -208,6 +210,102 @@ class Test_CodeArtifact_Service:
== OriginInformationValues.INTERNAL
)
def test_package_limit_bounds_package_version_lookups_to_selected_packages(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"packages": [
{
"format": "pypi",
"package": "first-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
{
"format": "pypi",
"package": "second-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
]
}
]
class FakeCodeArtifactClient:
def __init__(self):
self.version_calls = []
def get_paginator(self, name):
assert name == "list_packages"
return FakePaginator()
def list_package_versions(self, **kwargs):
self.version_calls.append(kwargs["package"])
return {
"versions": [
{
"version": "1.0.0",
"status": "Published",
"origin": {"originType": "INTERNAL"},
}
]
}
regional_client = FakeCodeArtifactClient()
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.repositories = {
TEST_REPOSITORY_ARN: Repository(
name="test-repository",
arn=TEST_REPOSITORY_ARN,
domain_name="test-domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
}
codeartifact._packages_listed = set()
codeartifact.package_limit = 1
codeartifact.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
pairs = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in pairs] == ["first-package"]
assert regional_client.version_calls == ["first-package"]
def test_package_limit_exposes_only_selected_packages(self):
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.package_limit = 2
codeartifact._packages_listed = set()
repository = Repository(
name="repository",
arn="repo",
domain_name="domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
codeartifact.repositories = {repository.arn: repository}
enriched = []
def iter_repository_packages(repository):
for index in range(3):
enriched.append(index)
yield SimpleNamespace(name=f"package-{index}")
codeartifact._iter_repository_packages = iter_repository_packages
packages = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in packages] == ["package-0", "package-1"]
assert enriched == [0, 1]
def mock_make_api_call_no_namespace(self, operation_name, kwarg):
"""Mock for packages without a namespace to exercise the else branch"""
@@ -11,7 +11,7 @@ from freezegun import freeze_time
from moto import mock_aws
from prowler.config.config import encoding_format_utf_8
from prowler.providers.aws.services.ec2.ec2_service import EC2
from prowler.providers.aws.services.ec2.ec2_service import EC2, Snapshot
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -103,6 +103,45 @@ class Test_EC2_Service:
ec2 = EC2(aws_provider)
assert ec2.audited_account == AWS_ACCOUNT_NUMBER
def test_snapshot_limit_bounds_public_attribute_calls_to_latest_selected(self):
class FakeEC2Client:
def __init__(self):
self.calls = []
def describe_snapshot_attribute(self, **kwargs):
self.calls.append(kwargs["SnapshotId"])
return {"CreateVolumePermissions": []}
regional_client = FakeEC2Client()
ec2 = EC2.__new__(EC2)
ec2.snapshots = [
Snapshot(
id="snap-old",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-old",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 1),
volume="vol-old",
),
Snapshot(
id="snap-new",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-new",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 2),
volume="vol-new",
),
]
ec2.snapshot_limit = 1
ec2.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
ec2._select_snapshots_for_analysis()
for snapshot in ec2.snapshots:
ec2._determine_public_snapshots(snapshot)
assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"]
assert regional_client.calls == ["snap-new"]
# Test EC2 Describe Instances
@mock_aws
@freeze_time(MOCK_DATETIME)
@@ -346,6 +385,24 @@ class Test_EC2_Service:
assert not snapshot.encrypted
assert snapshot.public
@mock_aws
def test_snapshot_limit_exposes_only_selected_snapshots(self):
ec2_client = client("ec2", region_name=AWS_REGION_US_EAST_1)
ec2_resource = resource("ec2", region_name=AWS_REGION_US_EAST_1)
volume_id = ec2_resource.create_volume(
AvailabilityZone="us-east-1a",
Size=80,
VolumeType="gp2",
).id
for _ in range(3):
ec2_client.create_snapshot(VolumeId=volume_id)
aws_provider = set_mocked_aws_provider(
[AWS_REGION_US_EAST_1], audit_config={"max_ebs_snapshots": 1}
)
ec2 = EC2(aws_provider)
assert len(ec2.snapshots) == 1
# Test EC2 Instance User Data
@mock_aws
def test_get_instance_user_data(self):
@@ -139,7 +139,6 @@ class Test_ECS_Service:
ecs = ECS(aws_provider)
assert ecs.session.__class__.__name__ == "Session"
# Test list ECS task definitions
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
def test_list_task_definitions(self):
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
@@ -201,6 +200,115 @@ class Test_ECS_Service:
.readonly_rootfilesystem
)
def test_task_definitions_are_loaded_once_for_analysis(self):
describe_calls = []
list_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
list_calls.append(kwarg)
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == [
"3",
"2",
"1",
]
assert list_calls == [{"sort": "DESC"}]
assert len(describe_calls) == 3
def test_task_definition_limit_exposes_only_selected_resources(self):
describe_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 2}
)
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == ["3", "2"]
assert len(describe_calls) == 2
def test_task_definition_limit_bounds_describe_calls(self):
describe_calls = []
def counting_make_api_call(self, operation_name, kwarg):
if operation_name == "ListTaskDefinitions":
return {
"taskDefinitionArns": [
f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
for i in (3, 2, 1)
]
}
if operation_name == "DescribeTaskDefinition":
describe_calls.append(kwarg["taskDefinition"])
return {
"taskDefinition": {
"containerDefinitions": [],
"networkMode": "bridge",
"pidMode": "",
"tags": [],
}
}
return mock_make_api_call(self, operation_name, kwarg)
with patch(
"botocore.client.BaseClient._make_api_call", new=counting_make_api_call
):
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 1}
)
ecs = ECS(aws_provider)
assert [td.revision for td in ecs.task_definitions.values()] == ["3"]
assert describe_calls == [
"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3"
]
# Test list ECS clusters
@patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
def test_list_clusters(self):