diff --git a/api/src/backend/api/attack_paths/database.py b/api/src/backend/api/attack_paths/database.py index eaa9da3713..1c1c64e30d 100644 --- a/api/src/backend/api/attack_paths/database.py +++ b/api/src/backend/api/attack_paths/database.py @@ -42,7 +42,12 @@ def init_driver() -> neo4j.Driver: config = settings.DATABASES["neo4j"] _driver = neo4j.GraphDatabase.driver( - uri, auth=(config["USER"], config["PASSWORD"]) + uri, + auth=(config["USER"], config["PASSWORD"]), + keep_alive=True, + max_connection_lifetime=7200, + connection_acquisition_timeout=120, + max_connection_pool_size=50, ) _driver.verify_connectivity() @@ -71,7 +76,6 @@ def get_session(database: str | None = None) -> Iterator[RetryableSession]: try: session_wrapper = RetryableSession( session_factory=lambda: get_driver().session(database=database), - close_driver=close_driver, # Just to avoid circular imports max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES, ) yield session_wrapper diff --git a/api/src/backend/api/attack_paths/retryable_session.py b/api/src/backend/api/attack_paths/retryable_session.py index 79bf383fff..026751a616 100644 --- a/api/src/backend/api/attack_paths/retryable_session.py +++ b/api/src/backend/api/attack_paths/retryable_session.py @@ -17,11 +17,9 @@ class RetryableSession: def __init__( self, session_factory: Callable[[], neo4j.Session], - close_driver: Callable[[], None], # Just to avoid circular imports max_retries: int, ) -> None: self._session_factory = session_factory - self._close_driver = close_driver self._max_retries = max(0, max_retries) self._session = self._session_factory() @@ -58,7 +56,7 @@ class RetryableSession: def _call_with_retry(self, method_name: str, *args: Any, **kwargs: Any) -> Any: attempt = 0 - last_exc: neo4j.exceptions.ServiceUnavailable | None = None + last_exc: Exception | None = None while attempt <= self._max_retries: try: @@ -66,7 +64,8 @@ class RetryableSession: return method(*args, **kwargs) except ( - neo4j.exceptions.ServiceUnavailable + neo4j.exceptions.ServiceUnavailable, + ConnectionResetError, ) as exc: # pragma: no cover - depends on infra last_exc = exc attempt += 1 @@ -75,7 +74,7 @@ class RetryableSession: raise logger.warning( - f"Neo4j session {method_name} failed with ServiceUnavailable ({attempt}/{self._max_retries} attempts). Retrying..." + f"Neo4j session {method_name} failed with {type(exc).__name__} ({attempt}/{self._max_retries} attempts). Retrying..." ) self._refresh_session() @@ -83,7 +82,10 @@ class RetryableSession: def _refresh_session(self) -> None: if self._session is not None: - self._session.close() + try: + self._session.close() + except Exception: + # Best-effort close; failures just mean we open a new session below + pass - self._close_driver() self._session = self._session_factory()