feat(sdk): limit selected high-volume AWS resource analysis (#11228)

Hugo Pereira Brito
2026-06-30 15:49:12 +01:00
committed by GitHub
parent 34e8e3ca61
commit c46cbaaa4a
25 changed files with 2392 additions and 247 deletions
@@ -88,6 +88,65 @@ The following list includes all the AWS checks with configurable variables that
| `vpc_endpoint_services_allowed_principals_trust_boundaries` | `trusted_account_ids` | List of Strings |
| `opensearch_service_domains_not_publicly_accessible` | `trusted_ips` | List of Strings |
### Resource Scan Limit
<VersionBadge version="5.32.0" />
Some AWS services accumulate large numbers of resources (EBS snapshots, backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages). Scanning every resource increases scan time, cost, API throttling, and finding volume. By default, Prowler scans every resource. Configure a positive resource scan limit to cap how many resources Prowler analyzes for these high-volume AWS resource paths.
The global default applies to the supported resources below and can be overridden per resource. The default global value is `0`, which disables the limit and scans every resource; a global `null` value is also unlimited. For per-resource values, `null` inherits the global default, while `0` or a negative value explicitly disables the limit for that resource, even when the global default is positive. The configuration schema accepts integers from `-1` to `1000000`. Positive values enable limits.
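The precedence rules above can be sketched in a few lines. This is a minimal illustration of the documented behavior, not Prowler's API: `resolve_limit` and the sample `config` dictionary are hypothetical names chosen for the example.

```python
from typing import Optional

GLOBAL_KEY = "max_scanned_resources_per_service"


def resolve_limit(config: dict, resource_key: str) -> Optional[int]:
    """Return a positive limit, or None when the resource path is unlimited."""
    value = config.get(resource_key)
    if value is None:
        # A null per-resource value inherits the global default.
        value = config.get(GLOBAL_KEY)
    if value is None or value <= 0:
        # 0, negative, or null resolved values all mean "scan everything".
        return None
    return value


# Hypothetical configuration used only for illustration.
config = {
    GLOBAL_KEY: 500,
    "max_ebs_snapshots": None,        # inherits the global default
    "max_lambda_functions": 0,        # explicit opt-out for this resource
    "max_cloudwatch_log_groups": 50,  # per-resource override
}

ebs_limit = resolve_limit(config, "max_ebs_snapshots")
lambda_limit = resolve_limit(config, "max_lambda_functions")
log_group_limit = resolve_limit(config, "max_cloudwatch_log_groups")
```

With this configuration, EBS snapshots inherit the global limit, Lambda functions stay unlimited despite the positive global default, and CloudWatch log groups use their own override.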
<Warning>
When positive resource scan limits are configured, compliance results are based only on the selected resources, not on the full set of matching resources in the account. Treat compliance summaries and percentages as partial evidence, because unselected resources are not analyzed and can change the real compliance posture.
</Warning>
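To make the warning concrete, the following sketch estimates how much of a resource path a positive limit actually covers. The function name `analysis_coverage` and the sample numbers are assumptions for illustration only; Prowler does not report this value.

```python
def analysis_coverage(total_resources: int, limit: int) -> float:
    """Fraction of matching resources analyzed under a resource scan limit."""
    if total_resources <= 0 or limit <= 0:
        # Nothing to scan, or the limit is disabled: coverage is complete.
        return 1.0
    return min(limit, total_resources) / total_resources


# Hypothetical account: 2000 EBS snapshots scanned with a limit of 100.
coverage = analysis_coverage(total_resources=2000, limit=100)
```

In this hypothetical account only 5% of the snapshots are analyzed, so a compliance percentage computed from the scan says nothing about the remaining 95%.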
#### Global Behavior
Resource scan limits select resources for analysis. They do not cap, prioritize, or reorder findings.
* **`0`, negative, or global `null` values:** Disable the limit and keep the legacy behavior for that resource path. Prowler analyzes every discovered matching resource.
* **Positive values:** Select at most that many resources for the affected resource path. A selected resource can produce zero, one, or many findings.
* **No PASS/FAIL prioritization:** Prowler does not inspect the compliance result before selecting resources. Limits do not prefer failed resources, passed resources, or resources with more findings.
* **Latest-first where possible:** When AWS exposes timestamps or useful ordering, Prowler selects the newest resources first. When AWS only exposes API order, Prowler preserves that API order and documents the behavior as best effort.
* **Findings are downstream:** Checks only evaluate the resources exposed by the service client after selection. Findings from unselected resources are not produced because those resources are not analyzed.
Exact list API call reduction depends on each AWS API's ordering and pagination capabilities. When Prowler must enumerate candidates locally to select the latest resources, list calls may still read candidates, but expensive per-resource enrichment calls are bounded to the selected resources for the supported paths below.
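The latest-first selection described above can be sketched as a sort followed by a cut. This is a simplified illustration under stated assumptions: `select_newest`, the `created` field, and the sample snapshot records are hypothetical and do not correspond to Prowler's internal models.

```python
from datetime import datetime, timezone
from itertools import islice
from typing import Optional


def select_newest(resources: list, limit: Optional[int]) -> list:
    """Order resources newest first, then keep at most `limit` of them."""

    def sort_key(resource: dict):
        created = resource.get("created")
        # Resources without a timestamp sort after every dated resource.
        missing = created is None
        return (missing, 0.0 if missing else -created.timestamp())

    ordered = sorted(resources, key=sort_key)
    if not limit or limit <= 0:
        # A disabled limit keeps every discovered resource.
        return ordered
    return list(islice(ordered, limit))


# Hypothetical snapshot metadata returned by a list call.
snapshots = [
    {"id": "snap-old", "created": datetime(2023, 1, 1, tzinfo=timezone.utc)},
    {"id": "snap-undated", "created": None},
    {"id": "snap-new", "created": datetime(2025, 6, 1, tzinfo=timezone.utc)},
    {"id": "snap-mid", "created": datetime(2024, 3, 1, tzinfo=timezone.utc)},
]

selected = select_newest(snapshots, limit=2)
selected_ids = [snapshot["id"] for snapshot in selected]
```

Selection happens before any check runs, so the compliance status of a resource never influences whether it is selected.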
#### Full Collections Versus Limited Analysis Sets
Some checks need lightweight evidence from a complete resource collection to avoid incorrect cross-service conclusions, while other checks perform primary analysis on a limited resource set.
Prowler keeps full lightweight collections where they are needed for cross-service evidence. For example:
* **Lambda security groups and regions:** Prowler records security groups used by all discovered Lambda functions and the regions where functions exist before it limits Lambda functions for primary Lambda checks. This helps Amazon EC2 and Amazon Inspector checks avoid false positives such as treating Lambda security groups as unused or assuming a region has no Lambda functions.
* **CloudWatch `all_log_groups`:** Prowler records all discovered CloudWatch log groups in `all_log_groups` before limiting the primary `log_groups` analysis set. Other services can still resolve log group evidence, while CloudWatch log group checks only analyze the selected log groups.
This split is intentional. It reduces expensive per-resource analysis calls without discarding lightweight context that other services need for accurate results.
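A minimal sketch of this split, assuming a simplified in-memory model: the class name `LogGroupIndex` and the dictionary fields are hypothetical, while the attribute names `all_log_groups` and `log_groups` follow the ones described above.

```python
from itertools import islice
from typing import Optional


class LogGroupIndex:
    """Keep a full lightweight index alongside a limited analysis set."""

    def __init__(self, discovered: list, limit: Optional[int]):
        # Complete index: cheap metadata retained for cross-service lookups.
        self.all_log_groups = {lg["arn"]: lg for lg in discovered}
        newest_first = sorted(
            discovered,
            key=lambda lg: lg.get("creation_time") or 0,
            reverse=True,
        )
        if limit and limit > 0:
            selected = islice(newest_first, limit)
        else:
            selected = newest_first
        # Limited set: only these log groups get expensive enrichment.
        self.log_groups = {lg["arn"]: lg for lg in selected}


# Hypothetical log groups; creation_time is epoch milliseconds.
discovered = [
    {"arn": "arn:aws:logs:eu-west-1:123456789012:log-group:old", "creation_time": 1_600_000_000_000},
    {"arn": "arn:aws:logs:eu-west-1:123456789012:log-group:new", "creation_time": 1_700_000_000_000},
    {"arn": "arn:aws:logs:eu-west-1:123456789012:log-group:mid", "creation_time": 1_650_000_000_000},
]

index = LogGroupIndex(discovered, limit=1)
old_arn = "arn:aws:logs:eu-west-1:123456789012:log-group:old"
```

Here a check in another service can still resolve `old_arn` through `all_log_groups`, even though that log group is absent from the limited `log_groups` set.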
#### Supported AWS Resource Limits
| Value | Scope | Type |
|-------|-------|------|
| `max_scanned_resources_per_service` | Global default for all supported high-volume AWS resources (default `0`, disabled/unlimited) | Integer |
| `max_ebs_snapshots` | EBS snapshots (`ec2_ebs_*` checks) | Integer |
| `max_backup_recovery_points` | Backup recovery points (`backup_recovery_point_*`) | Integer |
| `max_cloudwatch_log_groups` | CloudWatch log groups (`cloudwatch_log_group_*`) | Integer |
| `max_lambda_functions` | Lambda functions (`awslambda_function_*`) | Integer |
| `max_ecs_task_definitions` | ECS task definitions (`ecs_task_definitions_*`) | Integer |
| `max_codeartifact_packages` | CodeArtifact packages (`codeartifact_packages_*`) | Integer |
#### Resource Limit Behavior By Resource Path
| Resource Path | What Prowler Discovers | What A Positive Limit Selects For Analysis | Ordering And Latest Behavior | AWS Calls Reduced | Drawbacks And Consequences |
|---------------|------------------------|--------------------------------------------|------------------------------|-------------------|----------------------------|
| EBS snapshots (`max_ebs_snapshots`) | Prowler lists self-owned snapshots and keeps lightweight evidence that volumes and regions have snapshots. | The selected EBS snapshots exposed to `ec2_ebs_*` checks. | Prowler sorts discovered snapshots by `StartTime` newest first, then applies the limit. Snapshots without a timestamp sort last. | Bounds expensive per-snapshot public attribute checks to selected snapshots. Snapshot listing still runs so Prowler can choose the newest snapshots and keep volume/region evidence. | Older unselected snapshots are not analyzed by snapshot checks. A public, unencrypted, or otherwise noncompliant older snapshot can be missed when the limit is lower than the number of snapshots. |
| Backup recovery points (`max_backup_recovery_points`) | Prowler lists backup vaults, plans, selections, and recovery point candidates in discovered vaults. | The selected recovery points exposed to `backup_recovery_point_*` checks and tag hydration. | Prowler sorts discovered recovery points by `CreationDate` newest first across vaults, then applies the limit. Recovery points without a timestamp sort last. | Bounds recovery point tag calls to selected recovery points. Vault and recovery point list calls still run so Prowler can choose the newest points. | Older unselected recovery points are not analyzed. A nonencrypted or otherwise noncompliant older recovery point can be missed. |
| CloudWatch log groups (`max_cloudwatch_log_groups`) | Prowler lists log groups into both `all_log_groups` and the primary `log_groups` collection. `all_log_groups` remains available as lightweight cross-service evidence. | The selected log groups exposed to `cloudwatch_log_group_*` checks, tag hydration, and log event retrieval for checks that need log contents. | Prowler sorts discovered log groups by `creationTime` newest first, then applies the limit. Log groups without a creation time sort last. | Bounds tag calls and log event retrieval to selected log groups. Log group listing still runs to build `all_log_groups` and choose newest log groups. | Older unselected log groups are not analyzed by CloudWatch log group checks. Retention, encryption, or secrets-in-logs issues in older log groups can be missed, although cross-service evidence can still use `all_log_groups`. |
| Lambda functions (`max_lambda_functions`) | Prowler lists Lambda functions and records lightweight security group and region evidence for all discovered functions. | The selected Lambda functions exposed to `awslambda_function_*` checks and per-function enrichment such as tags, policies, function URLs, and event source mappings. | Prowler sorts discovered functions by `LastModified` newest first, then applies the limit. Functions without `LastModified` sort last. | Bounds per-function enrichment calls to selected functions. Function listing still runs to choose newest functions and keep security group/region evidence. | Older unselected functions are not analyzed by Lambda checks. Runtime, policy, URL, environment secret, or dead-letter queue issues in older functions can be missed. Cross-service checks can still use full Lambda security group and region evidence to avoid false positives. |
| ECS task definitions (`max_ecs_task_definitions`) | Prowler lists ECS task definition ARN candidates in each region. Candidate ARNs can remain visible and discoverable through AWS list operations, even when not all are described. | The selected task definitions that Prowler describes and exposes to `ecs_task_definitions_*` checks. | Selection is not random. Prowler calls `ListTaskDefinitions` with `sort=DESC`, which asks AWS to return task definition ARNs in descending family and revision order. Prowler then interleaves regional candidate lists to avoid starving later regions before applying the limit. This selects the latest task definition revisions according to the ARN order AWS provides, while preserving regional fairness. | Bounds `DescribeTaskDefinition` calls to selected task definitions. Prowler may still list candidates so it can select the bounded set and keep discovery deterministic. | Unselected task definitions are not described or analyzed. Issues in older task definition revisions, or in lower-priority families outside the selected AWS `sort=DESC` order, can be missed. Because ECS ordering is family/revision based rather than a registration timestamp sort across every family, this is latest-first according to AWS task definition ARN ordering, not a global newest-by-time guarantee. |
| CodeArtifact packages (`max_codeartifact_packages`) | Prowler lists CodeArtifact repositories and lazily lists packages inside them. | The selected packages exposed to `codeartifact_packages_*` checks, including latest-version metadata for those packages. | AWS `ListPackages` does not provide a newest-package timestamp ordering in this path. Prowler preserves repository order and package API order, then applies the limit. Latest package version metadata is retrieved for selected packages with `sortBy=PUBLISHED_TIME` and `maxResults=1`. | Bounds `ListPackageVersions` calls to selected packages and can stop package listing once the limit is reached. Repository listing still runs. | Package selection is best effort by API order, not newest package order. Packages outside the selected repository/API order are not analyzed, so origin restriction or latest-version issues can be missed. |
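The regional interleaving mentioned in the ECS row can be pictured as a round-robin merge of per-region candidate lists. The sketch below is one possible way to do this, not necessarily how Prowler implements it; `interleave_regions` and the sample ARN strings are hypothetical.

```python
from itertools import islice, zip_longest
from typing import Optional

_MISSING = object()  # sentinel for regions that ran out of candidates


def interleave_regions(candidates_by_region: dict, limit: Optional[int]) -> list:
    """Take one candidate per region per round, then apply the limit."""
    rounds = zip_longest(*candidates_by_region.values(), fillvalue=_MISSING)
    flattened = (arn for group in rounds for arn in group if arn is not _MISSING)
    if not limit or limit <= 0:
        return list(flattened)
    return list(islice(flattened, limit))


# Hypothetical per-region lists, already in the descending order AWS returned.
candidates = {
    "us-east-1": ["web:3", "web:2", "web:1"],
    "eu-west-1": ["api:2", "api:1"],
}

selected = interleave_regions(candidates, limit=3)
```

A plain concatenation would have filled the limit with `us-east-1` candidates only, which is the starvation that interleaving avoids.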
Use limits when scan duration, API throttling, or cost are more important than exhaustive coverage for these high-volume resources. Keep limits disabled when you need complete evidence for every resource in the affected checks.
### Validating Discovered Secrets
@@ -219,6 +278,19 @@ aws:
  # AWS Global Configuration
  # aws.mute_non_default_regions --> Set to True to mute failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config
  mute_non_default_regions: False
  # AWS Resource Scan Limit Configuration
  # Disabled by default: scan every resource unless a positive limit is configured.
  # Findings are not capped. Set to 0 (or a negative value) to disable the limit.
  # aws.max_scanned_resources_per_service --> global default for all services below
  max_scanned_resources_per_service: 0
  # Per-service overrides. Leave as null to fall back to the global default.
  max_ebs_snapshots: null
  max_backup_recovery_points: null
  max_cloudwatch_log_groups: null
  max_lambda_functions: null
  max_ecs_task_definitions: null
  max_codeartifact_packages: null
  # If you want to mute failed findings only in specific regions, create a file with the following syntax and run it with `prowler aws -w mutelist.yaml`:
  # Mutelist:
  # Accounts:
+4
@@ -38,6 +38,10 @@ All notable changes to the **Prowler SDK** are documented in this file.
- GitHub default branch protection checks now evaluate repository rulesets in addition to classic branch protection, avoiding false positives for repositories that enforce protection through rulesets [(#11723)](https://github.com/prowler-cloud/prowler/pull/11723)
- Okta, Alibaba Cloud and OpenStack scan-config sections are now validated against a registered schema instead of being silently accepted, so their configurable thresholds (session/idle timeouts, retention days, image-sharing and secret-scanning settings) log a warning and fall back to the built-in default whenever a value is out of range [(#11725)](https://github.com/prowler-cloud/prowler/pull/11725)
### 🔄 Changed
- AWS scans for EBS snapshots, Backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages now support configurable resource analysis limits via `aws.max_scanned_resources_per_service`; limits are disabled by default and only positive values cap analyzed resources [(#11228)](https://github.com/prowler-cloud/prowler/pull/11228)
---
## [5.31.1] (Prowler v5.31.1)
+26
@@ -3,6 +3,32 @@ aws:
  # AWS Global Configuration
  # aws.mute_non_default_regions --> Set to True to mute failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config
  mute_non_default_regions: False
  # AWS Resource Scan Limit Configuration
  # Limits the number of resources scanned per service for services that can
  # accumulate huge numbers of resources (EBS snapshots, backup recovery
  # points, CloudWatch log groups, Lambda functions, ECS task definitions,
  # CodeArtifact packages). Limits apply to resources analyzed, not findings:
  # a selected resource can produce zero, one, or many findings. Where the AWS
  # API supports server-side ordering the latest resources are scanned first;
  # otherwise it is best-effort API order.
  # Disabled by default: scan every resource unless a positive limit is configured.
  # Set to 0 (or a negative value) to disable the limit (scan every resource).
  # aws.max_scanned_resources_per_service --> global default for all services below
  max_scanned_resources_per_service: 0
  # Per-service overrides. Leave as null to fall back to the global default.
  # aws.max_ebs_snapshots --> ec2_ebs_* checks (EBS snapshots)
  max_ebs_snapshots: null
  # aws.max_backup_recovery_points --> backup_recovery_point_* checks
  max_backup_recovery_points: null
  # aws.max_cloudwatch_log_groups --> cloudwatch_log_group_* checks
  max_cloudwatch_log_groups: null
  # aws.max_lambda_functions --> awslambda_function_* checks
  max_lambda_functions: null
  # aws.max_ecs_task_definitions --> ecs_task_definitions_* checks
  max_ecs_task_definitions: null
  # aws.max_codeartifact_packages --> codeartifact_packages_* checks
  max_codeartifact_packages: null
  # aws.disallowed_regions --> List of AWS regions to exclude from the scan.
  # Also settable via the PROWLER_AWS_DISALLOWED_REGIONS environment variable or
  # the --excluded-region CLI flag. Precedence: CLI > env var > config file.
+56 -1
@@ -14,7 +14,7 @@ thresholds) and avoids ints that obviously break downstream maths
from typing import Annotated, Literal, Optional
from pydantic import AfterValidator, Field
from pydantic import AfterValidator, BeforeValidator, Field
from prowler.config.schema.base import ProviderConfigBase
from prowler.config.schema.validators import (
@@ -101,10 +101,65 @@ def _validate_account_ids(v: Optional[list[str]]) -> Optional[list[str]]:
return v
def _reject_bool_resource_limit(v):
    if isinstance(v, bool):
        raise ValueError("resource scan limits must be integers, not booleans")
    return v


ResourceScanLimit = Annotated[
    Optional[int], BeforeValidator(_reject_bool_resource_limit)
]


# ---- Main schema ------------------------------------------------------------
class AWSProviderConfig(ProviderConfigBase):
    # --- Resource scan limits ---------------------------------------------
    max_scanned_resources_per_service: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Global resource scan limit for high-volume AWS services. Use 0 or -1 to disable.",
    )
    max_ebs_snapshots: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for EBS snapshots. Use 0 or -1 to disable.",
    )
    max_backup_recovery_points: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for AWS Backup recovery points. Use 0 or -1 to disable.",
    )
    max_cloudwatch_log_groups: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for CloudWatch log groups. Use 0 or -1 to disable.",
    )
    max_lambda_functions: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for Lambda functions. Use 0 or -1 to disable.",
    )
    max_ecs_task_definitions: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for ECS task definitions. Use 0 or -1 to disable.",
    )
    max_codeartifact_packages: ResourceScanLimit = Field(
        default=None,
        ge=-1,
        le=1_000_000,
        description="Resource scan limit for CodeArtifact packages. Use 0 or -1 to disable.",
    )
    # --- IAM ---------------------------------------------------------------
    mute_non_default_regions: Optional[bool] = None
    disallowed_regions: Optional[list[str]] = None
+88
@@ -0,0 +1,88 @@
"""Scoped resource scan limits for high-volume resources.

Some services accumulate huge numbers of resources (EBS snapshots, backup
recovery points, log groups, Lambda functions, ECS task definitions,
CodeArtifact packages). Scanning all of them causes API throttling, slow
scans, higher cost, and noisy findings.

``get_resource_scan_limit`` resolves the configured number of resources to
analyze for a supported resource path. A limited resource can produce zero,
one, or many findings; findings are not capped or re-ordered here.

Tradeoff: for newest-based resources, services may need to list lightweight or
base metadata broadly to select the truly newest resources, then apply limits
only to expensive hydration or analysis. The helper must not send
user-configured limits as unsafe paginator ``PageSize`` values because AWS
services validate page sizes differently.
"""

from collections.abc import Callable, Iterable, Iterator, Mapping
from itertools import islice
from typing import Any, Optional, Protocol, TypeVar

GLOBAL_LIMIT_KEY = "max_scanned_resources_per_service"

T = TypeVar("T")


class PaginatorProtocol(Protocol):
    """Minimal boto3-compatible paginator interface used by this module."""

    def paginate(self, **operation_parameters: Any) -> Iterable[Mapping[str, Any]]:
        """Return paginator pages for the provided operation parameters."""


def get_resource_scan_limit(audit_config: dict, service_key: str) -> Optional[int]:
    """Resolve the resource scan limit for a service.

    Precedence: per-service key (``service_key``) > global
    ``max_scanned_resources_per_service`` > unlimited.

    A non-positive resolved value means **unlimited** (``None``), preserving
    the legacy behavior as an explicit opt-out.

    Args:
        audit_config: The provider ``audit_config`` dictionary.
        service_key: The per-service config key, e.g. ``max_lambda_functions``.

    Returns:
        The limit as a positive ``int``, or ``None`` for unlimited.
    """
    value = audit_config.get(service_key)
    if value is None:
        value = audit_config.get(GLOBAL_LIMIT_KEY)
    if value is None or value <= 0:
        return None
    return int(value)


def limit_resources(resources: Iterable[T], limit: Optional[int]) -> Iterator[T]:
    """Yield up to ``limit`` resources without changing resource order."""
    if not limit or limit <= 0:
        yield from resources
        return
    yield from islice(resources, limit)


def iter_limited_paginator_items(
    paginator: PaginatorProtocol,
    result_key: str,
    limit: Optional[int],
    item_filter: Optional[Callable[[T], bool]] = None,
    **operation_parameters: Any,
) -> Iterator[T]:
    """Yield paginator result items, stopping after ``limit`` selected items.

    The configured resource-analysis limit is intentionally not sent as
    ``PageSize`` because AWS services validate page sizes differently. The
    paginator receives only the operation parameters needed by the AWS API,
    while this iterator applies the analysis limit defensively client-side.
    """
    selected = 0
    for page in paginator.paginate(**operation_parameters):
        for item in page.get(result_key, []):
            if item_filter and not item_filter(item):
                continue
            yield item
            selected += 1
            # Treat non-positive limits as unlimited, matching limit_resources.
            if limit and limit > 0 and selected >= limit:
                return
@@ -10,6 +10,10 @@ from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
limit_resources,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -18,8 +22,16 @@ class Lambda(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Functions are listed first, then trimmed to the subset selected for
# analysis before expensive per-function detail is hydrated.
self.functions = {}
self.security_groups_in_use = set()
self.regions_with_functions = set()
self.function_limit = get_resource_scan_limit(
self.audit_config, "max_lambda_functions"
)
self.__threading_call__(self._list_functions)
self._select_functions_for_analysis()
self._list_tags_for_resource()
self.__threading_call__(self._get_policy)
self.__threading_call__(self._get_function_url_config)
@@ -30,24 +42,29 @@ class Lambda(AWSService):
try:
list_functions_paginator = regional_client.get_paginator("list_functions")
for page in list_functions_paginator.paginate():
for function in page["Functions"]:
if not self.audit_resources or (
is_resource_filtered(
function["FunctionArn"], self.audit_resources
)
for function in page.get("Functions", []):
if not self.audit_resources or is_resource_filtered(
function["FunctionArn"], self.audit_resources
):
lambda_name = function["FunctionName"]
lambda_arn = function["FunctionArn"]
vpc_config = function.get("VpcConfig", {})
security_groups = vpc_config.get("SecurityGroupIds", [])
self.security_groups_in_use.update(security_groups)
self.regions_with_functions.add(regional_client.region)
# We must use the Lambda ARN as the dict key since we could have Lambdas in different regions with the same name
self.functions[lambda_arn] = Function(
name=lambda_name,
arn=lambda_arn,
security_groups=vpc_config.get("SecurityGroupIds", []),
security_groups=security_groups,
vpc_id=vpc_config.get("VpcId"),
subnet_ids=set(vpc_config.get("SubnetIds", [])),
region=regional_client.region,
)
if "LastModified" in function:
self.functions[lambda_arn].last_modified = function[
"LastModified"
]
if "Runtime" in function:
self.functions[lambda_arn].runtime = function["Runtime"]
if "Environment" in function:
@@ -76,26 +93,61 @@ class Lambda(AWSService):
f" {error}"
)
    def _select_functions_for_analysis(self):
        self.functions = {
            function.arn: function
            for function in limit_resources(
                sorted(
                    self.functions.values(),
                    key=lambda f: f.last_modified or "",
                    reverse=True,
                ),
                self.function_limit,
            )
        }
def _list_event_source_mappings(self, regional_client):
logger.info("Lambda - Listing Event Source Mappings...")
try:
paginator = regional_client.get_paginator("list_event_source_mappings")
for page in paginator.paginate():
for mapping in page.get("EventSourceMappings", []):
function_arn = mapping.get("FunctionArn", "")
# Normalise to unqualified ARN (strip :qualifier suffix if present)
base_arn = ":".join(function_arn.split(":")[:7])
if base_arn not in self.functions:
continue
self.functions[base_arn].event_source_mappings.append(
EventSourceMapping(
uuid=mapping["UUID"],
event_source_arn=mapping.get("EventSourceArn", ""),
state=mapping.get("State", ""),
batch_size=mapping.get("BatchSize"),
starting_position=mapping.get("StartingPosition"),
if not self.function_limit:
for page in paginator.paginate():
self._add_event_source_mappings(page.get("EventSourceMappings", []))
return
for function in self.functions.values():
if function.region != regional_client.region:
continue
try:
for page in paginator.paginate(FunctionName=function.name):
self._add_event_source_mappings(
page.get("EventSourceMappings", [])
)
)
except ClientError as error:
if (
error.response.get("Error", {}).get("Code")
== "InvalidParameterValueException"
):
logger.warning(
f"{function.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
else:
logger.error(
f"{function.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
raise
except ClientError as error:
if self.function_limit:
raise
logger.error(
f"{regional_client.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} --"
@@ -103,6 +155,23 @@ class Lambda(AWSService):
f" {error}"
)
    def _add_event_source_mappings(self, event_source_mappings):
        for mapping in event_source_mappings:
            function_arn = mapping.get("FunctionArn", "")
            # Normalise to unqualified ARN (strip :qualifier suffix if present)
            base_arn = ":".join(function_arn.split(":")[:7])
            if base_arn not in self.functions:
                continue
            self.functions[base_arn].event_source_mappings.append(
                EventSourceMapping(
                    uuid=mapping["UUID"],
                    event_source_arn=mapping.get("EventSourceArn", ""),
                    state=mapping.get("State", ""),
                    batch_size=mapping.get("BatchSize"),
                    starting_position=mapping.get("StartingPosition"),
                )
            )
def _get_function_code(self):
logger.info("Lambda - Getting Function Code...")
# Use a thread pool to handle the queueing and execution of the _fetch_function_code tasks, up to max_workers tasks concurrently.
@@ -158,7 +227,6 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
self.functions[function.arn].policy = {}
except Exception as error:
logger.error(
f"{regional_client.region} --"
@@ -187,7 +255,6 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
self.functions[function.arn].url_config = None
except Exception as error:
logger.error(
f"{regional_client.region} --"
@@ -206,10 +273,9 @@ class Lambda(AWSService):
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
function.tags = []
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{function.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
@@ -259,6 +325,7 @@ class Function(BaseModel):
name: str
arn: str
security_groups: list
last_modified: Optional[str] = None
runtime: Optional[str] = None
environment: Optional[dict] = None
region: str
@@ -5,6 +5,10 @@ from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
limit_resources,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -27,8 +31,14 @@ class Backup(AWSService):
self.__threading_call__(self._list_backup_report_plans)
self.protected_resources = []
self.__threading_call__(self._list_backup_selections)
# Recovery points are listed first, then only the selected subset is
# tagged and exposed for checks.
self.recovery_points = []
self.__threading_call__(self._list_recovery_points)
self.recovery_point_limit = get_resource_scan_limit(
self.audit_config, "max_backup_recovery_points"
)
self.__threading_call__(self._list_recovery_points, self.backup_vaults or [])
self._select_recovery_points_for_analysis()
self.__threading_call__(self._list_tags, self.recovery_points)
def _list_backup_vaults(self, regional_client):
@@ -183,40 +193,63 @@ class Backup(AWSService):
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _list_recovery_points(self, regional_client):
def _list_recovery_points(self, backup_vault=None):
logger.info("Backup - Listing Recovery Points...")
if backup_vault is None:
for vault in self.backup_vaults or []:
self._list_recovery_points(vault)
return
try:
if self.backup_vaults:
for backup_vault in self.backup_vaults:
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
arn = recovery_point.get("RecoveryPointArn")
if arn:
self.recovery_points.append(
RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get(
"IsEncrypted", False
),
backup_vault_region=backup_vault.region,
region=regional_client.region,
tags=[],
)
)
regional_client = self.regional_clients[backup_vault.region]
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
arn = recovery_point.get("RecoveryPointArn")
if arn:
rp = RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get("IsEncrypted", False),
creation_date=recovery_point.get("CreationDate"),
backup_vault_region=backup_vault.region,
region=backup_vault.region,
tags=[],
)
self.recovery_points.append(rp)
except ClientError as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{backup_vault.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{backup_vault.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
    def _select_recovery_points_for_analysis(self):
        self.recovery_points = list(
            limit_resources(
                sorted(
                    self.recovery_points,
                    key=lambda rp: (
                        (
                            -rp.creation_date.timestamp()
                            if isinstance(rp.creation_date, datetime)
                            else 0.0
                        ),
                        rp.region or "",
                        rp.backup_vault_name or "",
                        rp.arn or "",
                        rp.id or "",
                    ),
                ),
                self.recovery_point_limit,
            )
        )
class BackupVault(BaseModel):
arn: str
@@ -256,4 +289,5 @@ class RecoveryPoint(BaseModel):
backup_vault_name: str
encrypted: bool
backup_vault_region: str
creation_date: Optional[datetime] = None
tags: Optional[list] = None
@@ -30,12 +30,13 @@ class bedrock_model_invocation_logs_encryption_enabled(Check):
s3_encryption = False
if logging.cloudwatch_log_group:
log_group_arn = f"arn:{logs_client.audited_partition}:logs:{region}:{logs_client.audited_account}:log-group:{logging.cloudwatch_log_group}"
all_log_groups = getattr(logs_client, "all_log_groups", None) or {}
if (
log_group_arn in logs_client.log_groups
and not logs_client.log_groups[log_group_arn].kms_id
log_group_arn in all_log_groups
and not all_log_groups[log_group_arn].kms_id
) or (
log_group_arn + ":*" in logs_client.log_groups
and not logs_client.log_groups[log_group_arn + ":*"].kms_id
log_group_arn + ":*" in all_log_groups
and not all_log_groups[log_group_arn + ":*"].kms_id
):
cloudwatch_encryption = False
if not s3_encryption and not cloudwatch_encryption:
@@ -6,6 +6,10 @@ from botocore.exceptions import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
limit_resources,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -83,8 +87,19 @@ class Logs(AWSService):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
self.log_group_arn_template = f"arn:{self.audited_partition}:logs:{self.region}:{self.audited_account}:log-group"
# Log groups are listed first, then only the selected subset is enriched
# and exposed for primary log group checks. Keep a complete lightweight
# index for cross-service evidence lookups.
self.all_log_groups = {}
self.log_groups = {}
self._log_groups_hydrated = set()
self.log_group_limit = get_resource_scan_limit(
self.audit_config, "max_cloudwatch_log_groups"
)
# The threshold for number of events to return per log group.
self.events_per_log_group_threshold = 1000
self.__threading_call__(self._describe_log_groups)
self._select_log_groups_for_analysis()
self.resource_policies = {}
self.__threading_call__(self._describe_resource_policies)
self.metric_filters = []
@@ -94,14 +109,27 @@ class Logs(AWSService):
"cloudwatch_log_group_no_secrets_in_logs"
in provider.audit_metadata.expected_checks
):
self.events_per_log_group_threshold = (
1000 # The threshold for number of events to return per log group.
)
self.__threading_call__(self._get_log_events)
self.__threading_call__(self._get_log_events, self.log_groups.values())
self.__threading_call__(
self._list_tags_for_resource, self.log_groups.values()
)
def _select_log_groups_for_analysis(self):
"""Select the newest log groups for bounded analysis."""
if not self.log_groups:
return
self.log_groups = {
log_group.arn: log_group
for log_group in limit_resources(
sorted(
self.log_groups.values(),
key=lambda lg: lg.creation_time or 0,
reverse=True,
),
self.log_group_limit,
)
}
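
Reviewer note: the service keeps two views of the log groups, a complete index (`all_log_groups`) for cross-service lookups and a selected subset (`log_groups`) for the primary checks. The sketch below illustrates that split with plain dictionaries. The `select_newest` helper and the sample ARNs are assumptions for illustration only, and `itertools.islice` stands in for the real `limit_resources` helper.

```python
from itertools import islice

# Complete lightweight index: never truncated, used for evidence lookups.
all_log_groups = {
    "arn:old": {"arn": "arn:old", "creation_time": 1_700_000_000_000},
    "arn:new": {"arn": "arn:new", "creation_time": 1_710_000_000_000},
    "arn:unknown": {"arn": "arn:unknown", "creation_time": None},
}


def select_newest(groups, limit):
    """Return the newest `limit` groups keyed by ARN (hypothetical helper)."""
    ordered = sorted(
        groups.values(),
        key=lambda lg: lg["creation_time"] or 0,  # missing times sort last
        reverse=True,
    )
    if limit is not None and limit > 0:
        ordered = islice(ordered, limit)
    return {lg["arn"]: lg for lg in ordered}


# Selected subset: the only groups that get enriched and checked.
log_groups = select_newest(all_log_groups, 1)
```

A check in another service can still resolve `arn:old` through the complete index even though it was not selected.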
def _describe_metric_filters(self, regional_client):
logger.info("CloudWatch Logs - Describing metric filters...")
try:
@@ -118,11 +146,21 @@ class Logs(AWSService):
self.metric_filters = []
log_group = None
for lg in self.log_groups.values():
if lg.name == filter["logGroupName"]:
for lg in (self.all_log_groups or {}).values():
if (
lg.name == filter["logGroupName"]
and lg.region == regional_client.region
):
log_group = lg
break
if (
log_group
and log_group.arn in (self.log_groups or {})
and log_group.arn not in self._log_groups_hydrated
):
self._list_tags_for_resource(log_group)
self.metric_filters.append(
MetricFilter(
arn=arn,
@@ -156,9 +194,9 @@ class Logs(AWSService):
"describe_log_groups"
)
for page in describe_log_groups_paginator.paginate():
for log_group in page["logGroups"]:
if not self.audit_resources or (
is_resource_filtered(log_group["arn"], self.audit_resources)
for log_group in page.get("logGroups", []):
if not self.audit_resources or is_resource_filtered(
log_group["arn"], self.audit_resources
):
never_expire = False
kms = log_group.get("kmsKeyId")
@@ -168,20 +206,26 @@ class Logs(AWSService):
retention_days = 9999
if self.log_groups is None:
self.log_groups = {}
self.log_groups[log_group["arn"]] = LogGroup(
if self.all_log_groups is None:
self.all_log_groups = {}
log_group_object = LogGroup(
arn=log_group["arn"],
name=log_group["logGroupName"],
retention_days=retention_days,
never_expire=never_expire,
kms_id=kms,
creation_time=log_group.get("creationTime"),
region=regional_client.region,
)
self.all_log_groups[log_group_object.arn] = log_group_object
self.log_groups[log_group_object.arn] = log_group_object
except ClientError as error:
if error.response["Error"]["Code"] == "AccessDeniedException":
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
if not self.log_groups:
self.all_log_groups = None
self.log_groups = None
else:
logger.error(
@@ -192,37 +236,29 @@ class Logs(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _get_log_events(self, regional_client):
regional_log_groups = [
log_group
for log_group in self.log_groups.values()
if log_group.region == regional_client.region
]
total_log_groups = len(regional_log_groups)
def _get_log_events(self, log_group):
"""Retrieve recent log events for a selected log group.
Args:
log_group: Log group selected for bounded analysis.
"""
logger.info(
f"CloudWatch Logs - Retrieving log events for {total_log_groups} log groups in {regional_client.region}..."
f"CloudWatch Logs - Retrieving log events for log group {log_group.name}..."
)
try:
for count, log_group in enumerate(regional_log_groups, start=1):
events = regional_client.filter_log_events(
logGroupName=log_group.name,
limit=self.events_per_log_group_threshold,
)["events"]
for event in events:
if event["logStreamName"] not in log_group.log_streams:
log_group.log_streams[event["logStreamName"]] = []
log_group.log_streams[event["logStreamName"]].append(event)
if count % 10 == 0:
logger.info(
f"CloudWatch Logs - Retrieved log events for {count}/{total_log_groups} log groups in {regional_client.region}..."
)
regional_client = self.regional_clients[log_group.region]
events = regional_client.filter_log_events(
logGroupName=log_group.name,
limit=self.events_per_log_group_threshold,
)["events"]
for event in events:
if event["logStreamName"] not in log_group.log_streams:
log_group.log_streams[event["logStreamName"]] = []
log_group.log_streams[event["logStreamName"]].append(event)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{log_group.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
logger.info(
f"CloudWatch Logs - Finished retrieving log events in {regional_client.region}..."
)
def _describe_resource_policies(self, regional_client):
logger.info("CloudWatch Logs - Describing resource policies...")
@@ -257,6 +293,13 @@ class Logs(AWSService):
)
def _list_tags_for_resource(self, log_group):
"""Hydrate tags for a selected log group once.
Args:
log_group: Log group selected for tag hydration.
"""
if log_group.arn in self._log_groups_hydrated:
return
logger.info(f"CloudWatch Logs - List Tags for Log Group {log_group.name}...")
try:
regional_client = self.regional_clients[log_group.region]
@@ -264,6 +307,7 @@ class Logs(AWSService):
resourceArn=log_group.arn
)["tags"]
log_group.tags = [response]
self._log_groups_hydrated.add(log_group.arn)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
@@ -292,6 +336,7 @@ class LogGroup(BaseModel):
retention_days: int
never_expire: bool
kms_id: Optional[str]
creation_time: Optional[int] = None
region: str
log_streams: dict[str, list[str]] = (
{}
@@ -1,10 +1,14 @@
from enum import Enum
from typing import Optional
from typing import Iterator, Optional, Tuple
from botocore.exceptions import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
iter_limited_paginator_items,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -15,9 +19,18 @@ class CodeArtifact(AWSService):
super().__init__(__class__.__name__, provider)
# repositories is a dictionary containing all the codeartifact service information
self.repositories = {}
# repository ARNs whose selected packages have been listed and memoized
# into repository.packages.
self._packages_listed = set()
self.package_limit = get_resource_scan_limit(
self.audit_config, "max_codeartifact_packages"
)
self.__threading_call__(self._list_repositories)
self.__threading_call__(self._list_packages)
self._list_tags_for_resource()
for _ in self._load_packages_for_analysis():
pass
self.__threading_call__(
self._list_tags_for_resource, self.repositories.values()
)
def _list_repositories(self, regional_client):
logger.info("CodeArtifact - Listing Repositories...")
@@ -51,134 +64,146 @@ class CodeArtifact(AWSService):
f" {error}"
)
def _list_packages(self, regional_client):
logger.info("CodeArtifact - Listing Packages and retrieving information...")
for repository in self.repositories:
try:
if self.repositories[repository].region == regional_client.region:
list_packages_paginator = regional_client.get_paginator(
"list_packages"
def _iter_repository_packages(
self, repository, limit: Optional[int] = None
) -> Iterator["Package"]:
"""Yield packages for a single repository, hydrating each one lazily.
Each package requires an extra ``list_package_versions`` call to
resolve its latest version, so producing them lazily lets the resource
limit stop before extra package version calls.
"""
regional_client = self.regional_clients[repository.region]
try:
list_packages_paginator = regional_client.get_paginator("list_packages")
list_packages_parameters = {
"domain": repository.domain_name,
"domainOwner": repository.domain_owner,
"repository": repository.name,
}
for package in iter_limited_paginator_items(
list_packages_paginator,
"packages",
limit,
**list_packages_parameters,
):
# Package information
package_format = package["format"]
package_namespace = package.get("namespace")
package_name = package["package"]
package_origin_configuration_restrictions_publish = package[
"originConfiguration"
]["restrictions"]["publish"]
package_origin_configuration_restrictions_upstream = package[
"originConfiguration"
]["restrictions"]["upstream"]
# Get Latest Package Version
list_package_versions_parameters = {
"domain": repository.domain_name,
"domainOwner": repository.domain_owner,
"repository": repository.name,
"format": package_format,
"package": package_name,
"sortBy": "PUBLISHED_TIME",
"maxResults": 1,
}
if package_namespace:
list_package_versions_parameters["namespace"] = package_namespace
latest_version_information = regional_client.list_package_versions(
**list_package_versions_parameters
)
latest_version = ""
latest_origin_type = "UNKNOWN"
latest_status = "Published"
if latest_version_information.get("versions"):
latest_version = latest_version_information["versions"][0].get(
"version"
)
list_packages_parameters = {
"domain": self.repositories[repository].domain_name,
"domainOwner": self.repositories[repository].domain_owner,
"repository": self.repositories[repository].name,
}
packages = []
for page in list_packages_paginator.paginate(
**list_packages_parameters
):
for package in page["packages"]:
# Package information
package_format = package["format"]
package_namespace = package.get("namespace")
package_name = package["package"]
package_origin_configuration_restrictions_publish = package[
"originConfiguration"
]["restrictions"]["publish"]
package_origin_configuration_restrictions_upstream = (
package["originConfiguration"]["restrictions"][
"upstream"
]
)
# Get Latest Package Version
if package_namespace:
latest_version_information = (
regional_client.list_package_versions(
domain=self.repositories[
repository
].domain_name,
domainOwner=self.repositories[
repository
].domain_owner,
repository=self.repositories[repository].name,
format=package_format,
namespace=package_namespace,
package=package_name,
sortBy="PUBLISHED_TIME",
maxResults=1,
)
)
else:
latest_version_information = (
regional_client.list_package_versions(
domain=self.repositories[
repository
].domain_name,
domainOwner=self.repositories[
repository
].domain_owner,
repository=self.repositories[repository].name,
format=package_format,
package=package_name,
sortBy="PUBLISHED_TIME",
maxResults=1,
)
)
latest_version = ""
latest_origin_type = "UNKNOWN"
latest_status = "Published"
if latest_version_information.get("versions"):
latest_version = latest_version_information["versions"][
0
].get("version")
latest_origin_type = (
latest_version_information["versions"][0]
.get("origin", {})
.get("originType", "UNKNOWN")
)
latest_status = latest_version_information["versions"][
0
].get("status", "Published")
packages.append(
Package(
name=package_name,
namespace=package_namespace,
format=package_format,
origin_configuration=OriginConfiguration(
restrictions=Restrictions(
publish=package_origin_configuration_restrictions_publish,
upstream=package_origin_configuration_restrictions_upstream,
)
),
latest_version=LatestPackageVersion(
version=latest_version,
status=latest_status,
origin=OriginInformation(
origin_type=latest_origin_type
),
),
)
)
# Save all the packages information
self.repositories[repository].packages = packages
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{regional_client.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
latest_origin_type = (
latest_version_information["versions"][0]
.get("origin", {})
.get("originType", "UNKNOWN")
)
latest_status = latest_version_information["versions"][0].get(
"status", "Published"
)
continue
except Exception as error:
logger.error(
f"{regional_client.region} --"
yield Package(
name=package_name,
namespace=package_namespace,
format=package_format,
origin_configuration=OriginConfiguration(
restrictions=Restrictions(
publish=package_origin_configuration_restrictions_publish,
upstream=package_origin_configuration_restrictions_upstream,
)
),
latest_version=LatestPackageVersion(
version=latest_version,
status=latest_status,
origin=OriginInformation(origin_type=latest_origin_type),
),
)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
else:
logger.error(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
except Exception as error:
logger.error(
f"{repository.region} --"
f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:"
f" {error}"
)
def _list_tags_for_resource(self):
def _load_packages_for_analysis(self) -> Iterator[Tuple["Repository", "Package"]]:
"""Yield the ``(repository, package)`` pairs selected for analysis.
Package listing stays in the service layer so checks receive only the
selected packages and remain unaware of resource-analysis limits.
"""
yielded = 0
for repository in list(self.repositories.values()):
if repository.arn in self._packages_listed:
for package in repository.packages:
yield repository, package
yielded += 1
if self.package_limit and yielded >= self.package_limit:
return
continue
collected = []
remaining_limit = None
if self.package_limit:
remaining_limit = self.package_limit - yielded
if remaining_limit <= 0:
return
for package in self._iter_repository_packages(repository, remaining_limit):
collected.append(package)
repository.packages = collected
yield repository, package
yielded += 1
if self.package_limit and yielded >= self.package_limit:
self._packages_listed.add(repository.arn)
return
self._packages_listed.add(repository.arn)
def _list_tags_for_resource(self, repository):
logger.info("CodeArtifact - List Tags...")
try:
for repository in self.repositories.values():
regional_client = self.regional_clients[repository.region]
response = regional_client.list_tags_for_resource(
resourceArn=repository.arn
)["tags"]
repository.tags = response
regional_client = self.regional_clients[repository.region]
response = regional_client.list_tags_for_resource(
resourceArn=repository.arn
)["tags"]
repository.tags = response
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -15,11 +15,10 @@ class ec2_securitygroup_not_used(Check):
report.resource_details = security_group.name
report.status = "PASS"
report.status_extended = f"Security group {security_group.name} ({security_group.id}) it is being used."
sg_in_lambda = False
sg_in_lambda = (
security_group.id in awslambda_client.security_groups_in_use
)
sg_associated = False
for function in awslambda_client.functions.values():
if security_group.id in function.security_groups:
sg_in_lambda = True
for sg in ec2_client.security_groups.values():
if security_group.id in sg.associated_sgs:
sg_associated = True
@@ -6,6 +6,10 @@ from botocore.client import ClientError
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
limit_resources,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -26,8 +30,12 @@ class EC2(AWSService):
self.snapshots = []
self.volumes_with_snapshots = {}
self.regions_with_snapshots = {}
# Snapshots are listed first, then limited after per-region snapshot
# presence is derived and before public status is hydrated.
self.snapshot_limit = get_resource_scan_limit(
self.audit_config, "max_ebs_snapshots"
)
self.__threading_call__(self._describe_snapshots)
self.__threading_call__(self._determine_public_snapshots, self.snapshots)
self.network_interfaces = {}
self.__threading_call__(self._describe_network_interfaces)
self.images = []
@@ -36,6 +44,8 @@ class EC2(AWSService):
self.__threading_call__(self._describe_volumes)
self.attributes_for_regions = {}
self.__threading_call__(self._get_resources_for_regions)
self._select_snapshots_for_analysis()
self.__threading_call__(self._determine_public_snapshots, self.snapshots)
self.ebs_encryption_by_default = []
self.__threading_call__(self._get_ebs_encryption_settings)
self.elastic_ips = []
@@ -207,6 +217,7 @@ class EC2(AWSService):
arn=arn,
region=regional_client.region,
encrypted=snapshot.get("Encrypted", False),
start_time=snapshot.get("StartTime"),
tags=snapshot.get("Tags"),
volume=snapshot["VolumeId"],
)
@@ -243,6 +254,18 @@ class EC2(AWSService):
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _select_snapshots_for_analysis(self):
self.snapshots = list(
limit_resources(
sorted(
self.snapshots,
key=lambda s: (s.start_time.timestamp() if s.start_time else 0.0),
reverse=True,
),
self.snapshot_limit,
)
)
def _describe_network_interfaces(self, regional_client):
try:
# Get Network Interfaces with Public IPs
@@ -686,6 +709,7 @@ class Snapshot(BaseModel):
region: str
encrypted: bool
public: bool = False
start_time: Optional[datetime] = None
tags: Optional[list] = []
volume: Optional[str]
@@ -1,9 +1,16 @@
from datetime import datetime
from itertools import zip_longest
from re import sub
from typing import Optional
from pydantic.v1 import BaseModel
from prowler.lib.logger import logger
from prowler.lib.resource_limit import (
get_resource_scan_limit,
iter_limited_paginator_items,
limit_resources,
)
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
@@ -12,40 +19,95 @@ class ECS(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Task definition ARNs are listed first, then only the selected subset
# is described and exposed for checks.
self.task_definitions = {}
self._task_definition_arns = None
self._task_definition_arns_by_region = {}
self.task_definition_limit = get_resource_scan_limit(
self.audit_config, "max_ecs_task_definitions"
)
self.services = {}
self.clusters = {}
self.task_sets = {}
self.__threading_call__(self._list_task_definitions)
self.__threading_call__(
self._describe_task_definition, self.task_definitions.values()
)
for _ in self._load_task_definitions_for_analysis():
pass
self.__threading_call__(self._list_clusters)
self.__threading_call__(self._describe_clusters, self.clusters.values())
self.__threading_call__(self._describe_services, self.clusters.values())
def _list_task_definitions(self, regional_client):
def _list_task_definition_arns(self) -> list:
"""List task definition ARNs newest-first, memoized.
AWS returns ``list_task_definitions(sort=DESC)`` results per region.
Prowler limits the task definitions it describes and exposes to checks.
"""
if self._task_definition_arns is not None:
return self._task_definition_arns
logger.info("ECS - Listing Task Definitions...")
self.__threading_call__(self._list_task_definition_arns_by_region)
arns_by_region = []
for region in self.regional_clients:
arns_by_region.append(self._task_definition_arns_by_region.get(region, []))
arns = []
for task_definition_batch in zip_longest(*arns_by_region):
for task_definition in task_definition_batch:
if task_definition:
arns.append(task_definition)
self._task_definition_arns = arns
return arns
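
Reviewer note: `zip_longest` merges the per-region lists round-robin, so the result alternates between regions rather than being a strict global newest-first ordering. A minimal sketch of that interleaving follows. The `interleave` function and the sample values are hypothetical, and it filters on `is not None` where the diff relies on truthiness.

```python
from itertools import zip_longest


def interleave(lists):
    """Round-robin merge of several lists, skipping zip_longest padding."""
    merged = []
    for batch in zip_longest(*lists):
        for item in batch:
            if item is not None:  # None is the default fill value
                merged.append(item)
    return merged


# Each inner list is assumed to be already sorted newest-first per region.
per_region = [["a3", "a2", "a1"], ["b1"], []]
merged = interleave(per_region)
```

Applying a limit of 2 to `merged` would select one item from each non-empty region here, which spreads the selection across regions.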
def _list_task_definition_arns_by_region(self, regional_client):
try:
list_ecs_paginator = regional_client.get_paginator("list_task_definitions")
for page in list_ecs_paginator.paginate():
for task_definition in page["taskDefinitionArns"]:
if not self.audit_resources or (
is_resource_filtered(task_definition, self.audit_resources)
):
self.task_definitions[task_definition] = TaskDefinition(
# we want the family name without the revision
name=sub(":.*", "", task_definition.split("/")[-1]),
arn=task_definition,
revision=task_definition.split(":")[-1],
region=regional_client.region,
environment_variables=[],
)
regional_arns = []
for task_definition in iter_limited_paginator_items(
list_ecs_paginator,
"taskDefinitionArns",
None,
item_filter=lambda task_definition: not self.audit_resources
or is_resource_filtered(task_definition, self.audit_resources),
sort="DESC",
):
regional_arns.append((task_definition, regional_client.region))
self._task_definition_arns_by_region[regional_client.region] = regional_arns
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _load_task_definitions_for_analysis(self):
"""Yield task definitions lazily, describing each one on demand.
Resources already fetched are memoized in ``self.task_definitions`` and
reused across checks (checks run sequentially, so no locking is needed).
The configured resource limit bounds ``describe_task_definition`` calls.
"""
task_definitions = []
for arn, region in limit_resources(
self._list_task_definition_arns(), self.task_definition_limit
):
task_definition = self.task_definitions.get(arn)
if task_definition is None:
task_definition = TaskDefinition(
# we want the family name without the revision
name=sub(":.*", "", arn.split("/")[-1]),
arn=arn,
revision=arn.split(":")[-1],
region=region,
environment_variables=[],
)
self.task_definitions[arn] = task_definition
task_definitions.append(task_definition)
self.__threading_call__(self._describe_task_definition, task_definitions)
for arn, _ in limit_resources(
self._list_task_definition_arns(), self.task_definition_limit
):
task_definition = self.task_definitions[arn]
yield task_definition
def _describe_task_definition(self, task_definition):
logger.info("ECS - Describing Task Definition...")
try:
@@ -84,6 +146,9 @@ class ECS(AWSService):
)
)
task_definition.pid_mode = response["taskDefinition"].get("pidMode", "")
task_definition.registered_at = response["taskDefinition"].get(
"registeredAt"
)
task_definition.tags = response.get("tags")
task_definition.network_mode = response["taskDefinition"].get(
"networkMode", "bridge"
@@ -208,6 +273,7 @@ class TaskDefinition(BaseModel):
region: str
container_definitions: list[ContainerDefinition] = []
pid_mode: Optional[str]
registered_at: Optional[datetime] = None
tags: Optional[list] = []
network_mode: Optional[str]
@@ -15,11 +15,10 @@ class inspector2_is_enabled(Check):
if inspector.status == "ENABLED":
report.status = "PASS"
report.status_extended = "Inspector2 is enabled for EC2 instances, ECR container images, Lambda functions and code."
funtions_in_region = False
functions_in_region = (
inspector.region in awslambda_client.regions_with_functions
)
ec2_in_region = False
for function in awslambda_client.functions.values():
if function.region == inspector.region:
funtions_in_region = True
for instance in ec2_client.instances:
if instance == inspector.region:
ec2_in_region = True
@@ -36,12 +35,12 @@ class inspector2_is_enabled(Check):
failed_services.append("ECR")
if inspector.lambda_status != "ENABLED" and (
inspector2_client.provider.scan_unused_services
or funtions_in_region
or functions_in_region
):
failed_services.append("Lambda")
if inspector.lambda_code_status != "ENABLED" and (
inspector2_client.provider.scan_unused_services
or funtions_in_region
or functions_in_region
):
failed_services.append("Lambda Code")
@@ -3,6 +3,7 @@ constraint surface (CIDRs, account IDs, port ranges, enums, thresholds)."""
import pytest
from prowler.config.scan_config_schema import SCAN_CONFIG_SCHEMA
from prowler.config.schema.aws import AWSProviderConfig
from prowler.config.schema.validator import validate_provider_config
@@ -11,6 +12,52 @@ def _validate(raw):
return validate_provider_config("aws", raw, AWSProviderConfig)
RESOURCE_LIMIT_KEYS = [
"max_scanned_resources_per_service",
"max_ebs_snapshots",
"max_backup_recovery_points",
"max_cloudwatch_log_groups",
"max_lambda_functions",
"max_ecs_task_definitions",
"max_codeartifact_packages",
]
class Test_AWS_Resource_Limits:
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_positive_values_round_trip(self, key):
assert _validate({key: 100}) == {key: 100}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_null_values_round_trip(self, key):
assert _validate({key: None}) == {key: None}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_zero_disable_sentinel_round_trips(self, key):
assert _validate({key: 0}) == {key: 0}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_numeric_strings_are_coerced_to_int(self, key):
assert _validate({key: "100"}) == {key: 100}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_disable_sentinel_minus_one_round_trips(self, key):
assert _validate({key: -1}) == {key: -1}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
@pytest.mark.parametrize("value", [True, False])
def test_booleans_are_dropped_not_coerced_to_int(self, key, value):
assert _validate({key: value}) == {}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_invalid_strings_are_dropped(self, key):
assert _validate({key: "not-an-int"}) == {}
@pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS)
def test_keys_are_exposed_in_scan_config_schema(self, key):
assert key in SCAN_CONFIG_SCHEMA["properties"]["aws"]["properties"]
class Test_AWS_Threat_Detection_Thresholds:
"""All threat detection thresholds are documented as fractions in 0..1.
The biggest risk of mistyping them is silently disabling the check."""
@@ -0,0 +1,156 @@
from prowler.lib.resource_limit import (
get_resource_scan_limit,
iter_limited_paginator_items,
limit_resources,
)
class FakePaginator:
def __init__(self, pages):
self.pages = pages
self.paginate_calls = []
self.pages_requested = 0
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
for page in self.pages:
self.pages_requested += 1
yield page
class Test_limit_resources:
def test_no_limit_returns_all_in_order(self):
resources = ["PASS", "FAIL", "PASS"]
result = list(limit_resources(iter(resources), None))
assert result == ["PASS", "FAIL", "PASS"]
class Test_iter_limited_paginator_items:
def test_positive_limit_stops_without_page_size(self):
paginator = FakePaginator(
[
{"Items": [1, 2]},
{"Items": [3, 4]},
{"Items": [5]},
]
)
result = list(iter_limited_paginator_items(paginator, "Items", 3))
assert result == [1, 2, 3]
assert paginator.paginate_calls == [{}]
assert paginator.pages_requested == 2
def test_absurd_limit_is_not_sent_as_page_size(self):
paginator = FakePaginator([{"Items": [1, 2]}])
result = list(iter_limited_paginator_items(paginator, "Items", 200000))
assert result == [1, 2]
assert paginator.paginate_calls == [{}]
def test_operation_parameters_are_forwarded_unchanged(self):
paginator = FakePaginator([{"Snapshots": ["snapshot"]}])
result = list(
iter_limited_paginator_items(
paginator,
"Snapshots",
1,
OwnerIds=["self"],
)
)
assert result == ["snapshot"]
assert paginator.paginate_calls == [{"OwnerIds": ["self"]}]
def test_item_filter_limits_selected_items_only(self):
paginator = FakePaginator(
[
{"Items": [{"arn": "skip"}, {"arn": "first"}]},
{"Items": [{"arn": "second"}, {"arn": "third"}]},
]
)
result = list(
iter_limited_paginator_items(
paginator,
"Items",
2,
item_filter=lambda item: item["arn"] != "skip",
)
)
assert result == [{"arn": "first"}, {"arn": "second"}]
assert paginator.pages_requested == 2
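
Reviewer note: the tests above pin down the observable behavior of `iter_limited_paginator_items`: parameters are forwarded unchanged, no page size is injected, the filter runs before counting, and no further pages are requested once the limit is reached. The following is a sketch that satisfies those properties, not the actual Prowler implementation. `iter_limited_items` and `DemoPaginator` are hypothetical names.

```python
def iter_limited_items(
    paginator, result_key, limit=None, item_filter=None, **operation_parameters
):
    """Yield items from paginated responses, stopping after `limit` matches."""
    unlimited = limit is None or limit <= 0
    yielded = 0
    # Parameters are passed through as-is; no PageSize is added.
    for page in paginator.paginate(**operation_parameters):
        for item in page.get(result_key, []):
            if item_filter is not None and not item_filter(item):
                continue  # filtered items do not count toward the limit
            yield item
            yielded += 1
            if not unlimited and yielded >= limit:
                return  # stop before the next page is requested


class DemoPaginator:
    """Hypothetical paginator that records how many pages were pulled."""

    def __init__(self, pages):
        self.pages = pages
        self.pages_requested = 0

    def paginate(self, **kwargs):
        for page in self.pages:
            self.pages_requested += 1
            yield page


demo = DemoPaginator([{"Items": [1, 2]}, {"Items": [3, 4]}, {"Items": [5]}])
first_three = list(iter_limited_items(demo, "Items", 3))
```

Because `paginate` is consumed lazily, the third page is never fetched in this example.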
def test_limit_zero_or_negative_is_unlimited(self):
resources = list(range(5))
assert list(limit_resources(iter(resources), 0)) == resources
assert list(limit_resources(iter(resources), -3)) == resources
def test_positive_limit_stops_after_selected_resources(self):
pulled = []
def gen():
for i in range(1000):
pulled.append(i)
yield i
result = list(limit_resources(gen(), 100))
assert result == list(range(100))
assert len(pulled) == 100
def test_does_not_reorder_or_inspect_resource_status(self):
resources = ["PASS", "FAIL", "PASS", "FAIL"]
result = list(limit_resources(iter(resources), 3))
assert result == ["PASS", "FAIL", "PASS"]
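
Reviewer note: these tests describe `limit_resources` as a lazy, order-preserving cap where `None`, `0`, and negative limits mean unlimited. One way to get that behavior is `itertools.islice`, sketched below under the assumption that the real helper behaves the same way. `limit_resources_sketch` is a hypothetical name.

```python
from itertools import islice


def limit_resources_sketch(resources, limit):
    """Return at most `limit` resources, in order, without reading ahead."""
    if limit is None or limit <= 0:
        return iter(resources)  # unlimited: pass everything through
    # islice stops after `limit` items and does not pull the next one.
    return islice(resources, limit)


pulled = []


def generate():
    for i in range(1000):
        pulled.append(i)
        yield i


selected = list(limit_resources_sketch(generate(), 100))
```

The helper never inspects the items, so it cannot reorder or prioritize by status.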
class Test_get_resource_scan_limit:
def test_per_service_override_wins(self):
config = {
"max_scanned_resources_per_service": 100,
"max_ecs_task_definitions": 25,
}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 25
def test_falls_back_to_global_default(self):
config = {"max_scanned_resources_per_service": 50}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50
def test_null_per_service_override_falls_back_to_global_default(self):
config = {
"max_scanned_resources_per_service": 50,
"max_ecs_task_definitions": None,
}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50
def test_default_is_unlimited_when_unset(self):
assert get_resource_scan_limit({}, "max_ecs_task_definitions") is None
def test_null_per_service_override_falls_back_to_unlimited_global_default(self):
config = {"max_ecs_task_definitions": None}
assert get_resource_scan_limit(config, "max_ecs_task_definitions") is None
def test_non_positive_means_unlimited(self):
assert (
get_resource_scan_limit(
{"max_scanned_resources_per_service": 0}, "max_lambda_functions"
)
is None
)
assert (
get_resource_scan_limit(
{"max_lambda_functions": -1}, "max_lambda_functions"
)
is None
)
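
Reviewer note: taken together, these tests define the resolution order for a limit: a per-resource value wins, `None` inherits the global default, and any non-positive result is reported as `None` (unlimited). The sketch below reproduces that order. `resolve_scan_limit` is a hypothetical name, and the real `get_resource_scan_limit` may handle extra cases such as type validation.

```python
GLOBAL_KEY = "max_scanned_resources_per_service"


def resolve_scan_limit(config, resource_key):
    """Resolve the effective limit for one resource path, or None if unlimited."""
    config = config or {}
    value = config.get(resource_key)
    if value is None:
        # Missing or null per-resource value inherits the global default.
        value = config.get(GLOBAL_KEY)
    if value is None or value <= 0:
        # Zero and negative values explicitly disable the limit.
        return None
    return value


override = resolve_scan_limit(
    {GLOBAL_KEY: 100, "max_ecs_task_definitions": 25}, "max_ecs_task_definitions"
)
inherited = resolve_scan_limit({GLOBAL_KEY: 50}, "max_ecs_task_definitions")
```

Note that a per-resource `0` does not fall back to the global value in this sketch; it disables the limit for that resource only.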
@@ -6,10 +6,16 @@ from re import search
from unittest.mock import patch
import mock
import pytest
from boto3 import client, resource
from botocore.client import ClientError
from moto import mock_aws
from prowler.providers.aws.services.awslambda.awslambda_service import AuthType, Lambda
from prowler.providers.aws.services.awslambda.awslambda_service import (
AuthType,
Function,
Lambda,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -85,6 +91,367 @@ class Test_Lambda_Service:
awslambda = Lambda(set_mocked_aws_provider([AWS_REGION_US_EAST_1]))
assert awslambda.service == "lambda"
def test_function_limit_selects_latest_functions_for_analysis(self):
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {
"old": Function(
name="old",
arn="old",
security_groups=[],
last_modified="2024-01-01T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
"new": Function(
name="new",
arn="new",
security_groups=[],
last_modified="2024-01-02T00:00:00.000+0000",
region=AWS_REGION_EU_WEST_1,
),
}
awslambda.function_limit = 1
awslambda._select_functions_for_analysis()
assert list(awslambda.functions) == ["new"]
def test_function_limit_selects_global_latest_across_regions(self):
class FakePaginator:
def __init__(self, functions):
self.functions = functions
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Functions": self.functions}]
class FakeLambdaClient:
def __init__(self, region, functions):
self.region = region
self.functions = functions
def get_paginator(self, name):
assert name == "list_functions"
return FakePaginator(self.functions)
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {}
awslambda.security_groups_in_use = set()
awslambda.regions_with_functions = set()
awslambda.function_limit = 1
awslambda.audit_resources = []
old_client = FakeLambdaClient(
AWS_REGION_EU_WEST_1,
[
{
"FunctionName": "old",
"FunctionArn": "arn:aws:lambda:eu-west-1:123456789012:function:old",
"LastModified": "2024-01-01T00:00:00.000+0000",
}
],
)
new_client = FakeLambdaClient(
AWS_REGION_US_EAST_1,
[
{
"FunctionName": "new",
"FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:new",
"LastModified": "2024-01-02T00:00:00.000+0000",
}
],
)
awslambda._list_functions(old_client)
awslambda._list_functions(new_client)
awslambda._select_functions_for_analysis()
assert [function.name for function in awslambda.functions.values()] == ["new"]
def test_function_limit_keeps_complete_auxiliary_indexes(self):
class FakePaginator:
def __init__(self, functions):
self.functions = functions
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Functions": self.functions}]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_functions"
return FakePaginator(
[
{
"FunctionName": "old",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:old"
),
"LastModified": "2024-01-01T00:00:00.000+0000",
"VpcConfig": {"SecurityGroupIds": ["sg-old"]},
},
{
"FunctionName": "new",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:new"
),
"LastModified": "2024-01-02T00:00:00.000+0000",
"VpcConfig": {"SecurityGroupIds": ["sg-new"]},
},
]
)
awslambda = Lambda.__new__(Lambda)
awslambda.functions = {}
awslambda.security_groups_in_use = set()
awslambda.regions_with_functions = set()
awslambda.function_limit = 1
awslambda.audit_resources = []
awslambda._list_functions(FakeLambdaClient())
awslambda._select_functions_for_analysis()
assert [function.name for function in awslambda.functions.values()] == ["new"]
# Auxiliary indexes still cover every listed function, including the unselected one.
assert awslambda.security_groups_in_use == {"sg-old", "sg-new"}
assert awslambda.regions_with_functions == {AWS_REGION_US_EAST_1}
def test_list_event_source_mappings_uses_selected_functions_as_api_scope(self):
class FakePaginator:
def __init__(self):
self.paginate_calls = []
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
function_name = kwargs["FunctionName"]
return [
{
"EventSourceMappings": [
{
"UUID": f"{function_name}-mapping",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:{function_name}:1"
),
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
"BatchSize": 10,
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def __init__(self):
self.paginator = FakePaginator()
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return self.paginator
selected_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
other_region_arn = (
f"arn:aws:lambda:{AWS_REGION_EU_WEST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:other-region"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 1
awslambda.functions = {
selected_arn: Function(
name="selected",
arn=selected_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
other_region_arn: Function(
name="other-region",
arn=other_region_arn,
security_groups=[],
region=AWS_REGION_EU_WEST_1,
),
}
regional_client = FakeLambdaClient()
awslambda._list_event_source_mappings(regional_client)
# With a limit set, only the selected function in this client's region is queried.
assert regional_client.paginator.paginate_calls == [
{"FunctionName": "selected"}
]
assert len(awslambda.functions[selected_arn].event_source_mappings) == 1
assert (
awslambda.functions[selected_arn].event_source_mappings[0].uuid
== "selected-mapping"
)
assert not awslambda.functions[other_region_arn].event_source_mappings
def test_list_event_source_mappings_keeps_unlimited_regional_api_scope(self):
class FakePaginator:
def __init__(self):
self.paginate_calls = []
def paginate(self, **kwargs):
self.paginate_calls.append(kwargs)
return [
{
"EventSourceMappings": [
{
"UUID": "selected-mapping",
"FunctionArn": selected_arn,
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def __init__(self):
self.paginator = FakePaginator()
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return self.paginator
selected_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = None
awslambda.functions = {
selected_arn: Function(
name="selected",
arn=selected_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
)
}
regional_client = FakeLambdaClient()
awslambda._list_event_source_mappings(regional_client)
# Without a limit, a single unfiltered regional call is made.
assert regional_client.paginator.paginate_calls == [{}]
assert len(awslambda.functions[selected_arn].event_source_mappings) == 1
def test_list_event_source_mappings_continues_after_invalid_parameter_value(self):
class FakePaginator:
def paginate(self, **kwargs):
function_name = kwargs["FunctionName"]
if function_name == "deleted":
raise ClientError(
{
"Error": {
"Code": "InvalidParameterValueException",
"Message": "Function no longer exists",
}
},
"ListEventSourceMappings",
)
return [
{
"EventSourceMappings": [
{
"UUID": f"{function_name}-mapping",
"FunctionArn": (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:{function_name}"
),
"EventSourceArn": "arn:aws:sqs:queue",
"State": "Enabled",
}
]
}
]
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return FakePaginator()
deleted_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:deleted"
)
remaining_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:remaining"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 2
awslambda.functions = {
deleted_arn: Function(
name="deleted",
arn=deleted_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
remaining_arn: Function(
name="remaining",
arn=remaining_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
),
}
awslambda._list_event_source_mappings(FakeLambdaClient())
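# The InvalidParameterValueException for the deleted function is tolerated,
# and the remaining function is still processed.
assert not awslambda.functions[deleted_arn].event_source_mappings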
assert len(awslambda.functions[remaining_arn].event_source_mappings) == 1
assert (
awslambda.functions[remaining_arn].event_source_mappings[0].uuid
== "remaining-mapping"
)
def test_list_event_source_mappings_raises_non_transient_client_error(self):
class FakePaginator:
def paginate(self, **kwargs):
raise ClientError(
{
"Error": {
"Code": "AccessDeniedException",
"Message": "Access denied",
}
},
"ListEventSourceMappings",
)
class FakeLambdaClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "list_event_source_mappings"
return FakePaginator()
function_arn = (
f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:"
f"{AWS_ACCOUNT_NUMBER}:function:selected"
)
awslambda = Lambda.__new__(Lambda)
awslambda.function_limit = 1
awslambda.functions = {
function_arn: Function(
name="selected",
arn=function_arn,
security_groups=[],
region=AWS_REGION_US_EAST_1,
)
}
with pytest.raises(ClientError) as error:
awslambda._list_event_source_mappings(FakeLambdaClient())
assert error.value.response["Error"]["Code"] == "AccessDeniedException"
@mock_aws
def test_list_functions(self):
# Create IAM Lambda Role
@@ -253,3 +620,63 @@ class Test_Lambda_Service:
f"{tmp_dir_name}/{files_in_zip[0]}", "r"
) as lambda_code_file:
assert lambda_code_file.read() == LAMBDA_FUNCTION_CODE
@mock_aws
def test_function_limit_exposes_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
assert len(awslambda.functions) == 1
@mock_aws
def test_get_function_code_fetches_only_selected_functions(self):
lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1)
iam_client = client("iam", region_name=AWS_REGION_US_EAST_1)
iam_role = iam_client.create_role(
RoleName="test-role",
AssumeRolePolicyDocument="{}",
)["Role"]["Arn"]
for name in ("function-1", "function-2"):
lambda_client.create_function(
FunctionName=name,
Runtime="python3.7",
Role=iam_role,
Handler="lambda_function.lambda_handler",
Code={"ZipFile": create_zip_file().read()},
PackageType="ZIP",
)
awslambda = Lambda(
set_mocked_aws_provider(
audited_regions=[AWS_REGION_US_EAST_1],
audit_config={"max_lambda_functions": 1},
)
)
fetched = []
def fetch_function_code(function_name, _function_region):
fetched.append(function_name)
return mock.MagicMock()
awslambda._fetch_function_code = fetch_function_code
assert len(list(awslambda._get_function_code())) == 1
assert len(fetched) == 1
@@ -1,11 +1,16 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
import botocore
from boto3 import client
from moto import mock_aws
from prowler.providers.aws.services.backup.backup_service import Backup
from prowler.providers.aws.services.backup.backup_service import (
Backup,
BackupVault,
RecoveryPoint,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -292,3 +297,248 @@ class TestBackupService:
assert backup.recovery_points[0].backup_vault_region == "eu-west-1"
assert backup.recovery_points[0].tags == []
assert backup.recovery_points[0].encrypted is True
def test_recovery_point_limit_bounds_tag_calls_to_selected_points(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 2),
},
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 1),
},
]
}
]
class FakeBackupClient:
def __init__(self):
self.tag_calls = []
def get_paginator(self, name):
assert name == "list_recovery_points_by_backup_vault"
return FakePaginator()
def list_tags(self, **kwargs):
self.tag_calls.append(kwargs["ResourceArn"])
return {"Tags": {}}
regional_client = FakeBackupClient()
backup = Backup.__new__(Backup)
backup.backup_vaults = [
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:vault",
name="vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=2,
locked=False,
)
]
backup.recovery_points = []
backup.recovery_point_limit = 1
backup.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert [rp.id for rp in backup.recovery_points] == ["new"]
assert regional_client.tag_calls == [
"arn:aws:backup:eu-west-1:123456789012:recovery-point:new"
]
def test_recovery_point_limit_selects_global_newest_across_vaults(self):
class FakePaginator:
def __init__(self, recovery_points_by_vault):
self.recovery_points_by_vault = recovery_points_by_vault
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [
{
"RecoveryPoints": self.recovery_points_by_vault[
kwargs["BackupVaultName"]
]
}
]
class FakeBackupClient:
def __init__(self, recovery_points_by_vault):
self.recovery_points_by_vault = recovery_points_by_vault
def get_paginator(self, name):
assert name == "list_recovery_points_by_backup_vault"
return FakePaginator(self.recovery_points_by_vault)
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 1
backup.recovery_points = []
backup.backup_vaults = [
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:old-vault",
name="old-vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=1,
locked=False,
),
BackupVault(
arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:new-vault",
name="new-vault",
region=AWS_REGION_EU_WEST_1,
encryption="",
recovery_points=1,
locked=False,
),
]
backup.regional_clients = {
AWS_REGION_EU_WEST_1: FakeBackupClient(
{
"old-vault": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 1),
}
],
"new-vault": [
{
"RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new",
"IsEncrypted": True,
"CreationDate": datetime(2024, 1, 2),
}
],
}
)
}
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
assert [rp.id for rp in backup.recovery_points] == ["new"]
def test_recovery_point_limit_exposes_only_selected_resources(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = []
backup.backup_vaults = [
BackupVault(
arn="arn",
name="vault",
region="eu-west-1",
encryption="",
recovery_points=3,
locked=False,
)
]
class Paginator:
def paginate(self, **_kwargs):
return [
{
"RecoveryPoints": [
{
"RecoveryPointArn": f"arn:aws:backup:eu-west-1:123456789012:recovery-point:{i}",
"IsEncrypted": True,
}
for i in range(3)
]
}
]
backup.regional_clients = {
"eu-west-1": SimpleNamespace(get_paginator=lambda _: Paginator())
}
tagged = []
def list_tags(recovery_point):
tagged.append(recovery_point.arn)
backup._list_tags = list_tags
backup._list_recovery_points()
backup._select_recovery_points_for_analysis()
for recovery_point in backup.recovery_points:
backup._list_tags(recovery_point)
assert len(backup.recovery_points) == 2
assert len(tagged) == 2
def test_recovery_point_limit_uses_deterministic_tie_breaker(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = [
RecoveryPoint(
arn="arn:aws:backup:us-east-1:123456789012:recovery-point:z",
id="z",
region="us-east-1",
backup_vault_name="vault-b",
encrypted=True,
backup_vault_region="us-east-1",
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:b",
id="b",
region="eu-west-1",
backup_vault_name="vault-b",
encrypted=True,
backup_vault_region="eu-west-1",
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:a",
id="a",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
),
]
backup._select_recovery_points_for_analysis()
# No creation dates are set, so only the tie-breaker decides which two points are kept.
assert [rp.id for rp in backup.recovery_points] == ["a", "b"]
def test_recovery_point_limit_keeps_newest_before_tie_breaker(self):
backup = Backup.__new__(Backup)
backup.recovery_point_limit = 2
backup.recovery_points = [
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:older-a",
id="older-a",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
creation_date=datetime(2024, 1, 1),
),
RecoveryPoint(
arn="arn:aws:backup:us-east-1:123456789012:recovery-point:newer-z",
id="newer-z",
region="us-east-1",
backup_vault_name="vault-z",
encrypted=True,
backup_vault_region="us-east-1",
creation_date=datetime(2024, 1, 2),
),
RecoveryPoint(
arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:missing-date",
id="missing-date",
region="eu-west-1",
backup_vault_name="vault-a",
encrypted=True,
backup_vault_region="eu-west-1",
),
]
backup._select_recovery_points_for_analysis()
# Dated points are ordered newest first; the point without a creation date is not selected.
assert [rp.id for rp in backup.recovery_points] == ["newer-z", "older-a"]
@@ -227,6 +227,72 @@ class Test_bedrock_model_invocation_logs_encryption_enabled:
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == []
def test_cloudwatch_logging_uses_complete_log_group_index(self):
from prowler.providers.aws.services.bedrock.bedrock_service import (
LoggingConfiguration,
)
from prowler.providers.aws.services.cloudwatch.cloudwatch_service import (
LogGroup,
)
bedrock_client = mock.MagicMock()
bedrock_client.logging_configurations = {
AWS_REGION_US_EAST_1: LoggingConfiguration(
enabled=True,
cloudwatch_log_group="Test",
)
}
bedrock_client._get_model_invocation_logging_arn_template.return_value = (
"arn:aws:bedrock:us-east-1:123456789012:model-invocation-logging"
)
logs_client = mock.MagicMock()
logs_client.audited_partition = "aws"
logs_client.audited_account = "123456789012"
logs_client.log_groups = {}
logs_client.all_log_groups = {
"arn:aws:logs:us-east-1:123456789012:log-group:Test:*": LogGroup(
arn="arn:aws:logs:us-east-1:123456789012:log-group:Test:*",
name="Test",
retention_days=30,
never_expire=False,
kms_id=None,
region=AWS_REGION_US_EAST_1,
)
}
s3_client = mock.MagicMock()
s3_client.audited_partition = "aws"
s3_client.buckets = {}
with (
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.bedrock_client",
new=bedrock_client,
),
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.logs_client",
new=logs_client,
),
mock.patch(
"prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.s3_client",
new=s3_client,
),
):
from prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled import (
bedrock_model_invocation_logs_encryption_enabled,
)
check = bedrock_model_invocation_logs_encryption_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Bedrock Model Invocation logs are not encrypted in CloudWatch Log Group: Test."
)
@mock_aws
def test_s3_and_cloudwatch_logging_encrypted(self):
logs_client = client("logs", region_name=AWS_REGION_US_EAST_1)
@@ -4,6 +4,7 @@ from moto import mock_aws
from prowler.providers.aws.services.cloudwatch.cloudwatch_service import (
CloudWatch,
LogGroup,
Logs,
)
from prowler.providers.aws.services.cloudwatch.lib.metric_filters import (
@@ -188,7 +189,9 @@ class Test_CloudWatch_Service:
arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*"
logs = Logs(aws_provider)
assert len(logs.log_groups) == 1
assert len(logs.all_log_groups) == 1
assert arn in logs.log_groups
assert arn in logs.all_log_groups
assert logs.log_groups[arn].name == "/log-group/test"
assert logs.log_groups[arn].retention_days == 400
assert logs.log_groups[arn].kms_id == "test_kms_id"
@@ -212,7 +215,9 @@ class Test_CloudWatch_Service:
arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*"
logs = Logs(aws_provider)
assert len(logs.log_groups) == 1
assert len(logs.all_log_groups) == 1
assert arn in logs.log_groups
assert arn in logs.all_log_groups
assert logs.log_groups[arn].name == "/log-group/test"
assert logs.log_groups[arn].never_expire
# Since it never expires, retention_days is not used
@@ -221,6 +226,190 @@ class Test_CloudWatch_Service:
assert logs.log_groups[arn].region == AWS_REGION_US_EAST_1
assert logs.log_groups[arn].tags == [{}]
def test_log_group_limit_exposes_only_selected_resources(self):
class FakeLogsClient:
def __init__(self):
self.filter_calls = []
def filter_log_events(self, **kwargs):
self.filter_calls.append(kwargs["logGroupName"])
return {"events": []}
regional_client = FakeLogsClient()
logs = Logs.__new__(Logs)
logs.log_group_limit = 1
logs._log_groups_hydrated = set()
logs.regional_clients = {AWS_REGION_US_EAST_1: regional_client}
logs.events_per_log_group_threshold = 1000
logs.log_groups = {
f"arn:{i}": LogGroup(
arn=f"arn:{i}",
name=f"log-{i}",
retention_days=30,
never_expire=False,
kms_id=None,
creation_time=i,
region=AWS_REGION_US_EAST_1,
)
for i in range(3)
}
tagged = []
def list_tags(log_group):
tagged.append(log_group.arn)
logs._list_tags_for_resource = list_tags
logs._select_log_groups_for_analysis()
for log_group in logs.log_groups.values():
logs._list_tags_for_resource(log_group)
logs._get_log_events(log_group)
assert list(logs.log_groups) == ["arn:2"]
assert tagged == ["arn:2"]
assert regional_client.filter_calls == ["log-2"]
def test_log_group_limit_selects_global_newest_across_regions(self):
class FakePaginator:
def __init__(self, log_groups):
self.log_groups = log_groups
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"logGroups": self.log_groups}]
class FakeLogsClient:
def __init__(self, region, log_groups):
self.region = region
self.log_groups = log_groups
def get_paginator(self, name):
assert name == "describe_log_groups"
return FakePaginator(self.log_groups)
logs = Logs.__new__(Logs)
logs.all_log_groups = {}
logs.log_groups = {}
logs.log_group_limit = 1
logs.audit_resources = []
logs._describe_log_groups(
FakeLogsClient(
"eu-west-1",
[
{
"arn": "arn:aws:logs:eu-west-1:123456789012:log-group:old:*",
"logGroupName": "old",
"creationTime": 1,
}
],
)
)
logs._describe_log_groups(
FakeLogsClient(
AWS_REGION_US_EAST_1,
[
{
"arn": "arn:aws:logs:us-east-1:123456789012:log-group:new:*",
"logGroupName": "new",
"creationTime": 2,
}
],
)
)
logs._select_log_groups_for_analysis()
assert [log_group.name for log_group in logs.log_groups.values()] == ["new"]
# all_log_groups keeps the complete index regardless of the limit.
assert [log_group.name for log_group in logs.all_log_groups.values()] == [
"old",
"new",
]
def test_metric_filters_use_complete_log_group_index(self):
class FakePaginator:
def paginate(self):
return [
{
"metricFilters": [
{
"filterName": "test-filter",
"filterPattern": "test-pattern",
"logGroupName": "old",
"metricTransformations": [
{"metricName": "test-metric"}
],
}
]
}
]
class FakeLogsClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "describe_metric_filters"
return FakePaginator()
logs = Logs.__new__(Logs)
old_log_group = LogGroup(
arn="arn:old",
name="old",
retention_days=30,
never_expire=False,
kms_id=None,
creation_time=1,
region=AWS_REGION_US_EAST_1,
)
logs.audited_partition = "aws"
logs.audited_account = AWS_ACCOUNT_NUMBER
logs.audit_resources = []
logs.metric_filters = []
logs.log_groups = {}
logs.all_log_groups = {old_log_group.arn: old_log_group}
logs._log_groups_hydrated = set()
logs._list_tags_for_resource = lambda log_group: None
logs._describe_metric_filters(FakeLogsClient())
assert len(logs.metric_filters) == 1
# The log group is resolved from all_log_groups even though log_groups is empty.
assert logs.metric_filters[0].log_group == old_log_group
def test_log_group_collection_recovers_all_log_groups_after_access_denied(self):
class FakePaginator:
def paginate(self):
return [
{
"logGroups": [
{
"arn": "arn:aws:logs:us-east-1:123456789012:log-group:success:*",
"logGroupName": "success",
"creationTime": 1,
}
]
}
]
class FakeLogsClient:
region = AWS_REGION_US_EAST_1
def get_paginator(self, name):
assert name == "describe_log_groups"
return FakePaginator()
logs = Logs.__new__(Logs)
logs.all_log_groups = None
logs.log_groups = None
logs.audit_resources = []
logs._describe_log_groups(FakeLogsClient())
assert list(logs.all_log_groups) == [
"arn:aws:logs:us-east-1:123456789012:log-group:success:*"
]
assert list(logs.log_groups) == [
"arn:aws:logs:us-east-1:123456789012:log-group:success:*"
]
class Test_build_metric_filter_pattern:
@pytest.mark.parametrize("bad_operator", ["==", "~=", "<", "<>", ">=", ""])
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import patch
import botocore
@@ -6,6 +7,7 @@ from prowler.providers.aws.services.codeartifact.codeartifact_service import (
CodeArtifact,
LatestPackageVersionStatus,
OriginInformationValues,
Repository,
RestrictionValues,
)
from tests.providers.aws.utils import (
@@ -208,6 +210,104 @@ class Test_CodeArtifact_Service:
== OriginInformationValues.INTERNAL
)
def test_package_limit_bounds_package_version_lookups_to_selected_packages(self):
class FakePaginator:
def paginate(self, **kwargs):
return [
{
"packages": [
{
"format": "pypi",
"package": "first-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
{
"format": "pypi",
"package": "second-package",
"originConfiguration": {
"restrictions": {
"publish": "ALLOW",
"upstream": "ALLOW",
}
},
},
]
}
]
class FakeCodeArtifactClient:
def __init__(self):
self.version_calls = []
def get_paginator(self, name):
assert name == "list_packages"
return FakePaginator()
def list_package_versions(self, **kwargs):
self.version_calls.append(kwargs["package"])
return {
"versions": [
{
"version": "1.0.0",
"status": "Published",
"origin": {"originType": "INTERNAL"},
}
]
}
regional_client = FakeCodeArtifactClient()
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.repositories = {
TEST_REPOSITORY_ARN: Repository(
name="test-repository",
arn=TEST_REPOSITORY_ARN,
domain_name="test-domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
}
codeartifact._packages_listed = set()
codeartifact.package_limit = 1
codeartifact.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
pairs = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in pairs] == ["first-package"]
assert regional_client.version_calls == ["first-package"]
def test_package_limit_exposes_only_selected_packages(self):
codeartifact = CodeArtifact.__new__(CodeArtifact)
codeartifact.package_limit = 2
codeartifact._packages_listed = set()
repository = Repository(
name="repository",
arn="repo",
domain_name="domain",
domain_owner=AWS_ACCOUNT_NUMBER,
region=AWS_REGION_EU_WEST_1,
)
codeartifact.repositories = {repository.arn: repository}
enriched = []
def iter_repository_packages(repository, limit=None):
for index in range(3):
if limit is not None and index >= limit:
return
enriched.append(index)
yield SimpleNamespace(name=f"package-{index}")
codeartifact._iter_repository_packages = iter_repository_packages
packages = list(codeartifact._load_packages_for_analysis())
assert [package.name for _, package in packages] == ["package-0", "package-1"]
assert enriched == [0, 1]
def mock_make_api_call_no_namespace(self, operation_name, kwarg):
"""Mock for packages without a namespace to exercise the else branch"""
@@ -14,6 +14,57 @@ EXAMPLE_AMI_ID = "ami-12c6146b"
class Test_ec2_securitygroup_not_used:
def test_ec2_sg_used_by_lambda_outside_selected_analysis_limit(self):
from prowler.providers.aws.services.ec2.ec2_service import SecurityGroup
sg_id = "sg-limited-out"
sg_name = "lambda-sg"
security_group = SecurityGroup(
name=sg_name,
region=AWS_REGION_US_EAST_1,
arn=f"arn:aws:ec2:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:security-group/{sg_id}",
id=sg_id,
vpc_id="vpc-test",
associated_sgs=[],
network_interfaces=[],
ingress_rules=[],
egress_rules=[],
tags=[],
)
ec2_client = mock.MagicMock()
ec2_client.security_groups = {security_group.arn: security_group}
awslambda_client = mock.MagicMock()
awslambda_client.functions = {}
awslambda_client.security_groups_in_use = {sg_id}
aws_provider = set_mocked_aws_provider()
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.ec2_client",
new=ec2_client,
),
mock.patch(
"prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.awslambda_client",
new=awslambda_client,
),
):
from prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used import (
ec2_securitygroup_not_used,
)
result = ec2_securitygroup_not_used().execute()
assert len(result) == 1
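# The security group appears in security_groups_in_use even though no Lambda
# function was selected for analysis, so it is still reported as used.
assert result[0].status == "PASS"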
assert (
result[0].status_extended
== f"Security group {sg_name} ({sg_id}) it is being used."
)
@mock_aws
def test_ec2_default_sgs(self):
# Create EC2 Mocked Resources
@@ -11,7 +11,7 @@ from freezegun import freeze_time
from moto import mock_aws
from prowler.config.config import encoding_format_utf_8
from prowler.providers.aws.services.ec2.ec2_service import EC2
from prowler.providers.aws.services.ec2.ec2_service import EC2, Snapshot
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
@@ -103,6 +103,99 @@ class Test_EC2_Service:
ec2 = EC2(aws_provider)
assert ec2.audited_account == AWS_ACCOUNT_NUMBER
def test_snapshot_limit_bounds_public_attribute_calls_to_latest_selected(self):
class FakeEC2Client:
def __init__(self):
self.calls = []
def describe_snapshot_attribute(self, **kwargs):
self.calls.append(kwargs["SnapshotId"])
return {"CreateVolumePermissions": []}
regional_client = FakeEC2Client()
ec2 = EC2.__new__(EC2)
ec2.snapshots = [
Snapshot(
id="snap-old",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-old",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 1),
volume="vol-old",
),
Snapshot(
id="snap-new",
arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-new",
region=AWS_REGION_EU_WEST_1,
encrypted=True,
start_time=datetime(2024, 1, 2),
volume="vol-new",
),
]
ec2.snapshot_limit = 1
ec2.regional_clients = {AWS_REGION_EU_WEST_1: regional_client}
ec2._select_snapshots_for_analysis()
for snapshot in ec2.snapshots:
ec2._determine_public_snapshots(snapshot)
assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"]
assert regional_client.calls == ["snap-new"]
def test_snapshot_limit_preserves_volume_index_and_selects_global_latest(self):
class FakePaginator:
def __init__(self, snapshots):
self.snapshots = snapshots
def paginate(self, **kwargs):
assert "PageSize" not in kwargs
return [{"Snapshots": self.snapshots}]
class FakeEC2Client:
def __init__(self, region, snapshots):
self.region = region
self.snapshots = snapshots
def get_paginator(self, name):
assert name == "describe_snapshots"
return FakePaginator(self.snapshots)
ec2 = EC2.__new__(EC2)
ec2.snapshots = []
ec2.volumes_with_snapshots = {}
ec2.regions_with_snapshots = {}
ec2.snapshot_limit = 1
ec2.audit_resources = []
ec2.audited_partition = "aws"
ec2.audited_account = AWS_ACCOUNT_NUMBER
old_client = FakeEC2Client(
AWS_REGION_EU_WEST_1,
[
{
"SnapshotId": "snap-old",
"VolumeId": "vol-old",
"StartTime": datetime(2024, 1, 1),
}
],
)
new_client = FakeEC2Client(
AWS_REGION_US_EAST_1,
[
{
"SnapshotId": "snap-new",
"VolumeId": "vol-new",
"StartTime": datetime(2024, 1, 2),
}
],
)
ec2._describe_snapshots(old_client)
ec2._describe_snapshots(new_client)
ec2._select_snapshots_for_analysis()
# The volume index covers every listed snapshot, not only the selected one.
assert ec2.volumes_with_snapshots == {"vol-old": True, "vol-new": True}
assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"]
# Test EC2 Describe Instances
@mock_aws
@freeze_time(MOCK_DATETIME)
@@ -346,6 +439,24 @@ class Test_EC2_Service:
assert not snapshot.encrypted
assert snapshot.public
@mock_aws
def test_snapshot_limit_exposes_only_selected_snapshots(self):
ec2_client = client("ec2", region_name=AWS_REGION_US_EAST_1)
ec2_resource = resource("ec2", region_name=AWS_REGION_US_EAST_1)
volume_id = ec2_resource.create_volume(
AvailabilityZone="us-east-1a",
Size=80,
VolumeType="gp2",
).id
for _ in range(3):
ec2_client.create_snapshot(VolumeId=volume_id)
aws_provider = set_mocked_aws_provider(
[AWS_REGION_US_EAST_1], audit_config={"max_ebs_snapshots": 1}
)
ec2 = EC2(aws_provider)
assert len(ec2.snapshots) == 1
# Test EC2 Instance User Data
@mock_aws
def test_get_instance_user_data(self):
@@ -3,7 +3,11 @@ from unittest.mock import patch
import botocore
from prowler.providers.aws.services.ecs.ecs_service import ECS
from tests.providers.aws.utils import AWS_REGION_EU_WEST_1, set_mocked_aws_provider
from tests.providers.aws.utils import (
AWS_REGION_EU_WEST_1,
AWS_REGION_US_EAST_1,
set_mocked_aws_provider,
)
make_api_call = botocore.client.BaseClient._make_api_call
@@ -115,6 +119,23 @@ def mock_generate_regional_clients(provider, service):
    return {AWS_REGION_EU_WEST_1: regional_client}


def mock_generate_multi_region_clients(provider, service):
    eu_west_1_client = provider._session.current_session.client(
        service, region_name=AWS_REGION_EU_WEST_1
    )
    eu_west_1_client.region = AWS_REGION_EU_WEST_1
    us_east_1_client = provider._session.current_session.client(
        service, region_name=AWS_REGION_US_EAST_1
    )
    us_east_1_client.region = AWS_REGION_US_EAST_1
    return {
        AWS_REGION_EU_WEST_1: eu_west_1_client,
        AWS_REGION_US_EAST_1: us_east_1_client,
    }


@patch(
    "prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
    new=mock_generate_regional_clients,
@@ -139,7 +160,6 @@ class Test_ECS_Service:
        ecs = ECS(aws_provider)
        assert ecs.session.__class__.__name__ == "Session"

    # Test list ECS task definitions
    @patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
    def test_list_task_definitions(self):
        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
@@ -201,6 +221,169 @@ class Test_ECS_Service:
.readonly_rootfilesystem
)
    def test_task_definitions_are_loaded_once_for_analysis(self):
        describe_calls = []
        list_calls = []

        def counting_make_api_call(self, operation_name, kwarg):
            if operation_name == "ListTaskDefinitions":
                list_calls.append(kwarg)
                return {
                    "taskDefinitionArns": [
                        f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
                        for i in (3, 2, 1)
                    ]
                }
            if operation_name == "DescribeTaskDefinition":
                describe_calls.append(kwarg["taskDefinition"])
                return {
                    "taskDefinition": {
                        "containerDefinitions": [],
                        "networkMode": "bridge",
                        "pidMode": "",
                        "tags": [],
                    }
                }
            return make_api_call(self, operation_name, kwarg)

        with patch(
            "botocore.client.BaseClient._make_api_call", new=counting_make_api_call
        ):
            aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
            ecs = ECS(aws_provider)

            assert [td.revision for td in ecs.task_definitions.values()] == [
                "3",
                "2",
                "1",
            ]
            assert list_calls == [{"sort": "DESC"}]
            assert len(describe_calls) == 3

    def test_task_definition_limit_exposes_only_selected_resources(self):
        describe_calls = []

        def counting_make_api_call(self, operation_name, kwarg):
            if operation_name == "ListTaskDefinitions":
                return {
                    "taskDefinitionArns": [
                        f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
                        for i in (3, 2, 1)
                    ]
                }
            if operation_name == "DescribeTaskDefinition":
                describe_calls.append(kwarg["taskDefinition"])
                return {
                    "taskDefinition": {
                        "containerDefinitions": [],
                        "networkMode": "bridge",
                        "pidMode": "",
                        "tags": [],
                    }
                }
            return make_api_call(self, operation_name, kwarg)

        with patch(
            "botocore.client.BaseClient._make_api_call", new=counting_make_api_call
        ):
            aws_provider = set_mocked_aws_provider(
                [AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 2}
            )
            ecs = ECS(aws_provider)

            assert [td.revision for td in ecs.task_definitions.values()] == ["3", "2"]
            assert len(describe_calls) == 2

    def test_task_definition_limit_bounds_describe_calls(self):
        describe_calls = []

        def counting_make_api_call(self, operation_name, kwarg):
            if operation_name == "ListTaskDefinitions":
                return {
                    "taskDefinitionArns": [
                        f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}"
                        for i in (3, 2, 1)
                    ]
                }
            if operation_name == "DescribeTaskDefinition":
                describe_calls.append(kwarg["taskDefinition"])
                return {
                    "taskDefinition": {
                        "containerDefinitions": [],
                        "networkMode": "bridge",
                        "pidMode": "",
                        "tags": [],
                    }
                }
            return mock_make_api_call(self, operation_name, kwarg)

        with patch(
            "botocore.client.BaseClient._make_api_call", new=counting_make_api_call
        ):
            aws_provider = set_mocked_aws_provider(
                [AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 1}
            )
            ecs = ECS(aws_provider)

            assert [td.revision for td in ecs.task_definitions.values()] == ["3"]
            assert describe_calls == [
                "arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3"
            ]

    def test_task_definition_limit_does_not_starve_later_regions(self):
        describe_calls = []

        def counting_make_api_call(self, operation_name, kwarg):
            region = self.meta.region_name
            if operation_name == "ListTaskDefinitions":
                task_definition_revisions = {
                    AWS_REGION_EU_WEST_1: (3, 2, 1),
                    AWS_REGION_US_EAST_1: (9,),
                }[region]
                return {
                    "taskDefinitionArns": [
                        f"arn:aws:ecs:{region}:123456789012:task-definition/fam:{revision}"
                        for revision in task_definition_revisions
                    ]
                }
            if operation_name == "DescribeTaskDefinition":
                describe_calls.append(kwarg["taskDefinition"])
                return {
                    "taskDefinition": {
                        "containerDefinitions": [],
                        "networkMode": "bridge",
                        "pidMode": "",
                        "tags": [],
                    }
                }
            if operation_name == "ListClusters":
                return {"clusterArns": []}
            return mock_make_api_call(self, operation_name, kwarg)

        with (
            patch(
                "prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
                new=mock_generate_multi_region_clients,
            ),
            patch(
                "botocore.client.BaseClient._make_api_call", new=counting_make_api_call
            ),
        ):
            aws_provider = set_mocked_aws_provider(
                [AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1],
                audit_config={"max_ecs_task_definitions": 2},
            )
            ecs = ECS(aws_provider)

            assert [td.region for td in ecs.task_definitions.values()] == [
                AWS_REGION_EU_WEST_1,
                AWS_REGION_US_EAST_1,
            ]
            assert set(describe_calls) == {
                "arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3",
                "arn:aws:ecs:us-east-1:123456789012:task-definition/fam:9",
            }

    # Test list ECS clusters
    @patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
    def test_list_clusters(self):
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest import mock
from prowler.providers.aws.services.inspector2.inspector2_service import Inspector
@@ -13,6 +14,65 @@ FINDING_ARN = (
class Test_inspector2_is_enabled:
    def test_lambda_disabled_with_region_hidden_by_function_analysis_limit(self):
        inspector2_client = mock.MagicMock()
        inspector2_client.provider = SimpleNamespace(scan_unused_services=False)
        inspector2_client.inspectors = [
            Inspector(
                id=AWS_ACCOUNT_NUMBER,
                arn=f"arn:aws:inspector2:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:inspector2",
                status="ENABLED",
                ec2_status="ENABLED",
                ecr_status="ENABLED",
                lambda_status="DISABLED",
                lambda_code_status="ENABLED",
                region=AWS_REGION_EU_WEST_1,
            )
        ]
        awslambda_client = mock.MagicMock()
        awslambda_client.functions = {}
        awslambda_client.regions_with_functions = {AWS_REGION_EU_WEST_1}
        ec2_client = mock.MagicMock()
        ec2_client.instances = []
        ecr_client = mock.MagicMock()
        ecr_client.registries = {AWS_REGION_EU_WEST_1: SimpleNamespace(repositories=[])}

        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])

        with (
            mock.patch(
                "prowler.providers.common.provider.Provider.get_global_provider",
                return_value=aws_provider,
            ),
            mock.patch(
                "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.inspector2_client",
                new=inspector2_client,
            ),
            mock.patch(
                "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.awslambda_client",
                new=awslambda_client,
            ),
            mock.patch(
                "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ec2_client",
                new=ec2_client,
            ),
            mock.patch(
                "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ecr_client",
                new=ecr_client,
            ),
        ):
            from prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled import (
                inspector2_is_enabled,
            )

            result = inspector2_is_enabled().execute()

            assert len(result) == 1
            assert result[0].status == "FAIL"
            assert (
                result[0].status_extended
                == "Inspector2 is not enabled for the following services: Lambda."
            )

    def test_inspector2_disabled(self):
        # Mock the inspector2 client
        inspector2_client = mock.MagicMock