"""Unit tests for ``prowler.providers.azure.azure_provider.AzureProvider``.

Covers provider construction and argument validation, ``test_connection``,
the region / resource-group helpers, ``setup_identity`` subscription
resolution, sovereign-cloud endpoint wiring and event-loop handling.
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential

from prowler.config.config import (
    default_config_file_path,
    default_fixer_config_file_path,
    load_and_validate_config_file,
)
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.azure.exceptions.exceptions import (
    AzureBrowserAuthNoTenantIDError,
    AzureHTTPResponseError,
    AzureInvalidProviderIdError,
    AzureNoAuthenticationMethodError,
    AzureTenantIDNoBrowserAuthError,
)
from prowler.providers.azure.models import AzureIdentityInfo, AzureRegionConfig
from prowler.providers.common.models import Connection


class TestAzureProvider:
    """Construction, argument validation and test_connection of AzureProvider."""

    def test_azure_provider(self):
        """An az-cli provider exposes the expected region, session, identity and audit config."""
        subscription_id = None
        tenant_id = None
        # We need to set exactly one auth method
        az_cli_auth = True
        sp_env_auth = None
        browser_auth = None
        managed_identity_auth = None
        client_id = None
        client_secret = None

        fixer_config = load_and_validate_config_file(
            "azure", default_fixer_config_file_path
        )
        azure_region = "AzureCloud"

        # setup_identity / get_locations would otherwise call Azure.
        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_identity",
                return_value=AzureIdentityInfo(),
            ),
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.get_locations",
                return_value={},
            ),
        ):
            azure_provider = AzureProvider(
                az_cli_auth,
                sp_env_auth,
                browser_auth,
                managed_identity_auth,
                tenant_id,
                azure_region,
                subscription_id,
                config_path=default_config_file_path,
                fixer_config=fixer_config,
                client_id=client_id,
                client_secret=client_secret,
            )

            assert azure_provider.region_config == AzureRegionConfig(
                name="AzureCloud",
                authority=None,
                base_url="https://management.azure.com",
                credential_scopes=["https://management.azure.com/.default"],
            )
            assert isinstance(azure_provider.session, DefaultAzureCredential)
            assert azure_provider.identity == AzureIdentityInfo(
                identity_id="",
                identity_type="",
                tenant_ids=[],
                tenant_domain="Unknown tenant domain (missing AAD permissions)",
                subscriptions={},
                locations={},
            )
            assert azure_provider.audit_config == {
                "shodan_api_key": None,
                "php_latest_version": "8.2",
                "python_latest_version": "3.12",
                "java_latest_version": "17",
                "recommended_minimal_tls_versions": ["1.2", "1.3"],
                "recommended_smb_channel_encryption_algorithms": ["AES-256-GCM"],
                "vm_backup_min_daily_retention_days": 7,
                "desired_vm_sku_sizes": [
                    "Standard_A8_v2",
                    "Standard_DS3_v2",
                    "Standard_D4s_v3",
                ],
                "defender_attack_path_minimal_risk_level": "High",
                "apim_threat_detection_llm_jacking_threshold": 0.1,
                "apim_threat_detection_llm_jacking_minutes": 1440,
                "apim_threat_detection_llm_jacking_actions": [
                    "ImageGenerations_Create",
                    "ChatCompletions_Create",
                    "Completions_Create",
                    "Embeddings_Create",
                    "FineTuning_Jobs_Create",
                    "Models_List",
                    "Deployments_List",
                    "Deployments_Get",
                    "Deployments_Create",
                    "Deployments_Delete",
                    "Messages_Create",
                    "Claude_Create",
                    "GenerateContent",
                    "GenerateText",
                    "GenerateImage",
                    "Llama_Create",
                    "CodeLlama_Create",
                    "Gemini_Generate",
                    "Claude_Generate",
                    "Llama_Generate",
                ],
            }

    def test_azure_provider_not_auth_methods(self):
        """Constructing a provider with no auth method raises AzureNoAuthenticationMethodError."""
        subscription_id = None
        tenant_id = None
        # Deliberately set no auth method so argument validation fails
        az_cli_auth = None
        sp_env_auth = None
        browser_auth = None
        managed_identity_auth = None
        config_file = default_config_file_path
        fixer_config = default_fixer_config_file_path
        azure_region = "AzureCloud"

        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_identity",
                return_value=AzureIdentityInfo(),
            ),
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.get_locations",
                return_value={},
            ),
        ):
            with pytest.raises(AzureNoAuthenticationMethodError) as exception:
                _ = AzureProvider(
                    az_cli_auth,
                    sp_env_auth,
                    browser_auth,
                    managed_identity_auth,
                    tenant_id,
                    azure_region,
                    subscription_id,
                    config_file,
                    fixer_config,
                )

            assert exception.type == AzureNoAuthenticationMethodError
            assert (
                "Azure provider requires at least one authentication method set: [--az-cli-auth | --sp-env-auth | --browser-auth | --managed-identity-auth]"
                in exception.value.args[0]
            )

    def test_azure_provider_browser_auth_but_not_tenant_id(self):
        """Browser auth without --tenant-id raises AzureBrowserAuthNoTenantIDError."""
        subscription_id = None
        tenant_id = None
        # Browser auth selected, but no tenant ID supplied
        az_cli_auth = None
        sp_env_auth = None
        browser_auth = True
        managed_identity_auth = None
        config_file = default_config_file_path
        fixer_config = default_fixer_config_file_path
        azure_region = "AzureCloud"

        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_identity",
                return_value=AzureIdentityInfo(),
            ),
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.get_locations",
                return_value={},
            ),
        ):
            with pytest.raises(AzureBrowserAuthNoTenantIDError) as exception:
                _ = AzureProvider(
                    az_cli_auth,
                    sp_env_auth,
                    browser_auth,
                    managed_identity_auth,
                    tenant_id,
                    azure_region,
                    subscription_id,
                    config_file,
                    fixer_config,
                )

            assert exception.type == AzureBrowserAuthNoTenantIDError
            assert (
                exception.value.args[0]
                == "[2004] Azure Tenant ID (--tenant-id) is required for browser authentication mode"
            )

    def test_azure_provider_not_browser_auth_but_tenant_id(self):
        """A --tenant-id without browser auth raises AzureTenantIDNoBrowserAuthError."""
        subscription_id = None
        tenant_id = "test-tenant-id"
        # Tenant ID supplied, but browser auth is not selected
        az_cli_auth = None
        sp_env_auth = None
        browser_auth = False
        managed_identity_auth = None
        config_file = default_config_file_path
        fixer_config = default_fixer_config_file_path
        azure_region = "AzureCloud"

        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_identity",
                return_value=AzureIdentityInfo(),
            ),
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.get_locations",
                return_value={},
            ),
        ):
            with pytest.raises(AzureTenantIDNoBrowserAuthError) as exception:
                _ = AzureProvider(
                    az_cli_auth,
                    sp_env_auth,
                    browser_auth,
                    managed_identity_auth,
                    tenant_id,
                    azure_region,
                    subscription_id,
                    config_file,
                    fixer_config,
                )

            assert exception.type == AzureTenantIDNoBrowserAuthError
            assert (
                exception.value.args[0]
                == "[2005] Azure Tenant ID (--tenant-id) is required for browser authentication mode"
            )

    def test_test_connection_browser_auth(self):
        """test_connection with browser auth returns a successful Connection."""
        with (
            patch(
                "prowler.providers.azure.azure_provider.DefaultAzureCredential"
            ) as mock_default_credential,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient"
            ) as mock_subscription_client,
        ):
            # Mock the return value of DefaultAzureCredential
            mock_credentials = MagicMock()
            mock_credentials.get_token.return_value = AccessToken(
                token="fake_token", expires_on=9999999999
            )
            mock_default_credential.return_value = mock_credentials

            # Mock setup_session to return a mocked session object
            mock_setup_session.return_value = MagicMock()

            # Mock SubscriptionClient to avoid real API calls
            mock_subscription_client.return_value = MagicMock()

            test_connection = AzureProvider.test_connection(
                browser_auth=True,
                tenant_id=str(uuid4()),
                region="AzureCloud",
                raise_on_exception=False,
            )

            assert isinstance(test_connection, Connection)
            assert test_connection.is_connected
            assert test_connection.error is None

    def test_test_connection_tenant_id_client_id_client_secret(self):
        """test_connection with static tenant/client/secret credentials returns a successful Connection."""
        with (
            patch(
                "prowler.providers.azure.azure_provider.DefaultAzureCredential"
            ) as mock_default_credential,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient"
            ) as mock_subscription_client,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
            ) as mock_validate_static_credentials,
        ):
            # Mock the return value of DefaultAzureCredential
            mock_default_credential.return_value = {
                "client_id": str(uuid4()),
                "client_secret": str(uuid4()),
                "tenant_id": str(uuid4()),
            }

            # Mock setup_session to return a mocked session object
            mock_setup_session.return_value = MagicMock()

            # Skip static-credential validation to avoid real API calls
            mock_validate_static_credentials.return_value = None

            # Mock SubscriptionClient to avoid real API calls
            mock_subscription_client.return_value = MagicMock()

            test_connection = AzureProvider.test_connection(
                browser_auth=False,
                tenant_id=str(uuid4()),
                region="AzureCloud",
                raise_on_exception=False,
                client_id=str(uuid4()),
                client_secret=str(uuid4()),
            )

            assert isinstance(test_connection, Connection)
            assert test_connection.is_connected
            assert test_connection.error is None

    def test_test_connection_provider_validation(self):
        """test_connection succeeds when provider_id matches a listed subscription."""
        with (
            patch(
                "prowler.providers.azure.azure_provider.DefaultAzureCredential"
            ) as mock_default_credential,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient"
            ) as mock_subscription_client,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
            ) as mock_validate_static_credentials,
        ):
            # Mock the return value of DefaultAzureCredential
            mock_default_credential.return_value = {
                "client_id": str(uuid4()),
                "client_secret": str(uuid4()),
                "tenant_id": str(uuid4()),
            }

            # Mock setup_session to return a mocked session object
            mock_setup_session.return_value = MagicMock()

            # Skip static-credential validation to avoid real API calls
            mock_validate_static_credentials.return_value = None

            # SubscriptionClient lists a subscription whose ID matches provider_id
            mock_subscription = MagicMock()
            mock_subscription.subscription_id = "test_provider_id"
            mock_client = MagicMock()
            mock_client.subscriptions.list.return_value = [mock_subscription]
            mock_subscription_client.return_value = mock_client

            test_connection = AzureProvider.test_connection(
                browser_auth=False,
                tenant_id=str(uuid4()),
                region="AzureCloud",
                raise_on_exception=False,
                client_id=str(uuid4()),
                client_secret=str(uuid4()),
                provider_id="test_provider_id",
            )

            assert isinstance(test_connection, Connection)
            assert test_connection.is_connected
            assert test_connection.error is None

    def test_test_connection_provider_validation_error(self):
        """test_connection reports AzureInvalidProviderIdError when provider_id is not listed."""
        with (
            patch(
                "prowler.providers.azure.azure_provider.DefaultAzureCredential"
            ) as mock_default_credential,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient"
            ) as mock_subscription_client,
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
            ) as mock_validate_static_credentials,
        ):
            # Mock the return value of DefaultAzureCredential
            mock_default_credential.return_value = {
                "client_id": str(uuid4()),
                "client_secret": str(uuid4()),
                "tenant_id": str(uuid4()),
            }

            # Mock setup_session to return a mocked session object
            mock_setup_session.return_value = MagicMock()

            # Skip static-credential validation to avoid real API calls
            mock_validate_static_credentials.return_value = None

            # SubscriptionClient lists only a subscription that does NOT match provider_id
            mock_subscription = MagicMock()
            mock_subscription.subscription_id = "test_invalid_provider_id"
            mock_client = MagicMock()
            mock_client.subscriptions.list.return_value = [mock_subscription]
            mock_subscription_client.return_value = mock_client

            test_connection = AzureProvider.test_connection(
                browser_auth=False,
                tenant_id=str(uuid4()),
                region="AzureCloud",
                raise_on_exception=False,
                client_id=str(uuid4()),
                client_secret=str(uuid4()),
                provider_id="test_provider_id",
            )

            assert test_connection.error is not None
            assert isinstance(test_connection.error, AzureInvalidProviderIdError)
            assert (
                "The provided credentials are not valid for the specified Azure subscription."
                in test_connection.error.args[0]
            )

    def test_test_connection_with_ClientAuthenticationError(self):
        """An HttpResponseError from subscriptions.list() surfaces as AzureHTTPResponseError."""
        tenant_id = str(uuid4())
        error_message = (
            "Authentication failed: Unable to get authority configuration for "
            f"https://login.microsoftonline.com/{tenant_id}."
        )
        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient"
            ) as mock_subscription_client,
            pytest.raises(AzureHTTPResponseError) as exception,
        ):
            mock_setup_session.return_value = MagicMock()
            mock_client = MagicMock()
            mock_client.subscriptions = MagicMock()
            mock_client.subscriptions.list.side_effect = HttpResponseError(
                message=error_message
            )
            mock_subscription_client.return_value = mock_client

            AzureProvider.test_connection(
                browser_auth=True,
                tenant_id=tenant_id,
                region="AzureCloud",
            )

        # Assertions live outside the raising block so they actually run.
        assert exception.type == AzureHTTPResponseError
        assert exception.value.args[0] == (
            f"[2010] Error in HTTP response from Azure - {error_message}"
        )

    def test_test_connection_without_any_method(self):
        """test_connection with no auth method raises AzureNoAuthenticationMethodError."""
        with pytest.raises(AzureNoAuthenticationMethodError) as exception:
            AzureProvider.test_connection()

        assert exception.type == AzureNoAuthenticationMethodError
        assert (
            "[2003] Azure provider requires at least one authentication method set: [--az-cli-auth | --sp-env-auth | --browser-auth | --managed-identity-auth]"
            in exception.value.args[0]
        )

    def test_test_connection_with_httpresponseerror(self):
        """An AzureHTTPResponseError from setup_session propagates when raise_on_exception=True."""
        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.get_locations",
                return_value={},
            ),
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
        ):
            mock_setup_session.side_effect = AzureHTTPResponseError(
                file="test_file", original_exception="Simulated HttpResponseError"
            )

            with pytest.raises(AzureHTTPResponseError) as exception:
                AzureProvider.test_connection(
                    az_cli_auth=True,
                    raise_on_exception=True,
                )

            assert exception.type == AzureHTTPResponseError
            assert (
                exception.value.args[0]
                == "[2010] Error in HTTP response from Azure - Simulated HttpResponseError"
            )

    def test_test_connection_with_exception(self):
        """A generic exception from setup_session propagates unchanged when raise_on_exception=True."""
        with patch(
            "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
        ) as mock_setup_session:
            mock_setup_session.side_effect = Exception("Simulated Exception")

            with pytest.raises(Exception) as exception:
                AzureProvider.test_connection(
                    sp_env_auth=True,
                    raise_on_exception=True,
                )

            assert exception.type is Exception
            assert exception.value.args[0] == "Simulated Exception"

    @pytest.mark.parametrize(
        "subscription_ids, expected_regions",
        [
            (None, {"region1", "region2", "region3"}),
            (["sub1", "sub2"], {"region1", "region2", "region3"}),
            ("sub1", {"region1", "region2"}),
            ("not_exists", set()),
        ],
    )
    @patch("prowler.providers.azure.azure_provider.AzureProvider.get_locations")
    @patch(
        "prowler.providers.azure.azure_provider.AzureProvider.__init__",
        return_value=None,
    )
    def test_get_regions(
        self,
        azure_provider_init_mock,  # noqa: F841
        azure_get_locations_mock,
        subscription_ids,
        expected_regions,
    ):
        """get_regions returns the union of locations for the selected subscriptions."""
        azure_get_locations_mock.return_value = {
            "sub1": ["region1", "region2"],
            "sub2": ["region2", "region3"],
        }
        azure_provider = AzureProvider()

        regions = azure_provider.get_regions(subscription_ids=subscription_ids)

        assert regions == expected_regions


class TestAzureProviderValidateResourceGroups:
    """AzureProvider.validate_resource_groups matching against listed resource groups."""

    def _make_provider(self, subscriptions=None):
        """Build an AzureProvider without running __init__.

        ``subscriptions`` maps subscription ID -> display name; when omitted a
        single random subscription is used. An explicit empty dict is kept.
        """
        # Same construction approach as the _build_provider helpers below; a
        # context manager avoids @patch appending its mock to positional args.
        with patch.object(AzureProvider, "__init__", return_value=None):
            provider = AzureProvider()
        provider._identity = MagicMock()
        provider._identity.subscriptions = (
            subscriptions if subscriptions is not None else {str(uuid4()): "Sub"}
        )
        provider._session = MagicMock()
        provider._region_config = MagicMock()
        return provider

    @patch("prowler.providers.azure.azure_provider.ResourceManagementClient")
    def test_validate_resource_groups_exact_match(self, mock_rm_client):
        """An exactly matching resource-group name is returned for the subscription."""
        provider = self._make_provider()
        sub_id = next(iter(provider._identity.subscriptions))
        mock_rg = MagicMock()
        mock_rg.name = "mygroup"
        mock_resource_groups = MagicMock()
        mock_resource_groups.list.return_value = [mock_rg]
        mock_rm_client.return_value.resource_groups = mock_resource_groups

        result = provider.validate_resource_groups(["mygroup"])

        assert result[sub_id] == ["mygroup"]

    @patch("prowler.providers.azure.azure_provider.ResourceManagementClient")
    def test_validate_resource_groups_mixed_case(self, mock_rm_client):
        """Matching ignores case and returns the name as Azure lists it."""
        provider = self._make_provider()
        sub_id = next(iter(provider._identity.subscriptions))
        mock_rg = MagicMock()
        mock_rg.name = "MyGroup"
        mock_resource_groups = MagicMock()
        mock_resource_groups.list.return_value = [mock_rg]
        mock_rm_client.return_value.resource_groups = mock_resource_groups

        result = provider.validate_resource_groups(["mygroup"])

        assert result[sub_id] == ["MyGroup"]
        mock_resource_groups.list.assert_called_once()

    @patch("prowler.providers.azure.azure_provider.ResourceManagementClient")
    def test_validate_resource_groups_multiple_rgs(self, mock_rm_client):
        """Several requested resource groups are all returned when present."""
        provider = self._make_provider()
        sub_id = next(iter(provider._identity.subscriptions))
        rg1, rg2 = MagicMock(), MagicMock()
        rg1.name = "rg1"
        rg2.name = "rg2"
        mock_resource_groups = MagicMock()
        mock_resource_groups.list.return_value = [rg1, rg2]
        mock_rm_client.return_value.resource_groups = mock_resource_groups

        result = provider.validate_resource_groups(["rg1", "rg2"])

        assert set(result[sub_id]) == {"rg1", "rg2"}

    @patch("prowler.providers.azure.azure_provider.ResourceManagementClient")
    def test_validate_resource_groups_not_found(self, mock_rm_client):
        """A resource group absent from the subscription yields an empty list."""
        provider = self._make_provider()
        sub_id = next(iter(provider._identity.subscriptions))
        mock_rg = MagicMock()
        mock_rg.name = "existing"
        mock_resource_groups = MagicMock()
        mock_resource_groups.list.return_value = [mock_rg]
        mock_rm_client.return_value.resource_groups = mock_resource_groups

        result = provider.validate_resource_groups(["nonexistent"])

        assert result[sub_id] == []

    def test_validate_resource_groups_empty_input(self):
        """No requested resource groups yields an empty mapping."""
        provider = self._make_provider()

        result = provider.validate_resource_groups([])

        assert result == {}

    @patch("prowler.providers.azure.azure_provider.ResourceManagementClient")
    def test_validate_resource_groups_strips_whitespace(self, mock_rm_client):
        """Surrounding whitespace in a requested name is ignored."""
        provider = self._make_provider()
        sub_id = next(iter(provider._identity.subscriptions))
        mock_rg = MagicMock()
        mock_rg.name = "rg-prod"
        mock_resource_groups = MagicMock()
        mock_resource_groups.list.return_value = [mock_rg]
        mock_rm_client.return_value.resource_groups = mock_resource_groups

        result = provider.validate_resource_groups(["  rg-prod  "])

        assert result[sub_id] == ["rg-prod"]


class TestAzureProviderSetupIdentitySubscriptions:
    """Regression tests ensuring identity.subscriptions preserves every
    subscription even when multiple Azure subscriptions share the same
    display_name (which is permitted by Azure)."""

    @staticmethod
    def _mock_subscription(display_name, subscription_id):
        """Return a subscription stand-in with display_name and subscription_id set."""
        mock_subscription = MagicMock()
        mock_subscription.display_name = display_name
        mock_subscription.subscription_id = subscription_id
        return mock_subscription

    @staticmethod
    def _build_subscriptions_client_mock(list_result=None, get_map=None):
        """Construct a fully explicit SubscriptionClient mock so the tests do
        not depend on MagicMock auto-attribute behavior, which makes the suite
        sensitive to shared state across test files.

        ``list_result`` is returned by ``subscriptions.list()``; ``get_map``
        maps subscription ID -> object returned by ``subscriptions.get()``.
        """
        subscriptions_operations = MagicMock()
        subscriptions_operations.list = MagicMock(return_value=list_result or [])
        if get_map is not None:
            subscriptions_operations.get = MagicMock(
                side_effect=lambda subscription_id: get_map[subscription_id]
            )
        else:
            subscriptions_operations.get = MagicMock()

        tenants_operations = MagicMock()
        tenants_operations.list = MagicMock(return_value=[])

        client_instance = MagicMock()
        client_instance.subscriptions = subscriptions_operations
        client_instance.tenants = tenants_operations

        return MagicMock(return_value=client_instance)

    @staticmethod
    def _build_provider():
        """Create an AzureProvider instance ready to invoke setup_identity with
        auth flags left False so the AAD lookup branches are skipped and the
        test focuses on the subscription resolution logic."""
        with patch.object(AzureProvider, "__init__", return_value=None):
            azure_provider = AzureProvider()
        azure_provider._session = MagicMock()
        azure_provider._region_config = AzureRegionConfig(
            name="AzureCloud",
            authority=None,
            base_url="https://management.azure.com",
            credential_scopes=["https://management.azure.com/.default"],
        )
        return azure_provider

    @classmethod
    def _run_setup_identity(cls, client_class, subscription_ids):
        """Run setup_identity with all auth flags False against ``client_class``."""
        with patch(
            "prowler.providers.azure.azure_provider.SubscriptionClient",
            client_class,
        ):
            azure_provider = cls._build_provider()
            return azure_provider.setup_identity(
                az_cli_auth=False,
                sp_env_auth=False,
                browser_auth=False,
                managed_identity_auth=False,
                subscription_ids=subscription_ids,
                client_id=None,
            )

    def test_setup_identity_auto_discovery_preserves_unique_display_names(self):
        """Auto-discovered subscriptions with distinct names are all kept, keyed by ID."""
        first_id = str(uuid4())
        second_id = str(uuid4())
        client_class = self._build_subscriptions_client_mock(
            list_result=[
                self._mock_subscription("Unique Name One", first_id),
                self._mock_subscription("Unique Name Two", second_id),
            ]
        )

        identity = self._run_setup_identity(client_class, subscription_ids=[])

        assert identity.subscriptions == {
            first_id: "Unique Name One",
            second_id: "Unique Name Two",
        }

    def test_setup_identity_auto_discovery_preserves_duplicate_display_names(
        self,
    ):
        """Auto-discovered subscriptions sharing a display name are all kept, keyed by ID."""
        shared_name = "Shared Display Name"
        first_id = str(uuid4())
        second_id = str(uuid4())
        client_class = self._build_subscriptions_client_mock(
            list_result=[
                self._mock_subscription(shared_name, first_id),
                self._mock_subscription(shared_name, second_id),
            ]
        )

        identity = self._run_setup_identity(client_class, subscription_ids=[])

        assert identity.subscriptions == {
            first_id: shared_name,
            second_id: shared_name,
        }

    def test_setup_identity_filtered_preserves_unique_display_names(self):
        """Explicitly requested subscriptions with distinct names are all kept, keyed by ID."""
        first_id = str(uuid4())
        second_id = str(uuid4())
        client_class = self._build_subscriptions_client_mock(
            get_map={
                first_id: self._mock_subscription("Unique Name One", first_id),
                second_id: self._mock_subscription("Unique Name Two", second_id),
            }
        )

        identity = self._run_setup_identity(
            client_class, subscription_ids=[first_id, second_id]
        )

        assert identity.subscriptions == {
            first_id: "Unique Name One",
            second_id: "Unique Name Two",
        }

    def test_setup_identity_filtered_preserves_duplicate_display_names(self):
        """Explicitly requested subscriptions sharing a display name are all kept, keyed by ID."""
        shared_name = "Shared Display Name"
        first_id = str(uuid4())
        second_id = str(uuid4())
        client_class = self._build_subscriptions_client_mock(
            get_map={
                first_id: self._mock_subscription(shared_name, first_id),
                second_id: self._mock_subscription(shared_name, second_id),
            }
        )

        identity = self._run_setup_identity(
            client_class, subscription_ids=[first_id, second_id]
        )

        assert identity.subscriptions == {
            first_id: shared_name,
            second_id: shared_name,
        }


class TestAzureProviderSovereignCloudSupport:
    """Sovereign-cloud authentication coverage across AzureCloud,
    AzureChinaCloud and AzureUSGovernment for every authentication code path
    Prowler exposes. Pinned to issue #8425."""

    # (region, authority, base_url, credential_scopes, graph_scope,
    #  logs_endpoint, login_endpoint)
    REGION_CASES = [
        (
            "AzureCloud",
            None,
            "https://management.azure.com",
            ["https://management.azure.com/.default"],
            "https://graph.microsoft.com/.default",
            "https://api.loganalytics.io",
            "login.microsoftonline.com",
        ),
        (
            "AzureChinaCloud",
            "login.chinacloudapi.cn",
            "https://management.chinacloudapi.cn",
            ["https://management.chinacloudapi.cn/.default"],
            "https://microsoftgraph.chinacloudapi.cn/.default",
            "https://api.loganalytics.azure.cn",
            "login.chinacloudapi.cn",
        ),
        (
            "AzureUSGovernment",
            "login.microsoftonline.us",
            "https://management.usgovcloudapi.net",
            ["https://management.usgovcloudapi.net/.default"],
            "https://graph.microsoft.us/.default",
            "https://api.loganalytics.us",
            "login.microsoftonline.us",
        ),
    ]

    @pytest.mark.parametrize(
        "region,authority,base_url,credential_scopes,graph_scope,logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_setup_region_config_per_cloud(
        self,
        region,
        authority,
        base_url,
        credential_scopes,
        graph_scope,
        logs_endpoint,
        _login_endpoint,
    ):
        """setup_region_config returns the per-cloud endpoints and scopes."""
        config = AzureProvider.setup_region_config(region)

        # graph_host mirrors graph_scope without the `/.default` suffix; we
        # derive it here to avoid threading a separate parameter through every
        # parametrized test in this class.
        expected_graph_host = graph_scope.removesuffix("/.default")

        assert config == AzureRegionConfig(
            name=region,
            authority=authority,
            base_url=base_url,
            credential_scopes=credential_scopes,
            graph_host=expected_graph_host,
            graph_scope=graph_scope,
            logs_endpoint=logs_endpoint,
        )

    @pytest.mark.parametrize(
        "region,authority,_base_url,_credential_scopes,_graph_scope,_logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_setup_session_static_credentials_passes_authority(
        self,
        region,
        authority,
        _base_url,
        _credential_scopes,
        _graph_scope,
        _logs_endpoint,
        _login_endpoint,
    ):
        """Static credentials build ClientSecretCredential with the cloud's authority."""
        with patch(
            "prowler.providers.azure.azure_provider.ClientSecretCredential"
        ) as mock_client_secret_credential:
            azure_credentials = {
                "tenant_id": str(uuid4()),
                "client_id": str(uuid4()),
                "client_secret": "fake-secret-value",
            }
            region_config = AzureProvider.setup_region_config(region)

            AzureProvider.setup_session(
                az_cli_auth=False,
                sp_env_auth=False,
                browser_auth=False,
                managed_identity_auth=False,
                tenant_id=azure_credentials["tenant_id"],
                azure_credentials=azure_credentials,
                region_config=region_config,
            )

            mock_client_secret_credential.assert_called_once_with(
                tenant_id=azure_credentials["tenant_id"],
                client_id=azure_credentials["client_id"],
                client_secret=azure_credentials["client_secret"],
                authority=authority,
            )

    @pytest.mark.parametrize(
        "region,authority,_base_url,_credential_scopes,_graph_scope,_logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_setup_session_browser_auth_passes_authority(
        self,
        region,
        authority,
        _base_url,
        _credential_scopes,
        _graph_scope,
        _logs_endpoint,
        _login_endpoint,
    ):
        """Browser auth builds InteractiveBrowserCredential with the cloud's authority."""
        with patch(
            "prowler.providers.azure.azure_provider.InteractiveBrowserCredential"
        ) as mock_interactive_browser_credential:
            tenant_id = str(uuid4())
            region_config = AzureProvider.setup_region_config(region)

            AzureProvider.setup_session(
                az_cli_auth=False,
                sp_env_auth=False,
                browser_auth=True,
                managed_identity_auth=False,
                tenant_id=tenant_id,
                azure_credentials=None,
                region_config=region_config,
            )

            mock_interactive_browser_credential.assert_called_once_with(
                tenant_id=tenant_id,
                authority=authority,
            )

    @pytest.mark.parametrize(
        "region,authority,_base_url,_credential_scopes,_graph_scope,_logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_setup_session_default_credential_passes_authority(
        self,
        region,
        authority,
        _base_url,
        _credential_scopes,
        _graph_scope,
        _logs_endpoint,
        _login_endpoint,
    ):
        """az-cli auth builds DefaultAzureCredential with the cloud's authority and CLI-only sources."""
        with patch(
            "prowler.providers.azure.azure_provider.DefaultAzureCredential"
        ) as mock_default_credential:
            region_config = AzureProvider.setup_region_config(region)

            AzureProvider.setup_session(
                az_cli_auth=True,
                sp_env_auth=False,
                browser_auth=False,
                managed_identity_auth=False,
                tenant_id=None,
                azure_credentials=None,
                region_config=region_config,
            )

            _, called_kwargs = mock_default_credential.call_args
            assert called_kwargs["authority"] == authority
            assert called_kwargs["exclude_cli_credential"] is False
            assert called_kwargs["exclude_environment_credential"] is True
            assert called_kwargs["exclude_managed_identity_credential"] is True

    @pytest.mark.parametrize(
        "region,_authority,_base_url,_credential_scopes,graph_scope,_logs_endpoint,login_endpoint",
        REGION_CASES,
    )
    def test_verify_client_uses_per_cloud_endpoints(
        self,
        region,
        _authority,
        _base_url,
        _credential_scopes,
        graph_scope,
        _logs_endpoint,
        login_endpoint,
    ):
        """verify_client posts to the cloud's token endpoint with its Graph scope."""
        tenant_id = str(uuid4())
        client_id = str(uuid4())
        client_secret = "fake-secret"
        region_config = AzureProvider.setup_region_config(region)

        with patch("prowler.providers.azure.azure_provider.requests.post") as mock_post:
            mock_post.return_value = MagicMock()
            mock_post.return_value.json.return_value = {"access_token": "fake-token"}

            AzureProvider.verify_client(
                tenant_id, client_id, client_secret, region_config
            )

        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        assert args[0] == (
            f"https://{login_endpoint}/{tenant_id}/oauth2/v2.0/token"
        )
        assert kwargs["data"]["scope"] == graph_scope
        assert kwargs["data"]["client_id"] == client_id
        assert kwargs["data"]["client_secret"] == client_secret

    @pytest.mark.parametrize(
        "region,_authority,base_url,credential_scopes,_graph_scope,_logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_test_connection_passes_base_url_to_subscription_client(
        self,
        region,
        _authority,
        base_url,
        credential_scopes,
        _graph_scope,
        _logs_endpoint,
        _login_endpoint,
    ):
        """test_connection builds SubscriptionClient with the cloud's base_url and scopes."""
        subscription_client_instance = MagicMock()
        subscription_client_instance.subscriptions = MagicMock()
        subscription_client_instance.subscriptions.list = MagicMock(return_value=[])
        subscription_client_class = MagicMock(return_value=subscription_client_instance)

        with (
            patch(
                "prowler.providers.azure.azure_provider.AzureProvider.setup_session"
            ) as mock_setup_session,
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient",
                subscription_client_class,
            ),
        ):
            mock_setup_session.return_value = MagicMock()

            AzureProvider.test_connection(
                az_cli_auth=True,
                region=region,
                raise_on_exception=False,
            )

        subscription_client_class.assert_called_once()
        _, kwargs = subscription_client_class.call_args
        assert kwargs["base_url"] == base_url
        assert kwargs["credential_scopes"] == credential_scopes

    @pytest.mark.parametrize(
        "region,_authority,base_url,credential_scopes,_graph_scope,_logs_endpoint,_login_endpoint",
        REGION_CASES,
    )
    def test_get_locations_passes_base_url_to_subscription_client(
        self,
        region,
        _authority,
        base_url,
        credential_scopes,
        _graph_scope,
        _logs_endpoint,
        _login_endpoint,
    ):
        """get_locations builds SubscriptionClient with the cloud's base_url and scopes."""
        subscription_client_instance = MagicMock()
        subscription_client_instance.subscriptions = MagicMock()
        subscription_client_instance.subscriptions.list_locations = MagicMock(
            return_value=[]
        )
        subscription_client_class = MagicMock(return_value=subscription_client_instance)

        with (
            patch.object(AzureProvider, "__init__", return_value=None),
            patch(
                "prowler.providers.azure.azure_provider.SubscriptionClient",
                subscription_client_class,
            ),
        ):
            azure_provider = AzureProvider()
            azure_provider._session = MagicMock()
            azure_provider._region_config = AzureProvider.setup_region_config(region)
            azure_provider._identity = AzureIdentityInfo(subscriptions={})

            azure_provider.get_locations()

        subscription_client_class.assert_called_once()
        _, kwargs = subscription_client_class.call_args
        assert kwargs["base_url"] == base_url
        assert kwargs["credential_scopes"] == credential_scopes


class TestAzureProviderSetupIdentityEventLoop:
    """Regression for the Celery worker scenario where asyncio.get_event_loop()
    raised "There is no current event loop in thread 'MainThread'." on
    Python 3.12. setup_identity now uses asyncio.run(), which creates its own
    loop and must work without a pre-existing one in the current thread."""

    @staticmethod
    def _mock_subscription(display_name, subscription_id):
        """Return a subscription stand-in with display_name and subscription_id set."""
        mock_subscription = MagicMock()
        mock_subscription.display_name = display_name
        mock_subscription.subscription_id = subscription_id
        return mock_subscription

    @staticmethod
    def _build_subscriptions_client_mock(subscriptions):
        """Return a SubscriptionClient class mock whose subscriptions.list() yields ``subscriptions``."""
        subscriptions_operations = MagicMock()
        subscriptions_operations.list = MagicMock(return_value=subscriptions)
        subscriptions_operations.get = MagicMock()

        tenants_operations = MagicMock()
        tenants_operations.list = MagicMock(return_value=[])

        client_instance = MagicMock()
        client_instance.subscriptions = subscriptions_operations
        client_instance.tenants = tenants_operations

        return MagicMock(return_value=client_instance)

    @staticmethod
    def _build_provider():
        """Create an AzureProvider without running __init__, on the AzureCloud region config."""
        with patch.object(AzureProvider, "__init__", return_value=None):
            azure_provider = AzureProvider()
        azure_provider._session = MagicMock()
        azure_provider._region_config = AzureRegionConfig(
            name="AzureCloud",
            authority=None,
            base_url="https://management.azure.com",
            credential_scopes=["https://management.azure.com/.default"],
        )
        return azure_provider

    def test_setup_identity_succeeds_without_active_event_loop(self):
        """setup_identity completes (and awaits the Graph domains call) with no current event loop."""
        sub_id = str(uuid4())
        subscriptions_client = self._build_subscriptions_client_mock(
            [self._mock_subscription("Sub", sub_id)]
        )

        graph_client = MagicMock()
        graph_client.domains.get = AsyncMock(return_value=MagicMock(value=[]))
        graph_client.me.get = AsyncMock(return_value=None)

        # Simulate the Celery worker state: no event loop registered for the
        # current thread. Before the fix this combination triggered
        # `RuntimeError: There is no current event loop in thread 'MainThread'.`
        # on Python 3.12 from asyncio.get_event_loop().
        asyncio.set_event_loop(None)
        try:
            with (
                patch(
                    "prowler.providers.azure.azure_provider.GraphServiceClient",
                    return_value=graph_client,
                ),
                patch(
                    "prowler.providers.azure.azure_provider.SubscriptionClient",
                    subscriptions_client,
                ),
            ):
                azure_provider = self._build_provider()
                identity = azure_provider.setup_identity(
                    az_cli_auth=False,
                    sp_env_auth=True,
                    browser_auth=False,
                    managed_identity_auth=False,
                    subscription_ids=[],
                    client_id="00000000-0000-0000-0000-000000000000",
                )
        finally:
            # Re-arm a loop for sibling tests that may rely on the default.
            asyncio.set_event_loop(asyncio.new_event_loop())

        assert isinstance(identity, AzureIdentityInfo)
        assert identity.subscriptions == {sub_id: "Sub"}
        graph_client.domains.get.assert_awaited_once()