refactor(aws): Refactor provider (#4808)

Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
This commit is contained in:
Pedro Martín
2024-08-23 15:19:05 +02:00
committed by GitHub
parent 0f79312c33
commit 2fda2388bb
7 changed files with 305 additions and 239 deletions

View File

@@ -1,6 +1,6 @@
import os
import pathlib
from argparse import ArgumentTypeError, Namespace
from argparse import ArgumentTypeError
from datetime import datetime
from re import fullmatch
@@ -64,34 +64,57 @@ class AwsProvider(Provider):
# TODO: this is not optional, enforce for all providers
audit_metadata: Audit_Metadata
def __init__(self, arguments: Namespace):
def __init__(
self,
retries_max_attempts: int = 3,
role_arn: str = None,
session_duration: int = None,
external_id: str = None,
role_session_name: str = None,
mfa: bool = None,
profile: str = None,
regions: set = set(),
organizations_role_arn: str = None,
scan_unused_services: bool = None,
resource_tags: list[str] = [],
resource_arn: list[str] = [],
config_file: str = None,
fixer_config: str = None,
):
"""
Initializes the AWS provider.
Arguments:
- retries_max_attempts: The maximum number of retries for the AWS client.
- role_arn: The ARN of the IAM role to assume.
- session_duration: The duration of the session in seconds.
- external_id: The external ID to use when assuming the IAM role.
- role_session_name: The name of the session when assuming the IAM role.
- mfa: A boolean indicating whether MFA is enabled.
- profile: The name of the AWS CLI profile to use.
- regions: A set of regions to audit.
- organizations_role_arn: The ARN of the AWS Organizations IAM role to assume.
- scan_unused_services: A boolean indicating whether to scan unused services.
- resource_tags: A list of tags to filter the resources to audit.
- resource_arn: A list of ARNs of the resources to audit.
- config_file: The path to the configuration file.
- fixer_config: The path to the fixer configuration
Raises:
- ArgumentTypeError: If the input MFA ARN is invalid.
- ArgumentTypeError: If the input session duration is invalid.
- ArgumentTypeError: If the input external ID is invalid.
- ArgumentTypeError: If the input role session name is invalid.
"""
logger.info("Initializing AWS provider ...")
######## Parse Arguments
# Session
aws_retries_max_attempts = getattr(arguments, "aws_retries_max_attempts", None)
# Assume Role
input_role = getattr(arguments, "role", None)
input_session_duration = getattr(arguments, "session_duration", None)
input_external_id = getattr(arguments, "external_id", None)
input_role_session_name = getattr(arguments, "role_session_name", None)
# MFA Configuration (false by default)
input_mfa = getattr(arguments, "mfa", None)
input_profile = getattr(arguments, "profile", None)
input_regions = set(getattr(arguments, "region", []) or [])
organizations_role_arn = getattr(arguments, "organizations_role", None)
# Set if unused services must be scanned
scan_unused_services = getattr(arguments, "scan_unused_services", None)
########
######## AWS Session
logger.info("Generating original session ...")
# Configure the initial AWS Session using the local credentials: profile or environment variables
aws_session = self.setup_session(input_mfa, input_profile)
session_config = self.set_session_config(aws_retries_max_attempts)
aws_session = self.setup_session(mfa, profile)
session_config = self.set_session_config(retries_max_attempts)
# Current session and the original session points to the same session object until we get a new one, if needed
self._session = AWSSession(
current_session=aws_session,
@@ -104,10 +127,10 @@ class AwsProvider(Provider):
# After the session is created, validate it
logger.info("Validating credentials ...")
sts_region = get_aws_region_for_sts(
self.session.current_session.region_name, input_regions
self.session.current_session.region_name, regions
)
# Use test_connection to validate the credentials
# Validate the credentials
caller_identity = self.validate_credentials(
session=self.session.current_session,
aws_region=sts_region,
@@ -123,23 +146,23 @@ class AwsProvider(Provider):
# Set identity
self._identity = self.set_identity(
caller_identity=caller_identity,
input_profile=input_profile,
input_regions=input_regions,
profile=profile,
regions=regions,
profile_region=profile_region,
)
########
######## AWS Session with Assume Role (if needed)
if input_role:
if role_arn:
# Validate the input role
valid_role_arn = parse_iam_credentials_arn(input_role)
valid_role_arn = parse_iam_credentials_arn(role_arn)
# Set assume IAM Role information
assumed_role_information = AWSAssumeRoleInfo(
role_arn=valid_role_arn,
session_duration=input_session_duration,
external_id=input_external_id,
mfa_enabled=input_mfa,
role_session_name=input_role_session_name,
session_duration=session_duration,
external_id=external_id,
mfa_enabled=mfa,
role_session_name=role_session_name,
sts_region=sts_region,
)
# Assume the IAM Role
@@ -181,10 +204,10 @@ class AwsProvider(Provider):
# Set assume IAM Role information
organizations_assumed_role_information = AWSAssumeRoleInfo(
role_arn=valid_role_arn,
session_duration=input_session_duration,
external_id=input_external_id,
mfa_enabled=input_mfa,
role_session_name=input_role_session_name,
session_duration=session_duration,
external_id=external_id,
mfa_enabled=mfa,
role_session_name=role_session_name,
sts_region=sts_region,
)
@@ -219,12 +242,12 @@ class AwsProvider(Provider):
########
# Parse Scan Tags
if getattr(arguments, "resource_tags", None):
self._audit_resources = self.get_tagged_resources(arguments.resource_tags)
if resource_tags:
self._audit_resources = self.get_tagged_resources(resource_tags)
# Parse Input Resource ARNs
if getattr(arguments, "resource_arn", None):
self._audit_resources = arguments.resource_arn
if resource_arn:
self._audit_resources = resource_arn
# Get Enabled Regions
self._enabled_regions = self.get_aws_enabled_regions(
@@ -237,14 +260,12 @@ class AwsProvider(Provider):
# TODO: move this to the providers, pending for AWS, GCP, AZURE and K8s
# Audit Config
self._audit_config = {}
if hasattr(arguments, "config_file"):
self._audit_config = load_and_validate_config_file(
self._type, arguments.config_file
)
if config_file:
self._audit_config = load_and_validate_config_file(self._type, config_file)
self._fixer_config = {}
if hasattr(arguments, "fixer_config"):
if fixer_config:
self._fixer_config = load_and_validate_fixer_config_file(
self._type, arguments.fixer_config
self._type, fixer_config
)
@property
@@ -388,8 +409,8 @@ class AwsProvider(Provider):
def set_identity(
self,
caller_identity: AWSCallerIdentity,
input_profile: str,
input_regions: set,
profile: str,
regions: set,
profile_region: str,
) -> AWSIdentityInfo:
logger.info(f"Original AWS Caller Identity UserId: {caller_identity.user_id}")
@@ -402,19 +423,19 @@ class AwsProvider(Provider):
user_id=caller_identity.user_id,
partition=partition,
identity_arn=caller_identity.arn.arn,
profile=input_profile,
profile=profile,
profile_region=profile_region,
audited_regions=input_regions,
audited_regions=regions,
)
@staticmethod
def setup_session(
input_mfa: bool = False,
input_profile: str = None,
mfa: bool = False,
profile: str = None,
) -> Session:
try:
logger.info("Creating original session ...")
if input_mfa:
if mfa:
mfa_info = AwsProvider.input_role_mfa_token_and_code()
# TODO: validate MFA ARN here
get_session_token_arguments = {
@@ -433,11 +454,11 @@ class AwsProvider(Provider):
aws_session_token=session_credentials["Credentials"][
"SessionToken"
],
profile_name=input_profile,
profile_name=profile,
)
else:
return Session(
profile_name=input_profile,
profile_name=profile,
)
except Exception as error:
logger.critical(
@@ -691,12 +712,12 @@ class AwsProvider(Provider):
audited_regions.add(region)
return audited_regions
def get_tagged_resources(self, input_resource_tags: list[str]) -> list[str]:
def get_tagged_resources(self, resource_tags: list[str]) -> list[str]:
"""
Returns a list of the resources that are going to be scanned based on the given input tags.
Parameters:
- input_resource_tags: A list of strings representing the tags to filter the resources. Each string should be in the format "key=value".
- resource_tags: A list of strings representing the tags to filter the resources. Each string should be in the format "key=value".
Returns:
- A list of strings representing the ARNs (Amazon Resource Names) of the tagged resources.
@@ -707,16 +728,16 @@ class AwsProvider(Provider):
- The method paginates through the results of the 'get_resources' operation to retrieve all the tagged resources.
Example usage:
input_resource_tags = ["Environment=Production", "Owner=John Doe"]
tagged_resources = get_tagged_resources(input_resource_tags)
resource_tags = ["Environment=Production", "Owner=John Doe"]
tagged_resources = get_tagged_resources(resource_tags)
"""
try:
resource_tags = []
resource_tags_values = []
tagged_resources = []
for tag in input_resource_tags:
for tag in resource_tags:
key = tag.split("=")[0]
value = tag.split("=")[1]
resource_tags.append({"Key": key, "Values": [value]})
resource_tags_values.append({"Key": key, "Values": [value]})
# Get Resources with resource_tags for all regions
for regional_client in self.generate_regional_clients(
"resourcegroupstaggingapi"
@@ -726,7 +747,7 @@ class AwsProvider(Provider):
"get_resources"
)
for page in get_resources_paginator.paginate(
TagFilters=resource_tags
TagFilters=resource_tags_values
):
for resource in page["ResourceTagMappingList"]:
tagged_resources.append(resource["ResourceARN"])
@@ -779,7 +800,7 @@ class AwsProvider(Provider):
mfa_TOTP = input("Enter MFA code: ")
return AWSMFAInfo(arn=mfa_ARN, totp=mfa_TOTP)
def set_session_config(self, aws_retries_max_attempts: int) -> Config:
def set_session_config(self, retries_max_attempts: int) -> Config:
"""
set_session_config returns a botocore Config object with the Prowler user agent and the default retrier configuration if nothing is passed as argument
"""
@@ -788,11 +809,11 @@ class AwsProvider(Provider):
retries={"max_attempts": 3, "mode": "standard"},
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
)
if aws_retries_max_attempts:
if retries_max_attempts:
# Create the new config
config = Config(
retries={
"max_attempts": aws_retries_max_attempts,
"max_attempts": retries_max_attempts,
"mode": "standard",
},
)
@@ -1102,9 +1123,9 @@ def get_aws_available_regions() -> set:
# TODO: This can be moved to another class since it doesn't need self
def get_aws_region_for_sts(session_region: str, input_regions: set[str]) -> str:
def get_aws_region_for_sts(session_region: str, regions: set[str]) -> str:
# If there is no region passed with -f/--region/--filter-region
if input_regions is None or len(input_regions) == 0:
if regions is None or len(regions) == 0:
# If you have a region configured in your AWS config or credentials file
if session_region is not None:
aws_region = session_region
@@ -1114,7 +1135,7 @@ def get_aws_region_for_sts(session_region: str, input_regions: set[str]) -> str:
aws_region = AWS_STS_GLOBAL_ENDPOINT_REGION
else:
# Get the first region passed to the -f/--region
aws_region = list(input_regions)[0]
aws_region = list(regions)[0]
return aws_region

View File

@@ -190,7 +190,7 @@ def validate_arguments(arguments: Namespace) -> tuple[bool, str]:
return (True, "")
def validate_bucket(bucket_name):
def validate_bucket(bucket_name: str) -> str:
"""validate_bucket validates that the input bucket_name is valid"""
if search("(?!(^xn--|.+-s3alias$))^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$", bucket_name):
return bucket_name

View File

@@ -180,7 +180,24 @@ class Provider(ABC):
)
if not isinstance(Provider._global, provider_class):
if "Kubernetes" in provider_class_name:
if "aws" in provider_class_name.lower():
global_provider = provider_class(
arguments.aws_retries_max_attempts,
arguments.role,
arguments.session_duration,
arguments.external_id,
arguments.role_session_name,
arguments.mfa,
arguments.profile,
set(arguments.region) if arguments.region else None,
arguments.organizations_role,
arguments.scan_unused_services,
arguments.resource_tag,
arguments.resource_arn,
arguments.config_file,
arguments.fixer_config,
)
elif "Kubernetes" in provider_class_name:
global_provider = provider_class(
arguments.kubeconfig_file,
arguments.context,

View File

@@ -2,7 +2,6 @@ import json
import os
import pathlib
import traceback
from argparse import Namespace
from importlib.machinery import FileFinder
from logging import DEBUG, ERROR
from pkgutil import ModuleInfo
@@ -475,8 +474,7 @@ class TestCheck:
},
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
for test in test_cases:
check_folder = test["input"]["path"]
provider = test["input"]["provider"]

View File

@@ -1064,7 +1064,7 @@ class Test_Parser:
self.parser.parse(command)
assert ex.type == SystemExit
def test_aws_parser_aws_retries_max_attempts(self):
def test_aws_parser_retries_max_attempts(self):
argument = "--aws-retries-max-attempts"
max_retries = "10"
command = [prowler_command, argument, max_retries]

View File

@@ -7,6 +7,7 @@ from datetime import datetime, timedelta
from json import dumps
from os import rmdir
from re import search
from unittest import mock
import botocore
import botocore.exceptions
@@ -246,10 +247,12 @@ def mock_recover_checks_from_aws_provider_cognito_service(*_):
class TestAWSProvider:
@mock_aws
def test_aws_provider_default(self):
arguments = Namespace()
arguments.mfa = False
arguments.scan_unused_services = True
aws_provider = AwsProvider(arguments)
mfa = False
scan_unused_services = True
aws_provider = AwsProvider(
mfa=mfa,
scan_unused_services=scan_unused_services,
)
assert aws_provider.type == "aws"
assert aws_provider.scan_unused_services is True
@@ -267,8 +270,7 @@ class TestAWSProvider:
],
)
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
assert isinstance(aws_provider.organizations_metadata, AWSOrganizationsInfo)
assert aws_provider.organizations_metadata.account_email == "master@example.com"
@@ -285,8 +287,8 @@ class TestAWSProvider:
@mock_aws
def test_aws_provider_organizations_none_organizations_metadata(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
assert isinstance(aws_provider.organizations_metadata, AWSOrganizationsInfo)
assert aws_provider.organizations_metadata.account_email == ""
@@ -346,10 +348,12 @@ class TestAWSProvider:
],
)
arguments = Namespace()
arguments.organizations_role = organizations_role["Arn"]
arguments.session_duration = 900
aws_provider = AwsProvider(arguments)
organizations_role = organizations_role["Arn"]
session_duration = 900
aws_provider = AwsProvider(
organizations_role_arn=organizations_role,
session_duration=session_duration,
)
assert isinstance(aws_provider.organizations_metadata, AWSOrganizationsInfo)
assert aws_provider.organizations_metadata.account_email == "master@example.com"
@@ -366,8 +370,7 @@ class TestAWSProvider:
@mock_aws
def test_aws_provider_session_with_mfa(self):
arguments = Namespace()
arguments.mfa = True
mfa = True
with patch(
"prowler.providers.aws.aws_provider.AwsProvider.input_role_mfa_token_and_code",
@@ -377,7 +380,7 @@ class TestAWSProvider:
),
):
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider(mfa=mfa)
assert aws_provider.type == "aws"
assert aws_provider.scan_unused_services is None
@@ -391,9 +394,7 @@ class TestAWSProvider:
@mock_aws
def test_aws_provider_get_output_mapping(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
assert aws_provider.get_output_mapping == {
"auth_method": "identity.profile",
@@ -410,13 +411,12 @@ class TestAWSProvider:
@mock_aws
def test_aws_provider_assume_role_with_mfa(self):
# Variables
arguments = Namespace()
arguments.mfa = True
mfa = True
role_name = "test-role"
arguments.role = f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:role/{role_name}"
arguments.session_duration = 900
arguments.role_session_name = "ProwlerAssessmentSession"
arguments.external_id = "test-external-id"
role_arn = f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:role/{role_name}"
session_duration = 900
role_session_name = "ProwlerAssessmentSession"
external_id = "test-external-id"
with patch(
"prowler.providers.aws.aws_provider.AwsProvider.input_role_mfa_token_and_code",
@@ -425,7 +425,13 @@ class TestAWSProvider:
totp="111111",
),
):
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider(
mfa=mfa,
role_arn=role_arn,
session_duration=session_duration,
role_session_name=role_session_name,
external_id=external_id,
)
assert (
aws_provider.session.current_session.region_name == AWS_REGION_US_EAST_1
)
@@ -436,11 +442,11 @@ class TestAWSProvider:
aws_provider._assumed_role_configuration.info, AWSAssumeRoleInfo
)
assert aws_provider._assumed_role_configuration.info == AWSAssumeRoleInfo(
role_arn=ARN(arn=arguments.role),
session_duration=arguments.session_duration,
external_id=arguments.external_id,
role_arn=ARN(arn=role_arn),
session_duration=session_duration,
external_id=external_id,
mfa_enabled=True, # <- MFA configuration
role_session_name=arguments.role_session_name,
role_session_name=role_session_name,
sts_region=AWS_REGION_US_EAST_1,
)
@@ -464,16 +470,20 @@ class TestAWSProvider:
@mock_aws
def test_aws_provider_assume_role_without_mfa(self):
# Variables
arguments = Namespace()
arguments.mfa = False
mfa = False
role_name = "test-role"
arguments.role = (
role_arn = (
f"arn:{AWS_COMMERCIAL_PARTITION}:iam::{AWS_ACCOUNT_NUMBER}:role/{role_name}"
)
arguments.session_duration = 900
arguments.role_session_name = "ProwlerAssessmentSession"
session_duration = 900
role_session_name = "ProwlerAssessmentSession"
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider(
mfa=mfa,
role_arn=role_arn,
session_duration=session_duration,
role_session_name=role_session_name,
)
assert aws_provider.session.current_session.region_name == AWS_REGION_US_EAST_1
assert aws_provider.identity.account == AWS_ACCOUNT_NUMBER
assert aws_provider.identity.account_arn == AWS_ACCOUNT_ARN
@@ -482,11 +492,11 @@ class TestAWSProvider:
aws_provider._assumed_role_configuration.info, AWSAssumeRoleInfo
)
assert aws_provider._assumed_role_configuration.info == AWSAssumeRoleInfo(
role_arn=ARN(arn=arguments.role),
session_duration=arguments.session_duration,
role_arn=ARN(arn=role_arn),
session_duration=session_duration,
external_id=None,
mfa_enabled=False, # <- MFA configuration
role_session_name=arguments.role_session_name,
role_session_name=role_session_name,
sts_region=AWS_REGION_US_EAST_1,
)
@@ -513,16 +523,20 @@ class TestAWSProvider:
monkeypatch.setenv("AWS_DEFAULT_REGION", AWS_REGION_GOV_CLOUD_US_EAST_1)
# Variables
arguments = Namespace()
arguments.mfa = False
mfa = False
role_name = "test-role"
arguments.role = (
role_arn = (
f"arn:{AWS_GOV_CLOUD_PARTITION}:iam::{AWS_ACCOUNT_NUMBER}:role/{role_name}"
)
arguments.session_duration = 900
arguments.role_session_name = "ProwlerAssessmentSession"
session_duration = 900
role_session_name = "ProwlerAssessmentSession"
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider(
mfa=mfa,
role_arn=role_arn,
session_duration=session_duration,
role_session_name=role_session_name,
)
assert (
aws_provider.session.current_session.region_name
== AWS_REGION_GOV_CLOUD_US_EAST_1
@@ -534,11 +548,11 @@ class TestAWSProvider:
aws_provider._assumed_role_configuration.info, AWSAssumeRoleInfo
)
assert aws_provider._assumed_role_configuration.info == AWSAssumeRoleInfo(
role_arn=ARN(arn=arguments.role),
session_duration=arguments.session_duration,
role_arn=ARN(arn=role_arn),
session_duration=session_duration,
external_id=None,
mfa_enabled=False, # <- MFA configuration
role_session_name=arguments.role_session_name,
role_session_name=role_session_name,
sts_region=AWS_REGION_GOV_CLOUD_US_EAST_1,
)
@@ -565,14 +579,15 @@ class TestAWSProvider:
aws:
test_key: value"""
config_file = tempfile.NamedTemporaryFile(delete=False)
config_file.write(bytes(config, encoding="raw_unicode_escape"))
config_file.close()
arguments = Namespace()
arguments.config_file = config_file.name
aws_provider = AwsProvider(arguments)
config_file_input = tempfile.NamedTemporaryFile(delete=False)
config_file_input.write(bytes(config, encoding="raw_unicode_escape"))
config_file_input.close()
config_file_input = config_file_input.name
aws_provider = AwsProvider(
config_file=config_file_input,
)
os.remove(config_file.name)
os.remove(config_file_input)
assert aws_provider.audit_config == {"test_key": "value"}
@@ -604,8 +619,7 @@ aws:
with open(mutelist_file.name, "w") as mutelist_file:
mutelist_file.write(json.dumps(mutelist, indent=4))
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider.mutelist = mutelist_file.name
@@ -617,8 +631,7 @@ aws:
@mock_aws
def test_aws_provider_mutelist_none(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
with patch(
"prowler.providers.aws.aws_provider.get_default_mute_file_path",
@@ -672,8 +685,7 @@ aws:
)
)
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider.mutelist = mutelist_bucket_object_uri
os.remove(mutelist_file.name)
@@ -707,8 +719,7 @@ aws:
}
}
lambda_mutelist_path = f"arn:aws:lambda:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:function:lambda-mutelist"
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
with patch(
"prowler.providers.aws.lib.mutelist.mutelist.AWSMutelist.get_mutelist_file_from_lambda",
@@ -745,8 +756,7 @@ aws:
}
}
dynamodb_mutelist_path = f"arn:aws:dynamodb:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:table/mutelist-dynamo"
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
with patch(
"prowler.providers.aws.lib.mutelist.mutelist.AWSMutelist.get_mutelist_file_from_dynamodb",
@@ -760,24 +770,21 @@ aws:
@mock_aws
def test_empty_input_regions_in_arguments(self):
arguments = Namespace()
arguments.region = None
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider(regions=None)
assert isinstance(aws_provider, AwsProvider)
@mock_aws
def test_generate_regional_clients_all_enabled_regions(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
response = aws_provider.generate_regional_clients("ec2")
assert len(response.keys()) == 29
@mock_aws
def test_generate_regional_clients_with_enabled_regions(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
enabled_regions = [AWS_REGION_EU_WEST_1]
aws_provider._enabled_regions = enabled_regions
@@ -787,9 +794,10 @@ aws:
@mock_aws
def test_generate_regional_clients_with_enabled_regions_and_input_regions(self):
arguments = Namespace()
arguments.region = [AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1]
aws_provider = AwsProvider(
regions=region,
)
enabled_regions = [AWS_REGION_EU_WEST_1]
aws_provider._enabled_regions = enabled_regions
@@ -800,9 +808,10 @@ aws:
@mock_aws
def test_generate_regional_clients_cn_partition(self):
arguments = Namespace()
arguments.region = [AWS_REGION_CN_NORTH_1, AWS_REGION_CN_NORTHWEST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_CN_NORTH_1, AWS_REGION_CN_NORTHWEST_1]
aws_provider = AwsProvider(
regions=region,
)
response = aws_provider.generate_regional_clients("ec2")
assert AWS_REGION_CN_NORTH_1 in response.keys()
@@ -810,9 +819,10 @@ aws:
@mock_aws
def test_generate_regional_clients_cn_partition_not_present_service(self):
arguments = Namespace()
arguments.region = ["cn-northwest-1", "cn-north-1"]
aws_provider = AwsProvider(arguments)
region = ["cn-northwest-1", "cn-north-1"]
aws_provider = AwsProvider(
regions=region,
)
response = aws_provider.generate_regional_clients("shield")
@@ -820,75 +830,76 @@ aws:
@mock_aws
def test_get_default_region(self):
arguments = Namespace()
arguments.region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(
regions=region,
)
aws_provider._identity.profile_region = AWS_REGION_EU_WEST_1
assert aws_provider.get_default_region("ec2") == AWS_REGION_EU_WEST_1
@mock_aws
def test_get_default_region_profile_region_not_audited(self):
arguments = Namespace()
arguments.region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(
regions=region,
)
aws_provider._identity.profile_region = AWS_REGION_US_EAST_2
assert aws_provider.get_default_region("ec2") == AWS_REGION_EU_WEST_1
@mock_aws
def test_get_default_region_non_profile_region(self):
arguments = Namespace()
arguments.region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(
regions=region,
)
aws_provider._identity.profile_region = None
assert aws_provider.get_default_region("ec2") == AWS_REGION_EU_WEST_1
@mock_aws
def test_get_default_region_non_profile_or_audited_region(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._identity.profile_region = None
assert aws_provider.get_default_region("ec2") == AWS_REGION_US_EAST_1
@mock_aws
def test_get_default_region_profile_region_not_present_in_service(self):
arguments = Namespace()
arguments.region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_WEST_1]
aws_provider = AwsProvider(
regions=region,
)
aws_provider._identity.profile_region = "non-existent-region"
assert aws_provider.get_default_region("ec2") == AWS_REGION_EU_WEST_1
@mock_aws
def test_aws_gov_get_global_region(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._identity.partition = AWS_GOV_CLOUD_PARTITION
assert aws_provider.get_global_region() == AWS_REGION_GOV_CLOUD_US_EAST_1
@mock_aws
def test_aws_cn_get_global_region(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._identity.partition = AWS_CHINA_PARTITION
assert aws_provider.get_global_region() == AWS_REGION_CN_NORTH_1
@mock_aws
def test_aws_iso_get_global_region(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._identity.partition = AWS_ISO_PARTITION
assert aws_provider.get_global_region() == AWS_REGION_ISO_GLOBAL
@mock_aws
def test_get_available_aws_service_regions_with_us_east_1_audited(self):
arguments = Namespace()
arguments.region = [AWS_REGION_US_EAST_1]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_US_EAST_1]
aws_provider = AwsProvider(
regions=region,
)
with patch(
"prowler.providers.aws.aws_provider.parse_json_file",
@@ -926,8 +937,8 @@ aws:
@mock_aws
def test_get_available_aws_service_regions_with_all_regions_audited(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
with patch(
"prowler.providers.aws.aws_provider.parse_json_file",
@@ -994,10 +1005,12 @@ aws:
)
# Through the AWS provider
arguments = Namespace()
arguments.region = [AWS_REGION_EU_CENTRAL_1]
arguments.resource_tags = ["ami=test"]
aws_provider = AwsProvider(arguments)
region = [AWS_REGION_EU_CENTRAL_1]
resource_tags = ["ami=test"]
aws_provider = AwsProvider(
regions=region,
resource_tags=resource_tags,
)
tagged_resources = aws_provider.audit_resources
assert len(tagged_resources) == 2
@@ -1012,9 +1025,10 @@ aws:
@mock_aws
def test_aws_provider_resource_tags(self):
arguments = Namespace()
arguments.resource_arn = [AWS_ACCOUNT_ARN]
aws_provider = AwsProvider(arguments)
resource_arn = [AWS_ACCOUNT_ARN]
aws_provider = AwsProvider(
resource_arn=resource_arn,
)
assert aws_provider.audit_resources == [AWS_ACCOUNT_ARN]
@@ -1032,7 +1046,7 @@ aws:
arguments.unix_timestamp = False
arguments.send_sh_only_fails = True
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
# This is needed since the output_options requires to get the global provider to get the audit config
with patch(
"prowler.providers.common.provider.Provider.get_global_provider",
@@ -1081,7 +1095,7 @@ aws:
arguments.unix_timestamp = False
arguments.send_sh_only_fails = True
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
# This is needed since the output_options requires to get the global provider to get the audit config
with patch(
"prowler.providers.common.provider.Provider.get_global_provider",
@@ -1247,16 +1261,33 @@ aws:
assert connection.is_connected
assert connection.error is None
@mock_aws
def test_test_connection_without_credentials(self, monkeypatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
def test_test_connection_without_credentials(self):
with mock.patch("boto3.Session.get_credentials", return_value=None), mock.patch(
"botocore.session.Session.get_scoped_config", return_value={}
), mock.patch(
"botocore.credentials.EnvProvider.load", return_value=None
), mock.patch(
"botocore.credentials.SharedCredentialProvider.load", return_value=None
), mock.patch(
"botocore.credentials.InstanceMetadataProvider.load", return_value=None
), mock.patch.dict(
"os.environ",
{
"AWS_ACCESS_KEY_ID": "",
"AWS_SECRET_ACCESS_KEY": "",
"AWS_SESSION_TOKEN": "",
"AWS_PROFILE": "",
},
clear=True,
):
with raises(botocore.exceptions.NoCredentialsError) as exception:
AwsProvider.test_connection()
with raises(botocore.exceptions.NoCredentialsError) as exception:
AwsProvider.test_connection(
profile=None
) # No profile to avoid ProfileNotFound error
assert exception.type == botocore.exceptions.NoCredentialsError
assert exception.value.args[0] == "Unable to locate credentials"
assert exception.type == botocore.exceptions.NoCredentialsError
assert "Unable to locate credentials" in str(exception.value)
@mock_aws
def test_test_connection_with_role_from_env(self, monkeypatch):
@@ -1394,8 +1425,8 @@ aws:
"elb_internet_facing",
"elb_logging_enabled",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:elasticloadbalancing:us-east-1:{AWS_ACCOUNT_NUMBER}:loadbalancer/test"
]
@@ -1415,8 +1446,8 @@ aws:
"efs_have_backup_enabled",
"efs_not_publicly_accessible",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:elasticfilesystem:us-east-1:{AWS_ACCOUNT_NUMBER}:file-system/fs-01234567"
]
@@ -1435,8 +1466,8 @@ aws:
"awslambda_function_no_secrets_in_code",
"awslambda_function_url_cors_policy",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
"arn:aws:lambda:us-east-1:123456789:function:test-lambda"
]
@@ -1456,8 +1487,8 @@ aws:
"iam_customer_attached_policy_no_administrative_privileges",
"iam_password_policy_minimum_length_14",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:user/user-name"
]
@@ -1478,8 +1509,8 @@ aws:
"s3_bucket_acl_prohibited",
"s3_bucket_policy_public_write_access",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = ["arn:aws:s3:::bucket-name"]
recovered_checks = aws_provider.get_checks_from_input_arn()
@@ -1496,8 +1527,8 @@ aws:
"cloudwatch_changes_to_network_gateways_alarm_configured",
"cloudwatch_changes_to_network_route_tables_alarm_configured",
]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:logs:us-east-1:{AWS_ACCOUNT_NUMBER}:destination:testDestination"
]
@@ -1512,8 +1543,8 @@ aws:
)
def test_get_checks_from_input_arn_cognito(self):
expected_checks = []
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:cognito-idp:us-east-1:{AWS_ACCOUNT_NUMBER}:userpool/test"
]
@@ -1528,8 +1559,8 @@ aws:
)
def test_get_checks_from_input_arn_ec2_security_group(self):
expected_checks = ["ec2_securitygroup_allow_ingress_from_internet_to_any_port"]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:ec2:us-east-1:{AWS_ACCOUNT_NUMBER}:security-group/sg-1111111111"
]
@@ -1544,8 +1575,8 @@ aws:
)
def test_get_checks_from_input_arn_ec2_acl(self):
expected_checks = ["ec2_networkacl_allow_ingress_any_port"]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:ec2:us-west-2:{AWS_ACCOUNT_NUMBER}:network-acl/acl-1"
]
@@ -1560,8 +1591,8 @@ aws:
)
def test_get_checks_from_input_arn_rds_snapshots(self):
expected_checks = ["rds_snapshots_public_access"]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:rds:us-east-2:{AWS_ACCOUNT_NUMBER}:snapshot:rds:snapshot-1",
]
@@ -1576,8 +1607,8 @@ aws:
)
def test_get_checks_from_input_arn_ec2_ami(self):
expected_checks = ["ec2_ami_public"]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:ec2:us-west-2:{AWS_ACCOUNT_NUMBER}:image/ami-1"
]
@@ -1595,8 +1626,8 @@ aws:
"arn:aws:apigateway:us-east-2::/restapis/api-id/stages/stage-name",
]
expected_regions = {"us-east-1", "eu-west-1", "us-east-2"}
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
recovered_regions = aws_provider.get_regions_from_audit_resources(
audit_resources
)
@@ -1605,8 +1636,8 @@ aws:
@mock_aws
def test_get_regions_from_audit_resources_without_regions(self):
audit_resources = ["arn:aws:s3:::bucket-name"]
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
recovered_regions = aws_provider.get_regions_from_audit_resources(
audit_resources
)
@@ -1681,8 +1712,8 @@ aws:
@mock_aws
def test_set_session_config_default(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
session_config = aws_provider.set_session_config(None)
assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
@@ -1690,8 +1721,8 @@ aws:
@mock_aws
def test_set_session_config_10_max_attempts(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
session_config = aws_provider.set_session_config(10)
assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
@@ -1703,8 +1734,8 @@ aws:
new=mock_recover_checks_from_aws_provider_ec2_service,
)
def test_get_checks_to_execute_by_audit_resources(self):
arguments = Namespace()
aws_provider = AwsProvider(arguments)
aws_provider = AwsProvider()
aws_provider._audit_resources = [
f"arn:aws:ec2:us-west-2:{AWS_ACCOUNT_NUMBER}:network-acl/acl-1"
]
@@ -1745,10 +1776,8 @@ aws:
@mock_aws
def test_refresh_credentials_before_expiration(self):
role_arn = create_role(AWS_REGION_EU_WEST_1)
arguments = Namespace()
arguments.role = role_arn
arguments.session_duration = 900
aws_provider = AwsProvider(arguments)
session_duration = 900
aws_provider = AwsProvider(role_arn=role_arn, session_duration=session_duration)
current_credentials = (
aws_provider._assumed_role_configuration.credentials.__dict__
@@ -1768,10 +1797,8 @@ aws:
def test_refresh_credentials_after_expiration(self):
role_arn = create_role(AWS_REGION_EU_WEST_1)
session_duration_in_seconds = 900
arguments = Namespace()
arguments.role = role_arn
arguments.session_duration = session_duration_in_seconds
aws_provider = AwsProvider(arguments)
session_duration = session_duration_in_seconds
aws_provider = AwsProvider(role_arn=role_arn, session_duration=session_duration)
# Manually expire credentials
aws_provider._assumed_role_configuration.credentials.expiration = datetime.now(

View File

@@ -118,7 +118,10 @@ def set_mocked_aws_provider(
arguments = set_default_provider_arguments(arguments, status)
# AWS Provider
provider = AwsProvider(arguments)
provider = AwsProvider()
# Set output options
provider.output_options = arguments, {}
# Output options