feat(aws): Add static credentials authentication (#5360)

This commit is contained in:
Pepe Fagoaga
2024-10-10 17:47:05 +02:00
committed by GitHub
parent 9f2de7d2f9
commit cad99c5e0f
4 changed files with 211 additions and 12 deletions
+1 -1
View File
@@ -85,7 +85,7 @@ repos:
# For running trufflehog in docker, use the following entry instead:
# entry: bash -c 'docker run -v "$(pwd):/workdir" -i --rm trufflesecurity/trufflehog:latest git file:///workdir --only-verified --fail'
language: system
stages: ["commit", "push"]
stages: ["pre-commit", "pre-push"]
- id: bandit
name: bandit
+83 -11
View File
@@ -2,8 +2,8 @@ import os
import pathlib
from datetime import datetime
from re import fullmatch
from typing import Optional
from boto3 import client
from boto3.session import Session
from botocore.config import Config
from botocore.credentials import RefreshableCredentials
@@ -24,6 +24,7 @@ from prowler.providers.aws.config import (
ROLE_SESSION_NAME,
)
from prowler.providers.aws.exceptions.exceptions import (
AWSAccessKeyIDInvalid,
AWSArgumentTypeValidationError,
AWSAssumeRoleError,
AWSClientError,
@@ -35,6 +36,7 @@ from prowler.providers.aws.exceptions.exceptions import (
AWSIAMRoleARNServiceNotIAMnorSTS,
AWSNoCredentialsError,
AWSProfileNotFoundError,
AWSSecretAccessKeyInvalid,
AWSSetUpSessionError,
)
from prowler.providers.aws.lib.arn.arn import parse_iam_credentials_arn
@@ -86,6 +88,9 @@ class AwsProvider(Provider):
resource_arn: list[str] = [],
audit_config: dict = {},
fixer_config: dict = {},
aws_access_key_id: str = None,
aws_secret_access_key: str = None,
aws_session_token: Optional[str] = None,
):
"""
Initializes the AWS provider.
@@ -105,6 +110,9 @@ class AwsProvider(Provider):
- resource_arn: A list of ARNs of the resources to audit.
- audit_config: The audit configuration.
- fixer_config: The fixer configuration.
- aws_access_key_id: The AWS access key ID.
- aws_secret_access_key: The AWS secret access key.
- aws_session_token: The AWS session token, optional.
Raises:
- ArgumentTypeError: If the input MFA ARN is invalid.
@@ -119,7 +127,13 @@ class AwsProvider(Provider):
logger.info("Generating original session ...")
# Configure the initial AWS Session using the local credentials: profile or environment variables
aws_session = self.setup_session(mfa, profile)
aws_session = self.setup_session(
mfa=mfa,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
session_config = self.set_session_config(retries_max_attempts)
# Current session and the original session points to the same session object until we get a new one, if needed
self._session = AWSSession(
@@ -423,17 +437,33 @@ class AwsProvider(Provider):
def setup_session(
mfa: bool = False,
profile: str = None,
aws_access_key_id: str = None,
aws_secret_access_key: str = None,
aws_session_token: Optional[str] = None,
) -> Session:
try:
logger.info("Creating original session ...")
logger.debug("Creating original session ...")
session_arguments = {}
if profile:
session_arguments["profile_name"] = profile
elif aws_access_key_id and aws_secret_access_key:
session_arguments["aws_access_key_id"] = aws_access_key_id
session_arguments["aws_secret_access_key"] = aws_secret_access_key
if aws_session_token:
session_arguments["aws_session_token"] = aws_session_token
if mfa:
session = Session(**session_arguments)
sts_client = session.client("sts")
# TODO: pass values from the input
mfa_info = AwsProvider.input_role_mfa_token_and_code()
# TODO: validate MFA ARN here
get_session_token_arguments = {
"SerialNumber": mfa_info.arn,
"TokenCode": mfa_info.totp,
}
sts_client = client("sts")
session_credentials = sts_client.get_session_token(
**get_session_token_arguments
)
@@ -445,12 +475,9 @@ class AwsProvider(Provider):
aws_session_token=session_credentials["Credentials"][
"SessionToken"
],
profile_name=profile,
)
else:
return Session(
profile_name=profile,
)
return Session(**session_arguments)
except Exception as error:
logger.critical(
f"AWSSetUpSessionError[{error.__traceback__.tb_lineno}]: {error}"
@@ -939,6 +966,26 @@ class AwsProvider(Provider):
arn=ARN(caller_identity.get("Arn")),
region=aws_region,
)
except ClientError as client_error:
logger.error(
f"{client_error.__class__.__name__}[{client_error.__traceback__.tb_lineno}]: {client_error}"
)
if client_error.response["Error"]["Code"] == "InvalidClientTokenId":
raise AWSAccessKeyIDInvalid(
original_exception=client_error,
file=pathlib.Path(__file__).name,
)
elif client_error.response["Error"]["Code"] == "SignatureDoesNotMatch":
raise AWSSecretAccessKeyInvalid(
original_exception=client_error,
file=pathlib.Path(__file__).name,
)
else:
raise AWSClientError(
original_exception=client_error,
file=pathlib.Path(__file__).name,
)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -955,6 +1002,9 @@ class AwsProvider(Provider):
external_id: str = None,
mfa_enabled: bool = False,
raise_on_exception: bool = True,
aws_access_key_id: str = None,
aws_secret_access_key: str = None,
aws_session_token: Optional[str] = None,
) -> Connection:
"""
Test the connection to AWS with one of the Boto3 credentials methods.
@@ -968,6 +1018,9 @@ class AwsProvider(Provider):
external_id (str): The external ID to use when assuming the role.
mfa_enabled (bool): Whether MFA (Multi-Factor Authentication) is enabled.
raise_on_exception (bool): Whether to raise an exception if an error occurs.
aws_access_key_id (str): The AWS access key ID to use for the session.
aws_secret_access_key (str): The AWS secret access key to use for the session.
aws_session_token (str): The AWS session token to use for the session. Optional.
Returns:
Connection: An object tha contains the result of the test connection operation.
@@ -994,9 +1047,17 @@ class AwsProvider(Provider):
Connection(is_connected=False, Error=ProfileNotFound('The config profile (not-found) could not be found'))
>>> AwsProvider.test_connection(raise_on_exception=False))
Connection(is_connected=False, Error=NoCredentialsError('Unable to locate credentials'))
>>> AwsProvider.test_connection(aws_access_key_id="XXXXXXXX", aws_secret_access_key="XXXXXXXX", raise_on_exception=False))
Connection(is_connected=True, Error=None))
"""
try:
session = AwsProvider.setup_session(mfa_enabled, profile)
session = AwsProvider.setup_session(
mfa=mfa_enabled,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
if role_arn:
session_duration = validate_session_duration(session_duration)
@@ -1021,8 +1082,7 @@ class AwsProvider(Provider):
profile_name=profile,
)
sts_client = AwsProvider.create_sts_session(session, aws_region)
_ = sts_client.get_caller_identity()
_ = AwsProvider.validate_credentials(session, aws_region)
return Connection(
is_connected=True,
)
@@ -1113,6 +1173,18 @@ class AwsProvider(Provider):
) from no_credentials_error
return Connection(error=no_credentials_error)
except AWSAccessKeyIDInvalid as access_key_id_invalid_error:
logger.error(str(access_key_id_invalid_error))
if raise_on_exception:
raise access_key_id_invalid_error
return Connection(error=access_key_id_invalid_error)
except AWSSecretAccessKeyInvalid as secret_access_key_invalid_error:
logger.error(str(secret_access_key_invalid_error))
if raise_on_exception:
raise secret_access_key_invalid_error
return Connection(error=secret_access_key_invalid_error)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -57,6 +57,14 @@ class AWSBaseException(ProwlerException):
"message": "AWS assume role error",
"remediation": "Check the AWS assume role configuration and ensure it is properly set up, please visit https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/aws/role-assumption/ and https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_terms-and-concepts.html#iam-term-role-session-name",
},
(1915, "AWSAccessKeyIDInvalid"): {
"message": "AWS Access Key ID or Session Token is invalid",
"remediation": "Check your AWS Access Key ID or Session Token and ensure it is valid.",
},
(1916, "AWSSecretAccessKeyInvalid"): {
"message": "AWS Secret Access Key is invalid",
"remediation": "Check your AWS Secret Access Key and signing method and ensure it is valid.",
},
}
def __init__(self, code, file=None, original_exception=None, message=None):
@@ -175,3 +183,17 @@ class AWSAssumeRoleError(AWSBaseException):
super().__init__(
1914, file=file, original_exception=original_exception, message=message
)
class AWSAccessKeyIDInvalid(AWSCredentialsError):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
1915, file=file, original_exception=original_exception, message=message
)
class AWSSecretAccessKeyInvalid(AWSCredentialsError):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
1916, file=file, original_exception=original_exception, message=message
)
+105
View File
@@ -260,6 +260,72 @@ class TestAWSProvider:
assert aws_provider.audit_config == {}
assert aws_provider.session.current_session.region_name == AWS_REGION_US_EAST_1
@mock_aws
def test_aws_provider_with_static_credentials(self):
# Create a mock IAM user
iam_client = client("iam", region_name=AWS_REGION_EU_WEST_1)
username = "test-user"
iam_user = iam_client.create_user(UserName=username)["User"]
# Create a mock IAM access keys
access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[
"AccessKey"
]
credentials = {
"aws_access_key_id": access_key["AccessKeyId"],
"aws_secret_access_key": access_key["SecretAccessKey"],
}
aws_provider = AwsProvider(**credentials)
assert aws_provider.type == "aws"
# Session
assert aws_provider.session.current_session.region_name == AWS_REGION_US_EAST_1
assert aws_provider.session.current_session.profile_name == "default"
assert aws_provider.session.original_session.region_name == AWS_REGION_US_EAST_1
assert aws_provider.session.original_session.profile_name == "default"
# Identity
assert aws_provider.identity.account == AWS_ACCOUNT_NUMBER
assert aws_provider.identity.account_arn == AWS_ACCOUNT_ARN
assert (
aws_provider.identity.identity_arn
== f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:user/{username}"
)
assert aws_provider.identity.partition == AWS_COMMERCIAL_PARTITION
assert aws_provider.identity.profile is None
assert aws_provider.identity.profile_region == AWS_REGION_US_EAST_1
@mock_aws
def test_aws_provider_with_session_credentials(self):
sts_client = client("sts", region_name=AWS_REGION_EU_WEST_1)
session_token = sts_client.get_session_token()
session_credentials = {
"aws_access_key_id": session_token["Credentials"]["AccessKeyId"],
"aws_secret_access_key": session_token["Credentials"]["SecretAccessKey"],
"aws_session_token": session_token["Credentials"]["SessionToken"],
}
aws_provider = AwsProvider(**session_credentials)
assert aws_provider.type == "aws"
# Session
assert aws_provider.session.current_session.region_name == AWS_REGION_US_EAST_1
assert aws_provider.session.current_session.profile_name == "default"
assert aws_provider.session.original_session.region_name == AWS_REGION_US_EAST_1
assert aws_provider.session.original_session.profile_name == "default"
# Identity
assert aws_provider.identity.account == AWS_ACCOUNT_NUMBER
assert aws_provider.identity.account_arn == AWS_ACCOUNT_ARN
# moto is the default user created by moto
assert (
aws_provider.identity.identity_arn
== f"arn:aws:sts::{AWS_ACCOUNT_NUMBER}:user/moto"
)
assert aws_provider.identity.partition == AWS_COMMERCIAL_PARTITION
assert aws_provider.identity.profile is None
assert aws_provider.identity.profile_region == AWS_REGION_US_EAST_1
@mock_aws
def test_aws_provider_organizations_delegated_administrator(self):
organizations_client = client("organizations", region_name=AWS_REGION_EU_WEST_1)
@@ -1293,6 +1359,45 @@ aws:
== "[1912] AWS IAM Role ARN resource type is invalid"
)
@mock_aws
def test_test_connection_with_static_credentials(self):
# Create a mock IAM user
iam_client = client("iam", region_name=AWS_REGION_EU_WEST_1)
username = "test-user"
iam_user = iam_client.create_user(UserName=username)["User"]
# Create a mock IAM access keys
access_key = iam_client.create_access_key(UserName=iam_user["UserName"])[
"AccessKey"
]
credentials = {
"aws_access_key_id": access_key["AccessKeyId"],
"aws_secret_access_key": access_key["SecretAccessKey"],
}
connection = AwsProvider.test_connection(**credentials)
assert isinstance(connection, Connection)
assert connection.is_connected
assert connection.error is None
@mock_aws
def test_test_connection_with_session_credentials(self):
sts_client = client("sts", region_name=AWS_REGION_EU_WEST_1)
session_token = sts_client.get_session_token()
session_credentials = {
"aws_access_key_id": session_token["Credentials"]["AccessKeyId"],
"aws_secret_access_key": session_token["Credentials"]["SecretAccessKey"],
"aws_session_token": session_token["Credentials"]["SessionToken"],
}
connection = AwsProvider.test_connection(**session_credentials)
assert isinstance(connection, Connection)
assert connection.is_connected
assert connection.error is None
@mock_aws
def test_create_sts_session(self):
current_session = session.Session()