diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index 684fbfd759..500d9201c8 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -26,7 +26,7 @@ from api.models import ( StateChoices, ) from api.models import StatusChoices as FindingStatus -from api.utils import initialize_prowler_provider +from api.utils import initialize_prowler_provider, return_prowler_provider from api.v1.serializers import ScanTaskSerializer from prowler.lib.outputs.finding import Finding as ProwlerFinding from prowler.lib.scan.scan import Scan as ProwlerScan @@ -149,7 +149,8 @@ def perform_prowler_scan( provider_instance.save() # If the provider is not connected, raise an exception outside the transaction. - # If raised within the transaction, the transaction will be rolled back and the provider will not be marked as not connected. + # If raised within the transaction, the transaction will be rolled back and the provider will not be marked + # as not connected. if exc: raise exc @@ -526,7 +527,7 @@ def create_compliance_requirements(tenant_id: str, scan_id: str): with rls_transaction(tenant_id): scan_instance = Scan.objects.get(pk=scan_id) provider_instance = scan_instance.provider - prowler_provider = initialize_prowler_provider(provider_instance) + prowler_provider = return_prowler_provider(provider_instance) # Get check status data by region from findings check_status_by_region = {} diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index e4ec0d4d41..9577e66a09 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -399,9 +399,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -427,9 +425,7 @@ class TestCreateComplianceRequirements: "us-east-1", "us-west-2", ] - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = { "cis_1.4_aws": { @@ -512,9 +508,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -557,9 +551,7 @@ class TestCreateComplianceRequirements: "us-east-1", "us-west-2", ] - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = { "test_compliance": { @@ -607,9 +599,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -641,9 +631,7 @@ class TestCreateComplianceRequirements: mock_prowler_provider_instance.get_regions.side_effect = AttributeError( "No get_regions method" ) - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = { "kubernetes_cis": { @@ -676,9 +664,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -704,9 +690,7 @@ class TestCreateComplianceRequirements: mock_prowler_provider_instance = MagicMock() mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"] - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = { "cis_1.4_aws": { @@ -743,9 +727,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, ): tenant = tenants_fixture[0] scan = scans_fixture[0] @@ -759,7 +741,7 @@ class TestCreateComplianceRequirements: tenant_id = str(tenant.id) scan_id = str(scan.id) - mock_initialize_prowler_provider.side_effect = Exception( + mock_prowler_provider.side_effect = Exception( "Provider initialization failed" ) @@ -774,9 +756,7 @@ class TestCreateComplianceRequirements: ): with ( patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.return_prowler_provider") as mock_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -800,9 +780,7 @@ class TestCreateComplianceRequirements: mock_prowler_provider_instance = MagicMock() mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"] - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = {} @@ -821,8 +799,8 @@ class TestCreateComplianceRequirements: with ( patch("api.db_utils.rls_transaction"), patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + "tasks.jobs.scan.return_prowler_provider" + ) as mock_return_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -862,9 +840,7 @@ class TestCreateComplianceRequirements: mock_prowler_provider_instance = MagicMock() mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"] - mock_initialize_prowler_provider.return_value = ( - mock_prowler_provider_instance - ) + mock_return_prowler_provider.return_value = mock_prowler_provider_instance mock_compliance_template.__getitem__.return_value = { "cis_1.4_aws": { @@ -898,8 +874,8 @@ class TestCreateComplianceRequirements: with ( patch("api.db_utils.rls_transaction"), patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + "tasks.jobs.scan.return_prowler_provider" + ) as mock_return_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -911,7 +887,6 @@ class TestCreateComplianceRequirements: ): tenant = tenants_fixture[0] scan = scans_fixture[0] - providers_fixture[0] mock_findings_filter.return_value = [] @@ -921,7 +896,7 @@ class TestCreateComplianceRequirements: "us-west-2", "eu-west-1", ] - mock_initialize_prowler_provider.return_value = mock_prowler_provider + mock_return_prowler_provider.return_value = mock_prowler_provider mock_compliance_template.__getitem__.return_value = { "test_compliance": { @@ -990,8 +965,8 @@ class TestCreateComplianceRequirements: with ( patch("api.db_utils.rls_transaction"), patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + "tasks.jobs.scan.return_prowler_provider" + ) as mock_return_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -1009,7 +984,7 @@ class TestCreateComplianceRequirements: mock_prowler_provider = MagicMock() mock_prowler_provider.get_regions.return_value = ["us-east-1", "us-west-2"] - mock_initialize_prowler_provider.return_value = mock_prowler_provider + mock_return_prowler_provider.return_value = mock_prowler_provider mock_compliance_template.__getitem__.return_value = { "test_compliance": { @@ -1077,8 +1052,8 @@ class TestCreateComplianceRequirements: with ( patch("api.db_utils.rls_transaction"), patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, + "tasks.jobs.scan.return_prowler_provider" + ) as mock_return_prowler_provider, patch( "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" ) as mock_compliance_template, @@ -1090,13 +1065,12 @@ class TestCreateComplianceRequirements: ): tenant = tenants_fixture[0] scan = scans_fixture[0] - providers_fixture[0] mock_findings_filter.return_value = [] mock_prowler_provider = MagicMock() mock_prowler_provider.get_regions.return_value = ["us-east-1", "us-west-2"] - mock_initialize_prowler_provider.return_value = mock_prowler_provider + mock_return_prowler_provider.return_value = mock_prowler_provider mock_compliance_template.__getitem__.return_value = { "test_compliance": {