mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
537c3ea71e
Signed-off-by: Legin-ML <leginml2004@gmail.com>
466 lines
19 KiB
Python
466 lines
19 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
|
|
|
|
from prowler.providers.azure.services.postgresql.postgresql_service import (
|
|
EntraIdAdmin,
|
|
Firewall,
|
|
PostgreSQL,
|
|
Server,
|
|
)
|
|
from tests.providers.azure.azure_fixtures import (
|
|
AZURE_SUBSCRIPTION_ID,
|
|
RESOURCE_GROUP,
|
|
RESOURCE_GROUP_LIST,
|
|
set_mocked_azure_provider,
|
|
)
|
|
|
|
|
|
def mock_sqlserver_get_postgresql_flexible_servers(_):
    """Stand-in for ``PostgreSQL._get_flexible_servers`` in the patched tests.

    The single positional argument receives the ``PostgreSQL`` instance that
    the method is bound to and is ignored. The return value maps the mocked
    subscription ID to a one-element list holding a fully populated ``Server``
    with one Entra ID administrator and one firewall rule.
    """
    rule = Firewall(
        id="id",
        name="name",
        start_ip="start_ip",
        end_ip="end_ip",
    )
    admin = EntraIdAdmin(
        object_id="11111111-1111-1111-1111-111111111111",
        principal_name="Test Admin User",
        principal_type="User",
        tenant_id="22222222-2222-2222-2222-222222222222",
    )
    server = Server(
        id="id",
        name="name",
        resource_group="resource_group",
        location="location",
        require_secure_transport="ON",
        active_directory_auth="ENABLED",
        entra_id_admins=[admin],
        log_checkpoints="ON",
        log_connections="ON",
        log_disconnections="ON",
        connection_throttling="ON",
        log_retention_days="3",
        firewall=[rule],
    )
    return {AZURE_SUBSCRIPTION_ID: [server]}
|
|
|
|
|
|
@patch(
    "prowler.providers.azure.services.postgresql.postgresql_service.PostgreSQL._get_flexible_servers",
    new=mock_sqlserver_get_postgresql_flexible_servers,
)
class Test_SqlServer_Service:
    """Tests for the Azure ``PostgreSQL`` service (the class name is historical).

    For every ``test*`` method, ``PostgreSQL._get_flexible_servers`` is
    patched with ``mock_sqlserver_get_postgresql_flexible_servers``. The
    service's ``flexible_servers`` inventory is therefore the single fixed
    ``Server`` built there. Tests that exercise an individual getter swap a
    ``MagicMock`` in as the subscription's management client.
    """

    @staticmethod
    def _first_server(postgresql):
        # The patched inventory holds exactly one server for the subscription.
        return postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]

    def test_get_client(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert (
            postgresql.clients[AZURE_SUBSCRIPTION_ID].__class__.__name__
            == "PostgreSQLManagementClient"
        )

    def test_get_sql_servers(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = self._first_server(postgresql)
        assert isinstance(server, Server)
        assert server.id == "id"
        assert server.name == "name"
        assert server.location == "location"
        assert server.resource_group == "resource_group"

    def test_get_resource_group(self):
        server_id = (
            "/subscriptions/subscription/resourceGroups/resource_group/providers/"
            "Microsoft.DBforPostgreSQL/flexibleServers/server"
        )
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert postgresql._get_resource_group(server_id) == "resource_group"

    def test_get_require_secure_transport(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).require_secure_transport == "ON"

    def test_get_log_checkpoints(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).log_checkpoints == "ON"

    def test_get_log_connections(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).log_connections == "ON"

    def test_get_log_disconnections(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).log_disconnections == "ON"

    def test_get_connection_throttling(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).connection_throttling == "ON"

    def test_get_connection_throttling_missing_parameter_returns_none(self):
        # Newer PostgreSQL major versions (the mocked error below reports v18)
        # do not have a "connection_throttle.enable" parameter. When it is
        # genuinely absent the Azure SDK raises ResourceNotFoundError. The
        # service must treat that as "not enabled" (a quiet None) rather than
        # aborting the whole subscription's server inventory.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.configurations.get.side_effect = ResourceNotFoundError(
            "The configuration 'connection_throttle.enable' does not exist for "
            "server version 18."
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_connection_throttling(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result is None
        mock_logger.error.assert_not_called()

    def test_get_connection_throttling_unexpected_error_propagates(self):
        # Any other failure (permissions, throttling, transient API errors)
        # must NOT be swallowed into None. That would make the downstream check
        # report the server as having throttling disabled, hiding a collection
        # failure as a security finding. The error propagates so the per-server
        # handler in _get_flexible_servers can record it as a collection
        # failure.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.configurations.get.side_effect = HttpResponseError(
            "(AuthorizationFailed) permission denied"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with pytest.raises(HttpResponseError):
            postgresql._get_connection_throttling(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

    def test_get_log_retention_days(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).log_retention_days == "3"

    def test_get_active_directory_auth(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert self._first_server(postgresql).active_directory_auth == "ENABLED"

    def test_get_entra_id_admins(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        admins = self._first_server(postgresql).entra_id_admins
        assert isinstance(admins, list)
        assert len(admins) == 1
        assert isinstance(admins[0], EntraIdAdmin)
        assert admins[0].principal_name == "Test Admin User"
        assert admins[0].object_id == "11111111-1111-1111-1111-111111111111"

    def test_get_entra_id_admins_aad_not_enabled_logs_warning(self):
        # A server that uses PostgreSQL authentication only (Entra/Azure AD
        # auth disabled) is an expected state. It should be logged as a
        # warning, not an error, and yield an empty admin list.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.administrators.list_by_server.side_effect = Exception(
            "Azure AD authentication is not enabled for the given server"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_entra_id_admins(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result == []
        mock_logger.warning.assert_called_once()
        mock_logger.error.assert_not_called()

    def test_get_entra_id_admins_unexpected_error_logs_error(self):
        # Any other failure (permissions, throttling, transient API errors) is
        # a genuine problem and must still be logged as an error.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.administrators.list_by_server.side_effect = Exception(
            "Some unexpected failure"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_entra_id_admins(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result == []
        mock_logger.error.assert_called_once()
        mock_logger.warning.assert_not_called()

    def test_get_firewall(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        rule = self._first_server(postgresql).firewall[0]
        assert isinstance(rule, Firewall)
        assert rule.id == "id"
        assert rule.name == "name"
        assert rule.start_ip == "start_ip"
        assert rule.end_ip == "end_ip"
|
|
|
|
|
|
class Test_PostgreSQL_get_flexible_servers:
    """Check which listing API ``PostgreSQL._get_flexible_servers`` calls.

    The service is constructed with the method stubbed out. It is then pointed
    at a ``MagicMock`` management client and a ``resource_groups`` filter, and
    the real method is invoked once.
    """

    @staticmethod
    def _service_for(client, resource_groups):
        # The stub only covers construction; the real method is back in place
        # once the ``with`` block exits, so the test's own call is unpatched.
        with patch.object(PostgreSQL, "_get_flexible_servers", return_value={}):
            service = PostgreSQL(set_mocked_azure_provider())
        service.clients = {AZURE_SUBSCRIPTION_ID: client}
        service.resource_groups = resource_groups
        return service

    def test_get_flexible_servers_no_resource_groups(self):
        client = MagicMock()
        client.servers.list.return_value = []
        service = self._service_for(client, resource_groups=None)

        result = service._get_flexible_servers()

        client.servers.list.assert_called_once()
        client.servers.list_by_resource_group.assert_not_called()
        assert AZURE_SUBSCRIPTION_ID in result

    def test_get_flexible_servers_with_resource_group(self):
        client = MagicMock()
        client.servers.list_by_resource_group.return_value = []
        service = self._service_for(
            client, resource_groups={AZURE_SUBSCRIPTION_ID: [RESOURCE_GROUP]}
        )

        result = service._get_flexible_servers()

        client.servers.list_by_resource_group.assert_called_once_with(
            resource_group_name=RESOURCE_GROUP
        )
        client.servers.list.assert_not_called()
        assert AZURE_SUBSCRIPTION_ID in result

    def test_get_flexible_servers_empty_resource_group_for_subscription(self):
        client = MagicMock()
        service = self._service_for(
            client, resource_groups={AZURE_SUBSCRIPTION_ID: []}
        )

        result = service._get_flexible_servers()

        # An explicitly empty filter lists nothing and yields an empty list.
        client.servers.list_by_resource_group.assert_not_called()
        client.servers.list.assert_not_called()
        assert result[AZURE_SUBSCRIPTION_ID] == []

    def test_get_flexible_servers_with_multiple_resource_groups(self):
        client = MagicMock()
        client.servers.list_by_resource_group.return_value = []
        service = self._service_for(
            client, resource_groups={AZURE_SUBSCRIPTION_ID: RESOURCE_GROUP_LIST}
        )

        result = service._get_flexible_servers()

        assert client.servers.list_by_resource_group.call_count == 2
        assert AZURE_SUBSCRIPTION_ID in result

    def test_get_flexible_servers_with_mixed_case_resource_group(self):
        client = MagicMock()
        client.servers.list_by_resource_group.return_value = []
        service = self._service_for(
            client, resource_groups={AZURE_SUBSCRIPTION_ID: ["RG"]}
        )

        service._get_flexible_servers()

        # The resource group name reaches the SDK with its casing intact.
        client.servers.list_by_resource_group.assert_called_once_with(
            resource_group_name="RG"
        )
|
|
|
|
|
|
def _make_server(name):
    """Return a ``MagicMock`` standing in for one listed flexible server.

    The mock's ``id`` is a flexible-server resource ID in resource group
    ``rg`` under the mocked subscription, and its ``name`` is *name*.
    """
    resource_id = (
        f"/subscriptions/{AZURE_SUBSCRIPTION_ID}/resourceGroups/rg/providers/"
        f"Microsoft.DBforPostgreSQL/flexibleServers/{name}"
    )
    server = MagicMock(id=resource_id)
    # ``name`` is a reserved Mock constructor argument (it only labels the
    # mock's repr). It must therefore be set as a plain attribute afterwards.
    server.name = name
    return server
|
|
|
|
|
|
class Test_PostgreSQL_Service_Resilience:
    """A failure while collecting one flexible server must stay per-server.

    Regression coverage: a missing or failing per-server configuration lookup
    used to silently drop every remaining server in the subscription.
    """

    def _build_service_with_client(self, mock_client):
        # Construct with the listing stubbed so no real call happens. Then
        # point the service at the mocked management client, so the test runs
        # the real collection logic against it.
        with patch.object(PostgreSQL, "_get_flexible_servers", return_value={}):
            service = PostgreSQL(set_mocked_azure_provider())
        service.clients = {AZURE_SUBSCRIPTION_ID: mock_client}
        return service

    @staticmethod
    def _client_listing(*servers):
        # Management client that lists ``servers`` and reports no Entra ID
        # admins and no firewall rules for any of them.
        client = MagicMock()
        client.servers.list.return_value = list(servers)
        client.administrators.list_by_server.return_value = []
        client.firewall_rules.list_by_server.return_value = []
        return client

    def test_missing_connection_throttle_config_still_collects_server(self):
        # Newer PostgreSQL major versions do not have the
        # "connection_throttle.enable" parameter, so Azure answers the lookup
        # with ResourceNotFoundError (ConfigurationNotExists).
        client = self._client_listing(_make_server("dev"), _make_server("prd"))
        client.servers.get.return_value = MagicMock(location="westeurope")

        def fake_configuration_get(resource_group, server_name, key):
            parameter_missing = (
                key == "connection_throttle.enable" and server_name == "prd"
            )
            if not parameter_missing:
                return MagicMock(value="ON")
            raise ResourceNotFoundError(
                "(ConfigurationNotExists) The configuration "
                "'connection_throttle.enable' does not exist for prd server "
                "version 18."
            )

        client.configurations.get.side_effect = fake_configuration_get

        collected = self._build_service_with_client(client)._get_flexible_servers()[
            AZURE_SUBSCRIPTION_ID
        ]

        assert sorted(server.name for server in collected) == ["dev", "prd"]
        by_name = {server.name: server for server in collected}
        assert by_name["prd"].connection_throttling is None
        assert by_name["dev"].connection_throttling == "ON"

    def test_unexpected_throttling_error_is_not_silently_collected(self):
        # Only ResourceNotFoundError (the parameter genuinely missing) may be
        # read as "not enabled". Any other failure reading
        # "connection_throttle.enable" (permission, throttling or transient SDK
        # errors) must not become connection_throttling=None. Otherwise the
        # downstream check would report throttling as disabled, disguising a
        # collection failure as a security finding. The failure instead stays
        # with that server, which is dropped rather than fabricated.
        client = self._client_listing(_make_server("ok"), _make_server("denied"))
        client.servers.get.return_value = MagicMock(location="westeurope")

        def fake_configuration_get(resource_group, server_name, key):
            if key == "connection_throttle.enable" and server_name == "denied":
                raise HttpResponseError("(AuthorizationFailed) permission denied")
            return MagicMock(value="ON")

        client.configurations.get.side_effect = fake_configuration_get

        collected = self._build_service_with_client(client)._get_flexible_servers()[
            AZURE_SUBSCRIPTION_ID
        ]

        # The server whose lookup failed is absent, not present with a
        # fabricated connection_throttling=None.
        assert [server.name for server in collected] == ["ok"]
        assert all(server.connection_throttling is not None for server in collected)

    def test_one_server_hard_failure_does_not_drop_others(self):
        # A failure outside the guarded getters (here, fetching the server
        # details) must cost only that server, not the whole subscription.
        client = self._client_listing(_make_server("broken"), _make_server("ok"))
        client.configurations.get.return_value = MagicMock(value="ON")

        def fake_servers_get(resource_group, server_name):
            if server_name == "broken":
                raise Exception("boom: transient failure fetching server details")
            return MagicMock(location="westeurope")

        client.servers.get.side_effect = fake_servers_get

        collected = self._build_service_with_client(client)._get_flexible_servers()

        assert [server.name for server in collected[AZURE_SUBSCRIPTION_ID]] == ["ok"]
|