mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
d27ec7d62e
Co-authored-by: Daniel Barranquero <danielbo2001@gmail.com>
241 lines
9.2 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from prowler.providers.azure.services.postgresql.postgresql_service import (
|
|
EntraIdAdmin,
|
|
Firewall,
|
|
PostgreSQL,
|
|
Server,
|
|
)
|
|
from tests.providers.azure.azure_fixtures import (
|
|
AZURE_SUBSCRIPTION_ID,
|
|
set_mocked_azure_provider,
|
|
)
|
|
|
|
|
|
def mock_sqlserver_get_postgresql_flexible_servers(_):
    """Fake replacement for ``PostgreSQL._get_flexible_servers``.

    The single positional argument is ignored. When patched in as a class
    attribute it receives the service instance.

    Returns:
        A dict mapping ``AZURE_SUBSCRIPTION_ID`` to a one-element list holding
        a fully populated ``Server`` fixture with one Entra ID admin and one
        firewall rule.
    """
    # Build the nested fixtures first so the Server call below stays flat.
    admin = EntraIdAdmin(
        object_id="11111111-1111-1111-1111-111111111111",
        principal_name="Test Admin User",
        principal_type="User",
        tenant_id="22222222-2222-2222-2222-222222222222",
    )
    rule = Firewall(
        id="id",
        name="name",
        start_ip="start_ip",
        end_ip="end_ip",
    )
    server = Server(
        id="id",
        name="name",
        resource_group="resource_group",
        location="location",
        require_secure_transport="ON",
        active_directory_auth="ENABLED",
        entra_id_admins=[admin],
        log_checkpoints="ON",
        log_connections="ON",
        log_disconnections="ON",
        connection_throttling="ON",
        log_retention_days="3",
        firewall=[rule],
    )
    return {AZURE_SUBSCRIPTION_ID: [server]}
|
|
|
|
|
|
@patch(
    "prowler.providers.azure.services.postgresql.postgresql_service.PostgreSQL._get_flexible_servers",
    new=mock_sqlserver_get_postgresql_flexible_servers,
)
class Test_SqlServer_Service:
    """Unit tests for the Azure ``PostgreSQL`` service wrapper.

    The class-level ``patch`` replaces ``PostgreSQL._get_flexible_servers``
    with ``mock_sqlserver_get_postgresql_flexible_servers`` for every
    ``test*`` method. The server inventory therefore comes from that local
    fixture rather than from the Azure API.
    """

    # NOTE(review): the class name and ``test_get_sql_servers`` are leftovers
    # from the SQL Server tests; they are kept so existing test IDs stay stable.

    def test_get_client(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        # Compare by class name to avoid importing the Azure SDK client here.
        assert (
            postgresql.clients[AZURE_SUBSCRIPTION_ID].__class__.__name__
            == "PostgreSQLManagementClient"
        )

    def test_get_sql_servers(self):
        """The patched inventory exposes the fixture server's identity fields."""
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert isinstance(server, Server)
        assert server.id == "id"
        assert server.name == "name"
        assert server.location == "location"
        assert server.resource_group == "resource_group"

    def test_get_resource_group(self):
        """The resource group is parsed out of a full ARM resource ID."""
        resource_id = (
            "/subscriptions/subscription/resourceGroups/resource_group"
            "/providers/Microsoft.DBforPostgreSQL/flexibleServers/server"
        )
        postgresql = PostgreSQL(set_mocked_azure_provider())
        assert postgresql._get_resource_group(resource_id) == "resource_group"

    def test_get_require_secure_transport(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.require_secure_transport == "ON"

    def test_get_log_checkpoints(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.log_checkpoints == "ON"

    def test_get_log_connections(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.log_connections == "ON"

    def test_get_log_disconnections(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.log_disconnections == "ON"

    def test_get_connection_throttling(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.connection_throttling == "ON"

    def test_get_connection_throttling_missing_parameter_returns_none(self):
        # PostgreSQL v18 removed the "connection_throttle.enable" parameter; the
        # service must degrade gracefully (quiet None) instead of raising and
        # aborting the whole subscription's server inventory.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.configurations.get.side_effect = Exception(
            "The configuration 'connection_throttle.enable' does not exist for "
            "server version 18."
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_connection_throttling(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result is None
        mock_logger.error.assert_not_called()

    def test_get_connection_throttling_unexpected_error_logs_error(self):
        # Any other failure (permissions, throttling, transient API errors) must
        # still be logged as an error, while keeping the scan resilient (None).
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.configurations.get.side_effect = Exception(
            "Some unexpected failure"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_connection_throttling(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result is None
        mock_logger.error.assert_called_once()

    def test_get_log_retention_days(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.log_retention_days == "3"

    def test_get_active_directory_auth(self):
        postgresql = PostgreSQL(set_mocked_azure_provider())
        server = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0]
        assert server.active_directory_auth == "ENABLED"

    def test_get_entra_id_admins(self):
        """Every field of the fixture's single Entra ID admin is exposed."""
        postgresql = PostgreSQL(set_mocked_azure_provider())
        admins = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0].entra_id_admins
        assert isinstance(admins, list)
        assert len(admins) == 1
        admin = admins[0]
        assert isinstance(admin, EntraIdAdmin)
        assert admin.principal_name == "Test Admin User"
        assert admin.object_id == "11111111-1111-1111-1111-111111111111"
        assert admin.principal_type == "User"
        assert admin.tenant_id == "22222222-2222-2222-2222-222222222222"

    def test_get_entra_id_admins_aad_not_enabled_logs_warning(self):
        # A server using PostgreSQL authentication only (Entra/Azure AD auth
        # disabled) is an expected state; it should be logged as a warning, not
        # an error, and return an empty admin list.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.administrators.list_by_server.side_effect = Exception(
            "Azure AD authentication is not enabled for the given server"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_entra_id_admins(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result == []
        mock_logger.warning.assert_called_once()
        mock_logger.error.assert_not_called()

    def test_get_entra_id_admins_unexpected_error_logs_error(self):
        # Any other failure (permissions, throttling, transient API errors) is a
        # genuine problem and must still be logged as an error.
        postgresql = PostgreSQL(set_mocked_azure_provider())
        mock_client = MagicMock()
        mock_client.administrators.list_by_server.side_effect = Exception(
            "Some unexpected failure"
        )
        postgresql.clients[AZURE_SUBSCRIPTION_ID] = mock_client

        with patch(
            "prowler.providers.azure.services.postgresql.postgresql_service.logger"
        ) as mock_logger:
            result = postgresql._get_entra_id_admins(
                AZURE_SUBSCRIPTION_ID, "resource_group", "server_name"
            )

        assert result == []
        mock_logger.error.assert_called_once()
        mock_logger.warning.assert_not_called()

    def test_get_firewall(self):
        """The fixture server's single firewall rule is exposed unchanged."""
        postgresql = PostgreSQL(set_mocked_azure_provider())
        rule = postgresql.flexible_servers[AZURE_SUBSCRIPTION_ID][0].firewall[0]
        assert isinstance(rule, Firewall)
        assert rule.id == "id"
        assert rule.name == "name"
        assert rule.start_ip == "start_ip"
        assert rule.end_ip == "end_ip"