fix(boto3): pass config to clients (#10944)

This commit is contained in:
Pepe Fagoaga
2026-04-30 14:11:29 +02:00
committed by GitHub
parent e821e07d7d
commit 36b8aa1b79
17 changed files with 153 additions and 76 deletions
+1
View File
@@ -22,6 +22,7 @@ All notable changes to the **Prowler SDK** are documented in this file.
### 🐞 Fixed
- AWS SDK test isolation: autouse `mock_aws` fixture and leak detector in `conftest.py` to prevent tests from hitting real AWS endpoints, with idempotent organization setup for tests calling `set_mocked_aws_provider` multiple times [(#10605)](https://github.com/prowler-cloud/prowler/pull/10605)
- AWS `boto` user agent extra is now applied to every client [(#10944)](https://github.com/prowler-cloud/prowler/pull/10944)
### 🔐 Security
+37 -18
View File
@@ -25,8 +25,8 @@ from prowler.lib.utils.utils import open_file, parse_json_file, print_boxes
from prowler.providers.aws.config import (
AWS_REGION_US_EAST_1,
AWS_STS_GLOBAL_ENDPOINT_REGION,
BOTO3_USER_AGENT_EXTRA,
ROLE_SESSION_NAME,
get_default_session_config,
)
from prowler.providers.aws.exceptions.exceptions import (
AWSAccessKeyIDInvalidError,
@@ -227,14 +227,15 @@ class AwsProvider(Provider):
# TODO: Use AwsSetUpSession ?????
# Configure the initial AWS Session using the local credentials: profile or environment variables
session_config = self.set_session_config(retries_max_attempts)
aws_session = self.setup_session(
mfa=mfa,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
session_config=session_config,
)
session_config = self.set_session_config(retries_max_attempts)
# Current session and the original session points to the same session object until we get a new one, if needed
self._session = AWSSession(
current_session=aws_session,
@@ -630,6 +631,7 @@ class AwsProvider(Provider):
aws_access_key_id: str = None,
aws_secret_access_key: str = None,
aws_session_token: Optional[str] = None,
session_config: Optional[Config] = None,
) -> Session:
"""
setup_session sets up an AWS session using the provided credentials.
@@ -640,6 +642,9 @@ class AwsProvider(Provider):
- aws_access_key_id: The AWS access key ID.
- aws_secret_access_key: The AWS secret access key.
- aws_session_token: The AWS session token, optional.
- session_config: Botocore Config applied as the session's default
client config so every client created from the session inherits
the Prowler user agent and retry settings.
Returns:
- Session: The AWS session.
@@ -650,6 +655,9 @@ class AwsProvider(Provider):
try:
logger.debug("Creating original session ...")
if session_config is None:
session_config = AwsProvider.set_session_config(None)
session_arguments = {}
if profile:
session_arguments["profile_name"] = profile
@@ -661,6 +669,7 @@ class AwsProvider(Provider):
if mfa:
session = Session(**session_arguments)
session._session.set_default_client_config(session_config)
sts_client = session.client("sts")
# TODO: pass values from the input
@@ -673,7 +682,7 @@ class AwsProvider(Provider):
session_credentials = sts_client.get_session_token(
**get_session_token_arguments
)
return Session(
mfa_session = Session(
aws_access_key_id=session_credentials["Credentials"]["AccessKeyId"],
aws_secret_access_key=session_credentials["Credentials"][
"SecretAccessKey"
@@ -682,8 +691,12 @@ class AwsProvider(Provider):
"SessionToken"
],
)
mfa_session._session.set_default_client_config(session_config)
return mfa_session
else:
return Session(**session_arguments)
session = Session(**session_arguments)
session._session.set_default_client_config(session_config)
return session
except Exception as error:
logger.critical(
f"AWSSetUpSessionError[{error.__traceback__.tb_lineno}]: {error}"
@@ -698,6 +711,7 @@ class AwsProvider(Provider):
identity: AWSIdentityInfo,
assumed_role_configuration: AWSAssumeRoleConfiguration,
session: AWSSession,
session_config: Optional[Config] = None,
) -> Session:
"""
Sets up an assumed session using the provided assumed role credentials.
@@ -742,6 +756,13 @@ class AwsProvider(Provider):
assumed_session = BotocoreSession()
assumed_session._credentials = assumed_refreshable_credentials
assumed_session.set_config_variable("region", identity.profile_region)
if session_config is None:
session_config = (
session.session_config
if session is not None
else AwsProvider.set_session_config(None)
)
assumed_session.set_default_client_config(session_config)
return Session(
profile_name=identity.profile,
botocore_session=assumed_session,
@@ -870,7 +891,7 @@ class AwsProvider(Provider):
for region in enabled_regions:
regional_client = self._session.current_session.client(
service, region_name=region, config=self._session.session_config
service, region_name=region
)
regional_client.region = region
regional_clients[region] = regional_client
@@ -1140,21 +1161,16 @@ class AwsProvider(Provider):
Returns:
- Config: The botocore Config object
"""
# Set the maximum retries for the standard retrier config
default_session_config = Config(
retries={"max_attempts": 3, "mode": "standard"},
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
)
default_session_config = get_default_session_config()
if retries_max_attempts:
# Create the new config
config = Config(
retries={
"max_attempts": retries_max_attempts,
"mode": "standard",
},
default_session_config = default_session_config.merge(
Config(
retries={
"max_attempts": retries_max_attempts,
"mode": "standard",
},
)
)
# Merge the new configuration
default_session_config = default_session_config.merge(config)
return default_session_config
@@ -1425,6 +1441,9 @@ class AwsProvider(Provider):
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
caller_identity = AwsProvider.validate_credentials(session, aws_region)
# Do an extra validation if the AWS account ID is provided
+9
View File
@@ -1,6 +1,15 @@
import os
from botocore.config import Config
AWS_STS_GLOBAL_ENDPOINT_REGION = "us-east-1"
AWS_REGION_US_EAST_1 = "us-east-1"
BOTO3_USER_AGENT_EXTRA = os.getenv("PROWLER_AWS_BOTO3_USER_AGENT_EXTRA", "APN_1826889")
ROLE_SESSION_NAME = "ProwlerAssessmentSession"
def get_default_session_config() -> Config:
return Config(
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
retries={"max_attempts": 3, "mode": "standard"},
)
@@ -56,9 +56,7 @@ def quick_inventory(provider: AwsProvider, args):
try:
# Scan IAM only once
if not iam_was_scanned:
global_resources.extend(
get_iam_resources(provider.session.current_session)
)
global_resources.extend(get_iam_resources(provider))
iam_was_scanned = True
# Get regional S3 buckets since none-tagged buckets are not supported by the resourcegroupstaggingapi
@@ -312,8 +310,8 @@ def create_output(resources: list, provider: AwsProvider, args):
if args.output_bucket:
output_bucket = args.output_bucket
bucket_session = provider.session.current_session
# Check if -D was input
elif args.output_bucket_no_assume:
# The outer condition guarantees -D was input when -B was not
else:
output_bucket = args.output_bucket_no_assume
bucket_session = provider.session.original_session
@@ -375,9 +373,9 @@ def get_regional_buckets(provider: AwsProvider, region: str) -> list:
return regional_buckets
def get_iam_resources(session) -> list:
def get_iam_resources(provider: AwsProvider) -> list:
iam_resources = []
iam_client = session.client("iam")
iam_client = provider.session.current_session.client("iam")
try:
get_roles_paginator = iam_client.get_paginator("list_roles")
for page in get_roles_paginator.paginate():
+11 -2
View File
@@ -111,6 +111,13 @@ class S3:
- None
"""
if session:
# Preserve the caller's existing default config (and the
# retries_max_attempts already baked into it) instead of clobbering
# it with a freshly built one.
if session._session.get_default_client_config() is None:
session._session.set_default_client_config(
AwsProvider.set_session_config(retries_max_attempts)
)
self._session = session.client(__class__.__name__.lower())
else:
aws_setup_session = AwsSetUpSession(
@@ -127,8 +134,7 @@ class S3:
regions=regions,
)
self._session = aws_setup_session._session.current_session.client(
__class__.__name__.lower(),
config=aws_setup_session._session.session_config,
__class__.__name__.lower()
)
self._bucket_name = bucket_name
@@ -313,6 +319,9 @@ class S3:
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
s3_client = session.client(__class__.__name__.lower())
if "s3://" in bucket_name:
bucket_name = bucket_name.removeprefix("s3://")
@@ -148,6 +148,13 @@ class SecurityHub:
regions=regions,
)
self._session = aws_setup_session._session.current_session
# Only install the Prowler default config when the caller-supplied
# session does not already carry one — overwriting would drop the
# provider's retries_max_attempts value.
if aws_session and self._session._session.get_default_client_config() is None:
self._session._session.set_default_client_config(
AwsProvider.set_session_config(retries_max_attempts)
)
self._aws_account_id = aws_account_id
if not aws_partition:
aws_partition = AwsProvider.validate_credentials(
@@ -235,7 +242,7 @@ class SecurityHub:
Args:
region (str): AWS region to check.
session (Session): AWS session object.
session (Session): AWS session object. Expected to carry the Prowler default client config.
aws_account_id (str): AWS account ID.
aws_partition (str): AWS partition.
@@ -540,6 +547,9 @@ class SecurityHub:
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
all_regions = AwsProvider.get_available_aws_service_regions(
service="securityhub", partition=aws_partition
+8 -2
View File
@@ -32,7 +32,13 @@ class AWSService:
def is_failed_check(cls, check_id, arn):
return (check_id.split(".")[-1], arn) in cls.failed_checks
def __init__(self, service: str, provider: AwsProvider, global_service=False):
def __init__(
self,
service: str,
provider: AwsProvider,
global_service=False,
region: str = None,
):
# Audit Information
# Do we need to store the whole provider?
self.provider = provider
@@ -61,7 +67,7 @@ class AWSService:
# Get a single region and client if the service needs it (e.g. AWS Global Service)
# We cannot include this within an else because some services needs both the regional_clients
# and a single client like S3
self.region = provider.get_default_region(
self.region = region or provider.get_default_region(
self.service, global_service=global_service
)
self.client = self.session.client(self.service, self.region)
@@ -73,15 +73,15 @@ class AwsSetUpSession:
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
# Setup the AWS session
session_config = AwsProvider.set_session_config(retries_max_attempts)
aws_session = AwsProvider.setup_session(
mfa=mfa,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
session_config=session_config,
)
session_config = AwsProvider.set_session_config(retries_max_attempts)
self._session = AWSSession(
current_session=aws_session,
session_config=session_config,
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
class GlobalAccelerator(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
# That is, for example, specify --region us-west-2 on AWS CLI commands.
region = "us-west-2" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.accelerators = {}
if self.audited_partition == "aws":
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
# That is, for example, specify --region us-west-2 on AWS CLI commands.
self.region = "us-west-2"
self.client = self.session.client(self.service, self.region)
self._list_accelerators()
self.__threading_call__(self._list_tags, self.accelerators.values())
@@ -176,14 +176,12 @@ class RecordSet(BaseModel):
class Route53Domains(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.domains = {}
if self.audited_partition == "aws":
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_domains()
self._get_domain_detail()
self._list_tags_for_domain()
@@ -9,20 +9,20 @@ from prowler.providers.aws.lib.service.service import AWSService
class TrustedAdvisor(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__("support", provider)
# Support API is not available in China Partition
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
partition = provider.identity.partition
if partition == "aws":
support_region = "us-east-1"
elif partition == "aws-cn":
support_region = None
else:
support_region = "us-gov-west-1"
super().__init__("support", provider, region=support_region)
self.account_arn_template = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:account"
self.checks = []
self.premium_support = PremiumSupport(enabled=False)
# Support API is not available in China Partition
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
if self.audited_partition != "aws-cn":
if self.audited_partition == "aws":
support_region = "us-east-1"
else:
support_region = "us-gov-west-1"
self.client = self.session.client(self.service, region_name=support_region)
self.client.region = support_region
self._describe_services()
if getattr(self.premium_support, "enabled", False):
self._describe_trusted_advisor_checks()
@@ -34,13 +34,13 @@ class TrustedAdvisor(AWSService):
for check in self.client.describe_trusted_advisor_checks(language="en").get(
"checks", []
):
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.client.region}:{self.audited_account}:check/{check['id']}"
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:check/{check['id']}"
self.checks.append(
Check(
id=check["id"],
name=check["name"],
arn=check_arn,
region=self.client.region,
region=self.region,
)
)
except ClientError as error:
@@ -50,22 +50,22 @@ class TrustedAdvisor(AWSService):
== "Amazon Web Services Premium Support Subscription is required to use this service."
):
logger.warning(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _describe_trusted_advisor_check_result(self):
logger.info("TrustedAdvisor - Describing Check Result...")
try:
for check in self.checks:
if check.region == self.client.region:
if check.region == self.region:
try:
response = self.client.describe_trusted_advisor_check_result(
checkId=check.id
@@ -78,11 +78,11 @@ class TrustedAdvisor(AWSService):
== "InvalidParameterValueException"
):
logger.warning(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _describe_services(self):
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
class WAF(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__("waf", provider)
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__("waf", provider, region=region)
self.rules = {}
self.rule_groups = {}
self.web_acls = {}
if self.audited_partition == "aws":
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_rules()
self.__threading_call__(self._get_rule, self.rules.values())
self._list_rule_groups()
@@ -11,13 +11,11 @@ from prowler.providers.aws.lib.service.service import AWSService
class WAFv2(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.web_acls = {}
if self.audited_partition == "aws":
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_web_acls_global()
self.__threading_call__(self._list_web_acls_regional)
self.__threading_call__(self._get_web_acl, self.web_acls.values())
+7
View File
@@ -21,6 +21,7 @@ from prowler.providers.aws.config import (
AWS_STS_GLOBAL_ENDPOINT_REGION,
BOTO3_USER_AGENT_EXTRA,
ROLE_SESSION_NAME,
get_default_session_config,
)
from prowler.providers.aws.exceptions.exceptions import (
AWSArgumentTypeValidationError,
@@ -2242,6 +2243,12 @@ aws:
assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
assert session_config.retries == {"max_attempts": 10, "mode": "standard"}
def test_get_default_session_config(self):
config = get_default_session_config()
assert config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
assert config.retries == {"max_attempts": 3, "mode": "standard"}
@mock_aws
@patch(
"prowler.lib.check.utils.recover_checks_from_provider",
@@ -4,6 +4,8 @@ import boto3
from botocore.exceptions import ClientError
from moto import mock_aws
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
from prowler.providers.aws.lib.organizations.organizations import (
_get_ou_metadata,
get_organizations_metadata,
@@ -222,6 +224,20 @@ class Test_AWS_Organizations:
assert tags == {}
assert ou_metadata == {}
def test_get_organizations_metadata_uses_user_agent_extra(self):
real_session = boto3.Session()
real_session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
wrapper = MagicMock(wraps=real_session)
get_organizations_metadata("123456789012", wrapper)
wrapper.client.assert_called_once()
default_config = real_session._session.get_default_client_config()
assert default_config is not None
assert BOTO3_USER_AGENT_EXTRA in default_config.user_agent_extra
def test_parse_organizations_metadata_with_empty_ou_metadata(self):
tags = {"Tags": []}
metadata = {
@@ -1,5 +1,6 @@
from mock import patch
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
from prowler.providers.aws.lib.service.service import AWSService
from tests.providers.aws.utils import (
AWS_ACCOUNT_ARN,
@@ -189,6 +190,15 @@ class TestAWSService:
== f"arn:{service.audited_partition}:{service_name}::{AWS_ACCOUNT_NUMBER}:bucket/unknown"
)
def test_AWSService_clients_carry_user_agent_extra(self):
provider = set_mocked_aws_provider()
service = AWSService("s3", provider)
ad_hoc_client = service.session.client("ec2", AWS_REGION_US_EAST_1)
assert BOTO3_USER_AGENT_EXTRA in service.client._client_config.user_agent_extra
assert BOTO3_USER_AGENT_EXTRA in ad_hoc_client._client_config.user_agent_extra
def test_AWSService_get_unknown_arn_resource_type_set_region(self):
service_name = "s3"
provider = set_mocked_aws_provider()
+3 -3
View File
@@ -2,7 +2,6 @@ from argparse import Namespace
from json import dumps
from boto3 import client, session
from botocore.config import Config
from moto import mock_aws
from prowler.config.config import (
@@ -133,10 +132,11 @@ def set_mocked_aws_provider(
provider = AwsProvider()
# Mock Session
provider._session.session_config = None
session_config = AwsProvider.set_session_config(None)
provider._session.session_config = session_config
provider._session.original_session = original_session
provider._session.current_session = audit_session
provider._session.session_config = Config()
audit_session._session.set_default_client_config(session_config)
# Mock Identity
provider._identity.account = audited_account
provider._identity.account_arn = audited_account_arn