diff --git a/docs/user-guide/cli/tutorials/configuration_file.mdx b/docs/user-guide/cli/tutorials/configuration_file.mdx index 8a7550de06..a0c920a59b 100644 --- a/docs/user-guide/cli/tutorials/configuration_file.mdx +++ b/docs/user-guide/cli/tutorials/configuration_file.mdx @@ -88,6 +88,65 @@ The following list includes all the AWS checks with configurable variables that | `vpc_endpoint_services_allowed_principals_trust_boundaries` | `trusted_account_ids` | List of Strings | | `opensearch_service_domains_not_publicly_accessible` | `trusted_ips` | List of Strings | +### Resource Scan Limit + + + +Some AWS services accumulate large numbers of resources (EBS snapshots, backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages). Scanning every resource increases scan time, cost, API throttling, and finding volume. By default, Prowler scans every resource. Configure a positive resource scan limit to cap how many resources Prowler analyzes for these high-volume AWS resource paths. + +The global default applies to the supported resources below and is overridable per resource. The default global value is `0`, which disables the limit and scans every resource. A global `null` value is also unlimited. For per-resource values, `null` means inherit the global default; set `0` or a negative value to disable that resource limit explicitly. Positive values enable limits. + + +When positive resource scan limits are configured, compliance results are based only on the selected resources, not on the full set of matching resources in the account. Treat compliance summaries and percentages as partial evidence, because unselected resources are not analyzed and can change the real compliance posture. + + +#### Global Behavior + +Resource scan limits select resources for analysis. They do not cap, prioritize, or reorder findings. + +* **`0`, negative, or global `null` values:** Disable the limit and keep the legacy behavior for that resource path. Prowler analyzes every discovered matching resource. +* **Positive values:** Select at most that many resources for the affected resource path. A selected resource can produce zero, one, or many findings. +* **No PASS/FAIL prioritization:** Prowler does not inspect the compliance result before selecting resources. Limits do not prefer failed resources, passed resources, or resources with more findings. +* **Latest-first where possible:** When AWS exposes timestamps or useful ordering, Prowler selects the newest resources first. When AWS only exposes API order, Prowler preserves that API order and documents the behavior as best effort. +* **Findings are downstream:** Checks only evaluate the resources exposed by the service client after selection. Findings from unselected resources are not produced because those resources are not analyzed. + +Exact list API call reduction depends on each AWS API's ordering and pagination capabilities. When Prowler must enumerate candidates locally to select the latest resources, list calls may still read candidates, but expensive per-resource enrichment calls are bounded to the selected resources for the supported paths below. + +#### Full Collections Versus Limited Analysis Sets + +Some checks need lightweight evidence from a complete resource collection to avoid incorrect cross-service conclusions, while other checks perform primary analysis on a limited resource set. + +Prowler keeps full lightweight collections where they are needed for cross-service evidence. For example: + +* **Lambda security groups and regions:** Prowler records security groups used by all discovered Lambda functions and the regions where functions exist before it limits Lambda functions for primary Lambda checks. This helps Amazon EC2 and Amazon Inspector checks avoid false positives such as treating Lambda security groups as unused or assuming a region has no Lambda functions. +* **CloudWatch `all_log_groups`:** Prowler records all discovered CloudWatch log groups in `all_log_groups` before limiting the primary `log_groups` analysis set. Other services can still resolve log group evidence, while CloudWatch log group checks only analyze the selected log groups. + +This split is intentional. It reduces expensive per-resource analysis calls without discarding lightweight context that other services need for accurate results. + +#### Supported AWS Resource Limits + +| Value | Scope | Type | +|-------|-------|------| +| `max_scanned_resources_per_service` | Global default for all supported high-volume AWS resources (default `0`, disabled/unlimited) | Integer | +| `max_ebs_snapshots` | EBS snapshots (`ec2_ebs_*` checks) | Integer | +| `max_backup_recovery_points` | Backup recovery points (`backup_recovery_point_*`) | Integer | +| `max_cloudwatch_log_groups` | CloudWatch log groups (`cloudwatch_log_group_*`) | Integer | +| `max_lambda_functions` | Lambda functions (`awslambda_function_*`) | Integer | +| `max_ecs_task_definitions` | ECS task definitions (`ecs_task_definitions_*`) | Integer | +| `max_codeartifact_packages` | CodeArtifact packages (`codeartifact_packages_*`) | Integer | + +#### Resource Limit Behavior By Resource Path + +| Resource Path | What Prowler Discovers | What A Positive Limit Selects For Analysis | Ordering And Latest Behavior | AWS Calls Reduced | Drawbacks And Consequences | +|---------------|------------------------|--------------------------------------------|------------------------------|-------------------|----------------------------| +| EBS snapshots (`max_ebs_snapshots`) | Prowler lists self-owned snapshots and keeps lightweight evidence that volumes and regions have snapshots. | The selected EBS snapshots exposed to `ec2_ebs_*` checks. | Prowler sorts discovered snapshots by `StartTime` newest first, then applies the limit. Snapshots without a timestamp sort last. | Bounds expensive per-snapshot public attribute checks to selected snapshots. Snapshot listing still runs so Prowler can choose the newest snapshots and keep volume/region evidence. | Older unselected snapshots are not analyzed by snapshot checks. A public, unencrypted, or otherwise noncompliant older snapshot can be missed when the limit is lower than the number of snapshots. | +| Backup recovery points (`max_backup_recovery_points`) | Prowler lists backup vaults, plans, selections, and recovery point candidates in discovered vaults. | The selected recovery points exposed to `backup_recovery_point_*` checks and tag hydration. | Prowler sorts discovered recovery points by `CreationDate` newest first across vaults, then applies the limit. Recovery points without a timestamp sort last. | Bounds recovery point tag calls to selected recovery points. Vault and recovery point list calls still run so Prowler can choose the newest points. | Older unselected recovery points are not analyzed. A nonencrypted or otherwise noncompliant older recovery point can be missed. | +| CloudWatch log groups (`max_cloudwatch_log_groups`) | Prowler lists log groups into both `all_log_groups` and the primary `log_groups` collection. `all_log_groups` remains available as lightweight cross-service evidence. | The selected log groups exposed to `cloudwatch_log_group_*` checks, tag hydration, and log event retrieval for checks that need log contents. | Prowler sorts discovered log groups by `creationTime` newest first, then applies the limit. Log groups without a creation time sort last. | Bounds tag calls and log event retrieval to selected log groups. Log group listing still runs to build `all_log_groups` and choose newest log groups. | Older unselected log groups are not analyzed by CloudWatch log group checks. Retention, encryption, or secrets-in-logs issues in older log groups can be missed, although cross-service evidence can still use `all_log_groups`. | +| Lambda functions (`max_lambda_functions`) | Prowler lists Lambda functions and records lightweight security group and region evidence for all discovered functions. | The selected Lambda functions exposed to `awslambda_function_*` checks and per-function enrichment such as tags, policies, function URLs, and event source mappings. | Prowler sorts discovered functions by `LastModified` newest first, then applies the limit. Functions without `LastModified` sort last. | Bounds per-function enrichment calls to selected functions. Function listing still runs to choose newest functions and keep security group/region evidence. | Older unselected functions are not analyzed by Lambda checks. Runtime, policy, URL, environment secret, or dead-letter queue issues in older functions can be missed. Cross-service checks can still use full Lambda security group and region evidence to avoid false positives. | +| ECS task definitions (`max_ecs_task_definitions`) | Prowler lists ECS task definition ARN candidates in each region. Candidate ARNs can remain visible and discoverable through AWS list operations, even when not all are described. | The selected task definitions that Prowler describes and exposes to `ecs_task_definitions_*` checks. | Selection is not random. Prowler calls `ListTaskDefinitions` with `sort=DESC`, which asks AWS to return task definition ARNs in descending family and revision order. Prowler then interleaves regional candidate lists to avoid starving later regions before applying the limit. This selects the latest task definition revisions according to the ARN order AWS provides, while preserving regional fairness. | Bounds `DescribeTaskDefinition` calls to selected task definitions. Prowler may still list candidates so it can select the bounded set and keep discovery deterministic. | Unselected task definitions are not described or analyzed. Issues in older task definition revisions, or in lower-priority families outside the selected AWS `sort=DESC` order, can be missed. Because ECS ordering is family/revision based rather than a registration timestamp sort across every family, this is latest-first according to AWS task definition ARN ordering, not a global newest-by-time guarantee. | +| CodeArtifact packages (`max_codeartifact_packages`) | Prowler lists CodeArtifact repositories and lazily lists packages inside them. | The selected packages exposed to `codeartifact_packages_*` checks, including latest-version metadata for those packages. | AWS `ListPackages` does not provide a newest-package timestamp ordering in this path. Prowler preserves repository order and package API order, then applies the limit. Latest package version metadata is retrieved for selected packages with `sortBy=PUBLISHED_TIME` and `maxResults=1`. | Bounds `ListPackageVersions` calls to selected packages and can stop package listing once the limit is reached. Repository listing still runs. | Package selection is best effort by API order, not newest package order. Packages outside the selected repository/API order are not analyzed, so origin restriction or latest-version issues can be missed. | + +Use limits when scan duration, API throttling, or cost are more important than exhaustive coverage for these high-volume resources. Keep limits disabled when you need complete evidence for every resource in the affected checks. ### Validating Discovered Secrets @@ -219,6 +278,19 @@ aws: # AWS Global Configuration # aws.mute_non_default_regions --> Set to True to muted failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config mute_non_default_regions: False + + # AWS Resource Scan Limit Configuration + # Disabled by default: scan every resource unless a positive limit is configured. + # Findings are not capped. Set to 0 (or a negative value) to disable the limit. + # aws.max_scanned_resources_per_service --> global default for all services below + max_scanned_resources_per_service: 0 + # Per-service overrides. Leave as null to fall back to the global default. + max_ebs_snapshots: null + max_backup_recovery_points: null + max_cloudwatch_log_groups: null + max_lambda_functions: null + max_ecs_task_definitions: null + max_codeartifact_packages: null # If you want to mute failed findings only in specific regions, create a file with the following syntax and run it with `prowler aws -w mutelist.yaml`: # Mutelist: # Accounts: diff --git a/prowler/CHANGELOG.md b/prowler/CHANGELOG.md index 9e21f3bfa0..11e694d9ba 100644 --- a/prowler/CHANGELOG.md +++ b/prowler/CHANGELOG.md @@ -38,6 +38,10 @@ All notable changes to the **Prowler SDK** are documented in this file. - GitHub default branch protection checks now evaluate repository rulesets in addition to classic branch protection, avoiding false positives for repositories that enforce protection through rulesets [(#11723)](https://github.com/prowler-cloud/prowler/pull/11723) - Okta, Alibaba Cloud and OpenStack scan-config sections are now validated against a registered schema instead of being silently accepted, so their configurable thresholds (session/idle timeouts, retention days, image-sharing and secret-scanning settings) log a warning and fall back to the built-in default whenever a value is out of range [(#11725)](https://github.com/prowler-cloud/prowler/pull/11725) +### 🔄 Changed + +- AWS scans for EBS snapshots, Backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages now support configurable resource analysis limits via `aws.max_scanned_resources_per_service`; limits are disabled by default and only positive values cap analyzed resources [(#11228)](https://github.com/prowler-cloud/prowler/pull/11228) + --- ## [5.31.1] (Prowler v5.31.1) diff --git a/prowler/config/config.yaml b/prowler/config/config.yaml index dae81a92d3..5a796d698c 100644 --- a/prowler/config/config.yaml +++ b/prowler/config/config.yaml @@ -3,6 +3,32 @@ aws: # AWS Global Configuration # aws.mute_non_default_regions --> Set to True to muted failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config mute_non_default_regions: False + + # AWS Resource Scan Limit Configuration + # Limits the number of resources scanned per service for services that can + # accumulate huge numbers of resources (EBS snapshots, backup recovery + # points, CloudWatch log groups, Lambda functions, ECS task definitions, + # CodeArtifact packages). Limits apply to resources analyzed, not findings: + # a selected resource can produce zero, one, or many findings. Where the AWS + # API supports server-side ordering the latest resources are scanned first; + # otherwise it is best-effort API order. + # Disabled by default: scan every resource unless a positive limit is configured. + # Set to 0 (or a negative value) to disable the limit (scan every resource). + # aws.max_scanned_resources_per_service --> global default for all services below + max_scanned_resources_per_service: 0 + # Per-service overrides. Leave as null to fall back to the global default. + # aws.max_ebs_snapshots --> ec2_ebs_* checks (EBS snapshots) + max_ebs_snapshots: null + # aws.max_backup_recovery_points --> backup_recovery_point_* checks + max_backup_recovery_points: null + # aws.max_cloudwatch_log_groups --> cloudwatch_log_group_* checks + max_cloudwatch_log_groups: null + # aws.max_lambda_functions --> awslambda_function_* checks + max_lambda_functions: null + # aws.max_ecs_task_definitions --> ecs_task_definitions_* checks + max_ecs_task_definitions: null + # aws.max_codeartifact_packages --> codeartifact_packages_* checks + max_codeartifact_packages: null # aws.disallowed_regions --> List of AWS regions to exclude from the scan. # Also settable via the PROWLER_AWS_DISALLOWED_REGIONS environment variable or # the --excluded-region CLI flag. Precedence: CLI > env var > config file. diff --git a/prowler/config/schema/aws.py b/prowler/config/schema/aws.py index d15fc276c5..a0dbe91932 100644 --- a/prowler/config/schema/aws.py +++ b/prowler/config/schema/aws.py @@ -14,7 +14,7 @@ thresholds) and avoids ints that obviously break downstream maths from typing import Annotated, Literal, Optional -from pydantic import AfterValidator, Field +from pydantic import AfterValidator, BeforeValidator, Field from prowler.config.schema.base import ProviderConfigBase from prowler.config.schema.validators import ( @@ -101,10 +101,65 @@ def _validate_account_ids(v: Optional[list[str]]) -> Optional[list[str]]: return v +def _reject_bool_resource_limit(v): + if isinstance(v, bool): + raise ValueError("resource scan limits must be integers, not booleans") + return v + + +ResourceScanLimit = Annotated[ + Optional[int], BeforeValidator(_reject_bool_resource_limit) +] + + # ---- Main schema ------------------------------------------------------------ class AWSProviderConfig(ProviderConfigBase): + # --- Resource scan limits --------------------------------------------- + max_scanned_resources_per_service: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Global resource scan limit for high-volume AWS services. Use 0 or -1 to disable.", + ) + max_ebs_snapshots: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for EBS snapshots. Use 0 or -1 to disable.", + ) + max_backup_recovery_points: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for AWS Backup recovery points. Use 0 or -1 to disable.", + ) + max_cloudwatch_log_groups: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for CloudWatch log groups. Use 0 or -1 to disable.", + ) + max_lambda_functions: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for Lambda functions. Use 0 or -1 to disable.", + ) + max_ecs_task_definitions: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for ECS task definitions. Use 0 or -1 to disable.", + ) + max_codeartifact_packages: ResourceScanLimit = Field( + default=None, + ge=-1, + le=1_000_000, + description="Resource scan limit for CodeArtifact packages. Use 0 or -1 to disable.", + ) + # --- IAM --------------------------------------------------------------- mute_non_default_regions: Optional[bool] = None disallowed_regions: Optional[list[str]] = None diff --git a/prowler/lib/resource_limit.py b/prowler/lib/resource_limit.py new file mode 100644 index 0000000000..da144e9dd0 --- /dev/null +++ b/prowler/lib/resource_limit.py @@ -0,0 +1,88 @@ +"""Scoped resource scan limits for high-volume resources. + +Some services accumulate huge numbers of resources (EBS snapshots, backup +recovery points, log groups, Lambda functions, ECS task definitions, +CodeArtifact packages). Scanning all of them causes API throttling, slow +scans, cost and noisy findings. + +``get_resource_scan_limit`` resolves the configured number of resources to +analyze for a supported resource path. A limited resource can produce zero, +one, or many findings; findings are not capped or re-ordered here. + +Tradeoff: for newest-based resources, services may need to list lightweight or +base metadata broadly to select the truly newest resources, then apply limits +only to expensive hydration or analysis. The helper must not send +user-configured limits as unsafe paginator ``PageSize`` values because AWS +services validate page sizes differently. +""" + +from collections.abc import Callable, Iterable, Iterator, Mapping +from itertools import islice +from typing import Any, Optional, Protocol, TypeVar + +GLOBAL_LIMIT_KEY = "max_scanned_resources_per_service" +T = TypeVar("T") + + +class PaginatorProtocol(Protocol): + """Minimal boto3-compatible paginator interface used by this module.""" + + def paginate(self, **operation_parameters: Any) -> Iterable[Mapping[str, Any]]: + """Return paginator pages for the provided operation parameters.""" + + +def get_resource_scan_limit(audit_config: dict, service_key: str) -> Optional[int]: + """Resolve the resource scan limit for a service. + + Precedence: per-service key (``service_key``) > global + ``max_scanned_resources_per_service`` > unlimited. + + A non-positive resolved value means **unlimited** (``None``), preserving + the legacy behavior as an explicit opt-out. + + Args: + audit_config: The provider ``audit_config`` dictionary. + service_key: The per-service config key, e.g. ``max_lambda_functions``. + + Returns: + The limit as a positive ``int``, or ``None`` for unlimited. + """ + value = audit_config.get(service_key) + if value is None: + value = audit_config.get(GLOBAL_LIMIT_KEY) + if value is None or value <= 0: + return None + return int(value) + + +def limit_resources(resources: Iterable[T], limit: Optional[int]) -> Iterator[T]: + """Yield up to ``limit`` resources without changing resource order.""" + if not limit or limit <= 0: + yield from resources + return + yield from islice(resources, limit) + + +def iter_limited_paginator_items( + paginator: PaginatorProtocol, + result_key: str, + limit: Optional[int], + item_filter: Optional[Callable[[T], bool]] = None, + **operation_parameters: Any, +) -> Iterator[T]: + """Yield paginator result items, stopping after ``limit`` selected items. + + The configured resource-analysis limit is intentionally not sent as + ``PageSize`` because AWS services validate page sizes differently. The + paginator receives only the operation parameters needed by the AWS API, + while this iterator applies the analysis limit defensively client-side. + """ + selected = 0 + for page in paginator.paginate(**operation_parameters): + for item in page.get(result_key, []): + if item_filter and not item_filter(item): + continue + yield item + selected += 1 + if limit and selected >= limit: + return diff --git a/prowler/providers/aws/services/awslambda/awslambda_service.py b/prowler/providers/aws/services/awslambda/awslambda_service.py index 433a5ab588..ac90e9fb11 100644 --- a/prowler/providers/aws/services/awslambda/awslambda_service.py +++ b/prowler/providers/aws/services/awslambda/awslambda_service.py @@ -10,6 +10,10 @@ from botocore.client import ClientError from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + limit_resources, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -18,8 +22,16 @@ class Lambda(AWSService): def __init__(self, provider): # Call AWSService's __init__ super().__init__(__class__.__name__, provider) + # Functions are listed first, then trimmed to the subset selected for + # analysis before expensive per-function detail is hydrated. self.functions = {} + self.security_groups_in_use = set() + self.regions_with_functions = set() + self.function_limit = get_resource_scan_limit( + self.audit_config, "max_lambda_functions" + ) self.__threading_call__(self._list_functions) + self._select_functions_for_analysis() self._list_tags_for_resource() self.__threading_call__(self._get_policy) self.__threading_call__(self._get_function_url_config) @@ -30,24 +42,29 @@ class Lambda(AWSService): try: list_functions_paginator = regional_client.get_paginator("list_functions") for page in list_functions_paginator.paginate(): - for function in page["Functions"]: - if not self.audit_resources or ( - is_resource_filtered( - function["FunctionArn"], self.audit_resources - ) + for function in page.get("Functions", []): + if not self.audit_resources or is_resource_filtered( + function["FunctionArn"], self.audit_resources ): lambda_name = function["FunctionName"] lambda_arn = function["FunctionArn"] vpc_config = function.get("VpcConfig", {}) + security_groups = vpc_config.get("SecurityGroupIds", []) + self.security_groups_in_use.update(security_groups) + self.regions_with_functions.add(regional_client.region) # We must use the Lambda ARN as the dict key since we could have Lambdas in different regions with the same name self.functions[lambda_arn] = Function( name=lambda_name, arn=lambda_arn, - security_groups=vpc_config.get("SecurityGroupIds", []), + security_groups=security_groups, vpc_id=vpc_config.get("VpcId"), subnet_ids=set(vpc_config.get("SubnetIds", [])), region=regional_client.region, ) + if "LastModified" in function: + self.functions[lambda_arn].last_modified = function[ + "LastModified" + ] if "Runtime" in function: self.functions[lambda_arn].runtime = function["Runtime"] if "Environment" in function: @@ -76,26 +93,61 @@ class Lambda(AWSService): f" {error}" ) + def _select_functions_for_analysis(self): + self.functions = { + function.arn: function + for function in limit_resources( + sorted( + self.functions.values(), + key=lambda f: f.last_modified or "", + reverse=True, + ), + self.function_limit, + ) + } + def _list_event_source_mappings(self, regional_client): logger.info("Lambda - Listing Event Source Mappings...") try: paginator = regional_client.get_paginator("list_event_source_mappings") - for page in paginator.paginate(): - for mapping in page.get("EventSourceMappings", []): - function_arn = mapping.get("FunctionArn", "") - # Normalise to unqualified ARN (strip :qualifier suffix if present) - base_arn = ":".join(function_arn.split(":")[:7]) - if base_arn not in self.functions: - continue - self.functions[base_arn].event_source_mappings.append( - EventSourceMapping( - uuid=mapping["UUID"], - event_source_arn=mapping.get("EventSourceArn", ""), - state=mapping.get("State", ""), - batch_size=mapping.get("BatchSize"), - starting_position=mapping.get("StartingPosition"), + if not self.function_limit: + for page in paginator.paginate(): + self._add_event_source_mappings(page.get("EventSourceMappings", [])) + return + + for function in self.functions.values(): + if function.region != regional_client.region: + continue + try: + for page in paginator.paginate(FunctionName=function.name): + self._add_event_source_mappings( + page.get("EventSourceMappings", []) ) - ) + except ClientError as error: + if ( + error.response.get("Error", {}).get("Code") + == "InvalidParameterValueException" + ): + logger.warning( + f"{function.region} --" + f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" + f" {error}" + ) + else: + logger.error( + f"{function.region} --" + f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" + f" {error}" + ) + raise + except ClientError as error: + if self.function_limit: + raise + logger.error( + f"{regional_client.region} --" + f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" + f" {error}" + ) except Exception as error: logger.error( f"{regional_client.region} --" @@ -103,6 +155,23 @@ class Lambda(AWSService): f" {error}" ) + def _add_event_source_mappings(self, event_source_mappings): + for mapping in event_source_mappings: + function_arn = mapping.get("FunctionArn", "") + # Normalise to unqualified ARN (strip :qualifier suffix if present) + base_arn = ":".join(function_arn.split(":")[:7]) + if base_arn not in self.functions: + continue + self.functions[base_arn].event_source_mappings.append( + EventSourceMapping( + uuid=mapping["UUID"], + event_source_arn=mapping.get("EventSourceArn", ""), + state=mapping.get("State", ""), + batch_size=mapping.get("BatchSize"), + starting_position=mapping.get("StartingPosition"), + ) + ) + def _get_function_code(self): logger.info("Lambda - Getting Function Code...") # Use a thread pool handle the queueing and execution of the _fetch_function_code tasks, up to max_workers tasks concurrently. @@ -158,7 +227,6 @@ class Lambda(AWSService): except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": self.functions[function.arn].policy = {} - except Exception as error: logger.error( f"{regional_client.region} --" @@ -187,7 +255,6 @@ class Lambda(AWSService): except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": self.functions[function.arn].url_config = None - except Exception as error: logger.error( f"{regional_client.region} --" @@ -206,10 +273,9 @@ class Lambda(AWSService): except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": function.tags = [] - except Exception as error: logger.error( - f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + f"{function.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) @@ -259,6 +325,7 @@ class Function(BaseModel): name: str arn: str security_groups: list + last_modified: Optional[str] = None runtime: Optional[str] = None environment: Optional[dict] = None region: str diff --git a/prowler/providers/aws/services/backup/backup_service.py b/prowler/providers/aws/services/backup/backup_service.py index 4320d42c0b..ddacc9e6d2 100644 --- a/prowler/providers/aws/services/backup/backup_service.py +++ b/prowler/providers/aws/services/backup/backup_service.py @@ -5,6 +5,10 @@ from botocore.client import ClientError from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + limit_resources, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -27,8 +31,14 @@ class Backup(AWSService): self.__threading_call__(self._list_backup_report_plans) self.protected_resources = [] self.__threading_call__(self._list_backup_selections) + # Recovery points are listed first, then only the selected subset is + # tagged and exposed for checks. self.recovery_points = [] - self.__threading_call__(self._list_recovery_points) + self.recovery_point_limit = get_resource_scan_limit( + self.audit_config, "max_backup_recovery_points" + ) + self.__threading_call__(self._list_recovery_points, self.backup_vaults or []) + self._select_recovery_points_for_analysis() self.__threading_call__(self._list_tags, self.recovery_points) def _list_backup_vaults(self, regional_client): @@ -183,40 +193,63 @@ class Backup(AWSService): f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - def _list_recovery_points(self, regional_client): + def _list_recovery_points(self, backup_vault=None): logger.info("Backup - Listing Recovery Points...") + if backup_vault is None: + for vault in self.backup_vaults or []: + self._list_recovery_points(vault) + return + try: - if self.backup_vaults: - for backup_vault in self.backup_vaults: - paginator = regional_client.get_paginator( - "list_recovery_points_by_backup_vault" - ) - for page in paginator.paginate(BackupVaultName=backup_vault.name): - for recovery_point in page.get("RecoveryPoints", []): - arn = recovery_point.get("RecoveryPointArn") - if arn: - self.recovery_points.append( - RecoveryPoint( - arn=arn, - id=arn.split(":")[-1], - backup_vault_name=backup_vault.name, - encrypted=recovery_point.get( - "IsEncrypted", False - ), - backup_vault_region=backup_vault.region, - region=regional_client.region, - tags=[], - ) - ) + regional_client = self.regional_clients[backup_vault.region] + paginator = regional_client.get_paginator( + "list_recovery_points_by_backup_vault" + ) + for page in paginator.paginate(BackupVaultName=backup_vault.name): + for recovery_point in page.get("RecoveryPoints", []): + arn = recovery_point.get("RecoveryPointArn") + if arn: + rp = RecoveryPoint( + arn=arn, + id=arn.split(":")[-1], + backup_vault_name=backup_vault.name, + encrypted=recovery_point.get("IsEncrypted", False), + creation_date=recovery_point.get("CreationDate"), + backup_vault_region=backup_vault.region, + region=backup_vault.region, + tags=[], + ) + self.recovery_points.append(rp) except ClientError as error: logger.error( - f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + f"{backup_vault.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) except Exception as error: logger.error( - f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + f"{backup_vault.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) + def _select_recovery_points_for_analysis(self): + self.recovery_points = list( + limit_resources( + sorted( + self.recovery_points, + key=lambda rp: ( + ( + -rp.creation_date.timestamp() + if isinstance(rp.creation_date, datetime) + else 0.0 + ), + rp.region or "", + rp.backup_vault_name or "", + rp.arn or "", + rp.id or "", + ), + ), + self.recovery_point_limit, + ) + ) + class BackupVault(BaseModel): arn: str @@ -256,4 +289,5 @@ class RecoveryPoint(BaseModel): backup_vault_name: str encrypted: bool backup_vault_region: str + creation_date: Optional[datetime] = None tags: Optional[list] = None diff --git a/prowler/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled.py b/prowler/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled.py index c4adddc344..230cfd2295 100644 --- a/prowler/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled.py +++ b/prowler/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled.py @@ -30,12 +30,13 @@ class bedrock_model_invocation_logs_encryption_enabled(Check): s3_encryption = False if logging.cloudwatch_log_group: log_group_arn = f"arn:{logs_client.audited_partition}:logs:{region}:{logs_client.audited_account}:log-group:{logging.cloudwatch_log_group}" + all_log_groups = getattr(logs_client, "all_log_groups", None) or {} if ( - log_group_arn in logs_client.log_groups - and not logs_client.log_groups[log_group_arn].kms_id + log_group_arn in all_log_groups + and not all_log_groups[log_group_arn].kms_id ) or ( - log_group_arn + ":*" in logs_client.log_groups - and not logs_client.log_groups[log_group_arn + ":*"].kms_id + log_group_arn + ":*" in all_log_groups + and not all_log_groups[log_group_arn + ":*"].kms_id ): cloudwatch_encryption = False if not s3_encryption and not cloudwatch_encryption: diff --git a/prowler/providers/aws/services/cloudwatch/cloudwatch_service.py b/prowler/providers/aws/services/cloudwatch/cloudwatch_service.py index 29ac103867..56b7ebe25c 100644 --- a/prowler/providers/aws/services/cloudwatch/cloudwatch_service.py +++ b/prowler/providers/aws/services/cloudwatch/cloudwatch_service.py @@ -6,6 +6,10 @@ from botocore.exceptions import ClientError from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + limit_resources, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -83,8 +87,19 @@ class Logs(AWSService): # Call AWSService's __init__ super().__init__(__class__.__name__, provider) self.log_group_arn_template = f"arn:{self.audited_partition}:logs:{self.region}:{self.audited_account}:log-group" + # Log groups are listed first, then only the selected subset is enriched + # and exposed for primary log group checks. Keep a complete lightweight + # index for cross-service evidence lookups. + self.all_log_groups = {} self.log_groups = {} + self._log_groups_hydrated = set() + self.log_group_limit = get_resource_scan_limit( + self.audit_config, "max_cloudwatch_log_groups" + ) + # The threshold for number of events to return per log group. + self.events_per_log_group_threshold = 1000 self.__threading_call__(self._describe_log_groups) + self._select_log_groups_for_analysis() self.resource_policies = {} self.__threading_call__(self._describe_resource_policies) self.metric_filters = [] @@ -94,14 +109,27 @@ class Logs(AWSService): "cloudwatch_log_group_no_secrets_in_logs" in provider.audit_metadata.expected_checks ): - self.events_per_log_group_threshold = ( - 1000 # The threshold for number of events to return per log group. - ) - self.__threading_call__(self._get_log_events) + self.__threading_call__(self._get_log_events, self.log_groups.values()) self.__threading_call__( self._list_tags_for_resource, self.log_groups.values() ) + def _select_log_groups_for_analysis(self): + """Select the newest log groups for bounded analysis.""" + if not self.log_groups: + return + self.log_groups = { + log_group.arn: log_group + for log_group in limit_resources( + sorted( + self.log_groups.values(), + key=lambda lg: lg.creation_time or 0, + reverse=True, + ), + self.log_group_limit, + ) + } + def _describe_metric_filters(self, regional_client): logger.info("CloudWatch Logs - Describing metric filters...") try: @@ -118,11 +146,21 @@ class Logs(AWSService): self.metric_filters = [] log_group = None - for lg in self.log_groups.values(): - if lg.name == filter["logGroupName"]: + for lg in (self.all_log_groups or {}).values(): + if ( + lg.name == filter["logGroupName"] + and lg.region == regional_client.region + ): log_group = lg break + if ( + log_group + and log_group.arn in (self.log_groups or {}) + and log_group.arn not in self._log_groups_hydrated + ): + self._list_tags_for_resource(log_group) + self.metric_filters.append( MetricFilter( arn=arn, @@ -156,9 +194,9 @@ class Logs(AWSService): "describe_log_groups" ) for page in describe_log_groups_paginator.paginate(): - for log_group in page["logGroups"]: - if not self.audit_resources or ( - is_resource_filtered(log_group["arn"], self.audit_resources) + for log_group in page.get("logGroups", []): + if not self.audit_resources or is_resource_filtered( + log_group["arn"], self.audit_resources ): never_expire = False kms = log_group.get("kmsKeyId") @@ -168,20 +206,26 @@ class Logs(AWSService): retention_days = 9999 if self.log_groups is None: self.log_groups = {} - self.log_groups[log_group["arn"]] = LogGroup( + if self.all_log_groups is None: + self.all_log_groups = {} + log_group_object = LogGroup( arn=log_group["arn"], name=log_group["logGroupName"], retention_days=retention_days, never_expire=never_expire, kms_id=kms, + creation_time=log_group.get("creationTime"), region=regional_client.region, ) + self.all_log_groups[log_group_object.arn] = log_group_object + self.log_groups[log_group_object.arn] = log_group_object except ClientError as error: if error.response["Error"]["Code"] == "AccessDeniedException": logger.error( f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) if not self.log_groups: + self.all_log_groups = None self.log_groups = None else: logger.error( @@ -192,37 +236,29 @@ class Logs(AWSService): f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - def _get_log_events(self, regional_client): - regional_log_groups = [ - log_group - for log_group in self.log_groups.values() - if log_group.region == regional_client.region - ] - total_log_groups = len(regional_log_groups) + def _get_log_events(self, log_group): + """Retrieve recent log events for a selected log group. + + Args: + log_group: Log group selected for bounded analysis. + """ logger.info( - f"CloudWatch Logs - Retrieving log events for {total_log_groups} log groups in {regional_client.region}..." + f"CloudWatch Logs - Retrieving log events for log group {log_group.name}..." ) try: - for count, log_group in enumerate(regional_log_groups, start=1): - events = regional_client.filter_log_events( - logGroupName=log_group.name, - limit=self.events_per_log_group_threshold, - )["events"] - for event in events: - if event["logStreamName"] not in log_group.log_streams: - log_group.log_streams[event["logStreamName"]] = [] - log_group.log_streams[event["logStreamName"]].append(event) - if count % 10 == 0: - logger.info( - f"CloudWatch Logs - Retrieved log events for {count}/{total_log_groups} log groups in {regional_client.region}..." - ) + regional_client = self.regional_clients[log_group.region] + events = regional_client.filter_log_events( + logGroupName=log_group.name, + limit=self.events_per_log_group_threshold, + )["events"] + for event in events: + if event["logStreamName"] not in log_group.log_streams: + log_group.log_streams[event["logStreamName"]] = [] + log_group.log_streams[event["logStreamName"]].append(event) except Exception as error: logger.error( - f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + f"{log_group.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - logger.info( - f"CloudWatch Logs - Finished retrieving log events in {regional_client.region}..." - ) def _describe_resource_policies(self, regional_client): logger.info("CloudWatch Logs - Describing resource policies...") @@ -257,6 +293,13 @@ class Logs(AWSService): ) def _list_tags_for_resource(self, log_group): + """Hydrate tags for a selected log group once. + + Args: + log_group: Log group selected for tag hydration. + """ + if log_group.arn in self._log_groups_hydrated: + return logger.info(f"CloudWatch Logs - List Tags for Log Group {log_group.name}...") try: regional_client = self.regional_clients[log_group.region] @@ -264,6 +307,7 @@ class Logs(AWSService): resourceArn=log_group.arn )["tags"] log_group.tags = [response] + self._log_groups_hydrated.add(log_group.arn) except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": logger.warning( @@ -292,6 +336,7 @@ class LogGroup(BaseModel): retention_days: int never_expire: bool kms_id: Optional[str] + creation_time: Optional[int] = None region: str log_streams: dict[str, list[str]] = ( {} diff --git a/prowler/providers/aws/services/codeartifact/codeartifact_service.py b/prowler/providers/aws/services/codeartifact/codeartifact_service.py index 1465092063..9a13e8fcf7 100644 --- a/prowler/providers/aws/services/codeartifact/codeartifact_service.py +++ b/prowler/providers/aws/services/codeartifact/codeartifact_service.py @@ -1,10 +1,14 @@ from enum import Enum -from typing import Optional +from typing import Iterator, Optional, Tuple from botocore.exceptions import ClientError from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + iter_limited_paginator_items, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -15,9 +19,18 @@ class CodeArtifact(AWSService): super().__init__(__class__.__name__, provider) # repositories is a dictionary containing all the codeartifact service information self.repositories = {} + # repository ARNs whose selected packages have been listed and memoized + # into repository.packages. + self._packages_listed = set() + self.package_limit = get_resource_scan_limit( + self.audit_config, "max_codeartifact_packages" + ) self.__threading_call__(self._list_repositories) - self.__threading_call__(self._list_packages) - self._list_tags_for_resource() + for _ in self._load_packages_for_analysis(): + pass + self.__threading_call__( + self._list_tags_for_resource, self.repositories.values() + ) def _list_repositories(self, regional_client): logger.info("CodeArtifact - Listing Repositories...") @@ -51,134 +64,146 @@ class CodeArtifact(AWSService): f" {error}" ) - def _list_packages(self, regional_client): - logger.info("CodeArtifact - Listing Packages and retrieving information...") - for repository in self.repositories: - try: - if self.repositories[repository].region == regional_client.region: - list_packages_paginator = regional_client.get_paginator( - "list_packages" + def _iter_repository_packages( + self, repository, limit: Optional[int] = None + ) -> Iterator["Package"]: + """Yield packages for a single repository, hydrating each one lazily. + + Each package requires an extra ``list_package_versions`` call to + resolve its latest version, so producing them lazily lets the resource + limit stop before extra package version calls. + """ + regional_client = self.regional_clients[repository.region] + try: + list_packages_paginator = regional_client.get_paginator("list_packages") + list_packages_parameters = { + "domain": repository.domain_name, + "domainOwner": repository.domain_owner, + "repository": repository.name, + } + for package in iter_limited_paginator_items( + list_packages_paginator, + "packages", + limit, + **list_packages_parameters, + ): + # Package information + package_format = package["format"] + package_namespace = package.get("namespace") + package_name = package["package"] + package_origin_configuration_restrictions_publish = package[ + "originConfiguration" + ]["restrictions"]["publish"] + package_origin_configuration_restrictions_upstream = package[ + "originConfiguration" + ]["restrictions"]["upstream"] + # Get Latest Package Version + list_package_versions_parameters = { + "domain": repository.domain_name, + "domainOwner": repository.domain_owner, + "repository": repository.name, + "format": package_format, + "package": package_name, + "sortBy": "PUBLISHED_TIME", + "maxResults": 1, + } + if package_namespace: + list_package_versions_parameters["namespace"] = package_namespace + latest_version_information = regional_client.list_package_versions( + **list_package_versions_parameters + ) + latest_version = "" + latest_origin_type = "UNKNOWN" + latest_status = "Published" + if latest_version_information.get("versions"): + latest_version = latest_version_information["versions"][0].get( + "version" ) - list_packages_parameters = { - "domain": self.repositories[repository].domain_name, - "domainOwner": self.repositories[repository].domain_owner, - "repository": self.repositories[repository].name, - } - packages = [] - for page in list_packages_paginator.paginate( - **list_packages_parameters - ): - for package in page["packages"]: - # Package information - package_format = package["format"] - package_namespace = package.get("namespace") - package_name = package["package"] - package_origin_configuration_restrictions_publish = package[ - "originConfiguration" - ]["restrictions"]["publish"] - package_origin_configuration_restrictions_upstream = ( - package["originConfiguration"]["restrictions"][ - "upstream" - ] - ) - # Get Latest Package Version - if package_namespace: - latest_version_information = ( - regional_client.list_package_versions( - domain=self.repositories[ - repository - ].domain_name, - domainOwner=self.repositories[ - repository - ].domain_owner, - repository=self.repositories[repository].name, - format=package_format, - namespace=package_namespace, - package=package_name, - sortBy="PUBLISHED_TIME", - maxResults=1, - ) - ) - else: - latest_version_information = ( - regional_client.list_package_versions( - domain=self.repositories[ - repository - ].domain_name, - domainOwner=self.repositories[ - repository - ].domain_owner, - repository=self.repositories[repository].name, - format=package_format, - package=package_name, - sortBy="PUBLISHED_TIME", - maxResults=1, - ) - ) - latest_version = "" - latest_origin_type = "UNKNOWN" - latest_status = "Published" - if latest_version_information.get("versions"): - latest_version = latest_version_information["versions"][ - 0 - ].get("version") - latest_origin_type = ( - latest_version_information["versions"][0] - .get("origin", {}) - .get("originType", "UNKNOWN") - ) - latest_status = latest_version_information["versions"][ - 0 - ].get("status", "Published") - - packages.append( - Package( - name=package_name, - namespace=package_namespace, - format=package_format, - origin_configuration=OriginConfiguration( - restrictions=Restrictions( - publish=package_origin_configuration_restrictions_publish, - upstream=package_origin_configuration_restrictions_upstream, - ) - ), - latest_version=LatestPackageVersion( - version=latest_version, - status=latest_status, - origin=OriginInformation( - origin_type=latest_origin_type - ), - ), - ) - ) - # Save all the packages information - self.repositories[repository].packages = packages - - except ClientError as error: - if error.response["Error"]["Code"] == "ResourceNotFoundException": - logger.warning( - f"{regional_client.region} --" - f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" - f" {error}" + latest_origin_type = ( + latest_version_information["versions"][0] + .get("origin", {}) + .get("originType", "UNKNOWN") + ) + latest_status = latest_version_information["versions"][0].get( + "status", "Published" ) - continue - except Exception as error: - logger.error( - f"{regional_client.region} --" + yield Package( + name=package_name, + namespace=package_namespace, + format=package_format, + origin_configuration=OriginConfiguration( + restrictions=Restrictions( + publish=package_origin_configuration_restrictions_publish, + upstream=package_origin_configuration_restrictions_upstream, + ) + ), + latest_version=LatestPackageVersion( + version=latest_version, + status=latest_status, + origin=OriginInformation(origin_type=latest_origin_type), + ), + ) + + except ClientError as error: + if error.response["Error"]["Code"] == "ResourceNotFoundException": + logger.warning( + f"{repository.region} --" f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" f" {error}" ) + else: + logger.error( + f"{repository.region} --" + f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" + f" {error}" + ) + except Exception as error: + logger.error( + f"{repository.region} --" + f" {error.__class__.__name__}[{error.__traceback__.tb_lineno}]:" + f" {error}" + ) - def _list_tags_for_resource(self): + def _load_packages_for_analysis(self) -> Iterator[Tuple["Repository", "Package"]]: + """Yield the ``(repository, package)`` pairs selected for analysis. + + Package listing stays in the service layer so checks receive only the + selected packages and remain unaware of resource-analysis limits. + """ + yielded = 0 + for repository in list(self.repositories.values()): + if repository.arn in self._packages_listed: + for package in repository.packages: + yield repository, package + yielded += 1 + if self.package_limit and yielded >= self.package_limit: + return + continue + collected = [] + remaining_limit = None + if self.package_limit: + remaining_limit = self.package_limit - yielded + if remaining_limit <= 0: + return + for package in self._iter_repository_packages(repository, remaining_limit): + collected.append(package) + repository.packages = collected + yield repository, package + yielded += 1 + if self.package_limit and yielded >= self.package_limit: + self._packages_listed.add(repository.arn) + return + self._packages_listed.add(repository.arn) + + def _list_tags_for_resource(self, repository): logger.info("CodeArtifact - List Tags...") try: - for repository in self.repositories.values(): - regional_client = self.regional_clients[repository.region] - response = regional_client.list_tags_for_resource( - resourceArn=repository.arn - )["tags"] - repository.tags = response + regional_client = self.regional_clients[repository.region] + response = regional_client.list_tags_for_resource( + resourceArn=repository.arn + )["tags"] + repository.tags = response except Exception as error: logger.error( f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" diff --git a/prowler/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used.py b/prowler/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used.py index aa621abdfb..c45693c498 100644 --- a/prowler/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used.py +++ b/prowler/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used.py @@ -15,11 +15,10 @@ class ec2_securitygroup_not_used(Check): report.resource_details = security_group.name report.status = "PASS" report.status_extended = f"Security group {security_group.name} ({security_group.id}) it is being used." - sg_in_lambda = False + sg_in_lambda = ( + security_group.id in awslambda_client.security_groups_in_use + ) sg_associated = False - for function in awslambda_client.functions.values(): - if security_group.id in function.security_groups: - sg_in_lambda = True for sg in ec2_client.security_groups.values(): if security_group.id in sg.associated_sgs: sg_associated = True diff --git a/prowler/providers/aws/services/ec2/ec2_service.py b/prowler/providers/aws/services/ec2/ec2_service.py index ccdb9e58fe..45a45e10d5 100644 --- a/prowler/providers/aws/services/ec2/ec2_service.py +++ b/prowler/providers/aws/services/ec2/ec2_service.py @@ -6,6 +6,10 @@ from botocore.client import ClientError from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + limit_resources, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -26,8 +30,12 @@ class EC2(AWSService): self.snapshots = [] self.volumes_with_snapshots = {} self.regions_with_snapshots = {} + # Snapshots are listed first, then limited after per-region snapshot + # presence is derived and before public status is hydrated. + self.snapshot_limit = get_resource_scan_limit( + self.audit_config, "max_ebs_snapshots" + ) self.__threading_call__(self._describe_snapshots) - self.__threading_call__(self._determine_public_snapshots, self.snapshots) self.network_interfaces = {} self.__threading_call__(self._describe_network_interfaces) self.images = [] @@ -36,6 +44,8 @@ class EC2(AWSService): self.__threading_call__(self._describe_volumes) self.attributes_for_regions = {} self.__threading_call__(self._get_resources_for_regions) + self._select_snapshots_for_analysis() + self.__threading_call__(self._determine_public_snapshots, self.snapshots) self.ebs_encryption_by_default = [] self.__threading_call__(self._get_ebs_encryption_settings) self.elastic_ips = [] @@ -207,6 +217,7 @@ class EC2(AWSService): arn=arn, region=regional_client.region, encrypted=snapshot.get("Encrypted", False), + start_time=snapshot.get("StartTime"), tags=snapshot.get("Tags"), volume=snapshot["VolumeId"], ) @@ -243,6 +254,18 @@ class EC2(AWSService): f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) + def _select_snapshots_for_analysis(self): + self.snapshots = list( + limit_resources( + sorted( + self.snapshots, + key=lambda s: (s.start_time.timestamp() if s.start_time else 0.0), + reverse=True, + ), + self.snapshot_limit, + ) + ) + def _describe_network_interfaces(self, regional_client): try: # Get Network Interfaces with Public IPs @@ -686,6 +709,7 @@ class Snapshot(BaseModel): region: str encrypted: bool public: bool = False + start_time: Optional[datetime] = None tags: Optional[list] = [] volume: Optional[str] diff --git a/prowler/providers/aws/services/ecs/ecs_service.py b/prowler/providers/aws/services/ecs/ecs_service.py index 560125bf58..e14197162f 100644 --- a/prowler/providers/aws/services/ecs/ecs_service.py +++ b/prowler/providers/aws/services/ecs/ecs_service.py @@ -1,9 +1,16 @@ +from datetime import datetime +from itertools import zip_longest from re import sub from typing import Optional from pydantic.v1 import BaseModel from prowler.lib.logger import logger +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + iter_limited_paginator_items, + limit_resources, +) from prowler.lib.scan_filters.scan_filters import is_resource_filtered from prowler.providers.aws.lib.service.service import AWSService @@ -12,40 +19,95 @@ class ECS(AWSService): def __init__(self, provider): # Call AWSService's __init__ super().__init__(__class__.__name__, provider) + # Task definition ARNs are listed first, then only the selected subset + # is described and exposed for checks. self.task_definitions = {} + self._task_definition_arns = None + self._task_definition_arns_by_region = {} + self.task_definition_limit = get_resource_scan_limit( + self.audit_config, "max_ecs_task_definitions" + ) self.services = {} self.clusters = {} self.task_sets = {} - self.__threading_call__(self._list_task_definitions) - self.__threading_call__( - self._describe_task_definition, self.task_definitions.values() - ) + for _ in self._load_task_definitions_for_analysis(): + pass self.__threading_call__(self._list_clusters) self.__threading_call__(self._describe_clusters, self.clusters.values()) self.__threading_call__(self._describe_services, self.clusters.values()) - def _list_task_definitions(self, regional_client): + def _list_task_definition_arns(self) -> list: + """List task definition ARNs newest-first, memoized. + + AWS returns ``list_task_definitions(sort=DESC)`` results per region. + Prowler limits the task definitions it describes and exposes to checks. + """ + if self._task_definition_arns is not None: + return self._task_definition_arns logger.info("ECS - Listing Task Definitions...") + self.__threading_call__(self._list_task_definition_arns_by_region) + arns_by_region = [] + for region in self.regional_clients: + arns_by_region.append(self._task_definition_arns_by_region.get(region, [])) + arns = [] + for task_definition_batch in zip_longest(*arns_by_region): + for task_definition in task_definition_batch: + if task_definition: + arns.append(task_definition) + self._task_definition_arns = arns + return arns + + def _list_task_definition_arns_by_region(self, regional_client): try: list_ecs_paginator = regional_client.get_paginator("list_task_definitions") - for page in list_ecs_paginator.paginate(): - for task_definition in page["taskDefinitionArns"]: - if not self.audit_resources or ( - is_resource_filtered(task_definition, self.audit_resources) - ): - self.task_definitions[task_definition] = TaskDefinition( - # we want the family name without the revision - name=sub(":.*", "", task_definition.split("/")[-1]), - arn=task_definition, - revision=task_definition.split(":")[-1], - region=regional_client.region, - environment_variables=[], - ) + regional_arns = [] + for task_definition in iter_limited_paginator_items( + list_ecs_paginator, + "taskDefinitionArns", + None, + item_filter=lambda task_definition: not self.audit_resources + or is_resource_filtered(task_definition, self.audit_resources), + sort="DESC", + ): + regional_arns.append((task_definition, regional_client.region)) + self._task_definition_arns_by_region[regional_client.region] = regional_arns except Exception as error: logger.error( f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) + def _load_task_definitions_for_analysis(self): + """Yield task definitions lazily, describing each one on demand. + + Resources already fetched are memoized in ``self.task_definitions`` and + reused across checks (checks run sequentially, so no locking is needed). + The configured resource limit bounds ``describe_task_definition`` calls. + """ + task_definitions = [] + for arn, region in limit_resources( + self._list_task_definition_arns(), self.task_definition_limit + ): + task_definition = self.task_definitions.get(arn) + if task_definition is None: + task_definition = TaskDefinition( + # we want the family name without the revision + name=sub(":.*", "", arn.split("/")[-1]), + arn=arn, + revision=arn.split(":")[-1], + region=region, + environment_variables=[], + ) + self.task_definitions[arn] = task_definition + task_definitions.append(task_definition) + + self.__threading_call__(self._describe_task_definition, task_definitions) + + for arn, _ in limit_resources( + self._list_task_definition_arns(), self.task_definition_limit + ): + task_definition = self.task_definitions[arn] + yield task_definition + def _describe_task_definition(self, task_definition): logger.info("ECS - Describing Task Definition...") try: @@ -84,6 +146,9 @@ class ECS(AWSService): ) ) task_definition.pid_mode = response["taskDefinition"].get("pidMode", "") + task_definition.registered_at = response["taskDefinition"].get( + "registeredAt" + ) task_definition.tags = response.get("tags") task_definition.network_mode = response["taskDefinition"].get( "networkMode", "bridge" @@ -208,6 +273,7 @@ class TaskDefinition(BaseModel): region: str container_definitions: list[ContainerDefinition] = [] pid_mode: Optional[str] + registered_at: Optional[datetime] = None tags: Optional[list] = [] network_mode: Optional[str] diff --git a/prowler/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled.py b/prowler/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled.py index a9f5efbedd..fd414badfa 100644 --- a/prowler/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled.py +++ b/prowler/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled.py @@ -15,11 +15,10 @@ class inspector2_is_enabled(Check): if inspector.status == "ENABLED": report.status = "PASS" report.status_extended = "Inspector2 is enabled for EC2 instances, ECR container images, Lambda functions and code." - funtions_in_region = False + functions_in_region = ( + inspector.region in awslambda_client.regions_with_functions + ) ec2_in_region = False - for function in awslambda_client.functions.values(): - if function.region == inspector.region: - funtions_in_region = True for instance in ec2_client.instances: if instance == inspector.region: ec2_in_region = True @@ -36,12 +35,12 @@ class inspector2_is_enabled(Check): failed_services.append("ECR") if inspector.lambda_status != "ENABLED" and ( inspector2_client.provider.scan_unused_services - or funtions_in_region + or functions_in_region ): failed_services.append("Lambda") if inspector.lambda_code_status != "ENABLED" and ( inspector2_client.provider.scan_unused_services - or funtions_in_region + or functions_in_region ): failed_services.append("Lambda Code") diff --git a/tests/config/schema/aws_schema_test.py b/tests/config/schema/aws_schema_test.py index 8731e08ba9..787cb4c03e 100644 --- a/tests/config/schema/aws_schema_test.py +++ b/tests/config/schema/aws_schema_test.py @@ -3,6 +3,7 @@ constraint surface (CIDRs, account IDs, port ranges, enums, thresholds).""" import pytest +from prowler.config.scan_config_schema import SCAN_CONFIG_SCHEMA from prowler.config.schema.aws import AWSProviderConfig from prowler.config.schema.validator import validate_provider_config @@ -11,6 +12,52 @@ def _validate(raw): return validate_provider_config("aws", raw, AWSProviderConfig) +RESOURCE_LIMIT_KEYS = [ + "max_scanned_resources_per_service", + "max_ebs_snapshots", + "max_backup_recovery_points", + "max_cloudwatch_log_groups", + "max_lambda_functions", + "max_ecs_task_definitions", + "max_codeartifact_packages", +] + + +class Test_AWS_Resource_Limits: + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_positive_values_round_trip(self, key): + assert _validate({key: 100}) == {key: 100} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_null_values_round_trip(self, key): + assert _validate({key: None}) == {key: None} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_zero_disable_sentinel_round_trips(self, key): + assert _validate({key: 0}) == {key: 0} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_numeric_strings_are_coerced_to_int(self, key): + assert _validate({key: "100"}) == {key: 100} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_disable_sentinel_minus_one_round_trips(self, key): + assert _validate({key: -1}) == {key: -1} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + @pytest.mark.parametrize("value", [True, False]) + def test_booleans_are_dropped_not_coerced_to_int(self, key, value): + assert _validate({key: value}) == {} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_invalid_strings_are_dropped(self, key): + assert _validate({key: "not-an-int"}) == {} + + @pytest.mark.parametrize("key", RESOURCE_LIMIT_KEYS) + def test_keys_are_exposed_in_scan_config_schema(self, key): + assert key in SCAN_CONFIG_SCHEMA["properties"]["aws"]["properties"] + + class Test_AWS_Threat_Detection_Thresholds: """All threat detection thresholds are documented as fractions in 0..1. The biggest risk of mistyping them is silently disabling the check.""" diff --git a/tests/lib/resource_limit_test.py b/tests/lib/resource_limit_test.py new file mode 100644 index 0000000000..ad2f3a292f --- /dev/null +++ b/tests/lib/resource_limit_test.py @@ -0,0 +1,156 @@ +from prowler.lib.resource_limit import ( + get_resource_scan_limit, + iter_limited_paginator_items, + limit_resources, +) + + +class FakePaginator: + def __init__(self, pages): + self.pages = pages + self.paginate_calls = [] + self.pages_requested = 0 + + def paginate(self, **kwargs): + self.paginate_calls.append(kwargs) + for page in self.pages: + self.pages_requested += 1 + yield page + + +class Test_limit_resources: + def test_no_limit_returns_all_in_order(self): + resources = ["PASS", "FAIL", "PASS"] + + result = list(limit_resources(iter(resources), None)) + + assert result == ["PASS", "FAIL", "PASS"] + + +class Test_iter_limited_paginator_items: + def test_positive_limit_stops_without_page_size(self): + paginator = FakePaginator( + [ + {"Items": [1, 2]}, + {"Items": [3, 4]}, + {"Items": [5]}, + ] + ) + + result = list(iter_limited_paginator_items(paginator, "Items", 3)) + + assert result == [1, 2, 3] + assert paginator.paginate_calls == [{}] + assert paginator.pages_requested == 2 + + def test_absurd_limit_is_not_sent_as_page_size(self): + paginator = FakePaginator([{"Items": [1, 2]}]) + + result = list(iter_limited_paginator_items(paginator, "Items", 200000)) + + assert result == [1, 2] + assert paginator.paginate_calls == [{}] + + def test_operation_parameters_are_forwarded_unchanged(self): + paginator = FakePaginator([{"Snapshots": ["snapshot"]}]) + + result = list( + iter_limited_paginator_items( + paginator, + "Snapshots", + 1, + OwnerIds=["self"], + ) + ) + + assert result == ["snapshot"] + assert paginator.paginate_calls == [{"OwnerIds": ["self"]}] + + def test_item_filter_limits_selected_items_only(self): + paginator = FakePaginator( + [ + {"Items": [{"arn": "skip"}, {"arn": "first"}]}, + {"Items": [{"arn": "second"}, {"arn": "third"}]}, + ] + ) + + result = list( + iter_limited_paginator_items( + paginator, + "Items", + 2, + item_filter=lambda item: item["arn"] != "skip", + ) + ) + + assert result == [{"arn": "first"}, {"arn": "second"}] + assert paginator.pages_requested == 2 + + def test_limit_zero_or_negative_is_unlimited(self): + resources = list(range(5)) + + assert list(limit_resources(iter(resources), 0)) == resources + assert list(limit_resources(iter(resources), -3)) == resources + + def test_positive_limit_stops_after_selected_resources(self): + pulled = [] + + def gen(): + for i in range(1000): + pulled.append(i) + yield i + + result = list(limit_resources(gen(), 100)) + + assert result == list(range(100)) + assert len(pulled) == 100 + + def test_does_not_reorder_or_inspect_resource_status(self): + resources = ["PASS", "FAIL", "PASS", "FAIL"] + + result = list(limit_resources(iter(resources), 3)) + + assert result == ["PASS", "FAIL", "PASS"] + + +class Test_get_resource_scan_limit: + def test_per_service_override_wins(self): + config = { + "max_scanned_resources_per_service": 100, + "max_ecs_task_definitions": 25, + } + assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 25 + + def test_falls_back_to_global_default(self): + config = {"max_scanned_resources_per_service": 50} + assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50 + + def test_null_per_service_override_falls_back_to_global_default(self): + config = { + "max_scanned_resources_per_service": 50, + "max_ecs_task_definitions": None, + } + + assert get_resource_scan_limit(config, "max_ecs_task_definitions") == 50 + + def test_default_is_unlimited_when_unset(self): + assert get_resource_scan_limit({}, "max_ecs_task_definitions") is None + + def test_null_per_service_override_falls_back_to_unlimited_global_default(self): + config = {"max_ecs_task_definitions": None} + + assert get_resource_scan_limit(config, "max_ecs_task_definitions") is None + + def test_non_positive_means_unlimited(self): + assert ( + get_resource_scan_limit( + {"max_scanned_resources_per_service": 0}, "max_lambda_functions" + ) + is None + ) + assert ( + get_resource_scan_limit( + {"max_lambda_functions": -1}, "max_lambda_functions" + ) + is None + ) diff --git a/tests/providers/aws/services/awslambda/awslambda_service_test.py b/tests/providers/aws/services/awslambda/awslambda_service_test.py index 412c944d8b..302b99d109 100644 --- a/tests/providers/aws/services/awslambda/awslambda_service_test.py +++ b/tests/providers/aws/services/awslambda/awslambda_service_test.py @@ -6,10 +6,16 @@ from re import search from unittest.mock import patch import mock +import pytest from boto3 import client, resource +from botocore.client import ClientError from moto import mock_aws -from prowler.providers.aws.services.awslambda.awslambda_service import AuthType, Lambda +from prowler.providers.aws.services.awslambda.awslambda_service import ( + AuthType, + Function, + Lambda, +) from tests.providers.aws.utils import ( AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1, @@ -85,6 +91,367 @@ class Test_Lambda_Service: awslambda = Lambda(set_mocked_aws_provider([AWS_REGION_US_EAST_1])) assert awslambda.service == "lambda" + def test_function_limit_selects_latest_functions_for_analysis(self): + awslambda = Lambda.__new__(Lambda) + awslambda.functions = { + "old": Function( + name="old", + arn="old", + security_groups=[], + last_modified="2024-01-01T00:00:00.000+0000", + region=AWS_REGION_EU_WEST_1, + ), + "new": Function( + name="new", + arn="new", + security_groups=[], + last_modified="2024-01-02T00:00:00.000+0000", + region=AWS_REGION_EU_WEST_1, + ), + } + awslambda.function_limit = 1 + + awslambda._select_functions_for_analysis() + + assert list(awslambda.functions) == ["new"] + + def test_function_limit_selects_global_latest_across_regions(self): + class FakePaginator: + def __init__(self, functions): + self.functions = functions + + def paginate(self, **kwargs): + assert "PageSize" not in kwargs + return [{"Functions": self.functions}] + + class FakeLambdaClient: + def __init__(self, region, functions): + self.region = region + self.functions = functions + + def get_paginator(self, name): + assert name == "list_functions" + return FakePaginator(self.functions) + + awslambda = Lambda.__new__(Lambda) + awslambda.functions = {} + awslambda.security_groups_in_use = set() + awslambda.regions_with_functions = set() + awslambda.function_limit = 1 + awslambda.audit_resources = [] + old_client = FakeLambdaClient( + AWS_REGION_EU_WEST_1, + [ + { + "FunctionName": "old", + "FunctionArn": "arn:aws:lambda:eu-west-1:123456789012:function:old", + "LastModified": "2024-01-01T00:00:00.000+0000", + } + ], + ) + new_client = FakeLambdaClient( + AWS_REGION_US_EAST_1, + [ + { + "FunctionName": "new", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:new", + "LastModified": "2024-01-02T00:00:00.000+0000", + } + ], + ) + + awslambda._list_functions(old_client) + awslambda._list_functions(new_client) + awslambda._select_functions_for_analysis() + + assert [function.name for function in awslambda.functions.values()] == ["new"] + + def test_function_limit_keeps_complete_auxiliary_indexes(self): + class FakePaginator: + def __init__(self, functions): + self.functions = functions + + def paginate(self, **kwargs): + assert "PageSize" not in kwargs + return [{"Functions": self.functions}] + + class FakeLambdaClient: + region = AWS_REGION_US_EAST_1 + + def get_paginator(self, name): + assert name == "list_functions" + return FakePaginator( + [ + { + "FunctionName": "old", + "FunctionArn": ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:old" + ), + "LastModified": "2024-01-01T00:00:00.000+0000", + "VpcConfig": {"SecurityGroupIds": ["sg-old"]}, + }, + { + "FunctionName": "new", + "FunctionArn": ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:new" + ), + "LastModified": "2024-01-02T00:00:00.000+0000", + "VpcConfig": {"SecurityGroupIds": ["sg-new"]}, + }, + ] + ) + + awslambda = Lambda.__new__(Lambda) + awslambda.functions = {} + awslambda.security_groups_in_use = set() + awslambda.regions_with_functions = set() + awslambda.function_limit = 1 + awslambda.audit_resources = [] + + awslambda._list_functions(FakeLambdaClient()) + awslambda._select_functions_for_analysis() + + assert [function.name for function in awslambda.functions.values()] == ["new"] + assert awslambda.security_groups_in_use == {"sg-old", "sg-new"} + assert awslambda.regions_with_functions == {AWS_REGION_US_EAST_1} + + def test_list_event_source_mappings_uses_selected_functions_as_api_scope(self): + class FakePaginator: + def __init__(self): + self.paginate_calls = [] + + def paginate(self, **kwargs): + self.paginate_calls.append(kwargs) + function_name = kwargs["FunctionName"] + return [ + { + "EventSourceMappings": [ + { + "UUID": f"{function_name}-mapping", + "FunctionArn": ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:{function_name}:1" + ), + "EventSourceArn": "arn:aws:sqs:queue", + "State": "Enabled", + "BatchSize": 10, + } + ] + } + ] + + class FakeLambdaClient: + region = AWS_REGION_US_EAST_1 + + def __init__(self): + self.paginator = FakePaginator() + + def get_paginator(self, name): + assert name == "list_event_source_mappings" + return self.paginator + + selected_arn = ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:selected" + ) + other_region_arn = ( + f"arn:aws:lambda:{AWS_REGION_EU_WEST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:other-region" + ) + awslambda = Lambda.__new__(Lambda) + awslambda.function_limit = 1 + awslambda.functions = { + selected_arn: Function( + name="selected", + arn=selected_arn, + security_groups=[], + region=AWS_REGION_US_EAST_1, + ), + other_region_arn: Function( + name="other-region", + arn=other_region_arn, + security_groups=[], + region=AWS_REGION_EU_WEST_1, + ), + } + regional_client = FakeLambdaClient() + + awslambda._list_event_source_mappings(regional_client) + + assert regional_client.paginator.paginate_calls == [ + {"FunctionName": "selected"} + ] + assert len(awslambda.functions[selected_arn].event_source_mappings) == 1 + assert ( + awslambda.functions[selected_arn].event_source_mappings[0].uuid + == "selected-mapping" + ) + assert not awslambda.functions[other_region_arn].event_source_mappings + + def test_list_event_source_mappings_keeps_unlimited_regional_api_scope(self): + class FakePaginator: + def __init__(self): + self.paginate_calls = [] + + def paginate(self, **kwargs): + self.paginate_calls.append(kwargs) + return [ + { + "EventSourceMappings": [ + { + "UUID": "selected-mapping", + "FunctionArn": selected_arn, + "EventSourceArn": "arn:aws:sqs:queue", + "State": "Enabled", + } + ] + } + ] + + class FakeLambdaClient: + region = AWS_REGION_US_EAST_1 + + def __init__(self): + self.paginator = FakePaginator() + + def get_paginator(self, name): + assert name == "list_event_source_mappings" + return self.paginator + + selected_arn = ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:selected" + ) + awslambda = Lambda.__new__(Lambda) + awslambda.function_limit = None + awslambda.functions = { + selected_arn: Function( + name="selected", + arn=selected_arn, + security_groups=[], + region=AWS_REGION_US_EAST_1, + ) + } + regional_client = FakeLambdaClient() + + awslambda._list_event_source_mappings(regional_client) + + assert regional_client.paginator.paginate_calls == [{}] + assert len(awslambda.functions[selected_arn].event_source_mappings) == 1 + + def test_list_event_source_mappings_continues_after_invalid_parameter_value(self): + class FakePaginator: + def paginate(self, **kwargs): + function_name = kwargs["FunctionName"] + if function_name == "deleted": + raise ClientError( + { + "Error": { + "Code": "InvalidParameterValueException", + "Message": "Function no longer exists", + } + }, + "ListEventSourceMappings", + ) + return [ + { + "EventSourceMappings": [ + { + "UUID": f"{function_name}-mapping", + "FunctionArn": ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:{function_name}" + ), + "EventSourceArn": "arn:aws:sqs:queue", + "State": "Enabled", + } + ] + } + ] + + class FakeLambdaClient: + region = AWS_REGION_US_EAST_1 + + def get_paginator(self, name): + assert name == "list_event_source_mappings" + return FakePaginator() + + deleted_arn = ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:deleted" + ) + remaining_arn = ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:remaining" + ) + awslambda = Lambda.__new__(Lambda) + awslambda.function_limit = 2 + awslambda.functions = { + deleted_arn: Function( + name="deleted", + arn=deleted_arn, + security_groups=[], + region=AWS_REGION_US_EAST_1, + ), + remaining_arn: Function( + name="remaining", + arn=remaining_arn, + security_groups=[], + region=AWS_REGION_US_EAST_1, + ), + } + + awslambda._list_event_source_mappings(FakeLambdaClient()) + + assert not awslambda.functions[deleted_arn].event_source_mappings + assert len(awslambda.functions[remaining_arn].event_source_mappings) == 1 + assert ( + awslambda.functions[remaining_arn].event_source_mappings[0].uuid + == "remaining-mapping" + ) + + def test_list_event_source_mappings_raises_non_transient_client_error(self): + class FakePaginator: + def paginate(self, **kwargs): + raise ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "Access denied", + } + }, + "ListEventSourceMappings", + ) + + class FakeLambdaClient: + region = AWS_REGION_US_EAST_1 + + def get_paginator(self, name): + assert name == "list_event_source_mappings" + return FakePaginator() + + function_arn = ( + f"arn:aws:lambda:{AWS_REGION_US_EAST_1}:" + f"{AWS_ACCOUNT_NUMBER}:function:selected" + ) + awslambda = Lambda.__new__(Lambda) + awslambda.function_limit = 1 + awslambda.functions = { + function_arn: Function( + name="selected", + arn=function_arn, + security_groups=[], + region=AWS_REGION_US_EAST_1, + ) + } + + with pytest.raises(ClientError) as error: + awslambda._list_event_source_mappings(FakeLambdaClient()) + + assert error.value.response["Error"]["Code"] == "AccessDeniedException" + @mock_aws def test_list_functions(self): # Create IAM Lambda Role @@ -253,3 +620,63 @@ class Test_Lambda_Service: f"{tmp_dir_name}/{files_in_zip[0]}", "r" ) as lambda_code_file: assert lambda_code_file.read() == LAMBDA_FUNCTION_CODE + + @mock_aws + def test_function_limit_exposes_only_selected_functions(self): + lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1) + iam_client = client("iam", region_name=AWS_REGION_US_EAST_1) + iam_role = iam_client.create_role( + RoleName="test-role", + AssumeRolePolicyDocument="{}", + )["Role"]["Arn"] + for name in ("function-1", "function-2"): + lambda_client.create_function( + FunctionName=name, + Runtime="python3.7", + Role=iam_role, + Handler="lambda_function.lambda_handler", + Code={"ZipFile": create_zip_file().read()}, + PackageType="ZIP", + ) + awslambda = Lambda( + set_mocked_aws_provider( + audited_regions=[AWS_REGION_US_EAST_1], + audit_config={"max_lambda_functions": 1}, + ) + ) + + assert len(awslambda.functions) == 1 + + @mock_aws + def test_get_function_code_fetches_only_selected_functions(self): + lambda_client = client("lambda", region_name=AWS_REGION_US_EAST_1) + iam_client = client("iam", region_name=AWS_REGION_US_EAST_1) + iam_role = iam_client.create_role( + RoleName="test-role", + AssumeRolePolicyDocument="{}", + )["Role"]["Arn"] + for name in ("function-1", "function-2"): + lambda_client.create_function( + FunctionName=name, + Runtime="python3.7", + Role=iam_role, + Handler="lambda_function.lambda_handler", + Code={"ZipFile": create_zip_file().read()}, + PackageType="ZIP", + ) + awslambda = Lambda( + set_mocked_aws_provider( + audited_regions=[AWS_REGION_US_EAST_1], + audit_config={"max_lambda_functions": 1}, + ) + ) + fetched = [] + + def fetch_function_code(function_name, _function_region): + fetched.append(function_name) + return mock.MagicMock() + + awslambda._fetch_function_code = fetch_function_code + + assert len(list(awslambda._get_function_code())) == 1 + assert len(fetched) == 1 diff --git a/tests/providers/aws/services/backup/backup_service_test.py b/tests/providers/aws/services/backup/backup_service_test.py index d4a7be392f..5894a62bda 100644 --- a/tests/providers/aws/services/backup/backup_service_test.py +++ b/tests/providers/aws/services/backup/backup_service_test.py @@ -1,11 +1,16 @@ from datetime import datetime +from types import SimpleNamespace from unittest.mock import patch import botocore from boto3 import client from moto import mock_aws -from prowler.providers.aws.services.backup.backup_service import Backup +from prowler.providers.aws.services.backup.backup_service import ( + Backup, + BackupVault, + RecoveryPoint, +) from tests.providers.aws.utils import ( AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1, @@ -292,3 +297,248 @@ class TestBackupService: assert backup.recovery_points[0].backup_vault_region == "eu-west-1" assert backup.recovery_points[0].tags == [] assert backup.recovery_points[0].encrypted is True + + def test_recovery_point_limit_bounds_tag_calls_to_selected_points(self): + class FakePaginator: + def paginate(self, **kwargs): + return [ + { + "RecoveryPoints": [ + { + "RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new", + "IsEncrypted": True, + "CreationDate": datetime(2024, 1, 2), + }, + { + "RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old", + "IsEncrypted": True, + "CreationDate": datetime(2024, 1, 1), + }, + ] + } + ] + + class FakeBackupClient: + def __init__(self): + self.tag_calls = [] + + def get_paginator(self, name): + assert name == "list_recovery_points_by_backup_vault" + return FakePaginator() + + def list_tags(self, **kwargs): + self.tag_calls.append(kwargs["ResourceArn"]) + return {"Tags": {}} + + regional_client = FakeBackupClient() + backup = Backup.__new__(Backup) + backup.backup_vaults = [ + BackupVault( + arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:vault", + name="vault", + region=AWS_REGION_EU_WEST_1, + encryption="", + recovery_points=2, + locked=False, + ) + ] + backup.recovery_points = [] + backup.recovery_point_limit = 1 + backup.regional_clients = {AWS_REGION_EU_WEST_1: regional_client} + + backup._list_recovery_points() + backup._select_recovery_points_for_analysis() + for recovery_point in backup.recovery_points: + backup._list_tags(recovery_point) + + assert [rp.id for rp in backup.recovery_points] == ["new"] + assert regional_client.tag_calls == [ + "arn:aws:backup:eu-west-1:123456789012:recovery-point:new" + ] + + def test_recovery_point_limit_selects_global_newest_across_vaults(self): + class FakePaginator: + def __init__(self, recovery_points_by_vault): + self.recovery_points_by_vault = recovery_points_by_vault + + def paginate(self, **kwargs): + assert "PageSize" not in kwargs + return [ + { + "RecoveryPoints": self.recovery_points_by_vault[ + kwargs["BackupVaultName"] + ] + } + ] + + class FakeBackupClient: + def __init__(self, recovery_points_by_vault): + self.recovery_points_by_vault = recovery_points_by_vault + + def get_paginator(self, name): + assert name == "list_recovery_points_by_backup_vault" + return FakePaginator(self.recovery_points_by_vault) + + backup = Backup.__new__(Backup) + backup.recovery_point_limit = 1 + backup.recovery_points = [] + backup.backup_vaults = [ + BackupVault( + arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:old-vault", + name="old-vault", + region=AWS_REGION_EU_WEST_1, + encryption="", + recovery_points=1, + locked=False, + ), + BackupVault( + arn="arn:aws:backup:eu-west-1:123456789012:backup-vault:new-vault", + name="new-vault", + region=AWS_REGION_EU_WEST_1, + encryption="", + recovery_points=1, + locked=False, + ), + ] + backup.regional_clients = { + AWS_REGION_EU_WEST_1: FakeBackupClient( + { + "old-vault": [ + { + "RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:old", + "IsEncrypted": True, + "CreationDate": datetime(2024, 1, 1), + } + ], + "new-vault": [ + { + "RecoveryPointArn": "arn:aws:backup:eu-west-1:123456789012:recovery-point:new", + "IsEncrypted": True, + "CreationDate": datetime(2024, 1, 2), + } + ], + } + ) + } + + backup._list_recovery_points() + backup._select_recovery_points_for_analysis() + + assert [rp.id for rp in backup.recovery_points] == ["new"] + + def test_recovery_point_limit_exposes_only_selected_resources(self): + backup = Backup.__new__(Backup) + backup.recovery_point_limit = 2 + backup.recovery_points = [] + backup.backup_vaults = [ + BackupVault( + arn="arn", + name="vault", + region="eu-west-1", + encryption="", + recovery_points=3, + locked=False, + ) + ] + + class Paginator: + def paginate(self, **_kwargs): + return [ + { + "RecoveryPoints": [ + { + "RecoveryPointArn": f"arn:aws:backup:eu-west-1:123456789012:recovery-point:{i}", + "IsEncrypted": True, + } + for i in range(3) + ] + } + ] + + backup.regional_clients = { + "eu-west-1": SimpleNamespace(get_paginator=lambda _: Paginator()) + } + tagged = [] + + def list_tags(recovery_point): + tagged.append(recovery_point.arn) + + backup._list_tags = list_tags + + backup._list_recovery_points() + backup._select_recovery_points_for_analysis() + for recovery_point in backup.recovery_points: + backup._list_tags(recovery_point) + + assert len(backup.recovery_points) == 2 + assert len(tagged) == 2 + + def test_recovery_point_limit_uses_deterministic_tie_breaker(self): + backup = Backup.__new__(Backup) + backup.recovery_point_limit = 2 + backup.recovery_points = [ + RecoveryPoint( + arn="arn:aws:backup:us-east-1:123456789012:recovery-point:z", + id="z", + region="us-east-1", + backup_vault_name="vault-b", + encrypted=True, + backup_vault_region="us-east-1", + ), + RecoveryPoint( + arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:b", + id="b", + region="eu-west-1", + backup_vault_name="vault-b", + encrypted=True, + backup_vault_region="eu-west-1", + ), + RecoveryPoint( + arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:a", + id="a", + region="eu-west-1", + backup_vault_name="vault-a", + encrypted=True, + backup_vault_region="eu-west-1", + ), + ] + + backup._select_recovery_points_for_analysis() + + assert [rp.id for rp in backup.recovery_points] == ["a", "b"] + + def test_recovery_point_limit_keeps_newest_before_tie_breaker(self): + backup = Backup.__new__(Backup) + backup.recovery_point_limit = 2 + backup.recovery_points = [ + RecoveryPoint( + arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:older-a", + id="older-a", + region="eu-west-1", + backup_vault_name="vault-a", + encrypted=True, + backup_vault_region="eu-west-1", + creation_date=datetime(2024, 1, 1), + ), + RecoveryPoint( + arn="arn:aws:backup:us-east-1:123456789012:recovery-point:newer-z", + id="newer-z", + region="us-east-1", + backup_vault_name="vault-z", + encrypted=True, + backup_vault_region="us-east-1", + creation_date=datetime(2024, 1, 2), + ), + RecoveryPoint( + arn="arn:aws:backup:eu-west-1:123456789012:recovery-point:missing-date", + id="missing-date", + region="eu-west-1", + backup_vault_name="vault-a", + encrypted=True, + backup_vault_region="eu-west-1", + ), + ] + + backup._select_recovery_points_for_analysis() + + assert [rp.id for rp in backup.recovery_points] == ["newer-z", "older-a"] diff --git a/tests/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled_test.py b/tests/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled_test.py index 4630ad25a0..8c45a9d56b 100644 --- a/tests/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled_test.py +++ b/tests/providers/aws/services/bedrock/bedrock_model_invocation_logs_encryption_enabled/bedrock_model_invocation_logs_encryption_enabled_test.py @@ -227,6 +227,72 @@ class Test_bedrock_model_invocation_logs_encryption_enabled: assert result[0].region == AWS_REGION_US_EAST_1 assert result[0].resource_tags == [] + def test_cloudwatch_logging_uses_complete_log_group_index(self): + from prowler.providers.aws.services.bedrock.bedrock_service import ( + LoggingConfiguration, + ) + from prowler.providers.aws.services.cloudwatch.cloudwatch_service import ( + LogGroup, + ) + + bedrock_client = mock.MagicMock() + bedrock_client.logging_configurations = { + AWS_REGION_US_EAST_1: LoggingConfiguration( + enabled=True, + cloudwatch_log_group="Test", + ) + } + bedrock_client._get_model_invocation_logging_arn_template.return_value = ( + "arn:aws:bedrock:us-east-1:123456789012:model-invocation-logging" + ) + + logs_client = mock.MagicMock() + logs_client.audited_partition = "aws" + logs_client.audited_account = "123456789012" + logs_client.log_groups = {} + logs_client.all_log_groups = { + "arn:aws:logs:us-east-1:123456789012:log-group:Test:*": LogGroup( + arn="arn:aws:logs:us-east-1:123456789012:log-group:Test:*", + name="Test", + retention_days=30, + never_expire=False, + kms_id=None, + region=AWS_REGION_US_EAST_1, + ) + } + + s3_client = mock.MagicMock() + s3_client.audited_partition = "aws" + s3_client.buckets = {} + + with ( + mock.patch( + "prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.bedrock_client", + new=bedrock_client, + ), + mock.patch( + "prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.logs_client", + new=logs_client, + ), + mock.patch( + "prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled.s3_client", + new=s3_client, + ), + ): + from prowler.providers.aws.services.bedrock.bedrock_model_invocation_logs_encryption_enabled.bedrock_model_invocation_logs_encryption_enabled import ( + bedrock_model_invocation_logs_encryption_enabled, + ) + + check = bedrock_model_invocation_logs_encryption_enabled() + result = check.execute() + + assert len(result) == 1 + assert result[0].status == "FAIL" + assert ( + result[0].status_extended + == "Bedrock Model Invocation logs are not encrypted in CloudWatch Log Group: Test." + ) + @mock_aws def test_s3_and_cloudwatch_logging_encrypted(self): logs_client = client("logs", region_name=AWS_REGION_US_EAST_1) diff --git a/tests/providers/aws/services/cloudwatch/cloudwatch_service_test.py b/tests/providers/aws/services/cloudwatch/cloudwatch_service_test.py index 33d4bde7d7..3d2d53ffa0 100644 --- a/tests/providers/aws/services/cloudwatch/cloudwatch_service_test.py +++ b/tests/providers/aws/services/cloudwatch/cloudwatch_service_test.py @@ -4,6 +4,7 @@ from moto import mock_aws from prowler.providers.aws.services.cloudwatch.cloudwatch_service import ( CloudWatch, + LogGroup, Logs, ) from prowler.providers.aws.services.cloudwatch.lib.metric_filters import ( @@ -188,7 +189,9 @@ class Test_CloudWatch_Service: arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*" logs = Logs(aws_provider) assert len(logs.log_groups) == 1 + assert len(logs.all_log_groups) == 1 assert arn in logs.log_groups + assert arn in logs.all_log_groups assert logs.log_groups[arn].name == "/log-group/test" assert logs.log_groups[arn].retention_days == 400 assert logs.log_groups[arn].kms_id == "test_kms_id" @@ -212,7 +215,9 @@ class Test_CloudWatch_Service: arn = f"arn:aws:logs:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:log-group:/log-group/test:*" logs = Logs(aws_provider) assert len(logs.log_groups) == 1 + assert len(logs.all_log_groups) == 1 assert arn in logs.log_groups + assert arn in logs.all_log_groups assert logs.log_groups[arn].name == "/log-group/test" assert logs.log_groups[arn].never_expire # Since it never expires we don't use the retention_days @@ -221,6 +226,190 @@ class Test_CloudWatch_Service: assert logs.log_groups[arn].region == AWS_REGION_US_EAST_1 assert logs.log_groups[arn].tags == [{}] + def test_log_group_limit_exposes_only_selected_resources(self): + class FakeLogsClient: + def __init__(self): + self.filter_calls = [] + + def filter_log_events(self, **kwargs): + self.filter_calls.append(kwargs["logGroupName"]) + return {"events": []} + + regional_client = FakeLogsClient() + logs = Logs.__new__(Logs) + logs.log_group_limit = 1 + logs._log_groups_hydrated = set() + logs.regional_clients = {AWS_REGION_US_EAST_1: regional_client} + logs.events_per_log_group_threshold = 1000 + logs.log_groups = { + f"arn:{i}": LogGroup( + arn=f"arn:{i}", + name=f"log-{i}", + retention_days=30, + never_expire=False, + kms_id=None, + creation_time=i, + region=AWS_REGION_US_EAST_1, + ) + for i in range(3) + } + tagged = [] + + def list_tags(log_group): + tagged.append(log_group.arn) + + logs._list_tags_for_resource = list_tags + + logs._select_log_groups_for_analysis() + for log_group in logs.log_groups.values(): + logs._list_tags_for_resource(log_group) + logs._get_log_events(log_group) + + assert list(logs.log_groups) == ["arn:2"] + assert tagged == ["arn:2"] + assert regional_client.filter_calls == ["log-2"] + + def test_log_group_limit_selects_global_newest_across_regions(self): + class FakePaginator: + def __init__(self, log_groups): + self.log_groups = log_groups + + def paginate(self, **kwargs): + assert "PageSize" not in kwargs + return [{"logGroups": self.log_groups}] + + class FakeLogsClient: + def __init__(self, region, log_groups): + self.region = region + self.log_groups = log_groups + + def get_paginator(self, name): + assert name == "describe_log_groups" + return FakePaginator(self.log_groups) + + logs = Logs.__new__(Logs) + logs.all_log_groups = {} + logs.log_groups = {} + logs.log_group_limit = 1 + logs.audit_resources = [] + + logs._describe_log_groups( + FakeLogsClient( + "eu-west-1", + [ + { + "arn": "arn:aws:logs:eu-west-1:123456789012:log-group:old:*", + "logGroupName": "old", + "creationTime": 1, + } + ], + ) + ) + logs._describe_log_groups( + FakeLogsClient( + AWS_REGION_US_EAST_1, + [ + { + "arn": "arn:aws:logs:us-east-1:123456789012:log-group:new:*", + "logGroupName": "new", + "creationTime": 2, + } + ], + ) + ) + logs._select_log_groups_for_analysis() + + assert [log_group.name for log_group in logs.log_groups.values()] == ["new"] + assert [log_group.name for log_group in logs.all_log_groups.values()] == [ + "old", + "new", + ] + + def test_metric_filters_use_complete_log_group_index(self): + class FakePaginator: + def paginate(self): + return [ + { + "metricFilters": [ + { + "filterName": "test-filter", + "filterPattern": "test-pattern", + "logGroupName": "old", + "metricTransformations": [ + {"metricName": "test-metric"} + ], + } + ] + } + ] + + class FakeLogsClient: + region = AWS_REGION_US_EAST_1 + + def get_paginator(self, name): + assert name == "describe_metric_filters" + return FakePaginator() + + logs = Logs.__new__(Logs) + old_log_group = LogGroup( + arn="arn:old", + name="old", + retention_days=30, + never_expire=False, + kms_id=None, + creation_time=1, + region=AWS_REGION_US_EAST_1, + ) + logs.audited_partition = "aws" + logs.audited_account = AWS_ACCOUNT_NUMBER + logs.audit_resources = [] + logs.metric_filters = [] + logs.log_groups = {} + logs.all_log_groups = {old_log_group.arn: old_log_group} + logs._log_groups_hydrated = set() + logs._list_tags_for_resource = lambda log_group: None + + logs._describe_metric_filters(FakeLogsClient()) + + assert len(logs.metric_filters) == 1 + assert logs.metric_filters[0].log_group == old_log_group + + def test_log_group_collection_recovers_all_log_groups_after_access_denied(self): + class FakePaginator: + def paginate(self): + return [ + { + "logGroups": [ + { + "arn": "arn:aws:logs:us-east-1:123456789012:log-group:success:*", + "logGroupName": "success", + "creationTime": 1, + } + ] + } + ] + + class FakeLogsClient: + region = AWS_REGION_US_EAST_1 + + def get_paginator(self, name): + assert name == "describe_log_groups" + return FakePaginator() + + logs = Logs.__new__(Logs) + logs.all_log_groups = None + logs.log_groups = None + logs.audit_resources = [] + + logs._describe_log_groups(FakeLogsClient()) + + assert list(logs.all_log_groups) == [ + "arn:aws:logs:us-east-1:123456789012:log-group:success:*" + ] + assert list(logs.log_groups) == [ + "arn:aws:logs:us-east-1:123456789012:log-group:success:*" + ] + class Test_build_metric_filter_pattern: @pytest.mark.parametrize("bad_operator", ["==", "~=", "<", "<>", ">=", ""]) diff --git a/tests/providers/aws/services/codeartifact/codeartifact_service_test.py b/tests/providers/aws/services/codeartifact/codeartifact_service_test.py index 99325dd2ea..50af85b767 100644 --- a/tests/providers/aws/services/codeartifact/codeartifact_service_test.py +++ b/tests/providers/aws/services/codeartifact/codeartifact_service_test.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import patch import botocore @@ -6,6 +7,7 @@ from prowler.providers.aws.services.codeartifact.codeartifact_service import ( CodeArtifact, LatestPackageVersionStatus, OriginInformationValues, + Repository, RestrictionValues, ) from tests.providers.aws.utils import ( @@ -208,6 +210,104 @@ class Test_CodeArtifact_Service: == OriginInformationValues.INTERNAL ) + def test_package_limit_bounds_package_version_lookups_to_selected_packages(self): + class FakePaginator: + def paginate(self, **kwargs): + return [ + { + "packages": [ + { + "format": "pypi", + "package": "first-package", + "originConfiguration": { + "restrictions": { + "publish": "ALLOW", + "upstream": "ALLOW", + } + }, + }, + { + "format": "pypi", + "package": "second-package", + "originConfiguration": { + "restrictions": { + "publish": "ALLOW", + "upstream": "ALLOW", + } + }, + }, + ] + } + ] + + class FakeCodeArtifactClient: + def __init__(self): + self.version_calls = [] + + def get_paginator(self, name): + assert name == "list_packages" + return FakePaginator() + + def list_package_versions(self, **kwargs): + self.version_calls.append(kwargs["package"]) + return { + "versions": [ + { + "version": "1.0.0", + "status": "Published", + "origin": {"originType": "INTERNAL"}, + } + ] + } + + regional_client = FakeCodeArtifactClient() + codeartifact = CodeArtifact.__new__(CodeArtifact) + codeartifact.repositories = { + TEST_REPOSITORY_ARN: Repository( + name="test-repository", + arn=TEST_REPOSITORY_ARN, + domain_name="test-domain", + domain_owner=AWS_ACCOUNT_NUMBER, + region=AWS_REGION_EU_WEST_1, + ) + } + codeartifact._packages_listed = set() + codeartifact.package_limit = 1 + codeartifact.regional_clients = {AWS_REGION_EU_WEST_1: regional_client} + + pairs = list(codeartifact._load_packages_for_analysis()) + + assert [package.name for _, package in pairs] == ["first-package"] + assert regional_client.version_calls == ["first-package"] + + def test_package_limit_exposes_only_selected_packages(self): + codeartifact = CodeArtifact.__new__(CodeArtifact) + codeartifact.package_limit = 2 + codeartifact._packages_listed = set() + repository = Repository( + name="repository", + arn="repo", + domain_name="domain", + domain_owner=AWS_ACCOUNT_NUMBER, + region=AWS_REGION_EU_WEST_1, + ) + codeartifact.repositories = {repository.arn: repository} + enriched = [] + + def iter_repository_packages(repository, limit=None): + for index in range(3): + if limit is not None and index >= limit: + return + enriched.append(index) + yield SimpleNamespace(name=f"package-{index}") + + codeartifact._iter_repository_packages = iter_repository_packages + + packages = list(codeartifact._load_packages_for_analysis()) + + assert [package.name for _, package in packages] == ["package-0", "package-1"] + assert enriched == [0, 1] + def mock_make_api_call_no_namespace(self, operation_name, kwarg): """Mock for packages without a namespace to exercise the else branch""" diff --git a/tests/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used_test.py b/tests/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used_test.py index d1b77b9f8b..b959f4b257 100644 --- a/tests/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used_test.py +++ b/tests/providers/aws/services/ec2/ec2_securitygroup_not_used/ec2_securitygroup_not_used_test.py @@ -14,6 +14,57 @@ EXAMPLE_AMI_ID = "ami-12c6146b" class Test_ec2_securitygroup_not_used: + def test_ec2_sg_used_by_lambda_outside_selected_analysis_limit(self): + from prowler.providers.aws.services.ec2.ec2_service import SecurityGroup + + sg_id = "sg-limited-out" + sg_name = "lambda-sg" + security_group = SecurityGroup( + name=sg_name, + region=AWS_REGION_US_EAST_1, + arn=f"arn:aws:ec2:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:security-group/{sg_id}", + id=sg_id, + vpc_id="vpc-test", + associated_sgs=[], + network_interfaces=[], + ingress_rules=[], + egress_rules=[], + tags=[], + ) + ec2_client = mock.MagicMock() + ec2_client.security_groups = {security_group.arn: security_group} + awslambda_client = mock.MagicMock() + awslambda_client.functions = {} + awslambda_client.security_groups_in_use = {sg_id} + aws_provider = set_mocked_aws_provider() + + with ( + mock.patch( + "prowler.providers.common.provider.Provider.get_global_provider", + return_value=aws_provider, + ), + mock.patch( + "prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.ec2_client", + new=ec2_client, + ), + mock.patch( + "prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used.awslambda_client", + new=awslambda_client, + ), + ): + from prowler.providers.aws.services.ec2.ec2_securitygroup_not_used.ec2_securitygroup_not_used import ( + ec2_securitygroup_not_used, + ) + + result = ec2_securitygroup_not_used().execute() + + assert len(result) == 1 + assert result[0].status == "PASS" + assert ( + result[0].status_extended + == f"Security group {sg_name} ({sg_id}) it is being used." + ) + @mock_aws def test_ec2_default_sgs(self): # Create EC2 Mocked Resources diff --git a/tests/providers/aws/services/ec2/ec2_service_test.py b/tests/providers/aws/services/ec2/ec2_service_test.py index aca200dd61..1302952e95 100644 --- a/tests/providers/aws/services/ec2/ec2_service_test.py +++ b/tests/providers/aws/services/ec2/ec2_service_test.py @@ -11,7 +11,7 @@ from freezegun import freeze_time from moto import mock_aws from prowler.config.config import encoding_format_utf_8 -from prowler.providers.aws.services.ec2.ec2_service import EC2 +from prowler.providers.aws.services.ec2.ec2_service import EC2, Snapshot from tests.providers.aws.utils import ( AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1, @@ -103,6 +103,99 @@ class Test_EC2_Service: ec2 = EC2(aws_provider) assert ec2.audited_account == AWS_ACCOUNT_NUMBER + def test_snapshot_limit_bounds_public_attribute_calls_to_latest_selected(self): + class FakeEC2Client: + def __init__(self): + self.calls = [] + + def describe_snapshot_attribute(self, **kwargs): + self.calls.append(kwargs["SnapshotId"]) + return {"CreateVolumePermissions": []} + + regional_client = FakeEC2Client() + ec2 = EC2.__new__(EC2) + ec2.snapshots = [ + Snapshot( + id="snap-old", + arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-old", + region=AWS_REGION_EU_WEST_1, + encrypted=True, + start_time=datetime(2024, 1, 1), + volume="vol-old", + ), + Snapshot( + id="snap-new", + arn="arn:aws:ec2:eu-west-1:123456789012:snapshot/snap-new", + region=AWS_REGION_EU_WEST_1, + encrypted=True, + start_time=datetime(2024, 1, 2), + volume="vol-new", + ), + ] + ec2.snapshot_limit = 1 + ec2.regional_clients = {AWS_REGION_EU_WEST_1: regional_client} + + ec2._select_snapshots_for_analysis() + for snapshot in ec2.snapshots: + ec2._determine_public_snapshots(snapshot) + + assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"] + assert regional_client.calls == ["snap-new"] + + def test_snapshot_limit_preserves_volume_index_and_selects_global_latest(self): + class FakePaginator: + def __init__(self, snapshots): + self.snapshots = snapshots + + def paginate(self, **kwargs): + assert "PageSize" not in kwargs + return [{"Snapshots": self.snapshots}] + + class FakeEC2Client: + def __init__(self, region, snapshots): + self.region = region + self.snapshots = snapshots + + def get_paginator(self, name): + assert name == "describe_snapshots" + return FakePaginator(self.snapshots) + + ec2 = EC2.__new__(EC2) + ec2.snapshots = [] + ec2.volumes_with_snapshots = {} + ec2.regions_with_snapshots = {} + ec2.snapshot_limit = 1 + ec2.audit_resources = [] + ec2.audited_partition = "aws" + ec2.audited_account = AWS_ACCOUNT_NUMBER + old_client = FakeEC2Client( + AWS_REGION_EU_WEST_1, + [ + { + "SnapshotId": "snap-old", + "VolumeId": "vol-old", + "StartTime": datetime(2024, 1, 1), + } + ], + ) + new_client = FakeEC2Client( + AWS_REGION_US_EAST_1, + [ + { + "SnapshotId": "snap-new", + "VolumeId": "vol-new", + "StartTime": datetime(2024, 1, 2), + } + ], + ) + + ec2._describe_snapshots(old_client) + ec2._describe_snapshots(new_client) + ec2._select_snapshots_for_analysis() + + assert ec2.volumes_with_snapshots == {"vol-old": True, "vol-new": True} + assert [snapshot.id for snapshot in ec2.snapshots] == ["snap-new"] + # Test EC2 Describe Instances @mock_aws @freeze_time(MOCK_DATETIME) @@ -346,6 +439,24 @@ class Test_EC2_Service: assert not snapshot.encrypted assert snapshot.public + @mock_aws + def test_snapshot_limit_exposes_only_selected_snapshots(self): + ec2_client = client("ec2", region_name=AWS_REGION_US_EAST_1) + ec2_resource = resource("ec2", region_name=AWS_REGION_US_EAST_1) + volume_id = ec2_resource.create_volume( + AvailabilityZone="us-east-1a", + Size=80, + VolumeType="gp2", + ).id + for _ in range(3): + ec2_client.create_snapshot(VolumeId=volume_id) + aws_provider = set_mocked_aws_provider( + [AWS_REGION_US_EAST_1], audit_config={"max_ebs_snapshots": 1} + ) + ec2 = EC2(aws_provider) + + assert len(ec2.snapshots) == 1 + # Test EC2 Instance User Data @mock_aws def test_get_instance_user_data(self): diff --git a/tests/providers/aws/services/ecs/ecs_service_test.py b/tests/providers/aws/services/ecs/ecs_service_test.py index b472e94b88..0427113867 100644 --- a/tests/providers/aws/services/ecs/ecs_service_test.py +++ b/tests/providers/aws/services/ecs/ecs_service_test.py @@ -3,7 +3,11 @@ from unittest.mock import patch import botocore from prowler.providers.aws.services.ecs.ecs_service import ECS -from tests.providers.aws.utils import AWS_REGION_EU_WEST_1, set_mocked_aws_provider +from tests.providers.aws.utils import ( + AWS_REGION_EU_WEST_1, + AWS_REGION_US_EAST_1, + set_mocked_aws_provider, +) make_api_call = botocore.client.BaseClient._make_api_call @@ -115,6 +119,23 @@ def mock_generate_regional_clients(provider, service): return {AWS_REGION_EU_WEST_1: regional_client} +def mock_generate_multi_region_clients(provider, service): + eu_west_1_client = provider._session.current_session.client( + service, region_name=AWS_REGION_EU_WEST_1 + ) + eu_west_1_client.region = AWS_REGION_EU_WEST_1 + + us_east_1_client = provider._session.current_session.client( + service, region_name=AWS_REGION_US_EAST_1 + ) + us_east_1_client.region = AWS_REGION_US_EAST_1 + + return { + AWS_REGION_EU_WEST_1: eu_west_1_client, + AWS_REGION_US_EAST_1: us_east_1_client, + } + + @patch( "prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients", new=mock_generate_regional_clients, @@ -139,7 +160,6 @@ class Test_ECS_Service: ecs = ECS(aws_provider) assert ecs.session.__class__.__name__ == "Session" - # Test list ECS task definitions @patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call) def test_list_task_definitions(self): aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1]) @@ -201,6 +221,169 @@ class Test_ECS_Service: .readonly_rootfilesystem ) + def test_task_definitions_are_loaded_once_for_analysis(self): + describe_calls = [] + list_calls = [] + + def counting_make_api_call(self, operation_name, kwarg): + if operation_name == "ListTaskDefinitions": + list_calls.append(kwarg) + return { + "taskDefinitionArns": [ + f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}" + for i in (3, 2, 1) + ] + } + if operation_name == "DescribeTaskDefinition": + describe_calls.append(kwarg["taskDefinition"]) + return { + "taskDefinition": { + "containerDefinitions": [], + "networkMode": "bridge", + "pidMode": "", + "tags": [], + } + } + return make_api_call(self, operation_name, kwarg) + + with patch( + "botocore.client.BaseClient._make_api_call", new=counting_make_api_call + ): + aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1]) + ecs = ECS(aws_provider) + + assert [td.revision for td in ecs.task_definitions.values()] == [ + "3", + "2", + "1", + ] + assert list_calls == [{"sort": "DESC"}] + assert len(describe_calls) == 3 + + def test_task_definition_limit_exposes_only_selected_resources(self): + describe_calls = [] + + def counting_make_api_call(self, operation_name, kwarg): + if operation_name == "ListTaskDefinitions": + return { + "taskDefinitionArns": [ + f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}" + for i in (3, 2, 1) + ] + } + if operation_name == "DescribeTaskDefinition": + describe_calls.append(kwarg["taskDefinition"]) + return { + "taskDefinition": { + "containerDefinitions": [], + "networkMode": "bridge", + "pidMode": "", + "tags": [], + } + } + return make_api_call(self, operation_name, kwarg) + + with patch( + "botocore.client.BaseClient._make_api_call", new=counting_make_api_call + ): + aws_provider = set_mocked_aws_provider( + [AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 2} + ) + ecs = ECS(aws_provider) + + assert [td.revision for td in ecs.task_definitions.values()] == ["3", "2"] + assert len(describe_calls) == 2 + + def test_task_definition_limit_bounds_describe_calls(self): + describe_calls = [] + + def counting_make_api_call(self, operation_name, kwarg): + if operation_name == "ListTaskDefinitions": + return { + "taskDefinitionArns": [ + f"arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:{i}" + for i in (3, 2, 1) + ] + } + if operation_name == "DescribeTaskDefinition": + describe_calls.append(kwarg["taskDefinition"]) + return { + "taskDefinition": { + "containerDefinitions": [], + "networkMode": "bridge", + "pidMode": "", + "tags": [], + } + } + return mock_make_api_call(self, operation_name, kwarg) + + with patch( + "botocore.client.BaseClient._make_api_call", new=counting_make_api_call + ): + aws_provider = set_mocked_aws_provider( + [AWS_REGION_EU_WEST_1], audit_config={"max_ecs_task_definitions": 1} + ) + ecs = ECS(aws_provider) + + assert [td.revision for td in ecs.task_definitions.values()] == ["3"] + assert describe_calls == [ + "arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3" + ] + + def test_task_definition_limit_does_not_starve_later_regions(self): + describe_calls = [] + + def counting_make_api_call(self, operation_name, kwarg): + region = self.meta.region_name + if operation_name == "ListTaskDefinitions": + task_definition_revisions = { + AWS_REGION_EU_WEST_1: (3, 2, 1), + AWS_REGION_US_EAST_1: (9,), + }[region] + return { + "taskDefinitionArns": [ + f"arn:aws:ecs:{region}:123456789012:task-definition/fam:{revision}" + for revision in task_definition_revisions + ] + } + if operation_name == "DescribeTaskDefinition": + describe_calls.append(kwarg["taskDefinition"]) + return { + "taskDefinition": { + "containerDefinitions": [], + "networkMode": "bridge", + "pidMode": "", + "tags": [], + } + } + if operation_name == "ListClusters": + return {"clusterArns": []} + return mock_make_api_call(self, operation_name, kwarg) + + with ( + patch( + "prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients", + new=mock_generate_multi_region_clients, + ), + patch( + "botocore.client.BaseClient._make_api_call", new=counting_make_api_call + ), + ): + aws_provider = set_mocked_aws_provider( + [AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1], + audit_config={"max_ecs_task_definitions": 2}, + ) + ecs = ECS(aws_provider) + + assert [td.region for td in ecs.task_definitions.values()] == [ + AWS_REGION_EU_WEST_1, + AWS_REGION_US_EAST_1, + ] + assert set(describe_calls) == { + "arn:aws:ecs:eu-west-1:123456789012:task-definition/fam:3", + "arn:aws:ecs:us-east-1:123456789012:task-definition/fam:9", + } + # Test list ECS clusters @patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call) def test_list_clusters(self): diff --git a/tests/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled_test.py b/tests/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled_test.py index 2ea7702cd3..66eb2ceaa3 100644 --- a/tests/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled_test.py +++ b/tests/providers/aws/services/inspector2/inspector2_is_enabled/inspector2_is_enabled_test.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest import mock from prowler.providers.aws.services.inspector2.inspector2_service import Inspector @@ -13,6 +14,65 @@ FINDING_ARN = ( class Test_inspector2_is_enabled: + def test_lambda_disabled_with_region_hidden_by_function_analysis_limit(self): + inspector2_client = mock.MagicMock() + inspector2_client.provider = SimpleNamespace(scan_unused_services=False) + inspector2_client.inspectors = [ + Inspector( + id=AWS_ACCOUNT_NUMBER, + arn=f"arn:aws:inspector2:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:inspector2", + status="ENABLED", + ec2_status="ENABLED", + ecr_status="ENABLED", + lambda_status="DISABLED", + lambda_code_status="ENABLED", + region=AWS_REGION_EU_WEST_1, + ) + ] + awslambda_client = mock.MagicMock() + awslambda_client.functions = {} + awslambda_client.regions_with_functions = {AWS_REGION_EU_WEST_1} + ec2_client = mock.MagicMock() + ec2_client.instances = [] + ecr_client = mock.MagicMock() + ecr_client.registries = {AWS_REGION_EU_WEST_1: SimpleNamespace(repositories=[])} + aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1]) + + with ( + mock.patch( + "prowler.providers.common.provider.Provider.get_global_provider", + return_value=aws_provider, + ), + mock.patch( + "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.inspector2_client", + new=inspector2_client, + ), + mock.patch( + "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.awslambda_client", + new=awslambda_client, + ), + mock.patch( + "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ec2_client", + new=ec2_client, + ), + mock.patch( + "prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled.ecr_client", + new=ecr_client, + ), + ): + from prowler.providers.aws.services.inspector2.inspector2_is_enabled.inspector2_is_enabled import ( + inspector2_is_enabled, + ) + + result = inspector2_is_enabled().execute() + + assert len(result) == 1 + assert result[0].status == "FAIL" + assert ( + result[0].status_extended + == "Inspector2 is not enabled for the following services: Lambda." + ) + def test_inspector2_disabled(self): # Mock the inspector2 client inspector2_client = mock.MagicMock