fix(config/html): handle encoding issues and improve error handling in config and HTML file loading functions (#4203)

Co-authored-by: Sergio <sergio@prowler.com>
This commit is contained in:
William Leung
2024-06-08 00:51:01 +08:00
committed by Sergio
parent 66199ee722
commit a46ea6a447
3 changed files with 91 additions and 39 deletions

View File

@@ -1,6 +1,5 @@
import os
import pathlib
import sys
from datetime import datetime, timezone
from os import getcwd
@@ -99,52 +98,87 @@ def check_current_version():
def load_and_validate_config_file(provider: str, config_file_path: str) -> dict:
"""
load_and_validate_config_file reads the Prowler config file in YAML format from the default location or the file passed with the --config-file flag
Reads the Prowler config file in YAML format from the default location or the file passed with the --config-file flag.
Args:
provider (str): The provider name (e.g., 'aws', 'gcp', 'azure', 'kubernetes').
config_file_path (str): The path to the configuration file.
Returns:
dict: The configuration dictionary for the specified provider.
"""
try:
with open(config_file_path) as f:
config = {}
with open(config_file_path, "r", encoding="utf-8") as f:
config_file = yaml.safe_load(f)
# Not to introduce a breaking change we have to allow the old format config file without any provider keys
# and a new format with a key for each provider to include their configuration values within
# Check if the new format is passed
if (
"aws" in config_file
or "gcp" in config_file
or "azure" in config_file
or "kubernetes" in config_file
):
# Not to introduce a breaking change, allow the old format config file without any provider keys
# and a new format with a key for each provider to include their configuration values within.
if any(key in config_file for key in ["aws", "gcp", "azure", "kubernetes"]):
config = config_file.get(provider, {})
else:
config = config_file if config_file else {}
# Not to break Azure, K8s and GCP does not support neither use the old config format
# Not to break Azure, K8s and GCP does not support or use the old config format
if provider in ["azure", "gcp", "kubernetes"]:
config = {}
return config
except Exception as error:
logger.critical(
except FileNotFoundError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
sys.exit(1)
except yaml.YAMLError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except UnicodeDecodeError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
return {}
def load_and_validate_fixer_config_file(
provider: str, fixer_config_file_path: str
) -> dict:
"""
load_and_validate_fixer_config_file reads the Prowler fixer config file in YAML format from the default location or the file passed with the --fixer-config flag
Reads the Prowler fixer config file in YAML format from the default location or the file passed with the --fixer-config flag.
Args:
provider (str): The provider name (e.g., 'aws', 'gcp', 'azure', 'kubernetes').
fixer_config_file_path (str): The path to the fixer configuration file.
Returns:
dict: The fixer configuration dictionary for the specified provider.
Raises:
SystemExit: If there is an error reading or parsing the fixer configuration file.
"""
try:
with open(fixer_config_file_path) as f:
with open(fixer_config_file_path, "r", encoding="utf-8") as f:
fixer_config_file = yaml.safe_load(f)
return fixer_config_file.get(provider, {})
except Exception as error:
logger.critical(
except FileNotFoundError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
sys.exit(1)
except yaml.YAMLError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except UnicodeDecodeError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
return {}

View File

@@ -173,33 +173,42 @@ def fill_html(file_descriptor, finding):
def fill_html_overview_statistics(stats, output_filename, output_directory):
try:
filename = f"{output_directory}/{output_filename}{html_file_suffix}"
# Read file
# Read file
if path.isfile(filename):
with open(filename, "r") as file:
with open(filename, "r", encoding="utf-8") as file:
filedata = file.read()
# Replace statistics
# TOTAL_FINDINGS
filedata = filedata.replace(
"TOTAL_FINDINGS", str(stats.get("findings_count"))
"TOTAL_FINDINGS", str(stats.get("findings_count", 0))
)
# TOTAL_RESOURCES
filedata = filedata.replace(
"TOTAL_RESOURCES", str(stats.get("resources_count"))
"TOTAL_RESOURCES", str(stats.get("resources_count", 0))
)
# TOTAL_PASS
filedata = filedata.replace("TOTAL_PASS", str(stats.get("total_pass")))
filedata = filedata.replace("TOTAL_PASS", str(stats.get("total_pass", 0)))
# TOTAL_FAIL
filedata = filedata.replace("TOTAL_FAIL", str(stats.get("total_fail")))
filedata = filedata.replace("TOTAL_FAIL", str(stats.get("total_fail", 0)))
# Write file
with open(filename, "w") as file:
with open(filename, "w", encoding="utf-8") as file:
file.write(filedata)
except Exception as error:
logger.critical(
except FileNotFoundError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except UnicodeDecodeError as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
sys.exit(1)
def add_html_footer(output_filename, output_directory):

View File

@@ -1,3 +1,4 @@
import logging
import os
import pathlib
from unittest import mock
@@ -386,12 +387,16 @@ class Test_Config:
assert load_and_validate_config_file("azure", config_test_file) == {}
assert load_and_validate_config_file("kubernetes", config_test_file) == {}
def test_load_and_validate_config_file_invalid_config_file_path(self):
def test_load_and_validate_config_file_invalid_config_file_path(self, caplog):
provider = "aws"
config_file_path = "invalid/path/to/fixer_config.yaml"
with pytest.raises(SystemExit):
load_and_validate_config_file(provider, config_file_path)
with caplog.at_level(logging.ERROR):
result = load_and_validate_config_file(provider, config_file_path)
assert "FileNotFoundError" in caplog.text
assert result == {}
assert pytest is not None
def test_load_and_validate_fixer_config_aws(self):
path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
@@ -421,9 +426,13 @@ class Test_Config:
assert load_and_validate_fixer_config_file(provider, config_test_file) == {}
def test_load_and_validate_fixer_config_invalid_fixer_config_path(self):
def test_load_and_validate_fixer_config_invalid_fixer_config_path(self, caplog):
provider = "aws"
fixer_config_path = "invalid/path/to/fixer_config.yaml"
with pytest.raises(SystemExit):
load_and_validate_fixer_config_file(provider, fixer_config_path)
with caplog.at_level(logging.ERROR):
result = load_and_validate_fixer_config_file(provider, fixer_config_path)
assert "FileNotFoundError" in caplog.text
assert result == {}
assert pytest is not None