From a46ea6a4472bb242718f542e65192f3da8f23c8d Mon Sep 17 00:00:00 2001 From: William Leung <61426712+lshw54@users.noreply.github.com> Date: Sat, 8 Jun 2024 00:51:01 +0800 Subject: [PATCH] fix(config/html): handle encoding issues and improve error handling in config and HTML file loading functions (#4203) Co-authored-by: Sergio --- prowler/config/config.py | 80 +++++++++++++++++++++++--------- prowler/lib/outputs/html/html.py | 29 ++++++++---- tests/config/config_test.py | 21 ++++++--- 3 files changed, 91 insertions(+), 39 deletions(-) diff --git a/prowler/config/config.py b/prowler/config/config.py index 399a633a6d..73fb4e7936 100644 --- a/prowler/config/config.py +++ b/prowler/config/config.py @@ -1,6 +1,5 @@ import os import pathlib -import sys from datetime import datetime, timezone from os import getcwd @@ -99,52 +98,87 @@ def check_current_version(): def load_and_validate_config_file(provider: str, config_file_path: str) -> dict: """ - load_and_validate_config_file reads the Prowler config file in YAML format from the default location or the file passed with the --config-file flag + Reads the Prowler config file in YAML format from the default location or the file passed with the --config-file flag. + + Args: + provider (str): The provider name (e.g., 'aws', 'gcp', 'azure', 'kubernetes'). + config_file_path (str): The path to the configuration file. + + Returns: + dict: The configuration dictionary for the specified provider. """ try: - with open(config_file_path) as f: - config = {} + with open(config_file_path, "r", encoding="utf-8") as f: config_file = yaml.safe_load(f) - # Not to introduce a breaking change we have to allow the old format config file without any provider keys - # and a new format with a key for each provider to include their configuration values within - # Check if the new format is passed - if ( - "aws" in config_file - or "gcp" in config_file - or "azure" in config_file - or "kubernetes" in config_file - ): + # Not to introduce a breaking change, allow the old format config file without any provider keys + # and a new format with a key for each provider to include their configuration values within. + if any(key in config_file for key in ["aws", "gcp", "azure", "kubernetes"]): config = config_file.get(provider, {}) else: config = config_file if config_file else {} - # Not to break Azure, K8s and GCP does not support neither use the old config format + # Not to break Azure, K8s and GCP does not support or use the old config format if provider in ["azure", "gcp", "kubernetes"]: config = {} return config - except Exception as error: - logger.critical( + except FileNotFoundError as error: + logger.error( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" ) - sys.exit(1) + except yaml.YAMLError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except UnicodeDecodeError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except Exception as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + + return {} def load_and_validate_fixer_config_file( provider: str, fixer_config_file_path: str ) -> dict: """ - load_and_validate_fixer_config_file reads the Prowler fixer config file in YAML format from the default location or the file passed with the --fixer-config flag + Reads the Prowler fixer config file in YAML format from the default location or the file passed with the --fixer-config flag. + + Args: + provider (str): The provider name (e.g., 'aws', 'gcp', 'azure', 'kubernetes'). + fixer_config_file_path (str): The path to the fixer configuration file. + + Returns: + dict: The fixer configuration dictionary for the specified provider. + + Raises: + SystemExit: If there is an error reading or parsing the fixer configuration file. """ try: - with open(fixer_config_file_path) as f: + with open(fixer_config_file_path, "r", encoding="utf-8") as f: fixer_config_file = yaml.safe_load(f) - return fixer_config_file.get(provider, {}) - except Exception as error: - logger.critical( + except FileNotFoundError as error: + logger.error( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" ) - sys.exit(1) + except yaml.YAMLError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except UnicodeDecodeError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except Exception as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + + return {} diff --git a/prowler/lib/outputs/html/html.py b/prowler/lib/outputs/html/html.py index abeb94e7ca..ecb6ef5465 100644 --- a/prowler/lib/outputs/html/html.py +++ b/prowler/lib/outputs/html/html.py @@ -173,33 +173,42 @@ def fill_html(file_descriptor, finding): def fill_html_overview_statistics(stats, output_filename, output_directory): try: filename = f"{output_directory}/{output_filename}{html_file_suffix}" - # Read file + + # Read file if path.isfile(filename): - with open(filename, "r") as file: + with open(filename, "r", encoding="utf-8") as file: filedata = file.read() # Replace statistics # TOTAL_FINDINGS filedata = filedata.replace( - "TOTAL_FINDINGS", str(stats.get("findings_count")) + "TOTAL_FINDINGS", str(stats.get("findings_count", 0)) ) # TOTAL_RESOURCES filedata = filedata.replace( - "TOTAL_RESOURCES", str(stats.get("resources_count")) + "TOTAL_RESOURCES", str(stats.get("resources_count", 0)) ) # TOTAL_PASS - filedata = filedata.replace("TOTAL_PASS", str(stats.get("total_pass"))) + filedata = filedata.replace("TOTAL_PASS", str(stats.get("total_pass", 0))) # TOTAL_FAIL - filedata = filedata.replace("TOTAL_FAIL", str(stats.get("total_fail"))) + filedata = filedata.replace("TOTAL_FAIL", str(stats.get("total_fail", 0))) + # Write file - with open(filename, "w") as file: + with open(filename, "w", encoding="utf-8") as file: file.write(filedata) - except Exception as error: - logger.critical( + except FileNotFoundError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except UnicodeDecodeError as error: + logger.error( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" + ) + except Exception as error: + logger.error( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" ) - sys.exit(1) def add_html_footer(output_filename, output_directory): diff --git a/tests/config/config_test.py b/tests/config/config_test.py index 7d33bcca69..b04074ef0b 100644 --- a/tests/config/config_test.py +++ b/tests/config/config_test.py @@ -1,3 +1,4 @@ +import logging import os import pathlib from unittest import mock @@ -386,12 +387,16 @@ class Test_Config: assert load_and_validate_config_file("azure", config_test_file) == {} assert load_and_validate_config_file("kubernetes", config_test_file) == {} - def test_load_and_validate_config_file_invalid_config_file_path(self): + def test_load_and_validate_config_file_invalid_config_file_path(self, caplog): provider = "aws" config_file_path = "invalid/path/to/fixer_config.yaml" - with pytest.raises(SystemExit): - load_and_validate_config_file(provider, config_file_path) + with caplog.at_level(logging.ERROR): + result = load_and_validate_config_file(provider, config_file_path) + assert "FileNotFoundError" in caplog.text + assert result == {} + + assert pytest is not None def test_load_and_validate_fixer_config_aws(self): path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) @@ -421,9 +426,13 @@ class Test_Config: assert load_and_validate_fixer_config_file(provider, config_test_file) == {} - def test_load_and_validate_fixer_config_invalid_fixer_config_path(self): + def test_load_and_validate_fixer_config_invalid_fixer_config_path(self, caplog): provider = "aws" fixer_config_path = "invalid/path/to/fixer_config.yaml" - with pytest.raises(SystemExit): - load_and_validate_fixer_config_file(provider, fixer_config_path) + with caplog.at_level(logging.ERROR): + result = load_and_validate_fixer_config_file(provider, fixer_config_path) + assert "FileNotFoundError" in caplog.text + assert result == {} + + assert pytest is not None