fix(neo4j): lazy load driver (#9868)

Co-authored-by: Josema Camacho <josema@prowler.com>
This commit is contained in:
Pepe Fagoaga
2026-01-23 06:36:47 +01:00
committed by GitHub
parent babf18ffea
commit 29133f2d7e
4 changed files with 309 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Deleting providers don't try to delete a `None` Neo4j database when an Attack Paths scan is scheduled [(#9858)](https://github.com/prowler-cloud/prowler/pull/9858)
- Use replica database for reading Findings to add them to the Attack Paths graph [(#9861)](https://github.com/prowler-cloud/prowler/pull/9861)
- Attack paths findings loading query to use streaming generator for O(batch_size) memory instead of O(total_findings) [(#9862)](https://github.com/prowler-cloud/prowler/pull/9862)
- Lazy load Neo4j driver [(#9868)](https://github.com/prowler-cloud/prowler/pull/9868)
- Use `Findings.all_objects` to avoid the `ActiveProviderPartitionedManager` [(#9869)](https://github.com/prowler-cloud/prowler/pull/9869)
## [1.18.0] (Prowler v5.17.0)

View File

@@ -1,4 +1,3 @@
import atexit
import logging
import os
import sys
@@ -31,7 +30,6 @@ class ApiConfig(AppConfig):
def ready(self):
from api import schema_extensions # noqa: F401
from api import signals # noqa: F401
from api.attack_paths import database as graph_database
from api.compliance import load_prowler_compliance
# Generate required cryptographic keys if not present, but only if:
@@ -43,29 +41,8 @@ class ApiConfig(AppConfig):
):
self._ensure_crypto_keys()
# Commands that don't need Neo4j
SKIP_NEO4J_DJANGO_COMMANDS = [
"migrate",
"makemigrations",
"check",
"help",
"showmigrations",
"check_and_fix_socialaccount_sites_migration",
]
# Skip Neo4j initialization during tests or specific Django commands
if getattr(settings, "TESTING", False) or (
len(sys.argv) > 1
and "manage.py" in sys.argv[0]
and sys.argv[1] in SKIP_NEO4J_DJANGO_COMMANDS
):
logger.info(
"Skipping Neo4j initialization because of the current Django command or testing"
)
else:
graph_database.init_driver()
atexit.register(graph_database.close_driver)
# Neo4j driver is initialized lazily on first use (see api.attack_paths.database)
# This avoids connection attempts during regular scans that don't need graph database
load_prowler_compliance()

View File

@@ -1,13 +1,12 @@
import atexit
import logging
import threading
from contextlib import contextmanager
from typing import Iterator
from uuid import UUID
import neo4j
import neo4j.exceptions
from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
@@ -51,6 +50,9 @@ def init_driver() -> neo4j.Driver:
)
_driver.verify_connectivity()
# Register cleanup handler (only runs once since we're inside the _driver is None block)
atexit.register(close_driver)
return _driver

View File

@@ -0,0 +1,302 @@
"""
Tests for Neo4j database lazy initialization.
The Neo4j driver should only connect when actually needed (lazy initialization),
not at Django app startup. This allows regular scans to run without Neo4j.
"""
import threading
from unittest.mock import MagicMock, patch
import pytest
class TestLazyInitialization:
"""Test that Neo4j driver is initialized lazily on first use."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
import api.attack_paths.database as db_module
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
def test_driver_not_initialized_at_import(self):
"""Driver should be None after module import (no eager connection)."""
import api.attack_paths.database as db_module
assert db_module._driver is None
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_creates_connection_on_first_call(
self, mock_driver_factory, mock_settings
):
"""init_driver() should create connection only when called."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
assert db_module._driver is None
result = db_module.init_driver()
mock_driver_factory.assert_called_once()
mock_driver.verify_connectivity.assert_called_once()
assert result is mock_driver
assert db_module._driver is mock_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_returns_cached_driver_on_subsequent_calls(
self, mock_driver_factory, mock_settings
):
"""Subsequent calls should return cached driver without reconnecting."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
first_result = db_module.init_driver()
second_result = db_module.init_driver()
third_result = db_module.init_driver()
# Only one connection attempt
assert mock_driver_factory.call_count == 1
assert mock_driver.verify_connectivity.call_count == 1
# All calls return same instance
assert first_result is second_result is third_result
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_get_driver_delegates_to_init_driver(
self, mock_driver_factory, mock_settings
):
"""get_driver() should use init_driver() for lazy initialization."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
result = db_module.get_driver()
assert result is mock_driver
mock_driver_factory.assert_called_once()
class TestAtexitRegistration:
"""Test that atexit cleanup handler is registered correctly."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
import api.attack_paths.database as db_module
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_on_first_init(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should be called on first initialization."""
import api.attack_paths.database as db_module
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
mock_atexit_register.assert_called_once_with(db_module.close_driver)
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_only_once(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should only be called once across multiple inits.
The double-checked locking on _driver ensures the atexit registration
block only executes once (when _driver is first created).
"""
import api.attack_paths.database as db_module
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
db_module.init_driver()
db_module.init_driver()
# Only registered once because subsequent calls hit the fast path
assert mock_atexit_register.call_count == 1
class TestCloseDriver:
"""Test driver cleanup functionality."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
import api.attack_paths.database as db_module
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
def test_close_driver_closes_and_clears_driver(self):
"""close_driver() should close the driver and set it to None."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
db_module._driver = mock_driver
db_module.close_driver()
mock_driver.close.assert_called_once()
assert db_module._driver is None
def test_close_driver_handles_none_driver(self):
"""close_driver() should handle case where driver is None."""
import api.attack_paths.database as db_module
db_module._driver = None
# Should not raise
db_module.close_driver()
assert db_module._driver is None
def test_close_driver_clears_driver_even_on_close_error(self):
"""Driver should be cleared even if close() raises an exception."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
mock_driver.close.side_effect = Exception("Connection error")
db_module._driver = mock_driver
with pytest.raises(Exception, match="Connection error"):
db_module.close_driver()
# Driver should still be cleared
assert db_module._driver is None
class TestThreadSafety:
"""Test thread-safe initialization."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
import api.attack_paths.database as db_module
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_concurrent_init_creates_single_driver(
self, mock_driver_factory, mock_settings
):
"""Multiple threads calling init_driver() should create only one driver."""
import api.attack_paths.database as db_module
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
results = []
errors = []
def call_init():
try:
result = db_module.init_driver()
results.append(result)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=call_init) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors, f"Threads raised errors: {errors}"
# Only one driver created
assert mock_driver_factory.call_count == 1
# All threads got the same driver instance
assert all(r is mock_driver for r in results)
assert len(results) == 10