chore(aws): Working outputs (#3488)

This commit is contained in:
Pepe Fagoaga
2024-03-04 17:17:20 +01:00
committed by GitHub
parent 5df9fd881c
commit 086148819c
26 changed files with 267 additions and 418 deletions

View File

@@ -27,7 +27,7 @@ lint: ## Lint Code
@echo "Running black... "
black --check .
@echo "Running pylint..."
pylint --disable=W,C,R,E -j 0 providers lib util config
pylint --disable=W,C,R,E -j 0 prowler util
##@ PyPI
pypi-clean: ## Delete the distribution files

View File

@@ -44,11 +44,7 @@ from prowler.providers.aws.lib.security_hub.security_hub import (
resolve_security_hub_previous_findings,
verify_security_hub_integration_enabled_per_region,
)
from prowler.providers.common.clean import clean_provider_local_output_directories
from prowler.providers.common.common import (
get_global_provider,
set_global_provider_object,
)
from prowler.providers.common.common import set_global_provider_object
from prowler.providers.common.outputs import set_provider_output_options
@@ -147,8 +143,7 @@ def prowler():
sys.exit()
# Provider to scan
set_global_provider_object(args)
global_provider = get_global_provider()
global_provider = set_global_provider_object(args)
# Print Provider Credentials
if not args.only_logs:
@@ -249,11 +244,11 @@ def prowler():
args.output_bucket or args.output_bucket_no_assume
):
output_bucket = args.output_bucket
bucket_session = audit_info.audit_session
bucket_session = global_provider.session.current_session
# Check if -D was input
if args.output_bucket_no_assume:
output_bucket = args.output_bucket_no_assume
bucket_session = audit_info.original_session
bucket_session = global_provider.session.original_session
send_to_s3_bucket(
audit_output_options.output_filename,
args.output_directory,
@@ -271,27 +266,27 @@ def prowler():
aws_security_enabled_regions = []
security_hub_regions = (
global_provider.get_available_aws_service_regions("securityhub")
if not audit_info.audited_regions
else audit_info.audited_regions
if not global_provider.identity.audited_regions
else global_provider.identity.audited_regions
)
for region in security_hub_regions:
# Save the regions where AWS Security Hub is enabled
if verify_security_hub_integration_enabled_per_region(
audit_info.audited_partition,
global_provider.identity.partition,
region,
audit_info.audit_session,
audit_info.audited_account,
global_provider.session.current_session,
global_provider.identity.account,
):
aws_security_enabled_regions.append(region)
# Prepare the findings to be sent to Security Hub
security_hub_findings_per_region = prepare_security_hub_findings(
findings, audit_info, audit_output_options, aws_security_enabled_regions
findings, provider, audit_output_options, aws_security_enabled_regions
)
# Send the findings to Security Hub
findings_sent_to_security_hub = batch_send_to_security_hub(
security_hub_findings_per_region, audit_info.audit_session
security_hub_findings_per_region, provider.session.current_session
)
print(
@@ -305,7 +300,7 @@ def prowler():
)
findings_archived_in_security_hub = resolve_security_hub_previous_findings(
security_hub_findings_per_region,
audit_info,
provider,
)
print(
f"{Style.BRIGHT}{Fore.GREEN}\n{findings_archived_in_security_hub} findings archived in AWS Security Hub!{Style.RESET_ALL}"
@@ -315,9 +310,8 @@ def prowler():
if not args.only_logs:
display_summary_table(
findings,
audit_info,
global_provider,
audit_output_options,
provider,
)
if findings:
@@ -344,9 +338,6 @@ def prowler():
if checks_folder:
remove_custom_checks_module(checks_folder, provider)
# clean local directories
clean_provider_local_output_directories(args)
# If there are failed findings exit code 3, except if -z is input
if not args.ignore_exit_code_3 and stats["total_fail"] > 0:
sys.exit(3)

View File

@@ -577,7 +577,7 @@ def execute(
)
# Report the check's findings
report(check_findings, audit_output_options, global_provider.identity)
report(check_findings, audit_output_options, global_provider)
if os.environ.get("PROWLER_REPORT_LIB_PATH"):
try:
@@ -587,7 +587,7 @@ def execute(
custom_report_interface = getattr(outputs_module, "report")
custom_report_interface(
check_findings, audit_output_options, global_provider.identity
check_findings, audit_output_options, global_provider
)
except Exception:
sys.exit(1)

View File

@@ -141,6 +141,7 @@ class Mitre_Requirement(BaseModel):
# Base Compliance Model
# TODO: move this to compliance folder
class Compliance_Requirement(BaseModel):
"""Compliance_Requirement holds the base model for every requirement within a compliance framework"""
@@ -151,9 +152,10 @@ class Compliance_Requirement(BaseModel):
Union[
CIS_Requirement_Attribute,
ENS_Requirement_Attribute,
Generic_Compliance_Requirement_Attribute,
ISO27001_2013_Requirement_Attribute,
AWS_Well_Architected_Requirement_Attribute,
# Generic_Compliance_Requirement_Attribute must be the last one since it is the fallback for generic compliance framework
Generic_Compliance_Requirement_Attribute,
]
]
Checks: list[str]

View File

@@ -1,6 +1,7 @@
from csv import DictWriter
from prowler.config.config import timestamp
from prowler.lib.logger import logger
from prowler.lib.outputs.models import (
Check_Output_CSV_AWS_Well_Architected,
generate_csv_fields,
@@ -9,47 +10,52 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def write_compliance_row_aws_well_architected_framework(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
):
compliance_output = compliance.Framework
if compliance.Version != "":
compliance_output += "_" + compliance.Version
if compliance.Provider != "":
compliance_output += "_" + compliance.Provider
compliance_output = compliance_output.lower().replace("-", "_")
csv_header = generate_csv_fields(Check_Output_CSV_AWS_Well_Architected)
csv_writer = DictWriter(
file_descriptors[compliance_output],
fieldnames=csv_header,
delimiter=";",
)
for requirement in compliance.Requirements:
requirement_description = requirement.Description
requirement_id = requirement.Id
for attribute in requirement.Attributes:
compliance_row = Check_Output_CSV_AWS_Well_Architected(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp
),
Requirements_Id=requirement_id,
Requirements_Description=requirement_description,
Requirements_Attributes_Name=attribute.Name,
Requirements_Attributes_WellArchitectedQuestionId=attribute.WellArchitectedQuestionId,
Requirements_Attributes_WellArchitectedPracticeId=attribute.WellArchitectedPracticeId,
Requirements_Attributes_Section=attribute.Section,
Requirements_Attributes_SubSection=attribute.SubSection,
Requirements_Attributes_LevelOfRisk=attribute.LevelOfRisk,
Requirements_Attributes_AssessmentMethod=attribute.AssessmentMethod,
Requirements_Attributes_Description=attribute.Description,
Requirements_Attributes_ImplementationGuidanceUrl=attribute.ImplementationGuidanceUrl,
Status=finding.status,
StatusExtended=finding.status_extended,
ResourceId=finding.resource_id,
CheckId=finding.check_metadata.CheckID,
)
try:
compliance_output = compliance.Framework
if compliance.Version != "":
compliance_output += "_" + compliance.Version
if compliance.Provider != "":
compliance_output += "_" + compliance.Provider
compliance_output = compliance_output.lower().replace("-", "_")
csv_header = generate_csv_fields(Check_Output_CSV_AWS_Well_Architected)
csv_writer = DictWriter(
file_descriptors[compliance_output],
fieldnames=csv_header,
delimiter=";",
)
for requirement in compliance.Requirements:
requirement_description = requirement.Description
requirement_id = requirement.Id
for attribute in requirement.Attributes:
compliance_row = Check_Output_CSV_AWS_Well_Architected(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp
),
Requirements_Id=requirement_id,
Requirements_Description=requirement_description,
Requirements_Attributes_Name=attribute.Name,
Requirements_Attributes_WellArchitectedQuestionId=attribute.WellArchitectedQuestionId,
Requirements_Attributes_WellArchitectedPracticeId=attribute.WellArchitectedPracticeId,
Requirements_Attributes_Section=attribute.Section,
Requirements_Attributes_SubSection=attribute.SubSection,
Requirements_Attributes_LevelOfRisk=attribute.LevelOfRisk,
Requirements_Attributes_AssessmentMethod=attribute.AssessmentMethod,
Requirements_Attributes_Description=attribute.Description,
Requirements_Attributes_ImplementationGuidanceUrl=attribute.ImplementationGuidanceUrl,
Status=finding.status,
StatusExtended=finding.status_extended,
ResourceId=finding.resource_id,
CheckId=finding.check_metadata.CheckID,
)
csv_writer.writerow(compliance_row.__dict__)
csv_writer.writerow(compliance_row.__dict__)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

View File

@@ -8,7 +8,7 @@ def write_compliance_row_cis(
finding,
compliance,
output_options,
audit_info,
provider,
input_compliance_frameworks,
):
compliance_output = "cis_" + compliance.Version + "_" + compliance.Provider.lower()
@@ -24,7 +24,7 @@ def write_compliance_row_cis(
requirement,
attribute,
output_options,
audit_info,
provider,
)
elif compliance.Provider == "GCP":
(compliance_row, csv_header) = generate_compliance_row_cis_gcp(

View File

@@ -4,12 +4,12 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def generate_compliance_row_cis_aws(
finding, compliance, requirement, attribute, output_options, audit_info
finding, compliance, requirement, attribute, output_options, provider
):
compliance_row = Check_Output_CSV_AWS_CIS(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(output_options.unix_timestamp, timestamp),
Requirements_Id=requirement.Id,

View File

@@ -79,7 +79,7 @@ def get_check_compliance_frameworks_in_input(
def fill_compliance(
output_options, finding, audit_info, file_descriptors, input_compliance_frameworks
output_options, finding, provider, file_descriptors, input_compliance_frameworks
):
try:
# We have to retrieve all the check's compliance requirements and get the ones matching with the input ones
@@ -92,7 +92,7 @@ def fill_compliance(
for compliance in check_compliances:
if compliance.Framework == "ENS" and compliance.Version == "RD2022":
write_compliance_row_ens_rd2022_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
)
elif compliance.Framework == "CIS":
@@ -101,7 +101,7 @@ def fill_compliance(
finding,
compliance,
output_options,
audit_info,
provider,
input_compliance_frameworks,
)
@@ -110,7 +110,7 @@ def fill_compliance(
and compliance.Provider == "AWS"
):
write_compliance_row_aws_well_architected_framework(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
)
elif (
@@ -119,7 +119,7 @@ def fill_compliance(
and compliance.Provider == "AWS"
):
write_compliance_row_iso27001_2013_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
)
elif (
@@ -128,12 +128,12 @@ def fill_compliance(
and compliance.Provider == "AWS"
):
write_compliance_row_mitre_attack_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
)
else:
write_compliance_row_generic(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
)
except Exception as error:

View File

@@ -6,7 +6,7 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def write_compliance_row_ens_rd2022_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
):
compliance_output = "ens_rd2022_aws"
csv_header = generate_csv_fields(Check_Output_CSV_ENS_RD2022)
@@ -22,7 +22,7 @@ def write_compliance_row_ens_rd2022_aws(
compliance_row = Check_Output_CSV_ENS_RD2022(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp

View File

@@ -9,7 +9,7 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def write_compliance_row_generic(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
):
compliance_output = compliance.Framework
if compliance.Version != "":
@@ -31,7 +31,7 @@ def write_compliance_row_generic(
compliance_row = Check_Output_CSV_Generic_Compliance(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp

View File

@@ -9,7 +9,7 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def write_compliance_row_iso27001_2013_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
):
compliance_output = compliance.Framework
if compliance.Version != "":
@@ -32,7 +32,7 @@ def write_compliance_row_iso27001_2013_aws(
compliance_row = Check_Output_CSV_AWS_ISO27001_2013(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp

View File

@@ -10,7 +10,7 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
def write_compliance_row_mitre_attack_aws(
file_descriptors, finding, compliance, output_options, audit_info
file_descriptors, finding, compliance, output_options, provider
):
compliance_output = compliance.Framework
if compliance.Version != "":
@@ -41,7 +41,7 @@ def write_compliance_row_mitre_attack_aws(
compliance_row = Check_Output_MITRE_ATTACK(
Provider=finding.check_metadata.Provider,
Description=compliance.Description,
AccountId=audit_info.audited_account,
AccountId=provider.identity.account,
Region=finding.region,
AssessmentDate=outputs_unix_timestamp(
output_options.unix_timestamp, timestamp

View File

@@ -22,10 +22,7 @@ from prowler.lib.outputs.models import (
generate_csv_fields,
)
from prowler.lib.utils.utils import file_exists, open_file
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info
from prowler.providers.common.outputs import get_provider_output_model
from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info
def initialize_file_descriptor(
@@ -66,20 +63,18 @@ def initialize_file_descriptor(
return file_descriptor
def fill_file_descriptors(output_modes, output_directory, output_filename, audit_info):
def fill_file_descriptors(output_modes, output_directory, output_filename, provider):
try:
file_descriptors = {}
if output_modes:
for output_mode in output_modes:
if output_mode == "csv":
filename = f"{output_directory}/{output_filename}{csv_file_suffix}"
output_model = get_provider_output_model(
audit_info.__class__.__name__
)
output_model = get_provider_output_model(provider.type)
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
output_model,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -87,7 +82,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
elif output_mode == "json":
filename = f"{output_directory}/{output_filename}{json_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info
filename, output_mode, provider
)
file_descriptors.update({output_mode: file_descriptor})
@@ -96,30 +91,30 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
f"{output_directory}/{output_filename}{json_ocsf_file_suffix}"
)
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info
filename, output_mode, provider
)
file_descriptors.update({output_mode: file_descriptor})
elif output_mode == "html":
filename = f"{output_directory}/{output_filename}{html_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info
filename, output_mode, provider
)
file_descriptors.update({output_mode: file_descriptor})
elif isinstance(audit_info, GCP_Audit_Info):
elif provider.type == "gcp":
if output_mode == "cis_2.0_gcp":
filename = f"{output_directory}/compliance/{output_filename}_cis_2.0_gcp{csv_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info, Check_Output_CSV_GCP_CIS
filename, output_mode, provider, Check_Output_CSV_GCP_CIS
)
file_descriptors.update({output_mode: file_descriptor})
elif isinstance(audit_info, AWS_Audit_Info):
elif provider.type == "aws":
if output_mode == "json-asff":
filename = f"{output_directory}/{output_filename}{json_asff_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info
filename, output_mode, provider
)
file_descriptors.update({output_mode: file_descriptor})
@@ -128,7 +123,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_CSV_ENS_RD2022,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -136,14 +131,14 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
elif output_mode == "cis_1.5_aws":
filename = f"{output_directory}/compliance/{output_filename}_cis_1.5_aws{csv_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info, Check_Output_CSV_AWS_CIS
filename, output_mode, provider, Check_Output_CSV_AWS_CIS
)
file_descriptors.update({output_mode: file_descriptor})
elif output_mode == "cis_1.4_aws":
filename = f"{output_directory}/compliance/{output_filename}_cis_1.4_aws{csv_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename, output_mode, audit_info, Check_Output_CSV_AWS_CIS
filename, output_mode, provider, Check_Output_CSV_AWS_CIS
)
file_descriptors.update({output_mode: file_descriptor})
@@ -155,7 +150,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_CSV_AWS_Well_Architected,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -168,7 +163,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_CSV_AWS_Well_Architected,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -178,7 +173,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_CSV_AWS_ISO27001_2013,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -188,7 +183,7 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_MITRE_ATTACK,
)
file_descriptors.update({output_mode: file_descriptor})
@@ -196,22 +191,16 @@ def fill_file_descriptors(output_modes, output_directory, output_filename, audit
else:
# Generic Compliance framework
if (
isinstance(audit_info, AWS_Audit_Info)
provider.type == "aws"
and "aws" in output_mode
or (
isinstance(audit_info, Azure_Audit_Info)
and "azure" in output_mode
)
or (
isinstance(audit_info, GCP_Audit_Info)
and "gcp" in output_mode
)
or (provider.type == "azure" and "azure" in output_mode)
or (provider.type == "gcp" and "gcp" in output_mode)
):
filename = f"{output_directory}/compliance/{output_filename}_{output_mode}{csv_file_suffix}"
file_descriptor = initialize_file_descriptor(
filename,
output_mode,
audit_info,
provider,
Check_Output_CSV_Generic_Compliance,
)
file_descriptors.update({output_mode: file_descriptor})

View File

@@ -18,13 +18,9 @@ from prowler.lib.outputs.models import (
unroll_tags,
)
from prowler.lib.utils.utils import open_file
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info
from prowler.providers.gcp.lib.audit_info.models import GCP_Audit_Info
from prowler.providers.kubernetes.lib.audit_info.models import Kubernetes_Audit_Info
def add_html_header(file_descriptor, audit_info):
def add_html_header(file_descriptor, provider):
try:
file_descriptor.write(
"""
@@ -113,7 +109,7 @@ def add_html_header(file_descriptor, audit_info):
</ul>
</div>
</div> """
+ get_assessment_summary(audit_info)
+ get_assessment_summary(provider)
+ """
<div class="col-md-2">
<div class="card">
@@ -336,18 +332,21 @@ def add_html_footer(output_filename, output_directory):
sys.exit(1)
def get_aws_html_assessment_summary(audit_info):
def get_aws_html_assessment_summary(provider):
try:
if isinstance(audit_info, AWS_Audit_Info):
if provider.type == "aws":
profile = (
audit_info.profile if audit_info.profile is not None else "default"
provider.identity.profile
if provider.identity.profile is not None
else "default"
)
if isinstance(audit_info.audited_regions, list):
audited_regions = " ".join(audit_info.audited_regions)
elif not audit_info.audited_regions:
if isinstance(provider.identity.audited_regions, list):
audited_regions = " ".join(provider.identity.audited_regions)
elif not provider.identity.audited_regions:
audited_regions = "All Regions"
# TODO: why this fallback?
else:
audited_regions = ", ".join(audit_info.audited_regions)
audited_regions = ", ".join(provider.identity.audited_regions)
return (
"""
<div class="col-md-2">
@@ -358,7 +357,7 @@ def get_aws_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>AWS Account:</b> """
+ audit_info.audited_account
+ provider.identity.account
+ """
</li>
<li class="list-group-item">
@@ -382,12 +381,12 @@ def get_aws_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>User Id:</b> """
+ audit_info.audited_user_id
+ provider.identity.user_id
+ """
</li>
<li class="list-group-item">
<b>Caller Identity ARN:</b> """
+ audit_info.audited_identity_arn
+ provider.identity.identity_arn
+ """
</li>
</ul>
@@ -403,21 +402,21 @@ def get_aws_html_assessment_summary(audit_info):
sys.exit(1)
def get_azure_html_assessment_summary(audit_info):
def get_azure_html_assessment_summary(provider):
try:
if isinstance(audit_info, Azure_Audit_Info):
if provider.type == "azure":
printed_subscriptions = []
for key, value in audit_info.identity.subscriptions.items():
for key, value in provider.identity.subscriptions.items():
intermediate = f"{key} : {value}"
printed_subscriptions.append(intermediate)
# check if identity is str(coming from SP) or dict(coming from browser or)
if isinstance(audit_info.identity.identity_id, dict):
html_identity = audit_info.identity.identity_id.get(
if isinstance(provider.identity.identity_id, dict):
html_identity = provider.identity.identity_id.get(
"userPrincipalName", "Identity not found"
)
else:
html_identity = audit_info.identity.identity_id
html_identity = provider.identity.identity_id
return (
"""
<div class="col-md-2">
@@ -428,12 +427,12 @@ def get_azure_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>Azure Tenant IDs:</b> """
+ " ".join(audit_info.identity.tenant_ids)
+ " ".join(provider.identity.tenant_ids)
+ """
</li>
<li class="list-group-item">
<b>Azure Tenant Domain:</b> """
+ audit_info.identity.domain
+ provider.identity.domain
+ """
</li>
<li class="list-group-item">
@@ -452,7 +451,7 @@ def get_azure_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>Azure Identity Type:</b> """
+ audit_info.identity.identity_type
+ provider.identity.identity_type
+ """
</li>
<li class="list-group-item">
@@ -472,14 +471,14 @@ def get_azure_html_assessment_summary(audit_info):
sys.exit(1)
def get_gcp_html_assessment_summary(audit_info):
def get_gcp_html_assessment_summary(provider):
try:
if isinstance(audit_info, GCP_Audit_Info):
if provider.type == "gcp":
try:
getattr(audit_info.credentials, "_service_account_email")
getattr(provider.credentials, "_service_account_email")
profile = (
audit_info.credentials._service_account_email
if audit_info.credentials._service_account_email is not None
provider.credentials._service_account_email
if provider.credentials._service_account_email is not None
else "default"
)
except AttributeError:
@@ -494,7 +493,7 @@ def get_gcp_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>GCP Project IDs:</b> """
+ ", ".join(audit_info.project_ids)
+ ", ".join(provider.project_ids)
+ """
</li>
</ul>
@@ -523,9 +522,9 @@ def get_gcp_html_assessment_summary(audit_info):
sys.exit(1)
def get_kubernetes_html_assessment_summary(audit_info):
def get_kubernetes_html_assessment_summary(provider):
try:
if isinstance(audit_info, Kubernetes_Audit_Info):
if provider.type == "kubernetes":
return (
"""
<div class="col-md-2">
@@ -536,7 +535,7 @@ def get_kubernetes_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>Kubernetes Context:</b> """
+ audit_info.context["name"]
+ provider.context["name"]
+ """
</li>
</ul>
@@ -550,12 +549,12 @@ def get_kubernetes_html_assessment_summary(audit_info):
<ul class="list-group list-group-flush">
<li class="list-group-item">
<b>Kubernetes Cluster:</b> """
+ audit_info.context["context"]["cluster"]
+ provider.context["context"]["cluster"]
+ """
</li>
<li class="list-group-item">
<b>Kubernetes User:</b> """
+ audit_info.context["context"]["user"]
+ provider.context["context"]["user"]
+ """
</li>
</ul>
@@ -570,26 +569,18 @@ def get_kubernetes_html_assessment_summary(audit_info):
sys.exit(1)
def get_assessment_summary(audit_info):
def get_assessment_summary(provider):
"""
get_assessment_summary gets the HTML assessment summary for the provider
"""
try:
# This is based in the Provider_Audit_Info class
# It is not pretty but useful
# AWS_Audit_Info --> aws
# GCP_Audit_Info --> gcp
# Azure_Audit_Info --> azure
# Kubernetes_Audit_Info --> kubernetes
provider = audit_info.__class__.__name__.split("_")[0].lower()
# Dynamically get the Provider quick inventory handler
provider_html_assessment_summary_function = (
f"get_{provider}_html_assessment_summary"
f"get_{provider.type}_html_assessment_summary"
)
return getattr(
importlib.import_module(__name__), provider_html_assessment_summary_function
)(audit_info)
)(provider)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -31,10 +31,9 @@ from prowler.lib.outputs.models import (
unroll_dict_to_list,
)
from prowler.lib.utils.utils import hash_sha512, open_file, outputs_unix_timestamp
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
def fill_json_asff(finding_output, audit_info, finding, output_options):
def fill_json_asff(finding_output, provider, finding, output_options):
try:
# Check if there are no resources in the finding
if finding.resource_arn == "":
@@ -43,13 +42,13 @@ def fill_json_asff(finding_output, audit_info, finding, output_options):
finding.resource_arn = finding.resource_id
# The following line cannot be changed because it is the format we use to generate unique findings for AWS Security Hub
# If changed some findings could be lost because the unique identifier will be different
finding_output.Id = f"prowler-{finding.check_metadata.CheckID}-{audit_info.audited_account}-{finding.region}-{hash_sha512(finding.resource_id)}"
finding_output.ProductArn = f"arn:{audit_info.audited_partition}:securityhub:{finding.region}::product/prowler/prowler"
finding_output.Id = f"prowler-{finding.check_metadata.CheckID}-{provider.identity.account}-{finding.region}-{hash_sha512(finding.resource_id)}"
finding_output.ProductArn = f"arn:{provider.identity.partition}:securityhub:{finding.region}::product/prowler/prowler"
finding_output.ProductFields = ProductFields(
ProviderVersion=prowler_version, ProwlerResourceName=finding.resource_arn
)
finding_output.GeneratorId = "prowler-" + finding.check_metadata.CheckID
finding_output.AwsAccountId = audit_info.audited_account
finding_output.AwsAccountId = provider.identity.account
finding_output.Types = finding.check_metadata.CheckType
finding_output.FirstObservedAt = finding_output.UpdatedAt = (
finding_output.CreatedAt
@@ -69,7 +68,7 @@ def fill_json_asff(finding_output, audit_info, finding, output_options):
Resource(
Id=finding.resource_arn,
Type=finding.check_metadata.ResourceType,
Partition=audit_info.audited_partition,
Partition=provider.identity.partition,
Region=finding.region,
Tags=resource_tags,
)
@@ -144,7 +143,7 @@ def generate_json_asff_resource_tags(tags):
)
def fill_json_ocsf(audit_info, finding, output_options) -> Check_Output_JSON_OCSF:
def fill_json_ocsf(provider, finding, output_options) -> Check_Output_JSON_OCSF:
try:
resource_region = ""
resource_name = ""
@@ -157,20 +156,22 @@ def fill_json_ocsf(audit_info, finding, output_options) -> Check_Output_JSON_OCS
account = None
org = None
profile = ""
if isinstance(audit_info, AWS_Audit_Info):
if provider.type == "aws":
profile = (
audit_info.profile if audit_info.profile is not None else "default"
provider.identity.profile
if provider.identity.profile is not None
else "default"
)
if (
hasattr(audit_info, "organizations_metadata")
and audit_info.organizations_metadata
hasattr(provider, "organizations_metadata")
and provider.organizations_metadata
):
aws_account_name = audit_info.organizations_metadata.account_details_name
aws_org_uid = audit_info.organizations_metadata.account_details_org
aws_account_name = provider.organizations_metadata.account_details_name
aws_org_uid = provider.organizations_metadata.account_details_org
if finding.check_metadata.Provider == "aws":
account = Account(
name=aws_account_name,
uid=audit_info.audited_account,
uid=provider.identity.account,
)
org = Organization(
name=aws_org_uid,
@@ -179,15 +180,15 @@ def fill_json_ocsf(audit_info, finding, output_options) -> Check_Output_JSON_OCS
resource_region = finding.region
resource_name = finding.resource_id
resource_uid = finding.resource_arn
finding_uid = f"prowler-{finding.check_metadata.Provider}-{finding.check_metadata.CheckID}-{audit_info.audited_account}-{finding.region}-{finding.resource_id}"
finding_uid = f"prowler-{finding.check_metadata.Provider}-{finding.check_metadata.CheckID}-{provider.identity.account}-{finding.region}-{finding.resource_id}"
elif finding.check_metadata.Provider == "azure":
account = Account(
name=finding.subscription,
uid=finding.subscription,
)
org = Organization(
name=audit_info.identity.domain,
uid=audit_info.identity.domain,
name=provider.identity.domain,
uid=provider.identity.domain,
)
resource_name = finding.resource_name
resource_uid = finding.resource_id

View File

@@ -13,7 +13,7 @@ from prowler.lib.utils.utils import outputs_unix_timestamp
from prowler.providers.aws.lib.audit_info.models import AWSOrganizationsInfo
def get_check_compliance(finding, provider, output_options) -> dict:
def get_check_compliance(finding, provider_type, output_options) -> dict:
"""get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
Example:
@@ -33,7 +33,7 @@ def get_check_compliance(finding, provider, output_options) -> dict:
compliance_fw = compliance.Framework
if compliance.Version:
compliance_fw = f"{compliance_fw}-{compliance.Version}"
if compliance.Provider == provider.upper():
if compliance.Provider == provider_type.upper():
if compliance_fw not in check_compliance:
check_compliance[compliance_fw] = []
for requirement in compliance.Requirements:
@@ -46,86 +46,86 @@ def get_check_compliance(finding, provider, output_options) -> dict:
sys.exit(1)
def generate_provider_output_csv(
provider: str, finding, audit_info, mode: str, fd, output_options
):
def generate_provider_output_csv(provider, finding, mode: str, fd, output_options):
"""
set_provider_output_options configures automatically the outputs based on the selected provider and returns the Provider_Output_Options object.
"""
try:
# Dynamically load the Provider_Output_Options class
finding_output_model = f"{provider.capitalize()}_Check_Output_{mode.upper()}"
finding_output_model = (
f"{provider.type.capitalize()}_Check_Output_{mode.upper()}"
)
output_model = getattr(importlib.import_module(__name__), finding_output_model)
# Fill common data among providers
data = fill_common_data_csv(finding, output_options.unix_timestamp)
if provider == "azure":
if provider.type == "azure":
data["resource_id"] = finding.resource_id
data["resource_name"] = finding.resource_name
data["subscription"] = finding.subscription
data["tenant_domain"] = audit_info.identity.domain
data["tenant_domain"] = provider.identity.domain
data["finding_unique_id"] = (
f"prowler-{provider}-{finding.check_metadata.CheckID}-{finding.subscription}-{finding.resource_id}"
f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{finding.subscription}-{finding.resource_id}"
)
data["compliance"] = unroll_dict(
get_check_compliance(finding, provider, output_options)
get_check_compliance(finding, provider.type, output_options)
)
finding_output = output_model(**data)
if provider == "gcp":
if provider.type == "gcp":
data["resource_id"] = finding.resource_id
data["resource_name"] = finding.resource_name
data["project_id"] = finding.project_id
data["location"] = finding.location.lower()
data["finding_unique_id"] = (
f"prowler-{provider}-{finding.check_metadata.CheckID}-{finding.project_id}-{finding.resource_id}"
f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{finding.project_id}-{finding.resource_id}"
)
data["compliance"] = unroll_dict(
get_check_compliance(finding, provider, output_options)
get_check_compliance(finding, provider.type, output_options)
)
finding_output = output_model(**data)
if provider == "kubernetes":
if provider.type == "kubernetes":
data["resource_id"] = finding.resource_id
data["resource_name"] = finding.resource_name
data["namespace"] = finding.namespace
data["finding_unique_id"] = (
f"prowler-{provider}-{finding.check_metadata.CheckID}-{finding.namespace}-{finding.resource_id}"
f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{finding.namespace}-{finding.resource_id}"
)
data["compliance"] = unroll_dict(
get_check_compliance(finding, provider, output_options)
get_check_compliance(finding, provider.type, output_options)
)
finding_output = output_model(**data)
if provider == "aws":
data["profile"] = audit_info.profile
data["account_id"] = audit_info.audited_account
if provider.type == "aws":
data["profile"] = provider.identity.profile
data["account_id"] = provider.identity.account
data["region"] = finding.region
data["resource_id"] = finding.resource_id
data["resource_arn"] = finding.resource_arn
data["finding_unique_id"] = (
f"prowler-{provider}-{finding.check_metadata.CheckID}-{audit_info.audited_account}-{finding.region}-{finding.resource_id}"
f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{provider.identity.account}-{finding.region}-{finding.resource_id}"
)
data["compliance"] = unroll_dict(
get_check_compliance(finding, provider, output_options)
get_check_compliance(finding, provider.type, output_options)
)
finding_output = output_model(**data)
if audit_info.organizations_metadata:
if provider.organizations_metadata:
finding_output.account_name = (
audit_info.organizations_metadata.account_details_name
provider.organizations_metadata.account_details_name
)
finding_output.account_email = (
audit_info.organizations_metadata.account_details_email
provider.organizations_metadata.account_details_email
)
finding_output.account_arn = (
audit_info.organizations_metadata.account_details_arn
provider.organizations_metadata.account_details_arn
)
finding_output.account_org = (
audit_info.organizations_metadata.account_details_org
provider.organizations_metadata.account_details_org
)
finding_output.account_tags = (
audit_info.organizations_metadata.account_details_tags
provider.organizations_metadata.account_details_tags
)
csv_writer = DictWriter(
@@ -379,15 +379,15 @@ class Kubernetes_Check_Output_CSV(Check_Output_CSV):
resource_name: str = ""
def generate_provider_output_json(
provider: str, finding, audit_info, mode: str, output_options
):
def generate_provider_output_json(provider, finding, mode: str, output_options):
"""
generate_provider_output_json configures automatically the outputs based on the selected provider and returns the Check_Output_JSON object.
"""
try:
# Dynamically load the Provider_Output_Options class for the JSON format
finding_output_model = f"{provider.capitalize()}_Check_Output_{mode.upper()}"
finding_output_model = (
f"{provider.type.capitalize()}_Check_Output_{mode.upper()}"
)
output_model = getattr(importlib.import_module(__name__), finding_output_model)
# Instantiate the class for the cloud provider
finding_output = output_model(**finding.check_metadata.dict())
@@ -400,13 +400,13 @@ def generate_provider_output_json(
finding_output.ResourceDetails = finding.resource_details
if provider == "azure":
finding_output.Tenant_Domain = audit_info.identity.domain
finding_output.Tenant_Domain = provider.identity.domain
finding_output.Subscription = finding.subscription
finding_output.ResourceId = finding.resource_id
finding_output.ResourceName = finding.resource_name
finding_output.FindingUniqueId = f"prowler-{provider}-{finding.check_metadata.CheckID}-{finding.subscription}-{finding.resource_id}"
finding_output.FindingUniqueId = f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{finding.subscription}-{finding.resource_id}"
finding_output.Compliance = get_check_compliance(
finding, provider, output_options
finding, provider.type, output_options
)
if provider == "gcp":
@@ -414,26 +414,26 @@ def generate_provider_output_json(
finding_output.Location = finding.location.lower()
finding_output.ResourceId = finding.resource_id
finding_output.ResourceName = finding.resource_name
finding_output.FindingUniqueId = f"prowler-{provider}-{finding.check_metadata.CheckID}-{finding.project_id}-{finding.resource_id}"
finding_output.FindingUniqueId = f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{finding.project_id}-{finding.resource_id}"
finding_output.Compliance = get_check_compliance(
finding, provider, output_options
finding, provider.type, output_options
)
if provider == "aws":
finding_output.Profile = audit_info.profile
finding_output.AccountId = audit_info.audited_account
finding_output.Profile = provider.identity.profile
finding_output.AccountId = provider.identity.account
finding_output.Region = finding.region
finding_output.ResourceId = finding.resource_id
finding_output.ResourceArn = finding.resource_arn
finding_output.ResourceTags = parse_json_tags(finding.resource_tags)
finding_output.FindingUniqueId = f"prowler-{provider}-{finding.check_metadata.CheckID}-{audit_info.audited_account}-{finding.region}-{finding.resource_id}"
finding_output.FindingUniqueId = f"prowler-{provider.type}-{finding.check_metadata.CheckID}-{provider.identity.account}-{finding.region}-{finding.resource_id}"
finding_output.Compliance = get_check_compliance(
finding, provider, output_options
finding, provider.type, output_options
)
if audit_info.organizations_metadata:
if provider.organizations_metadata:
finding_output.OrganizationsInfo = (
audit_info.organizations_metadata.__dict__
provider.organizations_metadata.__dict__
)
except Exception as error:

View File

@@ -16,8 +16,6 @@ from prowler.lib.outputs.models import (
generate_provider_output_csv,
generate_provider_output_json,
)
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
from prowler.providers.azure.lib.audit_info.models import Azure_Audit_Info
def stdout_report(finding, color, verbose, status):
@@ -36,26 +34,25 @@ def stdout_report(finding, color, verbose, status):
)
def report(check_findings, output_options, audit_info):
def report(check_findings, output_options, provider):
try:
file_descriptors = {}
if check_findings:
# TO-DO Generic Function
if isinstance(audit_info, AWS_Audit_Info):
if provider.type == "aws":
check_findings.sort(key=lambda x: x.region)
if isinstance(audit_info, Azure_Audit_Info):
if provider.type == "azure":
check_findings.sort(key=lambda x: x.subscription)
# Generate the required output files
if output_options.output_modes:
# if isinstance(audit_info, AWS_Audit_Info):
# We have to create the required output files
file_descriptors = fill_file_descriptors(
output_options.output_modes,
output_options.output_directory,
output_options.output_filename,
audit_info,
provider,
)
for finding in check_findings:
@@ -80,14 +77,14 @@ def report(check_findings, output_options, audit_info):
fill_compliance(
output_options,
finding,
audit_info,
provider,
file_descriptors,
input_compliance_frameworks,
)
add_manual_controls(
output_options,
audit_info,
provider,
file_descriptors,
input_compliance_frameworks,
)
@@ -97,7 +94,7 @@ def report(check_findings, output_options, audit_info):
if "json-asff" in file_descriptors:
finding_output = Check_Output_JSON_ASFF()
fill_json_asff(
finding_output, audit_info, finding, output_options
finding_output, provider, finding, output_options
)
json.dump(
@@ -114,9 +111,8 @@ def report(check_findings, output_options, audit_info):
if "csv" in file_descriptors:
csv_writer, finding_output = generate_provider_output_csv(
finding.check_metadata.Provider,
provider,
finding,
audit_info,
"csv",
file_descriptors["csv"],
output_options,
@@ -125,9 +121,8 @@ def report(check_findings, output_options, audit_info):
if "json" in file_descriptors:
finding_output = generate_provider_output_json(
finding.check_metadata.Provider,
provider,
finding,
audit_info,
"json",
output_options,
)
@@ -140,7 +135,7 @@ def report(check_findings, output_options, audit_info):
if "json-ocsf" in file_descriptors:
finding_output = fill_json_ocsf(
audit_info, finding, output_options
provider, finding, output_options
)
json.dump(

View File

@@ -16,32 +16,31 @@ from prowler.providers.common.outputs import Provider_Output_Options
def display_summary_table(
findings: list,
audit_info,
provider,
output_options: Provider_Output_Options,
provider: str,
):
output_directory = output_options.output_directory
output_filename = output_options.output_filename
try:
if provider == "aws":
if provider.type == "aws":
entity_type = "Account"
audited_entities = audit_info.audited_account
elif provider == "azure":
audited_entities = provider.identity.account
elif provider.type == "azure":
if (
audit_info.identity.domain
provider.identity.domain
!= "Unknown tenant domain (missing AAD permissions)"
):
entity_type = "Tenant Domain"
audited_entities = audit_info.identity.domain
audited_entities = provider.identity.domain
else:
entity_type = "Tenant ID/s"
audited_entities = " ".join(audit_info.identity.tenant_ids)
elif provider == "gcp":
audited_entities = " ".join(provider.identity.tenant_ids)
elif provider.type == "gcp":
entity_type = "Project ID/s"
audited_entities = ", ".join(audit_info.project_ids)
elif provider == "kubernetes":
audited_entities = ", ".join(provider.project_ids)
elif provider.type == "kubernetes":
entity_type = "Context"
audited_entities = audit_info.context["name"]
audited_entities = provider.context["name"]
if findings:
current = {
@@ -110,7 +109,7 @@ def display_summary_table(
)
if provider == "azure":
print(
f"\nSubscriptions scanned: {Fore.YELLOW}{' '.join(audit_info.identity.subscriptions.keys())}{Style.RESET_ALL}"
f"\nSubscriptions scanned: {Fore.YELLOW}{' '.join(provider.identity.subscriptions.keys())}{Style.RESET_ALL}"
)
print(tabulate(findings_table, headers="keys", tablefmt="rounded_grid"))
print(

View File

@@ -517,7 +517,6 @@ Caller Identity ARN: {Fore.YELLOW}[{self._identity.identity_arn}]{Style.RESET_AL
regions = json_regions
return regions
# Remove if not needed
def get_checks_from_input_arn(self) -> set:
"""
get_checks_from_input_arn gets the list of checks from the input arns

View File

@@ -33,6 +33,7 @@ class AWSOrganizationsInfo:
@dataclass
# TODO: remove this class once is not used
class AWS_Audit_Info:
original_session: session.Session
audit_session: session.Session

View File

@@ -5,14 +5,13 @@ from prowler.config.config import timestamp_utc
from prowler.lib.logger import logger
from prowler.lib.outputs.json import fill_json_asff
from prowler.lib.outputs.models import Check_Output_JSON_ASFF
from prowler.providers.aws.lib.audit_info.models import AWS_Audit_Info
SECURITY_HUB_INTEGRATION_NAME = "prowler/prowler"
SECURITY_HUB_MAX_BATCH = 100
def prepare_security_hub_findings(
findings: list, audit_info: AWS_Audit_Info, output_options, enabled_regions: list
findings: list, provider, output_options, enabled_regions: list
) -> dict:
security_hub_findings_per_region = {}
@@ -42,7 +41,7 @@ def prepare_security_hub_findings(
# Format the finding in the JSON ASFF format
finding_json_asff = fill_json_asff(
Check_Output_JSON_ASFF(), audit_info, finding, output_options
Check_Output_JSON_ASFF(), provider, finding, output_options
)
# Include that finding within their region in the JSON format
@@ -137,7 +136,7 @@ def batch_send_to_security_hub(
# Move previous Security Hub check findings to ARCHIVED (as prowler didn't re-detect them)
def resolve_security_hub_previous_findings(
security_hub_findings_per_region: dict, audit_info: AWS_Audit_Info
security_hub_findings_per_region: dict, provider
) -> list:
"""
resolve_security_hub_previous_findings archives all the findings that does not appear in the current execution
@@ -152,14 +151,14 @@ def resolve_security_hub_previous_findings(
for finding in current_findings:
current_findings_ids.append(finding["Id"])
# Get findings of that region
security_hub_client = audit_info.audit_session.client(
security_hub_client = provider.session.current_session.client(
"securityhub", region_name=region
)
findings_filter = {
"ProductName": [{"Value": "Prowler", "Comparison": "EQUALS"}],
"RecordState": [{"Value": "ACTIVE", "Comparison": "EQUALS"}],
"AwsAccountId": [
{"Value": audit_info.audited_account, "Comparison": "EQUALS"}
{"Value": provider.identity.account, "Comparison": "EQUALS"}
],
"Region": [{"Value": region, "Comparison": "EQUALS"}],
}

View File

@@ -3,10 +3,7 @@ from argparse import Namespace
from importlib import import_module
from prowler.lib.logger import logger
from prowler.providers.common.common import (
get_available_providers,
providers_prowler_lib_path,
)
from prowler.providers.common.common import get_available_providers, providers_path
provider_arguments_lib_path = "lib.arguments.arguments"
validate_provider_arguments_function = "validate_arguments"
@@ -21,7 +18,7 @@ def init_providers_parser(self):
try:
getattr(
import_module(
f"{providers_prowler_lib_path}.{provider}.{provider_arguments_lib_path}"
f"{providers_path}.{provider}.{provider_arguments_lib_path}"
),
init_provider_arguments_function,
)(self)
@@ -38,7 +35,7 @@ def validate_provider_arguments(arguments: Namespace) -> tuple[bool, str]:
# Provider function must be located at prowler.providers.<provider>.lib.arguments.arguments.validate_arguments
return getattr(
import_module(
f"{providers_prowler_lib_path}.{arguments.provider}.{provider_arguments_lib_path}"
f"{providers_path}.{arguments.provider}.{provider_arguments_lib_path}"
),
validate_provider_arguments_function,
)(arguments)

View File

@@ -1,32 +0,0 @@
import importlib
import sys
from shutil import rmtree
from prowler.config.config import default_output_directory
from prowler.lib.logger import logger
def clean_provider_local_output_directories(args):
"""
clean_provider_local_output_directories cleans deletes local custom dirs when output is sent to remote provider storage
"""
try:
# import provider cleaning function
provider_clean_function = f"clean_{args.provider}_local_output_directories"
getattr(importlib.import_module(__name__), provider_clean_function)(args)
except AttributeError as attribute_exception:
logger.info(
f"Cleaning local output directories not initialized for provider {args.provider}: {attribute_exception}"
)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
def clean_aws_local_output_directories(args):
"""clean_aws_provider_local_output_directories deletes local custom dirs when output is sent to remote provider storage for aws provider"""
if args.output_bucket or args.output_bucket_no_assume:
if args.output_directory != default_output_directory:
rmtree(args.output_directory)

View File

@@ -1,51 +1,50 @@
import sys
from importlib import import_module
from typing import Any
from pkgutil import iter_modules
from prowler.lib.logger import logger
providers_prowler_lib_path = "prowler.providers"
providers_path = "prowler.providers"
global_provider = None
def get_available_providers() -> list[str]:
"""get_available_providers returns a list of the available providers"""
providers = []
for _, provider, _ in iter_modules([providers_path.replace(".", "/")]):
if provider != "common":
providers.append(provider)
return providers
def get_global_provider():
return global_provider
def set_provider(provider, arguments) -> Any:
provider_class_name = f"{provider.capitalize()}Provider"
import_module_path = f"prowler.providers.{provider}.azure_provider_testing"
provider_instance = getattr(import_module(import_module_path), provider_class_name)(
arguments
)
return provider_instance
def get_available_providers() -> list[str]:
"""get_available_providers returns a list of the available providers"""
providers_list = import_module(f"{providers_prowler_lib_path}")
providers = [
provider
for provider in providers_list.__dict__
if not (provider.startswith("__") or provider.startswith("common"))
]
return providers
# TODO: rename to set_global_provider
def set_global_provider_object(arguments):
try:
global global_provider
# make here dynamic import
common_import_path = (
f"prowler.providers.{arguments.provider}.{arguments.provider}_provider"
provider_class_path = (
f"{providers_path}.{arguments.provider}.{arguments.provider}_provider"
)
provider_class = f"{arguments.provider.capitalize()}Provider"
global_provider = getattr(import_module(common_import_path), provider_class)(
arguments
provider_class_name = f"{arguments.provider.capitalize()}Provider"
provider_class = getattr(
import_module(provider_class_path), provider_class_name
)
if not isinstance(global_provider, provider_class):
global_provider = provider_class(arguments)
return global_provider
except TypeError as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
sys.exit(1)

View File

@@ -29,13 +29,12 @@ def set_provider_output_options(
return provider_output_options
def get_provider_output_model(audit_info_class_name):
def get_provider_output_model(provider_type):
"""
get_provider_output_model returns the model _Check_Output_CSV for each provider
get_provider_output_model returns the model <provider>_Check_Output_CSV for each provider
"""
# from AWS_Audit_Info -> AWS -> aws -> Aws
output_provider = audit_info_class_name.split("_", 1)[0].lower().capitalize()
output_provider_model_name = f"{output_provider}_Check_Output_CSV"
# TODO: classes should be AwsCheckOutputCSV
output_provider_model_name = f"{provider_type.capitalize()}_Check_Output_CSV"
output_provider_models_path = "prowler.lib.outputs.models"
output_provider_model = getattr(
importlib.import_module(output_provider_models_path), output_provider_model_name

View File

@@ -1,87 +0,0 @@
import importlib
import logging
import tempfile
from argparse import Namespace
from os import path
from mock import patch
from prowler.providers.common.clean import clean_provider_local_output_directories
class Test_Common_Clean:
def set_provider_input_args(self, provider):
set_args_function = f"set_{provider}_input_args"
args = getattr(
getattr(importlib.import_module(__name__), __class__.__name__),
set_args_function,
)(self)
return args
def set_aws_input_args(self):
args = Namespace()
args.provider = "aws"
args.output_bucket = "test-bucket"
args.output_bucket_no_assume = None
return args
def set_azure_input_args(self):
args = Namespace()
args.provider = "azure"
return args
def test_clean_provider_local_output_directories_non_initialized(self, caplog):
provider = "azure"
input_args = self.set_provider_input_args(provider)
caplog.set_level(logging.INFO)
clean_provider_local_output_directories(input_args)
assert (
f"Cleaning local output directories not initialized for provider {provider}:"
in caplog.text
)
def test_clean_aws_local_output_directories_non_default_dir_output_bucket(self):
provider = "aws"
input_args = self.set_provider_input_args(provider)
with tempfile.TemporaryDirectory() as temp_dir:
input_args.output_directory = temp_dir
clean_provider_local_output_directories(input_args)
assert not path.exists(input_args.output_directory)
def test_clean_aws_local_output_directories_non_default_dir_output_bucket_no_assume(
self,
):
provider = "aws"
input_args = self.set_provider_input_args(provider)
input_args.output_bucket = None
input_args.output_bucket_no_assume = "test"
with tempfile.TemporaryDirectory() as temp_dir:
input_args.output_directory = temp_dir
clean_provider_local_output_directories(input_args)
assert not path.exists(input_args.output_directory)
def test_clean_aws_local_output_directories_default_dir_output_bucket(self):
provider = "aws"
input_args = self.set_provider_input_args(provider)
with tempfile.TemporaryDirectory() as temp_dir:
with patch(
"prowler.providers.common.clean.default_output_directory", new=temp_dir
):
input_args.output_directory = temp_dir
clean_provider_local_output_directories(input_args)
assert path.exists(input_args.output_directory)
def test_clean_aws_local_output_directories_default_dir_output_bucket_no_assume(
self,
):
provider = "aws"
input_args = self.set_provider_input_args(provider)
input_args.output_bucket_no_assume = "test"
input_args.ouput_bucket = None
with tempfile.TemporaryDirectory() as temp_dir:
with patch(
"prowler.providers.common.clean.default_output_directory", new=temp_dir
):
input_args.output_directory = temp_dir
clean_provider_local_output_directories(input_args)
assert path.exists(input_args.output_directory)