mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-05-06 08:47:18 +00:00
fix(boto3): pass config to clients (#10944)
This commit is contained in:
@@ -22,6 +22,7 @@ All notable changes to the **Prowler SDK** are documented in this file.
|
||||
### 🐞 Fixed
|
||||
|
||||
- AWS SDK test isolation: autouse `mock_aws` fixture and leak detector in `conftest.py` to prevent tests from hitting real AWS endpoints, with idempotent organization setup for tests calling `set_mocked_aws_provider` multiple times [(#10605)](https://github.com/prowler-cloud/prowler/pull/10605)
|
||||
- AWS `boto` user agent extra is now applied to every client [(#10944)](https://github.com/prowler-cloud/prowler/pull/10944)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ from prowler.lib.utils.utils import open_file, parse_json_file, print_boxes
|
||||
from prowler.providers.aws.config import (
|
||||
AWS_REGION_US_EAST_1,
|
||||
AWS_STS_GLOBAL_ENDPOINT_REGION,
|
||||
BOTO3_USER_AGENT_EXTRA,
|
||||
ROLE_SESSION_NAME,
|
||||
get_default_session_config,
|
||||
)
|
||||
from prowler.providers.aws.exceptions.exceptions import (
|
||||
AWSAccessKeyIDInvalidError,
|
||||
@@ -227,14 +227,15 @@ class AwsProvider(Provider):
|
||||
|
||||
# TODO: Use AwsSetUpSession ?????
|
||||
# Configure the initial AWS Session using the local credentials: profile or environment variables
|
||||
session_config = self.set_session_config(retries_max_attempts)
|
||||
aws_session = self.setup_session(
|
||||
mfa=mfa,
|
||||
profile=profile,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
session_config=session_config,
|
||||
)
|
||||
session_config = self.set_session_config(retries_max_attempts)
|
||||
# Current session and the original session points to the same session object until we get a new one, if needed
|
||||
self._session = AWSSession(
|
||||
current_session=aws_session,
|
||||
@@ -630,6 +631,7 @@ class AwsProvider(Provider):
|
||||
aws_access_key_id: str = None,
|
||||
aws_secret_access_key: str = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
session_config: Optional[Config] = None,
|
||||
) -> Session:
|
||||
"""
|
||||
setup_session sets up an AWS session using the provided credentials.
|
||||
@@ -640,6 +642,9 @@ class AwsProvider(Provider):
|
||||
- aws_access_key_id: The AWS access key ID.
|
||||
- aws_secret_access_key: The AWS secret access key.
|
||||
- aws_session_token: The AWS session token, optional.
|
||||
- session_config: Botocore Config applied as the session's default
|
||||
client config so every client created from the session inherits
|
||||
the Prowler user agent and retry settings.
|
||||
|
||||
Returns:
|
||||
- Session: The AWS session.
|
||||
@@ -650,6 +655,9 @@ class AwsProvider(Provider):
|
||||
try:
|
||||
logger.debug("Creating original session ...")
|
||||
|
||||
if session_config is None:
|
||||
session_config = AwsProvider.set_session_config(None)
|
||||
|
||||
session_arguments = {}
|
||||
if profile:
|
||||
session_arguments["profile_name"] = profile
|
||||
@@ -661,6 +669,7 @@ class AwsProvider(Provider):
|
||||
|
||||
if mfa:
|
||||
session = Session(**session_arguments)
|
||||
session._session.set_default_client_config(session_config)
|
||||
sts_client = session.client("sts")
|
||||
|
||||
# TODO: pass values from the input
|
||||
@@ -673,7 +682,7 @@ class AwsProvider(Provider):
|
||||
session_credentials = sts_client.get_session_token(
|
||||
**get_session_token_arguments
|
||||
)
|
||||
return Session(
|
||||
mfa_session = Session(
|
||||
aws_access_key_id=session_credentials["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=session_credentials["Credentials"][
|
||||
"SecretAccessKey"
|
||||
@@ -682,8 +691,12 @@ class AwsProvider(Provider):
|
||||
"SessionToken"
|
||||
],
|
||||
)
|
||||
mfa_session._session.set_default_client_config(session_config)
|
||||
return mfa_session
|
||||
else:
|
||||
return Session(**session_arguments)
|
||||
session = Session(**session_arguments)
|
||||
session._session.set_default_client_config(session_config)
|
||||
return session
|
||||
except Exception as error:
|
||||
logger.critical(
|
||||
f"AWSSetUpSessionError[{error.__traceback__.tb_lineno}]: {error}"
|
||||
@@ -698,6 +711,7 @@ class AwsProvider(Provider):
|
||||
identity: AWSIdentityInfo,
|
||||
assumed_role_configuration: AWSAssumeRoleConfiguration,
|
||||
session: AWSSession,
|
||||
session_config: Optional[Config] = None,
|
||||
) -> Session:
|
||||
"""
|
||||
Sets up an assumed session using the provided assumed role credentials.
|
||||
@@ -742,6 +756,13 @@ class AwsProvider(Provider):
|
||||
assumed_session = BotocoreSession()
|
||||
assumed_session._credentials = assumed_refreshable_credentials
|
||||
assumed_session.set_config_variable("region", identity.profile_region)
|
||||
if session_config is None:
|
||||
session_config = (
|
||||
session.session_config
|
||||
if session is not None
|
||||
else AwsProvider.set_session_config(None)
|
||||
)
|
||||
assumed_session.set_default_client_config(session_config)
|
||||
return Session(
|
||||
profile_name=identity.profile,
|
||||
botocore_session=assumed_session,
|
||||
@@ -870,7 +891,7 @@ class AwsProvider(Provider):
|
||||
|
||||
for region in enabled_regions:
|
||||
regional_client = self._session.current_session.client(
|
||||
service, region_name=region, config=self._session.session_config
|
||||
service, region_name=region
|
||||
)
|
||||
regional_client.region = region
|
||||
regional_clients[region] = regional_client
|
||||
@@ -1140,21 +1161,16 @@ class AwsProvider(Provider):
|
||||
Returns:
|
||||
- Config: The botocore Config object
|
||||
"""
|
||||
# Set the maximum retries for the standard retrier config
|
||||
default_session_config = Config(
|
||||
retries={"max_attempts": 3, "mode": "standard"},
|
||||
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
|
||||
)
|
||||
default_session_config = get_default_session_config()
|
||||
if retries_max_attempts:
|
||||
# Create the new config
|
||||
config = Config(
|
||||
retries={
|
||||
"max_attempts": retries_max_attempts,
|
||||
"mode": "standard",
|
||||
},
|
||||
default_session_config = default_session_config.merge(
|
||||
Config(
|
||||
retries={
|
||||
"max_attempts": retries_max_attempts,
|
||||
"mode": "standard",
|
||||
},
|
||||
)
|
||||
)
|
||||
# Merge the new configuration
|
||||
default_session_config = default_session_config.merge(config)
|
||||
|
||||
return default_session_config
|
||||
|
||||
@@ -1425,6 +1441,9 @@ class AwsProvider(Provider):
|
||||
region_name=aws_region,
|
||||
profile_name=profile,
|
||||
)
|
||||
session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(None)
|
||||
)
|
||||
|
||||
caller_identity = AwsProvider.validate_credentials(session, aws_region)
|
||||
# Do an extra validation if the AWS account ID is provided
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
import os
|
||||
|
||||
from botocore.config import Config
|
||||
|
||||
AWS_STS_GLOBAL_ENDPOINT_REGION = "us-east-1"
|
||||
AWS_REGION_US_EAST_1 = "us-east-1"
|
||||
BOTO3_USER_AGENT_EXTRA = os.getenv("PROWLER_AWS_BOTO3_USER_AGENT_EXTRA", "APN_1826889")
|
||||
ROLE_SESSION_NAME = "ProwlerAssessmentSession"
|
||||
|
||||
|
||||
def get_default_session_config() -> Config:
|
||||
return Config(
|
||||
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
|
||||
retries={"max_attempts": 3, "mode": "standard"},
|
||||
)
|
||||
|
||||
@@ -56,9 +56,7 @@ def quick_inventory(provider: AwsProvider, args):
|
||||
try:
|
||||
# Scan IAM only once
|
||||
if not iam_was_scanned:
|
||||
global_resources.extend(
|
||||
get_iam_resources(provider.session.current_session)
|
||||
)
|
||||
global_resources.extend(get_iam_resources(provider))
|
||||
iam_was_scanned = True
|
||||
|
||||
# Get regional S3 buckets since none-tagged buckets are not supported by the resourcegroupstaggingapi
|
||||
@@ -312,8 +310,8 @@ def create_output(resources: list, provider: AwsProvider, args):
|
||||
if args.output_bucket:
|
||||
output_bucket = args.output_bucket
|
||||
bucket_session = provider.session.current_session
|
||||
# Check if -D was input
|
||||
elif args.output_bucket_no_assume:
|
||||
# The outer condition guarantees -D was input when -B was not
|
||||
else:
|
||||
output_bucket = args.output_bucket_no_assume
|
||||
bucket_session = provider.session.original_session
|
||||
|
||||
@@ -375,9 +373,9 @@ def get_regional_buckets(provider: AwsProvider, region: str) -> list:
|
||||
return regional_buckets
|
||||
|
||||
|
||||
def get_iam_resources(session) -> list:
|
||||
def get_iam_resources(provider: AwsProvider) -> list:
|
||||
iam_resources = []
|
||||
iam_client = session.client("iam")
|
||||
iam_client = provider.session.current_session.client("iam")
|
||||
try:
|
||||
get_roles_paginator = iam_client.get_paginator("list_roles")
|
||||
for page in get_roles_paginator.paginate():
|
||||
|
||||
@@ -111,6 +111,13 @@ class S3:
|
||||
- None
|
||||
"""
|
||||
if session:
|
||||
# Preserve the caller's existing default config (and the
|
||||
# retries_max_attempts already baked into it) instead of clobbering
|
||||
# it with a freshly built one.
|
||||
if session._session.get_default_client_config() is None:
|
||||
session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(retries_max_attempts)
|
||||
)
|
||||
self._session = session.client(__class__.__name__.lower())
|
||||
else:
|
||||
aws_setup_session = AwsSetUpSession(
|
||||
@@ -127,8 +134,7 @@ class S3:
|
||||
regions=regions,
|
||||
)
|
||||
self._session = aws_setup_session._session.current_session.client(
|
||||
__class__.__name__.lower(),
|
||||
config=aws_setup_session._session.session_config,
|
||||
__class__.__name__.lower()
|
||||
)
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
@@ -313,6 +319,9 @@ class S3:
|
||||
region_name=aws_region,
|
||||
profile_name=profile,
|
||||
)
|
||||
session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(None)
|
||||
)
|
||||
s3_client = session.client(__class__.__name__.lower())
|
||||
if "s3://" in bucket_name:
|
||||
bucket_name = bucket_name.removeprefix("s3://")
|
||||
|
||||
@@ -148,6 +148,13 @@ class SecurityHub:
|
||||
regions=regions,
|
||||
)
|
||||
self._session = aws_setup_session._session.current_session
|
||||
# Only install the Prowler default config when the caller-supplied
|
||||
# session does not already carry one — overwriting would drop the
|
||||
# provider's retries_max_attempts value.
|
||||
if aws_session and self._session._session.get_default_client_config() is None:
|
||||
self._session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(retries_max_attempts)
|
||||
)
|
||||
self._aws_account_id = aws_account_id
|
||||
if not aws_partition:
|
||||
aws_partition = AwsProvider.validate_credentials(
|
||||
@@ -235,7 +242,7 @@ class SecurityHub:
|
||||
|
||||
Args:
|
||||
region (str): AWS region to check.
|
||||
session (Session): AWS session object.
|
||||
session (Session): AWS session object. Expected to carry the Prowler default client config.
|
||||
aws_account_id (str): AWS account ID.
|
||||
aws_partition (str): AWS partition.
|
||||
|
||||
@@ -540,6 +547,9 @@ class SecurityHub:
|
||||
region_name=aws_region,
|
||||
profile_name=profile,
|
||||
)
|
||||
session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(None)
|
||||
)
|
||||
|
||||
all_regions = AwsProvider.get_available_aws_service_regions(
|
||||
service="securityhub", partition=aws_partition
|
||||
|
||||
@@ -32,7 +32,13 @@ class AWSService:
|
||||
def is_failed_check(cls, check_id, arn):
|
||||
return (check_id.split(".")[-1], arn) in cls.failed_checks
|
||||
|
||||
def __init__(self, service: str, provider: AwsProvider, global_service=False):
|
||||
def __init__(
|
||||
self,
|
||||
service: str,
|
||||
provider: AwsProvider,
|
||||
global_service=False,
|
||||
region: str = None,
|
||||
):
|
||||
# Audit Information
|
||||
# Do we need to store the whole provider?
|
||||
self.provider = provider
|
||||
@@ -61,7 +67,7 @@ class AWSService:
|
||||
# Get a single region and client if the service needs it (e.g. AWS Global Service)
|
||||
# We cannot include this within an else because some services needs both the regional_clients
|
||||
# and a single client like S3
|
||||
self.region = provider.get_default_region(
|
||||
self.region = region or provider.get_default_region(
|
||||
self.service, global_service=global_service
|
||||
)
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
|
||||
@@ -73,15 +73,15 @@ class AwsSetUpSession:
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
# Setup the AWS session
|
||||
session_config = AwsProvider.set_session_config(retries_max_attempts)
|
||||
aws_session = AwsProvider.setup_session(
|
||||
mfa=mfa,
|
||||
profile=profile,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
session_config=session_config,
|
||||
)
|
||||
session_config = AwsProvider.set_session_config(retries_max_attempts)
|
||||
self._session = AWSSession(
|
||||
current_session=aws_session,
|
||||
session_config=session_config,
|
||||
|
||||
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
|
||||
|
||||
class GlobalAccelerator(AWSService):
|
||||
def __init__(self, provider):
|
||||
# Call AWSService's __init__
|
||||
super().__init__(__class__.__name__, provider)
|
||||
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
|
||||
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
|
||||
# That is, for example, specify --region us-west-2 on AWS CLI commands.
|
||||
region = "us-west-2" if provider.identity.partition == "aws" else None
|
||||
super().__init__(__class__.__name__, provider, region=region)
|
||||
self.accelerators = {}
|
||||
if self.audited_partition == "aws":
|
||||
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
|
||||
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
|
||||
# That is, for example, specify --region us-west-2 on AWS CLI commands.
|
||||
self.region = "us-west-2"
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
self._list_accelerators()
|
||||
self.__threading_call__(self._list_tags, self.accelerators.values())
|
||||
|
||||
|
||||
@@ -176,14 +176,12 @@ class RecordSet(BaseModel):
|
||||
|
||||
class Route53Domains(AWSService):
|
||||
def __init__(self, provider):
|
||||
# Call AWSService's __init__
|
||||
super().__init__(__class__.__name__, provider)
|
||||
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
|
||||
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
|
||||
region = "us-east-1" if provider.identity.partition == "aws" else None
|
||||
super().__init__(__class__.__name__, provider, region=region)
|
||||
self.domains = {}
|
||||
if self.audited_partition == "aws":
|
||||
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
|
||||
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
|
||||
self.region = "us-east-1"
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
self._list_domains()
|
||||
self._get_domain_detail()
|
||||
self._list_tags_for_domain()
|
||||
|
||||
@@ -9,20 +9,20 @@ from prowler.providers.aws.lib.service.service import AWSService
|
||||
|
||||
class TrustedAdvisor(AWSService):
|
||||
def __init__(self, provider):
|
||||
# Call AWSService's __init__
|
||||
super().__init__("support", provider)
|
||||
# Support API is not available in China Partition
|
||||
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
|
||||
partition = provider.identity.partition
|
||||
if partition == "aws":
|
||||
support_region = "us-east-1"
|
||||
elif partition == "aws-cn":
|
||||
support_region = None
|
||||
else:
|
||||
support_region = "us-gov-west-1"
|
||||
super().__init__("support", provider, region=support_region)
|
||||
self.account_arn_template = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:account"
|
||||
self.checks = []
|
||||
self.premium_support = PremiumSupport(enabled=False)
|
||||
# Support API is not available in China Partition
|
||||
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
|
||||
if self.audited_partition != "aws-cn":
|
||||
if self.audited_partition == "aws":
|
||||
support_region = "us-east-1"
|
||||
else:
|
||||
support_region = "us-gov-west-1"
|
||||
self.client = self.session.client(self.service, region_name=support_region)
|
||||
self.client.region = support_region
|
||||
self._describe_services()
|
||||
if getattr(self.premium_support, "enabled", False):
|
||||
self._describe_trusted_advisor_checks()
|
||||
@@ -34,13 +34,13 @@ class TrustedAdvisor(AWSService):
|
||||
for check in self.client.describe_trusted_advisor_checks(language="en").get(
|
||||
"checks", []
|
||||
):
|
||||
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.client.region}:{self.audited_account}:check/{check['id']}"
|
||||
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:check/{check['id']}"
|
||||
self.checks.append(
|
||||
Check(
|
||||
id=check["id"],
|
||||
name=check["name"],
|
||||
arn=check_arn,
|
||||
region=self.client.region,
|
||||
region=self.region,
|
||||
)
|
||||
)
|
||||
except ClientError as error:
|
||||
@@ -50,22 +50,22 @@ class TrustedAdvisor(AWSService):
|
||||
== "Amazon Web Services Premium Support Subscription is required to use this service."
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def _describe_trusted_advisor_check_result(self):
|
||||
logger.info("TrustedAdvisor - Describing Check Result...")
|
||||
try:
|
||||
for check in self.checks:
|
||||
if check.region == self.client.region:
|
||||
if check.region == self.region:
|
||||
try:
|
||||
response = self.client.describe_trusted_advisor_check_result(
|
||||
checkId=check.id
|
||||
@@ -78,11 +78,11 @@ class TrustedAdvisor(AWSService):
|
||||
== "InvalidParameterValueException"
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def _describe_services(self):
|
||||
|
||||
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
|
||||
|
||||
class WAF(AWSService):
|
||||
def __init__(self, provider):
|
||||
# Call AWSService's __init__
|
||||
super().__init__("waf", provider)
|
||||
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
|
||||
region = "us-east-1" if provider.identity.partition == "aws" else None
|
||||
super().__init__("waf", provider, region=region)
|
||||
self.rules = {}
|
||||
self.rule_groups = {}
|
||||
self.web_acls = {}
|
||||
if self.audited_partition == "aws":
|
||||
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
|
||||
self.region = "us-east-1"
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
self._list_rules()
|
||||
self.__threading_call__(self._get_rule, self.rules.values())
|
||||
self._list_rule_groups()
|
||||
|
||||
@@ -11,13 +11,11 @@ from prowler.providers.aws.lib.service.service import AWSService
|
||||
|
||||
class WAFv2(AWSService):
|
||||
def __init__(self, provider):
|
||||
# Call AWSService's __init__
|
||||
super().__init__(__class__.__name__, provider)
|
||||
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
|
||||
region = "us-east-1" if provider.identity.partition == "aws" else None
|
||||
super().__init__(__class__.__name__, provider, region=region)
|
||||
self.web_acls = {}
|
||||
if self.audited_partition == "aws":
|
||||
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
|
||||
self.region = "us-east-1"
|
||||
self.client = self.session.client(self.service, self.region)
|
||||
self._list_web_acls_global()
|
||||
self.__threading_call__(self._list_web_acls_regional)
|
||||
self.__threading_call__(self._get_web_acl, self.web_acls.values())
|
||||
|
||||
@@ -21,6 +21,7 @@ from prowler.providers.aws.config import (
|
||||
AWS_STS_GLOBAL_ENDPOINT_REGION,
|
||||
BOTO3_USER_AGENT_EXTRA,
|
||||
ROLE_SESSION_NAME,
|
||||
get_default_session_config,
|
||||
)
|
||||
from prowler.providers.aws.exceptions.exceptions import (
|
||||
AWSArgumentTypeValidationError,
|
||||
@@ -2242,6 +2243,12 @@ aws:
|
||||
assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
|
||||
assert session_config.retries == {"max_attempts": 10, "mode": "standard"}
|
||||
|
||||
def test_get_default_session_config(self):
|
||||
config = get_default_session_config()
|
||||
|
||||
assert config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
|
||||
assert config.retries == {"max_attempts": 3, "mode": "standard"}
|
||||
|
||||
@mock_aws
|
||||
@patch(
|
||||
"prowler.lib.check.utils.recover_checks_from_provider",
|
||||
|
||||
@@ -4,6 +4,8 @@ import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from prowler.providers.aws.aws_provider import AwsProvider
|
||||
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
|
||||
from prowler.providers.aws.lib.organizations.organizations import (
|
||||
_get_ou_metadata,
|
||||
get_organizations_metadata,
|
||||
@@ -222,6 +224,20 @@ class Test_AWS_Organizations:
|
||||
assert tags == {}
|
||||
assert ou_metadata == {}
|
||||
|
||||
def test_get_organizations_metadata_uses_user_agent_extra(self):
|
||||
real_session = boto3.Session()
|
||||
real_session._session.set_default_client_config(
|
||||
AwsProvider.set_session_config(None)
|
||||
)
|
||||
wrapper = MagicMock(wraps=real_session)
|
||||
|
||||
get_organizations_metadata("123456789012", wrapper)
|
||||
|
||||
wrapper.client.assert_called_once()
|
||||
default_config = real_session._session.get_default_client_config()
|
||||
assert default_config is not None
|
||||
assert BOTO3_USER_AGENT_EXTRA in default_config.user_agent_extra
|
||||
|
||||
def test_parse_organizations_metadata_with_empty_ou_metadata(self):
|
||||
tags = {"Tags": []}
|
||||
metadata = {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from mock import patch
|
||||
|
||||
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
|
||||
from prowler.providers.aws.lib.service.service import AWSService
|
||||
from tests.providers.aws.utils import (
|
||||
AWS_ACCOUNT_ARN,
|
||||
@@ -189,6 +190,15 @@ class TestAWSService:
|
||||
== f"arn:{service.audited_partition}:{service_name}::{AWS_ACCOUNT_NUMBER}:bucket/unknown"
|
||||
)
|
||||
|
||||
def test_AWSService_clients_carry_user_agent_extra(self):
|
||||
provider = set_mocked_aws_provider()
|
||||
|
||||
service = AWSService("s3", provider)
|
||||
ad_hoc_client = service.session.client("ec2", AWS_REGION_US_EAST_1)
|
||||
|
||||
assert BOTO3_USER_AGENT_EXTRA in service.client._client_config.user_agent_extra
|
||||
assert BOTO3_USER_AGENT_EXTRA in ad_hoc_client._client_config.user_agent_extra
|
||||
|
||||
def test_AWSService_get_unknown_arn_resource_type_set_region(self):
|
||||
service_name = "s3"
|
||||
provider = set_mocked_aws_provider()
|
||||
|
||||
@@ -2,7 +2,6 @@ from argparse import Namespace
|
||||
from json import dumps
|
||||
|
||||
from boto3 import client, session
|
||||
from botocore.config import Config
|
||||
from moto import mock_aws
|
||||
|
||||
from prowler.config.config import (
|
||||
@@ -133,10 +132,11 @@ def set_mocked_aws_provider(
|
||||
provider = AwsProvider()
|
||||
|
||||
# Mock Session
|
||||
provider._session.session_config = None
|
||||
session_config = AwsProvider.set_session_config(None)
|
||||
provider._session.session_config = session_config
|
||||
provider._session.original_session = original_session
|
||||
provider._session.current_session = audit_session
|
||||
provider._session.session_config = Config()
|
||||
audit_session._session.set_default_client_config(session_config)
|
||||
# Mock Identity
|
||||
provider._identity.account = audited_account
|
||||
provider._identity.account_arn = audited_account_arn
|
||||
|
||||
Reference in New Issue
Block a user