mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
feat(api): make Attack Paths sink selectable between Neo4j and Neptune (#11524)
This commit is contained in:
@@ -1,623 +1,174 @@
|
||||
"""
|
||||
Tests for Neo4j database lazy initialization.
|
||||
"""Tests for the attack-paths database facade.
|
||||
|
||||
The Neo4j driver is created on first use for every process type; app startup
|
||||
never contacts Neo4j. These tests validate the database module behavior itself.
|
||||
After the Neptune port, `api.attack_paths.database` is a thin routing shim
|
||||
over `api.attack_paths.ingest` (cartography temp DB, always Neo4j) and
|
||||
`api.attack_paths.sink` (configurable Neo4j or Neptune). The facade's
|
||||
contract is routing by database-name prefix and the public exception
|
||||
hierarchy; sink-internal behavior is exercised in `test_sink.py`.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import api.attack_paths.database as db_module
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLazyInitialization:
|
||||
"""Test that Neo4j driver is initialized lazily on first use."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
def test_driver_not_initialized_at_import(self):
|
||||
"""Driver should be None after module import (no eager connection)."""
|
||||
assert db_module._driver is None
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_init_driver_creates_connection_on_first_call(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""init_driver() should create connection only when called."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
assert db_module._driver is None
|
||||
|
||||
result = db_module.init_driver()
|
||||
|
||||
mock_driver_factory.assert_called_once()
|
||||
mock_driver.verify_connectivity.assert_called_once()
|
||||
assert result is mock_driver
|
||||
assert db_module._driver is mock_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_init_driver_leaves_driver_none_when_verify_fails(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""A failed verify_connectivity() must not publish or leak the driver."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver.verify_connectivity.side_effect = (
|
||||
neo4j.exceptions.ServiceUnavailable("down")
|
||||
)
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(neo4j.exceptions.ServiceUnavailable):
|
||||
db_module.init_driver()
|
||||
|
||||
assert db_module._driver is None
|
||||
mock_driver.close.assert_called_once()
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_init_driver_returns_cached_driver_on_subsequent_calls(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""Subsequent calls should return cached driver without reconnecting."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
first_result = db_module.init_driver()
|
||||
second_result = db_module.init_driver()
|
||||
third_result = db_module.init_driver()
|
||||
|
||||
# Only one connection attempt
|
||||
assert mock_driver_factory.call_count == 1
|
||||
assert mock_driver.verify_connectivity.call_count == 1
|
||||
|
||||
# All calls return same instance
|
||||
assert first_result is second_result is third_result
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_get_driver_delegates_to_init_driver(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""get_driver() should use init_driver() for lazy initialization."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
result = db_module.get_driver()
|
||||
|
||||
assert result is mock_driver
|
||||
mock_driver_factory.assert_called_once()
|
||||
|
||||
|
||||
class TestConnectionAcquisitionTimeout:
|
||||
"""Test that the connection acquisition timeout is configurable."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
original_driver = db_module._driver
|
||||
original_acq_timeout = db_module.CONN_ACQUISITION_TIMEOUT
|
||||
original_conn_timeout = db_module.CONNECTION_TIMEOUT
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
db_module.CONN_ACQUISITION_TIMEOUT = original_acq_timeout
|
||||
db_module.CONNECTION_TIMEOUT = original_conn_timeout
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_driver_receives_configured_timeout(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""init_driver() should pass the configured timeouts to the neo4j driver."""
|
||||
mock_driver_factory.return_value = MagicMock()
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
db_module.CONN_ACQUISITION_TIMEOUT = 42
|
||||
db_module.CONNECTION_TIMEOUT = 7
|
||||
|
||||
db_module.init_driver()
|
||||
|
||||
_, kwargs = mock_driver_factory.call_args
|
||||
assert kwargs["connection_acquisition_timeout"] == 42
|
||||
assert kwargs["connection_timeout"] == 7
|
||||
|
||||
|
||||
class TestAtexitRegistration:
|
||||
"""Test that atexit cleanup handler is registered correctly."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.atexit.register")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_atexit_registered_on_first_init(
|
||||
self, mock_driver_factory, mock_atexit_register, mock_settings
|
||||
):
|
||||
"""atexit.register should be called on first initialization."""
|
||||
mock_driver_factory.return_value = MagicMock()
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
db_module.init_driver()
|
||||
|
||||
mock_atexit_register.assert_called_once_with(db_module.close_driver)
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.atexit.register")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_atexit_registered_only_once(
|
||||
self, mock_driver_factory, mock_atexit_register, mock_settings
|
||||
):
|
||||
"""atexit.register should only be called once across multiple inits.
|
||||
|
||||
The double-checked locking on _driver ensures the atexit registration
|
||||
block only executes once (when _driver is first created).
|
||||
"""
|
||||
mock_driver_factory.return_value = MagicMock()
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
db_module.init_driver()
|
||||
db_module.init_driver()
|
||||
db_module.init_driver()
|
||||
|
||||
# Only registered once because subsequent calls hit the fast path
|
||||
assert mock_atexit_register.call_count == 1
|
||||
|
||||
|
||||
class TestCloseDriver:
|
||||
"""Test driver cleanup functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
def test_close_driver_closes_and_clears_driver(self):
|
||||
"""close_driver() should close the driver and set it to None."""
|
||||
mock_driver = MagicMock()
|
||||
db_module._driver = mock_driver
|
||||
|
||||
db_module.close_driver()
|
||||
|
||||
mock_driver.close.assert_called_once()
|
||||
assert db_module._driver is None
|
||||
|
||||
def test_close_driver_handles_none_driver(self):
|
||||
"""close_driver() should handle case where driver is None."""
|
||||
db_module._driver = None
|
||||
|
||||
# Should not raise
|
||||
db_module.close_driver()
|
||||
|
||||
assert db_module._driver is None
|
||||
|
||||
def test_close_driver_clears_driver_even_on_close_error(self):
|
||||
"""Driver should be cleared even if close() raises an exception."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver.close.side_effect = Exception("Connection error")
|
||||
db_module._driver = mock_driver
|
||||
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
db_module.close_driver()
|
||||
|
||||
# Driver should still be cleared
|
||||
assert db_module._driver is None
|
||||
|
||||
|
||||
class TestExecuteReadQuery:
|
||||
"""Test read query execution helper."""
|
||||
|
||||
def test_execute_read_query_calls_read_session_and_returns_result(self):
|
||||
tx = MagicMock()
|
||||
expected_graph = MagicMock()
|
||||
run_result = MagicMock()
|
||||
run_result.graph.return_value = expected_graph
|
||||
tx.run.return_value = run_result
|
||||
|
||||
session = MagicMock()
|
||||
|
||||
def execute_read_side_effect(fn):
|
||||
return fn(tx)
|
||||
|
||||
session.execute_read.side_effect = execute_read_side_effect
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
) as mock_get_session:
|
||||
result = db_module.execute_read_query(
|
||||
"db-tenant-test-tenant-id",
|
||||
"MATCH (n) RETURN n",
|
||||
{"provider_uid": "123"},
|
||||
)
|
||||
|
||||
mock_get_session.assert_called_once_with(
|
||||
"db-tenant-test-tenant-id",
|
||||
default_access_mode=neo4j.READ_ACCESS,
|
||||
)
|
||||
session.execute_read.assert_called_once()
|
||||
tx.run.assert_called_once_with(
|
||||
"MATCH (n) RETURN n",
|
||||
{"provider_uid": "123"},
|
||||
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
run_result.graph.assert_called_once_with()
|
||||
assert result is expected_graph
|
||||
|
||||
def test_execute_read_query_defaults_parameters_to_empty_dict(self):
|
||||
tx = MagicMock()
|
||||
run_result = MagicMock()
|
||||
run_result.graph.return_value = MagicMock()
|
||||
tx.run.return_value = run_result
|
||||
|
||||
session = MagicMock()
|
||||
session.execute_read.side_effect = lambda fn: fn(tx)
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
db_module.execute_read_query(
|
||||
"db-tenant-test-tenant-id",
|
||||
"MATCH (n) RETURN n",
|
||||
)
|
||||
|
||||
tx.run.assert_called_once_with(
|
||||
"MATCH (n) RETURN n",
|
||||
{},
|
||||
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
run_result.graph.assert_called_once_with()
|
||||
|
||||
|
||||
class TestGetSessionReadOnly:
|
||||
"""Test that get_session translates Neo4j read-mode errors."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
original_driver = db_module._driver
|
||||
db_module._driver = None
|
||||
yield
|
||||
db_module._driver = original_driver
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"neo4j_code",
|
||||
[
|
||||
"Neo.ClientError.Statement.AccessMode",
|
||||
"Neo.ClientError.Procedure.ProcedureNotFound",
|
||||
],
|
||||
)
|
||||
def test_get_session_raises_write_query_not_allowed(self, neo4j_code):
|
||||
"""Read-mode Neo4j errors should raise `WriteQueryNotAllowedException`."""
|
||||
mock_session = MagicMock()
|
||||
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
|
||||
code=neo4j_code,
|
||||
message="Write operations are not allowed",
|
||||
)
|
||||
mock_session.run.side_effect = neo4j_error
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver.session.return_value = mock_session
|
||||
db_module._driver = mock_driver
|
||||
|
||||
with pytest.raises(db_module.WriteQueryNotAllowedException):
|
||||
with db_module.get_session(
|
||||
default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
session.run("CREATE (n) RETURN n")
|
||||
|
||||
def test_get_session_raises_generic_exception_for_other_errors(self):
|
||||
"""Non-read-mode Neo4j errors should raise GraphDatabaseQueryException."""
|
||||
mock_session = MagicMock()
|
||||
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
|
||||
code="Neo.ClientError.Statement.SyntaxError",
|
||||
message="Invalid syntax",
|
||||
)
|
||||
mock_session.run.side_effect = neo4j_error
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver.session.return_value = mock_session
|
||||
db_module._driver = mock_driver
|
||||
|
||||
with pytest.raises(db_module.GraphDatabaseQueryException):
|
||||
with db_module.get_session(
|
||||
default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
session.run("INVALID CYPHER")
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread-safe initialization."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_concurrent_init_creates_single_driver(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""Multiple threads calling init_driver() should create only one driver."""
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def call_init():
|
||||
try:
|
||||
result = db_module.init_driver()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=call_init) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors, f"Threads raised errors: {errors}"
|
||||
|
||||
# Only one driver created
|
||||
assert mock_driver_factory.call_count == 1
|
||||
|
||||
# All threads got the same driver instance
|
||||
assert all(r is mock_driver for r in results)
|
||||
assert len(results) == 10
|
||||
|
||||
|
||||
class TestHasProviderData:
|
||||
"""Test has_provider_data helper for checking provider nodes in Neo4j."""
|
||||
|
||||
def test_returns_true_when_nodes_exist(self):
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.single.return_value = MagicMock() # non-None record
|
||||
mock_session.run.return_value = mock_result
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = mock_session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
|
||||
|
||||
mock_session.run.assert_called_once()
|
||||
|
||||
def test_returns_false_when_no_nodes(self):
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.single.return_value = None
|
||||
mock_session.run.return_value = mock_result
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = mock_session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is False
|
||||
|
||||
def test_returns_false_when_database_not_found(self):
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
|
||||
message="Database does not exist",
|
||||
code="Neo.ClientError.Database.DatabaseNotFound",
|
||||
class TestDatabaseNameHelper:
|
||||
def test_tenant_name_lowercases_uuid(self):
|
||||
assert (
|
||||
db_module.get_database_name("ABC-123", temporary=False)
|
||||
== "db-tenant-abc-123"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
assert (
|
||||
db_module.has_provider_data("db-tenant-gone", "provider-123") is False
|
||||
)
|
||||
|
||||
def test_raises_on_other_errors(self):
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
|
||||
message="Connection refused",
|
||||
code="Neo.TransientError.General.UnknownError",
|
||||
def test_temporary_name_uses_tmp_scan_prefix(self):
|
||||
assert (
|
||||
db_module.get_database_name("XYZ-789", temporary=True)
|
||||
== "db-tmp-scan-xyz-789"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
with pytest.raises(db_module.GraphDatabaseQueryException):
|
||||
db_module.has_provider_data("db-tenant-abc", "provider-123")
|
||||
|
||||
class TestExceptionHierarchy:
|
||||
"""`tasks/` and `api/v1/views.py` import these from the facade."""
|
||||
|
||||
class TestDropSubgraph:
|
||||
"""Test drop_subgraph two-phase batched deletion of a provider's graph."""
|
||||
|
||||
@staticmethod
|
||||
def _result(count):
|
||||
result = MagicMock()
|
||||
result.single.return_value.get.return_value = count
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _session_ctx(session):
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = session
|
||||
ctx.__exit__.return_value = False
|
||||
return ctx
|
||||
|
||||
def test_deletes_relationships_then_nodes_in_batches(self):
|
||||
session = MagicMock()
|
||||
# Phase 1 (relationships): one full batch then empty.
|
||||
# Phase 2 (nodes): one full batch then empty.
|
||||
session.run.side_effect = [
|
||||
self._result(1000),
|
||||
self._result(0),
|
||||
self._result(1000),
|
||||
self._result(0),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=self._session_ctx(session),
|
||||
):
|
||||
deleted = db_module.drop_subgraph("db-tenant-abc", "provider-123")
|
||||
|
||||
# Only phase-2 node counts contribute to the return value.
|
||||
assert deleted == 1000
|
||||
assert session.run.call_count == 4
|
||||
|
||||
queries = [call.args[0] for call in session.run.call_args_list]
|
||||
|
||||
# Regression guard: the memory blow-up was caused by DETACH DELETE.
|
||||
assert all("DETACH DELETE" not in query for query in queries)
|
||||
|
||||
rel_queries = [query for query in queries if "DELETE r" in query]
|
||||
node_queries = [query for query in queries if "DELETE n" in query]
|
||||
assert rel_queries and node_queries
|
||||
# DISTINCT avoids double-counting relationships matched from both ends.
|
||||
assert all("DISTINCT r" in query for query in rel_queries)
|
||||
|
||||
# Relationships must be fully drained before nodes are deleted.
|
||||
first_node = next(i for i, q in enumerate(queries) if "DELETE n" in q)
|
||||
last_rel = max(i for i, q in enumerate(queries) if "DELETE r" in q)
|
||||
assert last_rel < first_node
|
||||
|
||||
def test_returns_zero_when_database_not_found(self):
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
|
||||
message="Database does not exist",
|
||||
code="Neo.ClientError.Database.DatabaseNotFound",
|
||||
def test_write_query_is_graph_database_exception(self):
|
||||
assert issubclass(
|
||||
db_module.WriteQueryNotAllowedException,
|
||||
db_module.GraphDatabaseQueryException,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
assert db_module.drop_subgraph("db-tenant-gone", "provider-123") == 0
|
||||
|
||||
def test_raises_on_other_errors(self):
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
|
||||
message="Connection refused",
|
||||
code="Neo.TransientError.General.UnknownError",
|
||||
def test_client_statement_is_graph_database_exception(self):
|
||||
assert issubclass(
|
||||
db_module.ClientStatementException, db_module.GraphDatabaseQueryException
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.database.get_session",
|
||||
return_value=session_ctx,
|
||||
):
|
||||
with pytest.raises(db_module.GraphDatabaseQueryException):
|
||||
db_module.drop_subgraph("db-tenant-abc", "provider-123")
|
||||
def test_exception_str_includes_code_when_set(self):
|
||||
exc = db_module.GraphDatabaseQueryException(
|
||||
message="boom", code="Neo.ClientError.X.Y"
|
||||
)
|
||||
assert str(exc) == "Neo.ClientError.X.Y: boom"
|
||||
|
||||
def test_exception_str_falls_back_to_message_without_code(self):
|
||||
exc = db_module.GraphDatabaseQueryException(message="boom")
|
||||
assert str(exc) == "boom"
|
||||
|
||||
|
||||
class TestExecuteReadQueryRoutes:
|
||||
def test_execute_read_query_delegates_to_sink(self, sink_backend_stub):
|
||||
sink_backend_stub.execute_read_query.return_value = "graph"
|
||||
|
||||
result = db_module.execute_read_query(
|
||||
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
|
||||
)
|
||||
|
||||
sink_backend_stub.execute_read_query.assert_called_once_with(
|
||||
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
|
||||
)
|
||||
assert result == "graph"
|
||||
|
||||
def test_execute_read_query_defaults_parameters_to_none(self, sink_backend_stub):
|
||||
db_module.execute_read_query("db-tenant-abc", "MATCH (n) RETURN n")
|
||||
|
||||
sink_backend_stub.execute_read_query.assert_called_once_with(
|
||||
"db-tenant-abc", "MATCH (n) RETURN n", None
|
||||
)
|
||||
|
||||
|
||||
class TestSinkOperationsDelegation:
|
||||
def test_has_provider_data_delegates_to_sink(self, sink_backend_stub):
|
||||
sink_backend_stub.has_provider_data.return_value = True
|
||||
|
||||
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
|
||||
sink_backend_stub.has_provider_data.assert_called_once_with(
|
||||
"db-tenant-abc", "provider-123"
|
||||
)
|
||||
|
||||
def test_drop_subgraph_delegates_to_sink(self, sink_backend_stub):
|
||||
sink_backend_stub.drop_subgraph.return_value = 42
|
||||
|
||||
assert db_module.drop_subgraph("db-tenant-abc", "provider-123") == 42
|
||||
sink_backend_stub.drop_subgraph.assert_called_once_with(
|
||||
"db-tenant-abc", "provider-123"
|
||||
)
|
||||
|
||||
|
||||
class TestRoutingByDatabasePrefix:
|
||||
"""`db-tmp-scan-*` and `None` route to ingest; everything else to sink."""
|
||||
|
||||
def test_create_database_routes_temp_to_ingest(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.create_database("db-tmp-scan-uuid-1")
|
||||
|
||||
mock_ingest.create_database.assert_called_once_with("db-tmp-scan-uuid-1")
|
||||
sink_backend_stub.create_database.assert_not_called()
|
||||
|
||||
def test_create_database_routes_tenant_to_sink(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.create_database("db-tenant-abc")
|
||||
|
||||
sink_backend_stub.create_database.assert_called_once_with("db-tenant-abc")
|
||||
mock_ingest.create_database.assert_not_called()
|
||||
|
||||
def test_drop_database_routes_temp_to_ingest(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.drop_database("db-tmp-scan-uuid-1")
|
||||
|
||||
mock_ingest.drop_database.assert_called_once_with("db-tmp-scan-uuid-1")
|
||||
sink_backend_stub.drop_database.assert_not_called()
|
||||
|
||||
def test_drop_database_routes_tenant_to_sink(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.drop_database("db-tenant-abc")
|
||||
|
||||
sink_backend_stub.drop_database.assert_called_once_with("db-tenant-abc")
|
||||
mock_ingest.drop_database.assert_not_called()
|
||||
|
||||
def test_clear_cache_routes_temp_to_ingest(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.clear_cache("db-tmp-scan-uuid-1")
|
||||
|
||||
mock_ingest.clear_cache.assert_called_once_with("db-tmp-scan-uuid-1")
|
||||
sink_backend_stub.clear_cache.assert_not_called()
|
||||
|
||||
def test_clear_cache_routes_tenant_to_sink(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
db_module.clear_cache("db-tenant-abc")
|
||||
|
||||
sink_backend_stub.clear_cache.assert_called_once_with("db-tenant-abc")
|
||||
mock_ingest.clear_cache.assert_not_called()
|
||||
|
||||
def test_get_session_routes_temp_to_ingest(self, sink_backend_stub):
|
||||
sentinel = MagicMock()
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
mock_ingest.get_session.return_value = sentinel
|
||||
|
||||
result = db_module.get_session("db-tmp-scan-uuid-1")
|
||||
|
||||
assert result is sentinel
|
||||
mock_ingest.get_session.assert_called_once()
|
||||
sink_backend_stub.get_session.assert_not_called()
|
||||
|
||||
def test_get_session_routes_none_to_ingest(self, sink_backend_stub):
|
||||
sentinel = MagicMock()
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
mock_ingest.get_session.return_value = sentinel
|
||||
|
||||
result = db_module.get_session(None)
|
||||
|
||||
assert result is sentinel
|
||||
sink_backend_stub.get_session.assert_not_called()
|
||||
|
||||
def test_get_ingest_uri_delegates_to_ingest(self, sink_backend_stub):
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
mock_ingest.get_uri.return_value = "bolt://neo4j:7687"
|
||||
|
||||
assert db_module.get_ingest_uri() == "bolt://neo4j:7687"
|
||||
|
||||
mock_ingest.get_uri.assert_called_once_with()
|
||||
|
||||
def test_get_session_routes_tenant_to_sink(self, sink_backend_stub):
|
||||
sentinel = MagicMock()
|
||||
sink_backend_stub.get_session.return_value = sentinel
|
||||
with patch("api.attack_paths.database.ingest") as mock_ingest:
|
||||
result = db_module.get_session("db-tenant-abc")
|
||||
|
||||
assert result is sentinel
|
||||
mock_ingest.get_session.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user