chore: merge master into Lighthouse branch

This commit is contained in:
alejandrobailo
2026-07-02 17:28:47 +02:00
696 changed files with 71811 additions and 6348 deletions
+1 -1
View File
@@ -157,7 +157,7 @@ SENTRY_RELEASE=local
# REO_DEV_CLIENT_ID=
#### Prowler release version ####
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.32.0
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.33.0
# Social login credentials
SOCIAL_GOOGLE_OAUTH_CALLBACK_URL="${AUTH_URL}/api/auth/callback/google"
+2 -2
View File
@@ -46,7 +46,7 @@ runs:
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
LATEST_COMMIT=$(curl -sf \
LATEST_COMMIT=$(curl -sf --retry 3 --retry-all-errors --retry-delay 2 --retry-max-time 60 \
-H "Authorization: Bearer ${GITHUB_TOKEN}" \
-H "Accept: application/vnd.github+json" \
"https://api.github.com/repos/prowler-cloud/prowler/commits/master" \
@@ -66,7 +66,7 @@ runs:
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
LATEST_COMMIT=$(curl -sf \
LATEST_COMMIT=$(curl -sf --retry 3 --retry-all-errors --retry-delay 2 --retry-max-time 60 \
-H "Authorization: Bearer ${GITHUB_TOKEN}" \
-H "Accept: application/vnd.github+json" \
"https://api.github.com/repos/prowler-cloud/prowler/commits/master" \
+2 -2
View File
@@ -63,7 +63,7 @@ runs:
exit-code: '0'
scanners: 'vuln'
timeout: '5m'
version: 'v0.71.0'
version: 'v0.71.2'
- name: Run Trivy vulnerability scan (SARIF)
if: inputs.upload-sarif == 'true' && github.event_name == 'push'
@@ -76,7 +76,7 @@ runs:
exit-code: '0'
scanners: 'vuln'
timeout: '5m'
version: 'v0.71.0'
version: 'v0.71.2'
- name: Upload Trivy results to GitHub Security tab
if: inputs.upload-sarif == 'true' && github.event_name == 'push'
+24 -22
View File
@@ -30,17 +30,18 @@ updates:
# - "pip"
# - "component/api"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
open-pull-requests-limit: 25
target-branch: master
labels:
- "dependencies"
- "github_actions"
cooldown:
default-days: 7
# Dependabot version updates disabled - migrated to Renovate - 2026/07/02
# - package-ecosystem: "github-actions"
# directory: "/"
# schedule:
# interval: "monthly"
# open-pull-requests-limit: 25
# target-branch: master
# labels:
# - "dependencies"
# - "github_actions"
# cooldown:
# default-days: 7
# Dependabot Updates are temporarily disabled - 2025/03/19
# - package-ecosystem: "npm"
@@ -54,17 +55,18 @@ updates:
# - "npm"
# - "component/ui"
- package-ecosystem: "docker"
directory: "/"
schedule:
interval: "monthly"
open-pull-requests-limit: 25
target-branch: master
labels:
- "dependencies"
- "docker"
cooldown:
default-days: 7
# Dependabot version updates disabled - migrated to Renovate - 2026/07/02
# - package-ecosystem: "docker"
# directory: "/"
# schedule:
# interval: "monthly"
# open-pull-requests-limit: 25
# target-branch: master
# labels:
# - "dependencies"
# - "docker"
# cooldown:
# default-days: 7
# - package-ecosystem: "pre-commit"
# directory: "/"
+3 -3
View File
@@ -38,7 +38,7 @@
"schedule": [
"* 22-23,0-5 1 * *"
],
"enabled": false
"enabled": true
},
{
"description": "Minors: 8th of every 3 months, Madrid overnight window (22:00-06:00)",
@@ -48,7 +48,7 @@
"schedule": [
"* 22-23,0-5 8 */3 *"
],
"enabled": false
"enabled": true
},
{
"description": "Majors: 15th of every 3 months, Madrid overnight window",
@@ -58,7 +58,7 @@
"schedule": [
"* 22-23,0-5 15 */3 *"
],
"enabled": false
"enabled": true
},
{
"description": "GitHub Actions - single grouped PR, no changelog, scope=ci",
+1 -25
View File
@@ -215,7 +215,7 @@ jobs:
- name: Install regctl
if: always()
uses: regclient/actions/regctl-installer@da9319db8e44e8b062b3a147e1dfb2f574d41a03 # main
uses: regclient/actions/regctl-installer@9a2d4216180dbb3e2dccfa60d2dd4afd98e42ec5 # main
- name: Cleanup intermediate architecture tags
if: always()
@@ -272,27 +272,3 @@ jobs:
payload-file-path: "./.github/scripts/slack-messages/container-release-completed.json"
step-outcome: ${{ steps.outcome.outputs.outcome }}
update-ts: ${{ needs.notify-release-started.outputs.message-ts }}
trigger-deployment:
needs: [setup, container-build-push]
if: always() && github.event_name == 'push' && needs.setup.result == 'success' && needs.container-build-push.result == 'success'
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
contents: read
steps:
- name: Harden Runner
uses: step-security/harden-runner@ab7a9404c0f3da075243ca237b5fac12c98deaa5 # v2.19.3
with:
egress-policy: block
allowed-endpoints: >
api.github.com:443
- name: Trigger API deployment
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
repository: ${{ secrets.CLOUD_DISPATCH }}
event-type: api-prowler-deployment
client-payload: '{"sha": "${{ github.sha }}", "short_sha": "${{ needs.setup.outputs.short-sha }}"}'
+2 -2
View File
@@ -48,7 +48,7 @@ jobs:
services:
postgres:
image: postgres:17@sha256:2cd82735a36356842d5eb1ef80db3ae8f1154172f0f653db48fde079b2a0b7f7
image: postgres:17@sha256:5c855ad7b85e68e48a62f34662853f38b57c1c1d80f3a927ab58034fd6d31c5e
env:
POSTGRES_HOST: ${{ env.POSTGRES_HOST }}
POSTGRES_PORT: ${{ env.POSTGRES_PORT }}
@@ -63,7 +63,7 @@ jobs:
--health-timeout 5s
--health-retries 5
valkey:
image: valkey/valkey:7-alpine3.19
image: valkey/valkey:7-alpine3.19@sha256:4054fe7fc607b9326ac7c4691ed26e9670d2ff17a9fb28c2577adecf928acbcc
env:
VALKEY_HOST: ${{ env.VALKEY_HOST }}
VALKEY_PORT: ${{ env.VALKEY_PORT }}
+5 -4
View File
@@ -29,10 +29,11 @@ jobs:
with:
# We can't block as Trufflehog needs to verify secrets against vendors
egress-policy: audit
# allowed-endpoints: >
# github.com:443
# ghcr.io:443
# pkg-containers.githubusercontent.com:443
allowed-endpoints: >
github.com:443
ghcr.io:443
pkg-containers.githubusercontent.com:443
www.formbucket.com:443
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+1 -25
View File
@@ -206,7 +206,7 @@ jobs:
- name: Install regctl
if: always()
uses: regclient/actions/regctl-installer@da9319db8e44e8b062b3a147e1dfb2f574d41a03 # main
uses: regclient/actions/regctl-installer@9a2d4216180dbb3e2dccfa60d2dd4afd98e42ec5 # main
- name: Cleanup intermediate architecture tags
if: always()
@@ -263,27 +263,3 @@ jobs:
payload-file-path: "./.github/scripts/slack-messages/container-release-completed.json"
step-outcome: ${{ steps.outcome.outputs.outcome }}
update-ts: ${{ needs.notify-release-started.outputs.message-ts }}
trigger-deployment:
needs: [setup, container-build-push]
if: always() && github.event_name == 'push' && needs.setup.result == 'success' && needs.container-build-push.result == 'success'
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
contents: read
steps:
- name: Harden Runner
uses: step-security/harden-runner@ab7a9404c0f3da075243ca237b5fac12c98deaa5 # v2.19.3
with:
egress-policy: block
allowed-endpoints: >
api.github.com:443
- name: Trigger MCP deployment
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
repository: ${{ secrets.CLOUD_DISPATCH }}
event-type: mcp-prowler-deployment
client-payload: '{"sha": "${{ github.sha }}", "short_sha": "${{ needs.setup.outputs.short-sha }}"}'
+62 -26
View File
@@ -37,8 +37,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 1
# zizmor: ignore[artipacked]
persist-credentials: true # Required by tj-actions/changed-files to fetch PR branch
persist-credentials: false # No write token in the untrusted PR-head tree; public repo so base fetch/changed-files work unauthenticated
- name: Fetch PR base ref for tj-actions/changed-files
env:
@@ -50,6 +49,8 @@ jobs:
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: '**'
safe_output: false # Raw paths (list read via env var, injection-safe); default escaping backslash-quotes chars like () and breaks the -f test
separator: "\n" # Newline-delimited so the reader tolerates spaces and glob chars in paths
- name: Check for conflict markers
id: conflict-check
@@ -59,19 +60,18 @@ jobs:
CONFLICT_FILES=""
HAS_CONFLICTS=false
# Check each changed file for conflict markers
for file in ${STEPS_CHANGED_FILES_OUTPUTS_ALL_CHANGED_FILES}; do
if [ -f "$file" ]; then
echo "Checking file: $file"
# Read newline-delimited paths so spaces/globs neither word-split nor glob-expand
while IFS= read -r file; do
[ -n "$file" ] || continue
[ -f "$file" ] || continue
echo "Checking file: $file"
# Look for conflict markers (more precise regex)
if grep -qE '^(<<<<<<<|=======|>>>>>>>)' "$file" 2>/dev/null; then
echo "Conflict markers found in: $file"
CONFLICT_FILES="${CONFLICT_FILES}- \`${file}\`"$'\n'
HAS_CONFLICTS=true
fi
if grep -qE '^(<<<<<<<|=======|>>>>>>>)' "$file" 2>/dev/null; then
echo "Conflict markers found in: $file"
CONFLICT_FILES="${CONFLICT_FILES}- \`${file}\`"$'\n'
HAS_CONFLICTS=true
fi
done
done <<< "$STEPS_CHANGED_FILES_OUTPUTS_ALL_CHANGED_FILES"
if [ "$HAS_CONFLICTS" = true ]; then
echo "has_conflicts=true" >> $GITHUB_OUTPUT
@@ -88,18 +88,49 @@ jobs:
env:
STEPS_CHANGED_FILES_OUTPUTS_ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
- name: Check base-branch mergeability
id: merge-check
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
REPO: ${{ github.repository }}
run: |
MERGEABLE=null
# GitHub computes mergeability async, so .mergeable is null until ready; poll until resolved
for attempt in 1 2 3 4 5; do
MERGEABLE=$(gh api "repos/${REPO}/pulls/${PR_NUMBER}" --jq '.mergeable')
if [ "$MERGEABLE" != "null" ]; then
break
fi
echo "Attempt ${attempt}: mergeability not computed yet, retrying..."
sleep 3
done
# Keep 'unknown' distinct from 'clean' so we never assert a clean merge we could not confirm
case "$MERGEABLE" in
false) STATUS=conflict; echo "PR branch cannot be merged cleanly into its base branch" ;;
true) STATUS=clean; echo "PR branch merges cleanly into its base branch" ;;
*) STATUS=unknown; echo "::warning::Mergeability did not resolve after retries; leaving it undetermined" ;;
esac
echo "merge_status=${STATUS}" >> "$GITHUB_OUTPUT"
- name: Manage conflict label
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
HAS_CONFLICTS: ${{ steps.conflict-check.outputs.has_conflicts }}
MERGE_STATUS: ${{ steps.merge-check.outputs.merge_status }}
run: |
LABEL_NAME="has-conflicts"
# Add or remove label based on conflict status
if [ "$HAS_CONFLICTS" = "true" ]; then
if [ "$HAS_CONFLICTS" = "true" ] || [ "$MERGE_STATUS" = "conflict" ]; then
echo "Adding conflict label to PR #${PR_NUMBER}..."
gh pr edit "$PR_NUMBER" --add-label "$LABEL_NAME" --repo ${{ github.repository }} || true
elif [ "$MERGE_STATUS" = "unknown" ]; then
# Don't drop the label on an undetermined merge state; a later run will settle it
echo "Mergeability undetermined; leaving label unchanged"
else
echo "Removing conflict label from PR #${PR_NUMBER}..."
gh pr edit "$PR_NUMBER" --remove-label "$LABEL_NAME" --repo ${{ github.repository }} || true
@@ -121,20 +152,25 @@ jobs:
edit-mode: replace
body: |
<!-- conflict-checker-comment -->
${{ steps.conflict-check.outputs.has_conflicts == 'true' && '⚠️ **Conflict Markers Detected**' || '✅ **Conflict Markers Resolved**' }}
${{ steps.conflict-check.outputs.has_conflicts == 'true' && format('This pull request contains unresolved conflict markers in the following files:
${{ (steps.conflict-check.outputs.has_conflicts == 'true' || steps.merge-check.outputs.merge_status == 'conflict') && '⚠️ **Conflicts Detected**' || (steps.merge-check.outputs.merge_status == 'unknown' && '️ **Conflict Check Incomplete**' || '✅ **No Conflicts**') }}
${{ steps.conflict-check.outputs.has_conflicts == 'true' && format('
**Conflict markers** are present in the following files:
{0}
Please resolve these conflicts by:
1. Locating the conflict markers: `<<<<<<<`, `=======`, and `>>>>>>>`
2. Manually editing the files to resolve the conflicts
3. Removing all conflict markers
4. Committing and pushing the changes', steps.conflict-check.outputs.conflict_files) || 'All conflict markers have been successfully resolved in this pull request.' }}
Resolve them by removing every `<<<<<<<`, `=======`, and `>>>>>>>` marker, then commit and push.', steps.conflict-check.outputs.conflict_files) || '' }}
${{ steps.merge-check.outputs.merge_status == 'conflict' && '
**Merge conflict with the base branch.** This PR cannot be merged cleanly. Update your branch with the latest base (rebase or merge) and resolve the conflicts.' || '' }}
${{ steps.merge-check.outputs.merge_status == 'unknown' && '
GitHub had not finished computing mergeability, so base-branch conflict status could not be verified on this run.' || '' }}
${{ (steps.conflict-check.outputs.has_conflicts != 'true' && steps.merge-check.outputs.merge_status == 'clean') && '
No conflict markers, and the branch merges cleanly into its base.' || '' }}
- name: Fail workflow if conflicts detected
if: steps.conflict-check.outputs.has_conflicts == 'true'
if: steps.conflict-check.outputs.has_conflicts == 'true' || steps.merge-check.outputs.merge_status == 'conflict'
env:
HAS_CONFLICTS: ${{ steps.conflict-check.outputs.has_conflicts }}
MERGE_STATUS: ${{ steps.merge-check.outputs.merge_status }}
run: |
echo "::error::Workflow failed due to conflict markers detected in the PR"
[ "$HAS_CONFLICTS" = "true" ] && echo "::error::Conflict markers detected in changed files"
[ "$MERGE_STATUS" = "conflict" ] && echo "::error::PR branch has merge conflicts with the base branch"
exit 1
+2 -2
View File
@@ -56,6 +56,6 @@ jobs:
"PROWLER_PR_BODY": ${{ toJson(github.event.pull_request.body) }},
"PROWLER_PR_URL": ${{ toJson(github.event.pull_request.html_url) }},
"PROWLER_PR_MERGED_BY": "${{ github.event.pull_request.merged_by.login }}",
"PROWLER_PR_BASE_BRANCH": "${{ github.event.pull_request.base.ref }}",
"PROWLER_PR_HEAD_BRANCH": "${{ github.event.pull_request.head.ref }}"
"PROWLER_PR_BASE_BRANCH": ${{ toJson(github.event.pull_request.base.ref) }},
"PROWLER_PR_HEAD_BRANCH": ${{ toJson(github.event.pull_request.head.ref) }}
}
+25 -15
View File
@@ -138,6 +138,7 @@ jobs:
permissions:
contents: read
packages: write
id-token: write
steps:
- name: Harden Runner
@@ -147,6 +148,8 @@ jobs:
allowed-endpoints: >
api.ecr-public.us-east-1.amazonaws.com:443
public.ecr.aws:443
sts.amazonaws.com:443
sts.us-east-1.amazonaws.com:443
registry-1.docker.io:443
production.cloudflare.docker.com:443
production.cloudfront.docker.com:443
@@ -173,14 +176,16 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to Public ECR
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
- name: Configure AWS credentials (OIDC)
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6.1.1
with:
registry: public.ecr.aws
username: ${{ secrets.PUBLIC_ECR_AWS_ACCESS_KEY_ID }}
password: ${{ secrets.PUBLIC_ECR_AWS_SECRET_ACCESS_KEY }}
env:
AWS_REGION: ${{ env.AWS_REGION }}
aws-region: us-east-1
role-to-assume: ${{ secrets.PUBLIC_ECR_IAM_ROLE_ARN }}
- name: Login to Public ECR
uses: aws-actions/amazon-ecr-login@d539f0932e70871a027e9d5a9d8fc38589180a64 # v2.1.6
with:
registry-type: public
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@@ -206,6 +211,7 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
id-token: write
steps:
- name: Harden Runner
@@ -221,6 +227,8 @@ jobs:
github.com:443
release-assets.githubusercontent.com:443
api.ecr-public.us-east-1.amazonaws.com:443
sts.amazonaws.com:443
sts.us-east-1.amazonaws.com:443
- name: Login to DockerHub
@@ -229,14 +237,16 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to Public ECR
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
- name: Configure AWS credentials (OIDC)
uses: aws-actions/configure-aws-credentials@d979d5b3a71173a29b74b5b88418bfda9437d885 # v6.1.1
with:
registry: public.ecr.aws
username: ${{ secrets.PUBLIC_ECR_AWS_ACCESS_KEY_ID }}
password: ${{ secrets.PUBLIC_ECR_AWS_SECRET_ACCESS_KEY }}
env:
AWS_REGION: ${{ env.AWS_REGION }}
aws-region: us-east-1
role-to-assume: ${{ secrets.PUBLIC_ECR_IAM_ROLE_ARN }}
- name: Login to Public ECR
uses: aws-actions/amazon-ecr-login@d539f0932e70871a027e9d5a9d8fc38589180a64 # v2.1.6
with:
registry-type: public
- name: Create and push manifests for push event
if: github.event_name == 'push'
@@ -299,7 +309,7 @@ jobs:
- name: Install regctl
if: always()
uses: regclient/actions/regctl-installer@da9319db8e44e8b062b3a147e1dfb2f574d41a03 # main
uses: regclient/actions/regctl-installer@9a2d4216180dbb3e2dccfa60d2dd4afd98e42ec5 # main
- name: Cleanup intermediate architecture tags
if: always()
+1 -1
View File
@@ -73,7 +73,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: '3.12'
python-version: '3.12.13'
- name: Install PyYAML
run: pip install pyyaml
+1 -25
View File
@@ -201,7 +201,7 @@ jobs:
- name: Install regctl
if: always()
uses: regclient/actions/regctl-installer@da9319db8e44e8b062b3a147e1dfb2f574d41a03 # main
uses: regclient/actions/regctl-installer@9a2d4216180dbb3e2dccfa60d2dd4afd98e42ec5 # main
- name: Cleanup intermediate architecture tags
if: always()
@@ -258,27 +258,3 @@ jobs:
payload-file-path: "./.github/scripts/slack-messages/container-release-completed.json"
step-outcome: ${{ steps.outcome.outputs.outcome }}
update-ts: ${{ needs.notify-release-started.outputs.message-ts }}
trigger-deployment:
needs: [setup, container-build-push]
if: always() && github.event_name == 'push' && needs.setup.result == 'success' && needs.container-build-push.result == 'success'
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
contents: read
steps:
- name: Harden Runner
uses: step-security/harden-runner@ab7a9404c0f3da075243ca237b5fac12c98deaa5 # v2.19.3
with:
egress-policy: block
allowed-endpoints: >
api.github.com:443
- name: Trigger UI deployment
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
repository: ${{ secrets.CLOUD_DISPATCH }}
event-type: ui-prowler-deployment
client-payload: '{"sha": "${{ github.sha }}", "short_sha": "${{ needs.setup.outputs.short-sha }}"}'
+4
View File
@@ -37,8 +37,12 @@ jobs:
allowed-endpoints: >
github.com:443
registry.npmjs.org:443
nodejs.org:443
fonts.googleapis.com:443
fonts.gstatic.com:443
api.iconify.design:443
api.simplesvg.com:443
api.unisvg.com:443
api.github.com:443
release-assets.githubusercontent.com:443
cdn.playwright.dev:443
+4
View File
@@ -169,3 +169,7 @@ GEMINI.md
# Claude Code
.claude/*
# Docker
docker-compose.override.yml
docker-compose-dev.override.yml
+13
View File
@@ -52,6 +52,19 @@ CVE-2026-43185 pkg:linux-libc-dev exp:2026-07-15
CVE-2023-45853 pkg:zlib1g exp:2026-07-15
CVE-2023-45853 pkg:zlib1g-dev exp:2026-07-15
# CVE-2026-55200 — libssh2 out-of-bounds write in ssh2_transport_read() due to
# an unchecked packet_length field in transport.c (heap corruption, possible RCE).
# Package: libssh2-1.
# Why ignored: libssh2-1 is pulled in only as a transitive dependency of libcurl4
# (installed in the SDK Dockerfile for the networking/PowerShell stack). The
# vulnerable path is reached exclusively when libssh2 acts as an SSH/SCP/SFTP
# client parsing transport packets from a server. Prowler never uses libcurl's
# SSH/SCP/SFTP transports; it talks to cloud provider HTTPS endpoints only, so the
# affected code is unreachable at runtime. Fixed upstream in libssh2 commit
# 97acf3df (PR #2052); no Debian bookworm fix is available yet.
# Ref: https://security-tracker.debian.org/tracker/CVE-2026-55200
CVE-2026-55200 pkg:libssh2-1 exp:2026-07-15
# --- API container image (api/Dockerfile) ---
# The entries below are specific to the Prowler API image, which ships
# PowerShell and additional build tooling on top of the same bookworm base.
+1
View File
@@ -114,6 +114,7 @@ When performing these actions, ALWAYS invoke the corresponding skill FIRST:
| Review PR requirements: template, title conventions, changelog gate | `prowler-pr` |
| Review changelog format and conventions | `prowler-changelog` |
| Reviewing JSON:API compliance | `jsonapi` |
| Reviewing Prowler UI components | `prowler-ui` |
| Reviewing compliance framework PRs | `prowler-compliance-review` |
| Running makemigrations or pgmakemigrations | `django-migration-psql` |
| Syncing compliance framework with upstream catalog | `prowler-compliance` |
+14 -2
View File
@@ -1,4 +1,4 @@
FROM python:3.12.13-slim-bookworm@sha256:76d4b7b6305788c6b4c6a19d6a22a3921bf802e9af4d5e1e5bd771208dba74bf AS build
FROM python:3.12.13-slim-bookworm@sha256:8a7e7cc04fd3e2bd787f7f24e22d5d119aa590d429b50c95dfe12b3abe52f48b AS build
LABEL maintainer="https://github.com/prowler-cloud/prowler"
LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
@@ -6,7 +6,7 @@ LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
ARG POWERSHELL_VERSION=7.5.0
ENV POWERSHELL_VERSION=${POWERSHELL_VERSION}
ARG TRIVY_VERSION=0.71.0
ARG TRIVY_VERSION=0.71.2
ENV TRIVY_VERSION=${TRIVY_VERSION}
ARG ZIZMOR_VERSION=1.24.1
@@ -95,6 +95,18 @@ RUN uv sync --locked --compile-bytecode && \
# Install PowerShell modules
RUN .venv/bin/python prowler/providers/m365/lib/powershell/m365_powershell.py
USER root
# Remove build-only packages from the final image after Python dependencies are installed.
RUN apt-get purge -y --auto-remove \
build-essential \
pkg-config \
libzstd-dev \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
USER prowler
# Remove deprecated dash dependencies
RUN pip uninstall dash-html-components -y && \
pip uninstall dash-core-components -y
+45 -26
View File
@@ -83,16 +83,35 @@ prowler dashboard
## Attack Paths
Attack Paths automatically extends every completed AWS scan with a Neo4j graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan and therefore requires:
Attack Paths automatically extends every completed AWS scan with a graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan.
- An accessible Neo4j instance (the Docker Compose files already ships a `neo4j` service).
- The following environment variables so Django and Celery can connect:
Two graph backends are supported as the long-lived sink:
| Variable | Description | Default |
| --- | --- | --- |
| `NEO4J_HOST` | Hostname used by the API containers. | `neo4j` |
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
- **Neo4j** (default; the Docker Compose files already ship a `neo4j` service).
- **Amazon Neptune** (cloud-managed; opt-in).
Select the sink with `ATTACK_PATHS_SINK_DATABASE` (`neo4j` or `neptune`; default `neo4j`).
> Note: Cartography ingestion always uses a temporary Neo4j database, regardless of the configured sink. The `NEO4J_*` variables below must remain set even when `ATTACK_PATHS_SINK_DATABASE=neptune`.
### Neo4j sink
| Variable | Description | Default |
| --- | --- | --- |
| `NEO4J_HOST` | Hostname used by the API containers. | `neo4j` |
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
### Neptune sink
| Variable | Description | Default |
| --- | --- | --- |
| `NEPTUNE_WRITER_ENDPOINT` | Bolt host for the Neptune writer instance. Required when sink is `neptune`. | _empty_ |
| `NEPTUNE_READER_ENDPOINT` | Optional reader endpoint for read-only queries. Falls back to the writer when unset. | _empty_ |
| `NEPTUNE_PORT` | Bolt port exposed by Neptune. | `8182` |
| `AWS_REGION` | Region the Neptune cluster lives in. Required when sink is `neptune`. | _empty_ |
Neptune authenticates with SigV4 using the standard boto3 credential chain. The worker's IAM role (or `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`) supplies the credentials. There is no Neptune password variable.
Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations.
@@ -104,27 +123,27 @@ Every AWS provider scan will enqueue an Attack Paths ingestion job automatically
| Provider | Checks | Services | [Compliance Frameworks](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/compliance/) | [Categories](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/misc/#categories) | Support | Interface |
|---|---|---|---|---|---|---|
| AWS | 613 | 86 | 46 | 19 | Official | UI, API, CLI |
| Azure | 190 | 22 | 20 | 16 | Official | UI, API, CLI |
| GCP | 109 | 20 | 18 | 12 | Official | UI, API, CLI |
| Kubernetes | 90 | 7 | 7 | 11 | Official | UI, API, CLI |
| GitHub | 24 | 3 | 1 | 5 | Official | UI, API, CLI |
| M365 | 107 | 10 | 4 | 10 | Official | UI, API, CLI |
| OCI | 52 | 14 | 4 | 10 | Official | UI, API, CLI |
| Alibaba Cloud | 63 | 9 | 5 | 9 | Official | UI, API, CLI |
| Cloudflare | 29 | 3 | 1 | 5 | Official | UI, API, CLI |
| AWS | 615 | 86 | 47 | 19 | Official | UI, API, CLI |
| Azure | 190 | 22 | 21 | 16 | Official | UI, API, CLI |
| GCP | 109 | 20 | 19 | 12 | Official | UI, API, CLI |
| Kubernetes | 90 | 7 | 8 | 11 | Official | UI, API, CLI |
| GitHub | 24 | 3 | 2 | 5 | Official | UI, API, CLI |
| M365 | 109 | 10 | 6 | 10 | Official | UI, API, CLI |
| OCI | 52 | 14 | 5 | 10 | Official | UI, API, CLI |
| Alibaba Cloud | 63 | 9 | 6 | 9 | Official | UI, API, CLI |
| Cloudflare | 29 | 3 | 2 | 5 | Official | UI, API, CLI |
| IaC | [See `trivy` docs.](https://trivy.dev/latest/docs/coverage/iac/) | N/A | N/A | N/A | Official | UI, API, CLI |
| MongoDB Atlas | 10 | 3 | 0 | 8 | Official | UI, API, CLI |
| MongoDB Atlas | 10 | 3 | 1 | 8 | Official | UI, API, CLI |
| LLM | [See `promptfoo` docs.](https://www.promptfoo.dev/docs/red-team/plugins/) | N/A | N/A | N/A | Official | CLI |
| Image | N/A | N/A | N/A | N/A | Official | CLI, API |
| Google Workspace | 65 | 11 | 2 | 6 | Official | UI, API, CLI |
| OpenStack | 34 | 5 | 0 | 9 | Official | UI, API, CLI |
| Vercel | 26 | 6 | 0 | 8 | Official | UI, API, CLI |
| Okta | 29 | 8 | 1 | 2 | Official | UI, API, CLI |
| Linode [Contact us](https://prowler.com/contact) | 10 | 3 | 0 | 4 | Unofficial | CLI |
| Scaleway [Contact us](https://prowler.com/contact) | 1 | 1 | 0 | 1 | Unofficial | CLI |
| StackIT [Contact us](https://prowler.com/contact) | 7 | 2 | 0 | 3 | Unofficial | CLI |
| NHN | 6 | 2 | 1 | 0 | Unofficial | CLI |
| Google Workspace | 65 | 11 | 3 | 6 | Official | UI, API, CLI |
| OpenStack | 34 | 5 | 1 | 9 | Official | UI, API, CLI |
| Vercel | 26 | 6 | 1 | 8 | Official | UI, API, CLI |
| Okta | 29 | 8 | 2 | 2 | Official | UI, API, CLI |
| Linode [Contact us](https://prowler.com/contact) | 10 | 3 | 1 | 4 | Unofficial | CLI |
| Scaleway [Contact us](https://prowler.com/contact) | 1 | 1 | 1 | 1 | Unofficial | CLI |
| StackIT [Contact us](https://prowler.com/contact) | 7 | 2 | 1 | 3 | Unofficial | CLI |
| NHN | 6 | 2 | 2 | 0 | Unofficial | CLI |
> [!Note]
> The numbers in the table are updated periodically.
+28
View File
@@ -2,6 +2,34 @@
All notable changes to the **Prowler API** are documented in this file.
## [1.33.0] (Prowler v5.32.0)
### 🚀 Added
- Timestamp precision support for `/api/v1/findings` `inserted_at` and `updated_at` filters [(#11754)](https://github.com/prowler-cloud/prowler/pull/11754)
### 🔄 Changed
- Attack Paths: AWS Neptune is now supported as a persistent sink database, selectable via `ATTACK_PATHS_SINK_DATABASE=neptune` (default `neo4j`); Cartography (bumped to 0.138.1) keeps its per-scan ingest database on Neo4j [(#11524)](https://github.com/prowler-cloud/prowler/pull/11524)
- Attack Paths: Scan task now checks the ingest Neo4j database and configured graph sink before starting graph ingestion [(#11743)](https://github.com/prowler-cloud/prowler/pull/11743)
- Disable PowerShell telemetry in the API container image [(#11746)](https://github.com/prowler-cloud/prowler/pull/11746)
### 🐞 Fixed
- Attack Paths: Provider graph cleanup now deletes Neo4j and Neptune relationships in directed batches before deleting nodes [(#11755)](https://github.com/prowler-cloud/prowler/pull/11755)
- `scan-perform` no longer reports an error when a provider is deleted during a running scan [(#11696)](https://github.com/prowler-cloud/prowler/pull/11696)
---
## [1.32.1] (Prowler v5.31.1)
### 🐞 Fixed
- API key auth no longer mutates `TenantAPIKey.objects` during admin DB lookups [(#11686)](https://github.com/prowler-cloud/prowler/pull/11686)
---
## [1.32.0] (Prowler v5.31.0)
### 🚀 Added
+21 -2
View File
@@ -1,11 +1,13 @@
FROM python:3.12.13-slim-bookworm@sha256:76d4b7b6305788c6b4c6a19d6a22a3921bf802e9af4d5e1e5bd771208dba74bf AS build
FROM python:3.12.13-slim-bookworm@sha256:8a7e7cc04fd3e2bd787f7f24e22d5d119aa590d429b50c95dfe12b3abe52f48b AS build
LABEL maintainer="https://github.com/prowler-cloud/api"
ARG POWERSHELL_VERSION=7.5.0
ENV POWERSHELL_VERSION=${POWERSHELL_VERSION}
# Opt out of PowerShell telemetry (Application Insights -> dc.services.visualstudio.com)
ENV POWERSHELL_TELEMETRY_OPTOUT=1
ARG TRIVY_VERSION=0.71.0
ARG TRIVY_VERSION=0.71.2
ENV TRIVY_VERSION=${TRIVY_VERSION}
ARG ZIZMOR_VERSION=1.24.1
@@ -102,6 +104,23 @@ RUN uv sync --locked --no-install-project && \
RUN .venv/bin/python .venv/lib/python3.12/site-packages/prowler/providers/m365/lib/powershell/m365_powershell.py
USER root
# Remove build-only packages from the final image after Python dependencies are installed.
RUN apt-get purge -y --auto-remove \
gcc \
g++ \
make \
libxml2-dev \
libxmlsec1-dev \
pkg-config \
libtool \
libxslt1-dev \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
USER prowler
COPY --chown=prowler:prowler src/backend/ ./backend/
COPY --chown=prowler:prowler docker-entrypoint.sh ./docker-entrypoint.sh
+13 -7
View File
@@ -58,7 +58,7 @@ dependencies = [
"matplotlib (==3.10.8)",
"reportlab (==4.4.10)",
"neo4j (==6.1.0)",
"cartography (==0.135.0)",
"cartography (==0.138.1)",
"gevent (==25.9.1)",
"werkzeug (==3.1.7)",
"sqlparse (==0.5.5)",
@@ -71,7 +71,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.33.0"
version = "1.34.0"
# Shared ruff baseline (kept in sync with mcp_server/pyproject.toml).
# target-version tracks this project's lowest supported Python.
@@ -193,7 +193,7 @@ constraint-dependencies = [
"blinker==1.9.0",
"boto3==1.40.61",
"botocore==1.40.61",
"cartography==0.135.0",
"cartography==0.138.1",
"celery==5.6.2",
"certifi==2026.1.4",
"cffi==2.0.0",
@@ -218,7 +218,6 @@ constraint-dependencies = [
"debugpy==1.8.20",
"decorator==5.2.1",
"defusedxml==0.7.1",
"detect-secrets==1.5.0",
"dill==0.4.1",
"distro==1.9.0",
"dj-rest-auth==7.0.1",
@@ -301,6 +300,7 @@ constraint-dependencies = [
"jsonschema==4.23.0",
"jsonschema-specifications==2025.9.1",
"keystoneauth1==5.13.0",
"kingfisher-bin==1.104.0",
"kiwisolver==1.4.9",
"knack==0.11.0",
"kombu==5.6.2",
@@ -447,7 +447,7 @@ constraint-dependencies = [
"wcwidth==0.5.3",
"websocket-client==1.9.0",
"werkzeug==3.1.7",
"workos==6.0.4",
"workos==6.0.8",
"wrapt==1.17.3",
"xlsxwriter==3.2.9",
"xmlsec==1.3.17",
@@ -458,8 +458,13 @@ constraint-dependencies = [
"zope-interface==8.2",
"zstd==1.5.7.3"
]
# prowler@master needs okta==3.4.2; cartography 0.135.0 declares okta<1.0.0 for an
# integration prowler does not import.
# prowler@master needs okta==3.4.2, but cartography 0.138.1 requires okta<1.0.0.
# Attack Paths does not ingest Okta today, so override the Cartography
# dependency to the Prowler pin.
#
# prowler@master needs azure-mgmt-containerservice==34.1.0, but cartography
# 0.138.1 requires azure-mgmt-containerservice>=41.0.0. Attack Paths does not
# ingest Azure today, so override the Cartography dependency to the Prowler pin.
#
# prowler@master hard-pins microsoft-kiota-abstractions==1.9.2 in [project.dependencies].
# The microsoft-kiota-http security bump to 1.9.9 (GHSA-7j59-v9qr-6fq9) requires
@@ -475,6 +480,7 @@ constraint-dependencies = [
# that request pyjwt[crypto] and leave cryptography (needed for RS256) only transitive.
override-dependencies = [
"okta==3.4.2",
"azure-mgmt-containerservice==34.1.0",
"microsoft-kiota-abstractions==1.9.9",
"dulwich==1.2.5",
"pyjwt[crypto]==2.13.0"
-3
View File
@@ -42,9 +42,6 @@ class ApiConfig(AppConfig):
):
self._ensure_crypto_keys()
# Neo4j driver is created lazily on first use (see api.attack_paths.database).
# App init never contacts Neo4j, so a Neo4j outage cannot block API startup.
def _ensure_crypto_keys(self):
"""
Orchestrator method that ensures all required cryptographic keys are present.
@@ -4,10 +4,10 @@ Cypher sanitizer for custom (user-supplied) Attack Paths queries.
Two responsibilities:
1. **Validation** - reject queries containing SSRF or dangerous procedure
patterns (defense-in-depth; the primary control is ``neo4j.READ_ACCESS``).
patterns (defense-in-depth; the primary control is `neo4j.READ_ACCESS`).
2. **Provider-scoped label injection** - inject a dynamic
``_Provider_{uuid}`` label into every node pattern so the database can
`_Provider_{uuid}` label into every node pattern so the database can
use its native label index for provider isolation.
Label-injection pipeline:
@@ -25,13 +25,13 @@ from rest_framework.exceptions import ValidationError
from tasks.jobs.attack_paths.config import get_provider_label
# Step 1 - String / comment protection
# Single combined regex: strings first, then line comments.
# Single combined regex: strings first, then line comments
# The regex engine finds the leftmost match, so a string like 'https://prowler.com'
# is consumed as a string before the // inside it can match as a comment.
# is consumed as a string before the // inside it can match as a comment
_PROTECTED_RE = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"|//[^\n]*")
# Step 2 - Clause splitting
# OPTIONAL MATCH must come before MATCH to avoid partial matching.
# `OPTIONAL MATCH` must come before `MATCH` to avoid partial matching
_CLAUSE_RE = re.compile(
r"\b(OPTIONAL\s+MATCH|MATCH|WHERE|RETURN|WITH|ORDER\s+BY"
r"|SKIP|LIMIT|UNION|UNWIND|CALL)\b",
@@ -39,10 +39,10 @@ _CLAUSE_RE = re.compile(
)
# Pass A - Labeled node patterns (all segments)
# Matches node patterns that have at least one :Label.
# (?<!\w)\( - open paren NOT preceded by a word char (excludes function calls).
# Group 1: optional variable + one or more :Label
# Group 2: optional {properties} + closing paren
# Matches node patterns that have at least one `:Label`
# `(?<!\w)\(` - open paren NOT preceded by a word char, excludes function calls
# Group 1: optional variable + one or more `:Label`
# Group 2: optional `{properties}` + closing paren
_LABELED_NODE_RE = re.compile(
r"(?<!\w)\("
r"("
@@ -55,9 +55,9 @@ _LABELED_NODE_RE = re.compile(
r")"
)
# Pass B - Bare node patterns (MATCH segments only)
# Matches (identifier) or (identifier {properties}) without any :Label.
# Only applied in MATCH/OPTIONAL MATCH segments.
# Pass B - Bare node patterns (`MATCH` segments only)
# Matches (identifier) or (identifier {properties}) without any `:Label`
# Only applied in `MATCH` / `OPTIONAL MATCH` segments
_BARE_NODE_RE = re.compile(
r"(?<!\w)\(" r"(\s*[a-zA-Z_]\w*)" r"(\s*(?:\{[^}]*\})?)" r"\s*\)"
)
@@ -96,6 +96,11 @@ def inject_provider_label(cypher: str, provider_id: str) -> str:
node pattern.
"""
label = get_provider_label(provider_id)
return inject_label(cypher, label)
def inject_label(cypher: str, label: str) -> str:
"""Rewrite a Cypher query to append a label to every node pattern."""
# Step 1: Protect strings and comments (single pass, leftmost-first)
protected: list[str] = []
@@ -134,9 +139,7 @@ def inject_provider_label(cypher: str, provider_id: str) -> str:
return work
# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
# Patterns that indicate SSRF or dangerous procedure calls
# Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS`
+195 -251
View File
@@ -1,261 +1,32 @@
import atexit
import logging
import threading
from collections.abc import Iterator
from contextlib import contextmanager
"""Backwards-compatible facade over the ingest and sink modules.
Historically this module owned a single Neo4j driver used for both the
cartography temp database and the per-tenant sink database. The port to AWS
Neptune split those roles: the cartography ingest (temp) database is always
Neo4j and lives in `api.attack_paths.ingest`; the sink is configurable
(Neo4j or Neptune) and lives in `api.attack_paths.sink`. This shim preserves
the public API that `tasks/` and `api/v1/views.py` already depend on, and
dispatches to the right module by database-name prefix.
A database name starting with `db-tmp-scan-` is a cartography temp DB and
routes to ingest. Everything else routes to the configured sink.
"""
from contextlib import AbstractContextManager
from typing import Any
from uuid import UUID
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
import neo4j # noqa: F401 - kept for tests that patch api.attack_paths.database.neo4j
from api.attack_paths import ingest
from api.attack_paths import sink as sink_module
from config.env import env
from django.conf import settings
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
from django.conf import (
settings, # noqa: F401 - kept for tests that patch ...database.settings
)
# Without this Celery goes crazy with Neo4j logging
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
MAX_CUSTOM_QUERY_NODES = env.int("ATTACK_PATHS_MAX_CUSTOM_QUERY_NODES", default=250)
# Shorter than CONN_ACQUISITION_TIMEOUT — the driver requires acquisition to be
# the longer of the two (it may include opening a new connection).
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
# Module-level process-wide driver singleton
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
# Base Neo4j functions
def get_uri() -> str:
host = settings.DATABASES["neo4j"]["HOST"]
port = settings.DATABASES["neo4j"]["PORT"]
return f"bolt://{host}:{port}"
def init_driver() -> neo4j.Driver:
global _driver
if _driver is not None:
return _driver
with _lock:
if _driver is None:
uri = get_uri()
config = settings.DATABASES["neo4j"]
driver = neo4j.GraphDatabase.driver(
uri,
auth=(config["USER"], config["PASSWORD"]),
keep_alive=True,
max_connection_lifetime=7200,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
max_connection_pool_size=50,
)
# Publish the singleton only after connectivity is verified so a
# failed probe does not leave an unverified driver behind. Close the
# driver on failure so a repeatedly-probed outage cannot leak pools.
try:
driver.verify_connectivity()
except Exception:
driver.close()
raise
_driver = driver
# Register cleanup handler (only runs once since we're inside the _driver is None block)
atexit.register(close_driver)
return _driver
def get_driver() -> neo4j.Driver:
return init_driver()
def close_driver() -> None: # TODO: Use it
global _driver
with _lock:
if _driver is not None:
try:
_driver.close()
finally:
_driver = None
@contextmanager
def get_session(
database: str | None = None, default_access_mode: str | None = None
) -> Iterator[RetryableSession]:
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: get_driver().session(
database=database, default_access_mode=default_access_mode
),
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
message = "Read query not allowed"
code = READ_EXCEPTION_CODES[0]
raise WriteQueryNotAllowedException(message=message, code=code)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
def execute_read_query(
database: str,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session:
def _run(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
result = tx.run(
cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
)
return result.graph()
return session.execute_read(_run)
def create_database(database: str) -> None:
query = "CREATE DATABASE $database IF NOT EXISTS"
parameters = {"database": database}
with get_session() as session:
session.run(query, parameters)
def drop_database(database: str) -> None:
query = f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA"
with get_session() as session:
session.run(query)
def drop_subgraph(database: str, provider_id: str) -> int:
"""
Delete all nodes for a provider from the tenant database.
Deletes relationships then nodes in batches (not `DETACH DELETE`) so a dense
provider's graph cannot exceed Neo4j's transaction memory limit.
Silently returns 0 if the database doesn't exist.
"""
provider_label = get_provider_label(provider_id)
deleted_nodes = 0
try:
with get_session(database) as session:
# Phase 1: delete relationships incident to provider nodes in batches.
deleted_count = 1
while deleted_count > 0:
result = session.run(
f"""
MATCH (:`{provider_label}`)-[r]-()
WITH DISTINCT r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_rels_count", 0)
# Phase 2: delete the now relationship-free nodes in batches.
deleted_count = 1
while deleted_count > 0:
result = session.run(
f"""
MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`)
WITH n LIMIT $batch_size
DELETE n
RETURN COUNT(n) AS deleted_nodes_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_nodes_count", 0)
deleted_nodes += deleted_count
except GraphDatabaseQueryException as exc:
if exc.code == "Neo.ClientError.Database.DatabaseNotFound":
return 0
raise
return deleted_nodes
def has_provider_data(database: str, provider_id: str) -> bool:
"""
Check if any ProviderResource node exists for this provider.
Returns `False` if the database doesn't exist.
"""
provider_label = get_provider_label(provider_id)
query = f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`) RETURN 1 LIMIT 1"
try:
with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session:
result = session.run(query)
return result.single() is not None
except GraphDatabaseQueryException as exc:
if exc.code == "Neo.ClientError.Database.DatabaseNotFound":
return False
raise
def clear_cache(database: str) -> None:
query = "CALL db.clearQueryCaches()"
try:
with get_session(database) as session:
session.run(query)
except GraphDatabaseQueryException as exc:
logging.warning(f"Failed to clear query cache for database `{database}`: {exc}")
# Neo4j functions related to Prowler + Cartography
def get_database_name(entity_id: str | UUID, temporary: bool = False) -> str:
prefix = "tmp-scan" if temporary else "tenant"
return f"db-{prefix}-{str(entity_id).lower()}"
TEMP_DB_PREFIX = "db-tmp-scan-"
# Exceptions
@@ -270,7 +41,6 @@ class GraphDatabaseQueryException(Exception):
def __str__(self) -> str:
if self.code:
return f"{self.code}: {self.message}"
return self.message
@@ -280,3 +50,177 @@ class WriteQueryNotAllowedException(GraphDatabaseQueryException):
class ClientStatementException(GraphDatabaseQueryException):
pass
# Routing
def _is_ingest_database(database: str | None) -> bool:
    """Tell whether `database` names a cartography temp (ingest) database.

    `None` and the empty string never qualify; any other name qualifies when
    it starts with `TEMP_DB_PREFIX`.
    """
    if not database:
        return False
    return database.startswith(TEMP_DB_PREFIX)
# Driver lifecycle
def init_driver() -> Any:
    """Initialize the configured sink backend and return it.

    Only the sink is touched here. The ingest driver (Neo4j for cartography
    temp DBs) stays lazy: it is only initialized when a temp-DB operation
    actually runs, which never happens on API pods.
    """
    sink_backend = sink_module.init()
    return sink_backend
def close_driver() -> None:
    """Close every driver held by this process.

    The sink backends are closed first, then the ingest driver.
    """
    # Sink side: active backend plus any cached secondary backends.
    sink_module.close()
    # Ingest side: the Neo4j driver used for cartography temp databases.
    ingest.close_driver()
def get_driver() -> neo4j.Driver:
    """Return the sink backend's underlying driver.

    Only meaningful for the Neo4j sink (where the backend has a single Neo4j
    driver). On Neptune this returns the writer driver. Kept for tests and
    legacy call-sites; prefer `get_session` for new code.

    Raises:
        RuntimeError: the active backend exposes neither accessor.
    """
    backend = sink_module.get_backend()

    # Neo4jSink exposes get_driver(); NeptuneSink exposes get_writer().
    # Probed in that order, first match wins.
    for accessor_name in ("get_driver", "get_writer"):
        if hasattr(backend, accessor_name):
            return getattr(backend, accessor_name)()

    raise RuntimeError("Active sink backend does not expose a driver handle")
def verify_connectivity() -> None:
    """Raise if the configured graph database is unreachable on the API read path.

    Backend-agnostic entry point for the readiness probe: the check is
    delegated to the active sink backend (Neo4j verifies its driver, Neptune
    verifies the reader endpoint).
    """
    active_backend = sink_module.get_backend()
    active_backend.verify_connectivity()
def verify_scan_databases_available() -> None:
    """Raise if either graph database needed by an Attack Paths scan is unavailable.

    Both the ingest Neo4j driver and the sink driver are probed, even when the
    first probe fails, so the error names every unreachable database.

    Raises:
        RuntimeError: at least one probe failed. The message lists each
            failure; the exception is chained from the first one.
    """
    # Each entry pairs the text shown to the operator with the exception
    # that produced it, in probe order.
    failures: list[tuple[str, Exception]] = []

    try:
        ingest.get_driver().verify_connectivity()
    except Exception as ingest_exc:
        failures.append((f"ingest Neo4j: {ingest_exc}", ingest_exc))

    try:
        get_driver().verify_connectivity()
    except Exception as sink_exc:
        failures.append(
            (f"sink {settings.ATTACK_PATHS_SINK_DATABASE}: {sink_exc}", sink_exc)
        )

    if not failures:
        return

    summary = "; ".join(text for text, _ in failures)
    raise RuntimeError(
        "Attack Paths graph database unavailable before scan start: " + summary
    ) from failures[0][1]
def get_uri() -> str:
    """Return the sink URI. Retained for backwards compatibility.

    Returns:
        `bolt+s://{WRITER_ENDPOINT}:{PORT}` built from
        `settings.DATABASES["neptune"]` when the configured sink is Neptune,
        otherwise `bolt://{HOST}:{PORT}` built from
        `settings.DATABASES["neo4j"]`.
    """
    # Case-insensitive on purpose: backend selection lowercases
    # `ATTACK_PATHS_SINK_DATABASE` before resolving it, so a value such as
    # "Neptune" must map to the Neptune URI here as well.
    if settings.ATTACK_PATHS_SINK_DATABASE.lower() == "neptune":
        cfg = settings.DATABASES["neptune"]
        return f"bolt+s://{cfg['WRITER_ENDPOINT']}:{cfg['PORT']}"

    cfg = settings.DATABASES["neo4j"]
    return f"bolt://{cfg['HOST']}:{cfg['PORT']}"
def get_ingest_uri() -> str:
    """Return the Neo4j URI of the cartography temp (ingest) database.

    The ingest database is always Neo4j, regardless of the configured sink.
    """
    ingest_uri = ingest.get_uri()
    return ingest_uri
# Session API
def get_session(
    database: str | None = None,
    default_access_mode: str | None = None,
) -> AbstractContextManager:
    """Return a session against the right backend.

    - `database` names starting with `db-tmp-scan-` always go to ingest.
    - No database name → ingest (used for CREATE / DROP DATABASE admin ops).
    - Any other name → sink.
    """
    routes_to_ingest = database is None or _is_ingest_database(database)

    if not routes_to_ingest:
        sink_backend = sink_module.get_backend()
        return sink_backend.get_session(
            database=database, default_access_mode=default_access_mode
        )

    return ingest.get_session(
        database=database, default_access_mode=default_access_mode
    )
def execute_read_query(
    database: str,
    cypher: str,
    parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
    """Run a read-only query against the sink and return its result graph.

    All three arguments are forwarded unchanged to the active sink backend.
    """
    sink_backend = sink_module.get_backend()
    return sink_backend.execute_read_query(database, cypher, parameters)
def create_database(database: str) -> None:
    """Create a database. Temp DBs always land on ingest (Neo4j).

    On the Neo4j sink, tenant DBs also route to ingest because both drivers
    connect to the same Neo4j cluster. On the Neptune sink, tenant DB creates
    are no-ops.
    """
    # Temp-scan names go to the ingest module; everything else goes to the
    # active sink backend, which is only resolved when actually needed.
    target = ingest if _is_ingest_database(database) else sink_module.get_backend()
    target.create_database(database)
def drop_database(database: str) -> None:
    """Drop a database. Mirrors `create_database` routing."""
    # Temp-scan names go to the ingest module; everything else goes to the
    # active sink backend, which is only resolved when actually needed.
    target = ingest if _is_ingest_database(database) else sink_module.get_backend()
    target.drop_database(database)
def drop_subgraph(database: str, provider_id: str) -> int:
    """Delegate provider-subgraph deletion to the active sink backend.

    Returns whatever integer the backend's `drop_subgraph` reports.
    """
    sink_backend = sink_module.get_backend()
    return sink_backend.drop_subgraph(database, provider_id)
def has_provider_data(database: str, provider_id: str) -> bool:
    """Ask the active sink backend whether it holds data for this provider."""
    sink_backend = sink_module.get_backend()
    return sink_backend.has_provider_data(database, provider_id)
def clear_cache(database: str) -> None:
    """Clear the query cache for `database` on the backend that hosts it.

    Routing matches `create_database`: temp-scan names go to ingest, every
    other name goes to the active sink backend.
    """
    target = ingest if _is_ingest_database(database) else sink_module.get_backend()
    target.clear_cache(database)
# Name helper
def get_database_name(entity_id: str | UUID, temporary: bool = False) -> str:
prefix = "tmp-scan" if temporary else "tenant"
return f"db-{prefix}-{str(entity_id).lower()}"
@@ -0,0 +1,29 @@
"""Cartography ingest layer.
Public surface for the per-scan Neo4j temp database driver. Implementation
lives in `api.attack_paths.ingest.driver`.
"""
from api.attack_paths.ingest.driver import (
clear_cache,
close_driver,
create_database,
drop_database,
get_driver,
get_session,
get_uri,
init_driver,
run_cypher,
)
__all__ = [
"clear_cache",
"close_driver",
"create_database",
"drop_database",
"get_driver",
"get_session",
"get_uri",
"init_driver",
"run_cypher",
]
@@ -0,0 +1,187 @@
"""Cartography ingest driver: per-scan throw-away Neo4j database.
Cartography writes each scan's graph into a throw-away Neo4j database named
`db-tmp-scan-{scan_uuid}`. This is always Neo4j, regardless of the configured
sink: Neptune is single-database and cannot host per-scan throw-away
databases. This module owns the Neo4j driver used for those temp DBs and the
admin ops they need (CREATE / DROP DATABASE).
"""
import atexit
import logging
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from config.env import env
from django.conf import settings
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# host can't pin a worker on a temp-DB op longer than this.
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
MAX_CONNECTION_LIFETIME = env.int("NEO4J_MAX_CONNECTION_LIFETIME", default=7200)
MAX_CONNECTION_POOL_SIZE = env.int("NEO4J_MAX_CONNECTION_POOL_SIZE", default=50)
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
def _neo4j_config() -> dict:
    """Return the `neo4j` entry of Django's `DATABASES` setting."""
    databases = settings.DATABASES
    return databases["neo4j"]
def get_uri() -> str:
    """Bolt URI for the Neo4j temp (ingest) database. Always Neo4j.

    Raises:
        RuntimeError: the configured host or port is empty.
    """
    config = _neo4j_config()
    host, port = config["HOST"], config["PORT"]

    if host and port:
        return f"bolt://{host}:{port}"

    raise RuntimeError(
        "NEO4J_HOST / NEO4J_PORT must be set to use the attack-paths "
        "temp database. Workers require Neo4j env even when the sink is Neptune."
    )
def init_driver() -> neo4j.Driver:
    """Initialize the temp-database Neo4j driver. Idempotent.

    The driver is created at most once per process under `_lock`; later calls
    return the cached instance.
    """
    global _driver

    if _driver is None:
        with _lock:
            # Re-check under the lock: another thread may have created the
            # driver between the unlocked test and lock acquisition.
            if _driver is None:
                neo4j_settings = _neo4j_config()
                uri = get_uri()
                credentials = (neo4j_settings["USER"], neo4j_settings["PASSWORD"])
                _driver = neo4j.GraphDatabase.driver(
                    uri,
                    auth=credentials,
                    keep_alive=True,
                    max_connection_lifetime=MAX_CONNECTION_LIFETIME,
                    connection_timeout=CONNECTION_TIMEOUT,
                    connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
                    max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
                )

                # Best-effort connectivity check: a Neo4j that is down at boot
                # must not crash the worker. The driver reconnects lazily on
                # first use.
                try:
                    _driver.verify_connectivity()
                except Exception:
                    logging.warning(
                        "Neo4j temp-database unreachable at init; continuing with a "
                        "lazily-reconnecting driver",
                        exc_info=True,
                    )

                atexit.register(close_driver)

    return _driver
def get_driver() -> neo4j.Driver:
    """Return the temp-database driver, creating it on first use."""
    driver = init_driver()
    return driver
def close_driver() -> None:
    """Close the temp-database driver, if one exists, and forget it.

    The cached reference is cleared even when closing raises, so a later
    `init_driver` call builds a fresh driver.
    """
    global _driver

    with _lock:
        if _driver is None:
            return
        try:
            _driver.close()
        finally:
            _driver = None
@contextmanager
def get_session(
    database: str | None = None,
    default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
    """Yield a retrying session against the Neo4j temp-database cluster.

    Used for temp DB sessions and for admin operations (CREATE / DROP
    DATABASE) when `database` is None. The session is always closed on exit.

    Raises:
        WriteQueryNotAllowedException: in read access mode, when Neo4j reports
            an access-mode or procedure-not-found error.
        ClientStatementException: for any other `Neo.ClientError.Statement.*`
            error code.
        GraphDatabaseQueryException: for every remaining `Neo4jError`.
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid a circular import, since `api.attack_paths.database` imports this
    # package.
    from api.attack_paths.database import (
        ClientStatementException,
        GraphDatabaseQueryException,
        WriteQueryNotAllowedException,
    )

    access_mode_code = "Neo.ClientError.Statement.AccessMode"
    read_rejection_codes = (
        access_mode_code,
        "Neo.ClientError.Procedure.ProcedureNotFound",
    )
    statement_error_prefix = "Neo.ClientError.Statement."

    def _open_session():
        return get_driver().session(
            database=database, default_access_mode=default_access_mode
        )

    wrapper: RetryableSession | None = None
    try:
        wrapper = RetryableSession(
            session_factory=_open_session,
            max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
        )
        yield wrapper
    except neo4j.exceptions.Neo4jError as exc:
        code = exc.code
        in_read_mode = default_access_mode == neo4j.READ_ACCESS

        if in_read_mode and code and code in read_rejection_codes:
            # Both rejection codes are reported under the access-mode code.
            raise WriteQueryNotAllowedException(
                message="Read query not allowed", code=access_mode_code
            )

        message = exc.message if exc.message is not None else str(exc)
        if code and code.startswith(statement_error_prefix):
            raise ClientStatementException(message=message, code=code)

        raise GraphDatabaseQueryException(message=message, code=code)
    finally:
        if wrapper is not None:
            wrapper.close()
def create_database(database: str) -> None:
    """Create a database on the Neo4j cluster. Used for temp scan DBs.

    The statement runs in a session opened without a database name and passes
    the name as a query parameter.
    """
    statement = "CREATE DATABASE $database IF NOT EXISTS"
    with get_session() as admin_session:
        admin_session.run(statement, {"database": database})
def drop_database(database: str) -> None:
    """Drop a database on the Neo4j cluster. Used for temp scan DBs.

    The name is interpolated into the statement inside backticks, so callers
    are expected to pass internally generated names only.
    """
    statement = f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA"
    with get_session() as admin_session:
        admin_session.run(statement)
def clear_cache(database: str) -> None:
    """Best-effort cache clear for a Neo4j database.

    A `GraphDatabaseQueryException` is logged as a warning instead of being
    propagated.
    """
    from api.attack_paths.database import GraphDatabaseQueryException

    statement = "CALL db.clearQueryCaches()"
    try:
        with get_session(database) as db_session:
            db_session.run(statement)
    except GraphDatabaseQueryException as exc:
        logging.warning(f"Failed to clear query cache for database `{database}`: {exc}")
def run_cypher(
    database: str | None,
    cypher: str,
    parameters: dict[str, Any] | None = None,
) -> Any:
    """Execute Cypher directly without the context manager. Thin helper.

    Opens a session via `get_session`, runs the statement with `parameters`
    (an empty dict when omitted) and returns whatever `session.run` produced.
    """
    bound_parameters = parameters or {}
    with get_session(database) as session:
        outcome = session.run(cypher, bound_parameters)
    # NOTE(review): `outcome` reaches the caller after the session context
    # has exited. With the plain neo4j driver a result is generally not
    # readable once its session is closed; confirm callers only rely on this
    # helper for statements whose records they do not read.
    return outcome
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,12 +1,14 @@
from api.attack_paths.queries.aws import AWS_QUERIES
# TODO: drop after Neptune cutover
from api.attack_paths.queries.aws_deprecated import AWS_DEPRECATED_QUERIES
from api.attack_paths.queries.types import AttackPathsQueryDefinition
# Query definitions organized by provider
# Query definitions for scans synced with the current schema.
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
"aws": AWS_QUERIES,
}
# Flat lookup by query ID for O(1) access
_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
definition.id: definition
for definitions in _QUERY_DEFINITIONS.values()
@@ -14,11 +16,45 @@ _QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
}
def get_queries_for_provider(provider: str) -> list[AttackPathsQueryDefinition]:
"""Get all attack path queries for a specific provider."""
return _QUERY_DEFINITIONS.get(provider, [])
# TODO: drop after Neptune cutover
#
# Query definitions for pre-cutover scans (`AttackPathsScan.is_migrated=False`)
# whose graph data was written under the previous schema. Both maps expose the
# same query IDs so the API contract is identical regardless of which set is
# routed to.
_DEPRECATED_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
"aws": AWS_DEPRECATED_QUERIES,
}
_DEPRECATED_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
definition.id: definition
for definitions in _DEPRECATED_QUERY_DEFINITIONS.values()
for definition in definitions
}
def get_query_by_id(query_id: str) -> AttackPathsQueryDefinition | None:
"""Get a specific attack path query by its ID."""
return _QUERIES_BY_ID.get(query_id)
def get_queries_for_provider(
    provider: str,
    is_migrated: bool = True,
) -> list[AttackPathsQueryDefinition]:
    """Get all attack path queries for a provider.

    `is_migrated` selects the catalog: True for scans synced with the current
    schema, False for pre-cutover scans still using the legacy graph shape.
    An unknown provider yields an empty list.
    """
    # TODO: drop the `is_migrated` parameter after Neptune cutover
    if is_migrated:
        definitions_by_provider = _QUERY_DEFINITIONS
    else:
        definitions_by_provider = _DEPRECATED_QUERY_DEFINITIONS
    return definitions_by_provider.get(provider, [])
def get_query_by_id(
    query_id: str,
    is_migrated: bool = True,
) -> AttackPathsQueryDefinition | None:
    """Get a specific attack path query by ID.

    `is_migrated` selects the catalog (see `get_queries_for_provider`).
    Returns `None` when the selected catalog has no query with that ID.
    """
    # TODO: drop the `is_migrated` parameter after Neptune cutover
    lookup = _DEPRECATED_QUERIES_BY_ID
    if is_migrated:
        lookup = _QUERIES_BY_ID
    return lookup.get(query_id)
@@ -0,0 +1,28 @@
"""Attack-paths sink database layer.
The sink is the persistent store where attack-paths graphs live after a scan
finishes. Currently selectable between Neo4j (OSS / local dev default) and
AWS Neptune (hosted dev/staging/prod). Backend is picked by the
`ATTACK_PATHS_SINK_DATABASE` setting at process init.
This package exposes the public factory API; the implementation lives in
`api.attack_paths.sink.factory`.
"""
from api.attack_paths.sink.factory import (
SinkBackend,
close,
get_backend,
get_backend_for_name,
get_backend_for_scan,
init,
)
__all__ = [
"SinkBackend",
"close",
"get_backend",
"get_backend_for_name",
"get_backend_for_scan",
"init",
]
@@ -0,0 +1,92 @@
"""Protocol every sink backend must implement."""
from contextlib import AbstractContextManager
from typing import Any, Protocol
import neo4j
class SinkDatabase(Protocol):
    """Structural interface of the persistent attack-paths graph store.

    Every method that takes `database` receives the opaque identifier handed
    down from the legacy `database.py` API surface. The Neo4j backend reads it
    as the per-tenant database name (e.g. `db-tenant-{uuid}`); the Neptune
    backend ignores it, because the cluster has a single graph and isolation
    is label-based.
    """

    def init(self) -> None:
        """Prepare the backend for use."""
        ...

    def close(self) -> None:
        """Release the resources held by the backend."""
        ...

    def verify_connectivity(self) -> None:
        """Raise if the backend the API read path uses is unreachable.

        Neo4j verifies its single driver. Neptune verifies the reader
        driver (the endpoint the API serves reads from); on single-endpoint
        clusters the reader aliases the writer, so that path is covered too.

        Used by the readiness probe; must not block longer than the caller's
        probe budget.
        """
        ...

    def get_session(
        self,
        database: str | None = None,
        default_access_mode: str | None = None,
    ) -> AbstractContextManager:
        """Return a context manager wrapping a session on this backend."""
        ...

    def execute_read_query(
        self,
        database: str,
        cypher: str,
        parameters: dict[str, Any] | None = None,
    ) -> neo4j.graph.Graph:
        """Run a read-only Cypher query and return the resulting graph."""
        ...

    def create_database(self, database: str) -> None:
        """Create `database` on the backend."""
        ...

    def drop_database(self, database: str) -> None:
        """Drop `database` from the backend."""
        ...

    def drop_subgraph(self, database: str, provider_id: str) -> int:
        """Delete the subgraph of one provider and return an integer count."""
        ...

    def has_provider_data(self, database: str, provider_id: str) -> bool:
        """Report whether the sink holds data for `provider_id`."""
        ...

    def clear_cache(self, database: str) -> None:
        """Clear the backend's cache for `database`."""
        ...

    def ensure_sync_indexes(self, database: str) -> None:
        """Create any index needed for the sync write path.

        Called once at the start of each provider sync; must be idempotent.
        Neo4j creates a `_provider_element_id` index on `_ProviderResource`;
        Neptune is a no-op (its `~id` lookup needs no index).
        """
        ...

    def write_nodes(
        self,
        database: str,
        labels: str,
        rows: list[dict[str, Any]],
    ) -> None:
        """Upsert a batch of nodes into the sink.

        `labels` is a pre-rendered Cypher label string ready to drop after
        the node variable (e.g. `` `AWSUser`:`_ProviderResource`:`_Tenant_x` ``).
        Each row carries `provider_element_id` and `props`.
        """
        ...

    def write_relationships(
        self,
        database: str,
        rel_type: str,
        provider_id: str,
        rows: list[dict[str, Any]],
    ) -> None:
        """Upsert a batch of relationships into the sink.

        Each row carries `start_element_id`, `end_element_id`,
        `provider_element_id` and `props`. `rel_type` is the relationship
        type (already a valid Cypher identifier).
        """
        ...
@@ -0,0 +1,78 @@
"""Shared batched deletion helpers for sink backends."""
import logging
import time
from typing import Any
# Batched relationship-delete statements, one per direction, keyed by a
# human-readable name. `{provider_label}` is a brace placeholder that must be
# substituted before the query runs (presumably via `str.format` in the sink
# backends - TODO confirm). `$batch_size` is a query parameter;
# `delete_batches` in this module binds `batch_size` on every run.
RELATIONSHIP_DELETE_QUERY_TEMPLATES = {
"outgoing relationship": """
MATCH (n:`{provider_label}`)-[r]->()
WITH r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
"incoming relationship": """
MATCH (n:`{provider_label}`)<-[r]-()
WITH r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
}
# Batched node-delete statement. Uses plain `DELETE`, not `DETACH DELETE`, so
# it presumably expects relationships to have been removed first - verify
# against the callers. Placeholders: `{provider_resource_label}` and
# `{provider_label}`.
NODE_DELETE_QUERY_TEMPLATE = """
MATCH (n:{provider_resource_label}:`{provider_label}`)
WITH n LIMIT $batch_size
DELETE n
RETURN COUNT(n) AS deleted_nodes_count
"""
def delete_batches(
    *,
    session: Any,
    logger: logging.Logger,
    log_target: str,
    provider_id: str,
    query: str,
    phase: str,
    count_key: str,
    total_key: str,
    deleted_key: str,
    initial_total: int,
    batch_size: int,
    drop_t0: float,
) -> tuple[int, int]:
    """Run `query` repeatedly until it reports that nothing was deleted.

    Args:
        session: object whose `run(query, params)` returns something with a
            `single()` method.
        logger: receives one INFO line before every attempt and one after
            every non-empty batch.
        log_target: name of the store shown in log lines.
        provider_id: provider identifier shown in log lines.
        query: statement executed on each iteration with a `batch_size`
            parameter.
        phase: short description of what is being deleted (log lines only).
        count_key: record key holding the number of items deleted by a run.
        total_key: suffix for the `total_...` field in log lines.
        deleted_key: name of the per-batch count field in log lines.
        initial_total: running total to start from.
        batch_size: value bound to the `batch_size` query parameter.
        drop_t0: `time.perf_counter()` reading used as the elapsed-time origin.

    Returns:
        `(running_total, batches)`: `initial_total` plus everything deleted
        here, and the number of batches that deleted at least one item.
    """
    running_total = initial_total
    completed_batches = 0

    def _elapsed() -> float:
        return time.perf_counter() - drop_t0

    while True:
        logger.info(
            "Deleting %s batch from %s "
            "(provider=%s, batch=%s, total_%s=%s, elapsed=%.3fs)",
            phase,
            log_target,
            provider_id,
            completed_batches + 1,
            total_key,
            running_total,
            _elapsed(),
        )

        row = session.run(query, {"batch_size": batch_size}).single()
        # A missing record and a null count both mean "nothing deleted".
        removed = (row[count_key] or 0) if row else 0
        if removed == 0:
            break

        completed_batches += 1
        running_total += removed
        logger.info(
            "Deleted %s batch from %s "
            "(provider=%s, batch=%s, %s=%s, total_%s=%s, elapsed=%.3fs)",
            phase,
            log_target,
            provider_id,
            completed_batches,
            deleted_key,
            removed,
            total_key,
            running_total,
            _elapsed(),
        )

    return running_total, completed_batches
@@ -0,0 +1,134 @@
"""Sink backend factory and process-wide handle cache.
Picks the active backend from `settings.ATTACK_PATHS_SINK_DATABASE` at first
use, holds the active backend plus any secondary backends needed to serve
scans written under the previous configuration, and tears them all down on
process shutdown. Imported via `from api.attack_paths import sink as
sink_module`.
"""
import threading
from enum import StrEnum, auto
from api.attack_paths.sink.base import SinkDatabase
from api.models import AttackPathsScan
from django.conf import settings
# Backend names
class SinkBackend(StrEnum):
NEO4J = auto()
NEPTUNE = auto()
# Backend cache
_backend: SinkDatabase | None = None
_secondary_backends: dict[SinkBackend, SinkDatabase] = {}
_lock = threading.Lock()
def _resolve_setting() -> SinkBackend:
    """Map `settings.ATTACK_PATHS_SINK_DATABASE` to a `SinkBackend` member.

    The setting is lower-cased before lookup, so it is matched
    case-insensitively.

    Raises:
        RuntimeError: the setting does not name a known backend. The original
            `ValueError` is attached as the explicit cause.
    """
    raw = settings.ATTACK_PATHS_SINK_DATABASE.lower()
    try:
        return SinkBackend(raw)
    except ValueError as exc:
        valid = sorted(b.value for b in SinkBackend)
        # Chain explicitly so the traceback reads as a translated
        # configuration error rather than a failure inside the handler.
        raise RuntimeError(
            f"ATTACK_PATHS_SINK_DATABASE must be one of {valid}; got {raw!r}"
        ) from exc
def _build_backend(name: SinkBackend) -> SinkDatabase:
    """Instantiate (without initializing) the sink class for `name`.

    The backend module is imported only inside the branch that needs it.

    Raises:
        RuntimeError: `name` is not one of the known `SinkBackend` members.
    """
    if name is SinkBackend.NEO4J:
        from api.attack_paths.sink.neo4j import Neo4jSink as sink_cls
    elif name is SinkBackend.NEPTUNE:
        from api.attack_paths.sink.neptune import NeptuneSink as sink_cls
    else:
        raise RuntimeError(f"Unknown sink backend {name!r}")

    return sink_cls()
# Lifecycle
def init(name: SinkBackend | str | None = None) -> SinkDatabase:
    """Initialize the configured sink backend. Idempotent.

    Args:
        name: backend to build on the first call; when falsy the backend is
            resolved from settings. Ignored once a backend is cached.

    Returns:
        The process-wide active backend.
    """
    global _backend

    if _backend is None:
        with _lock:
            # Re-check under the lock so concurrent first calls build the
            # backend only once.
            if _backend is None:
                resolved = SinkBackend(name) if name else _resolve_setting()
                candidate = _build_backend(resolved)
                candidate.init()
                # Publish only after `init()` returned without raising.
                _backend = candidate

    return _backend
def close() -> None:
    """Close the active backend and every cached secondary backend.

    The cache is emptied under the lock first; the backends are then closed
    outside it. Closing is best-effort: a backend that fails to close is
    logged and the remaining backends are still closed.
    """
    global _backend

    # Function-scope import: this module does not import `logging` at the top.
    import logging

    with _lock:
        backends = [
            b for b in (_backend, *_secondary_backends.values()) if b is not None
        ]
        _backend = None
        _secondary_backends.clear()

    for backend in backends:
        try:
            backend.close()
        except Exception:  # pragma: no cover - best-effort
            # Still best-effort, but leave a trace instead of discarding the
            # failure silently.
            logging.getLogger(__name__).warning(
                "Failed to close attack-paths sink backend %s",
                type(backend).__name__,
                exc_info=True,
            )
def get_backend() -> SinkDatabase:
"""Return the active sink. Initializes on first call."""
# Delegates to `init()` with no explicit name: a cached backend is returned
# as-is, otherwise the backend is resolved from settings.
return init()
# Per-scan routing
def get_backend_for_scan(scan: AttackPathsScan) -> SinkDatabase:
    """Return the backend recorded on `scan.sink_backend`.

    Falls back to Neo4j when the attribute is missing or is not a string.
    """
    neo4j_name = SinkBackend.NEO4J.value
    recorded = getattr(scan, "sink_backend", neo4j_name)
    if isinstance(recorded, str):
        return get_backend_for_name(recorded)
    return get_backend_for_name(neo4j_name)
def get_backend_for_name(name: SinkBackend | str) -> SinkDatabase:
    """Return the backend for a persisted backend name.

    The configured backend is served from the active handle; any other
    backend is served from the secondary cache.

    Raises:
        ValueError: if `name` is not a valid `SinkBackend` value.
    """
    wanted = SinkBackend(name)
    if wanted is not _resolve_setting():
        return _build_backend_cached(wanted)
    return get_backend()
def _build_backend_cached(name: SinkBackend) -> SinkDatabase:
    """Return the cached secondary backend for `name`, building it on first use."""
    # TODO: drop after Neptune cutover
    # Needed only during cutover to serve Neo4j-written scans from a Neptune-
    # configured API pod (and vice versa). Once every scan is on Neptune,
    # `get_backend_for_scan` becomes a one-liner returning `get_backend()`.

    # Single `.get()` on the lock-free fast path: a separate membership test
    # followed by an index lookup can raise KeyError if `close()` clears the
    # cache between the two operations.
    cached = _secondary_backends.get(name)
    if cached is not None:
        return cached
    with _lock:
        cached = _secondary_backends.get(name)
        if cached is None:
            cached = _build_backend(name)
            cached.init()
            # Cache only after `init()` returned without raising.
            _secondary_backends[name] = cached
        return cached
@@ -0,0 +1,417 @@
"""Neo4j sink implementation.
Owns a Neo4j driver independent from the staging driver. On OSS and local dev
this is the only sink; on hosted deployments it runs only as a legacy read
path while phase-1 drains tenant DBs.
"""
import atexit
import logging
import threading
import time
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from typing import Any
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from api.attack_paths.sink.base import SinkDatabase
from api.attack_paths.sink.drop import (
NODE_DELETE_QUERY_TEMPLATE,
RELATIONSHIP_DELETE_QUERY_TEMPLATES,
delete_batches,
)
from config.env import env
from django.conf import settings
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = logging.getLogger(__name__)
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# host can't pin a request or the readiness probe longer than this.
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
MAX_CONNECTION_LIFETIME = env.int("NEO4J_MAX_CONNECTION_LIFETIME", default=7200)
MAX_CONNECTION_POOL_SIZE = env.int("NEO4J_MAX_CONNECTION_POOL_SIZE", default=50)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
DATABASE_NOT_FOUND_CODE = "Neo.ClientError.Database.DatabaseNotFound"
class Neo4jSink(SinkDatabase):
"""Neo4j-backed sink. Multi-database cluster; tenant isolation is physical."""
def __init__(self) -> None:
# The driver is created lazily by `init()`; `None` means not built yet or closed.
self._driver: neo4j.Driver | None = None
# Guards driver creation in `init()` and teardown in `close()`.
self._lock = threading.Lock()
# Ensures `close` is registered with `atexit` at most once per instance.
self._atexit_registered = False
# Driver
def _config(self) -> dict:
# Connection settings come from the "neo4j" entry of Django's DATABASES;
# this class reads its HOST, PORT, USER and PASSWORD keys.
return settings.DATABASES["neo4j"]
def _uri(self) -> str:
cfg = self._config()
host = cfg["HOST"]
port = cfg["PORT"]
if not host or not port:
raise RuntimeError(
"NEO4J_HOST / NEO4J_PORT must be set when ATTACK_PATHS_SINK_DATABASE=neo4j"
)
return f"bolt://{host}:{port}"
def init(self) -> neo4j.Driver:
    """Create the driver on first call and return it; later calls reuse it.

    A failed connectivity probe is logged and ignored, so an unreachable
    server at startup does not make this method raise.
    """
    if self._driver is not None:
        return self._driver

    with self._lock:
        # Re-check under the lock: another thread may have built the driver.
        if self._driver is not None:
            return self._driver

        cfg = self._config()
        uri = self._uri()
        credentials = (cfg["USER"], cfg["PASSWORD"])
        self._driver = neo4j.GraphDatabase.driver(
            uri,
            auth=credentials,
            keep_alive=True,
            max_connection_lifetime=MAX_CONNECTION_LIFETIME,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
        )
        # Eager connectivity check is best-effort:
        # A Neo4j that is down at boot must not crash the process, same degradation model as Postgres
        # The driver reconnects lazily on first use
        # /health/ready surfaces the outage until it recovers
        try:
            self._driver.verify_connectivity()
        except Exception:
            logger.warning(
                "Neo4j sink unreachable at init; continuing with a lazily-reconnecting driver",
                exc_info=True,
            )
        if not self._atexit_registered:
            atexit.register(self.close)
            self._atexit_registered = True
        return self._driver
def _get_driver(self) -> neo4j.Driver:
# Lazily builds the driver on first use by delegating to `init()`.
return self.init()
def verify_connectivity(self) -> None:
# Unlike the probe inside `init()`, nothing is caught here: any error raised
# by the driver's connectivity check propagates to the caller.
self._get_driver().verify_connectivity()
def close(self) -> None:
    """Close the driver, if any, and forget it so `init()` can rebuild it.

    The driver reference is cleared even when closing raises; the error
    still propagates to the caller.
    """
    with self._lock:
        driver, self._driver = self._driver, None
        if driver is not None:
            driver.close()
# Sessions
@contextmanager
def get_session(
    self,
    database: str | None = None,
    default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
    """Yield a `RetryableSession` and translate driver errors on the way out.

    Any `Neo4jError` raised inside the `with` body is re-raised as one of the
    `api.attack_paths.database` exceptions:

    - `WriteQueryNotAllowedException` for a read-access session whose error
      code is listed in `READ_EXCEPTION_CODES`;
    - `ClientStatementException` for other `Neo.ClientError.Statement.*` codes;
    - `GraphDatabaseQueryException` for everything else.

    The session wrapper is always closed on exit.
    """
    from api.attack_paths.database import (
        ClientStatementException,
        GraphDatabaseQueryException,
        WriteQueryNotAllowedException,
    )

    def _open_session():
        # Resolved on every call so a retry picks up the current driver.
        return self._get_driver().session(
            database=database, default_access_mode=default_access_mode
        )

    session_wrapper: RetryableSession | None = None
    try:
        session_wrapper = RetryableSession(
            session_factory=_open_session,
            max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
        )
        yield session_wrapper
    except neo4j.exceptions.Neo4jError as exc:
        code = exc.code
        reading = default_access_mode == neo4j.READ_ACCESS
        if reading and code and code in READ_EXCEPTION_CODES:
            raise WriteQueryNotAllowedException(
                message="Read query not allowed", code=READ_EXCEPTION_CODES[0]
            )
        message = str(exc) if exc.message is None else exc.message
        if code and code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
            raise ClientStatementException(message=message, code=code)
        raise GraphDatabaseQueryException(message=message, code=code)
    finally:
        if session_wrapper is not None:
            session_wrapper.close()
# Operations
def execute_read_query(
    self,
    database: str,
    cypher: str,
    parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
    """Run `cypher` in a managed read transaction and return its graph.

    The query runs on a read-access session against `database` with a
    timeout of `READ_QUERY_TIMEOUT_SECONDS`. A missing `parameters` is
    treated as an empty mapping.
    """

    def _collect_graph(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
        outcome = tx.run(
            cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
        )
        return outcome.graph()

    with self.get_session(
        database, default_access_mode=neo4j.READ_ACCESS
    ) as session:
        return session.execute_read(_collect_graph)
def create_database(self, database: str) -> None:
    """Create `database` if it does not exist yet.

    The name is passed as a query parameter rather than interpolated.
    Errors surface as the exceptions raised by `get_session`.
    """
    with self.get_session() as session:
        # Consume the result so a failure reported after the query was
        # accepted is raised here. The neo4j driver discards a pending
        # auto-commit result, and its error, when the session is closed.
        # NOTE(review): assumes `RetryableSession.run` returns a driver
        # `Result`, as the `.consume()` calls in the write path do.
        session.run(
            "CREATE DATABASE $database IF NOT EXISTS", {"database": database}
        ).consume()
def drop_database(self, database: str) -> None:
    """Drop `database` and its data; a no-op if it does not exist.

    Errors surface as the exceptions raised by `get_session`.
    """
    # NOTE(review): the name is interpolated inside backticks; this assumes
    # callers only pass internally generated database names — confirm.
    with self.get_session() as session:
        # Consume the result so a failure reported after the query was
        # accepted is raised here. The neo4j driver discards a pending
        # auto-commit result, and its error, when the session is closed.
        session.run(f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA").consume()
def drop_subgraph(self, database: str, provider_id: str) -> int:
    """Remove one provider's graph from a tenant database in batches.

    Relationships are deleted first (one pass per template in
    `RELATIONSHIP_DELETE_QUERY_TEMPLATES`), then the nodes. Batching is used
    instead of `DETACH DELETE` so a dense provider graph cannot exceed
    Neo4j's transaction memory limit.

    Returns:
        The number of deleted nodes, or 0 when the database does not exist.

    Raises:
        GraphDatabaseQueryException: for any failure other than a missing
            database.
    """
    from api.attack_paths.database import GraphDatabaseQueryException
    from tasks.jobs.attack_paths.config import (
        BATCH_SIZE,
        PROVIDER_RESOURCE_LABEL,
        get_provider_label,
    )

    provider_label = get_provider_label(provider_id)
    deleted_nodes = deleted_relationships = 0
    relationship_batches = node_batches = 0
    drop_t0 = time.perf_counter()
    logger.info(
        "Dropping provider graph from Neo4j sink database %s "
        "(provider=%s, provider_label=%s)",
        database,
        provider_id,
        provider_label,
    )
    try:
        logger.info(
            "Opening Neo4j sink session for provider graph drop "
            "(database=%s, provider=%s)",
            database,
            provider_id,
        )
        with self.get_session(database) as session:
            logger.info(
                "Opened Neo4j sink session for provider graph drop "
                "(database=%s, provider=%s)",
                database,
                provider_id,
            )
            # Arguments shared by every `delete_batches` call below.
            shared = {
                "session": session,
                "logger": logger,
                "log_target": f"Neo4j sink database {database}",
                "provider_id": provider_id,
                "batch_size": BATCH_SIZE,
                "drop_t0": drop_t0,
            }
            for phase, template in RELATIONSHIP_DELETE_QUERY_TEMPLATES.items():
                # The running total is threaded through `initial_total`, so
                # the returned count already includes earlier phases.
                deleted_relationships, phase_batches = delete_batches(
                    query=template.format(provider_label=provider_label),
                    phase=phase,
                    count_key="deleted_rels_count",
                    total_key="rels",
                    deleted_key="deleted_rels",
                    initial_total=deleted_relationships,
                    **shared,
                )
                relationship_batches += phase_batches
            deleted_nodes, node_batches = delete_batches(
                query=NODE_DELETE_QUERY_TEMPLATE.format(
                    provider_label=provider_label,
                    provider_resource_label=PROVIDER_RESOURCE_LABEL,
                ),
                phase="node",
                count_key="deleted_nodes_count",
                total_key="nodes",
                deleted_key="deleted_nodes",
                initial_total=0,
                **shared,
            )
    except GraphDatabaseQueryException as exc:
        if exc.code != DATABASE_NOT_FOUND_CODE:
            raise
        logger.info(
            "Skipped provider graph drop from Neo4j sink database %s "
            "(provider=%s, reason=database_not_found, elapsed=%.3fs)",
            database,
            provider_id,
            time.perf_counter() - drop_t0,
        )
        return 0

    logger.info(
        "Finished dropping provider graph from Neo4j sink database %s "
        "(provider=%s, relationship_batches=%s, deleted_rels=%s, "
        "node_batches=%s, deleted_nodes=%s, elapsed=%.3fs)",
        database,
        provider_id,
        relationship_batches,
        deleted_relationships,
        node_batches,
        deleted_nodes,
        time.perf_counter() - drop_t0,
    )
    return deleted_nodes
def has_provider_data(self, database: str, provider_id: str) -> bool:
    """Tell whether `database` holds at least one node for the provider.

    Returns False when the database does not exist; any other query failure
    is re-raised.
    """
    from api.attack_paths.database import GraphDatabaseQueryException
    from tasks.jobs.attack_paths.config import (
        PROVIDER_RESOURCE_LABEL,
        get_provider_label,
    )

    label = get_provider_label(provider_id)
    # `LIMIT 1` keeps this an existence probe rather than a count.
    probe = f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{label}`) RETURN 1 LIMIT 1"
    try:
        with self.get_session(
            database, default_access_mode=neo4j.READ_ACCESS
        ) as session:
            return session.run(probe).single() is not None
    except GraphDatabaseQueryException as exc:
        if exc.code != DATABASE_NOT_FOUND_CODE:
            raise
        return False
def clear_cache(self, database: str) -> None:
    """Clear the query caches of `database`; failures are logged, not raised.

    Only `GraphDatabaseQueryException` is downgraded to a warning; other
    exception types still propagate.
    """
    from api.attack_paths.database import GraphDatabaseQueryException

    try:
        with self.get_session(database) as session:
            # Consume the result so a server-side failure is raised inside
            # the `with` block and reaches the warning below. The neo4j
            # driver discards a pending auto-commit result, and its error,
            # when the session is closed.
            session.run("CALL db.clearQueryCaches()").consume()
    except GraphDatabaseQueryException as exc:
        # Lazy %-formatting renders the same message as before, but only
        # when the record is actually emitted.
        logger.warning(
            "Failed to clear query cache for database `%s`: %s", database, exc
        )
# Sync write path
def ensure_sync_indexes(self, database: str) -> None:
    """Ensure the `_provider_element_id` index on `_ProviderResource` exists.

    All synced nodes carry the `_ProviderResource` label, so this one index
    serves node upserts as well as the endpoint MATCHes of the relationship
    sync. Without it the relationship sync falls back to a label scan per
    row, which makes large provider syncs unworkable.
    """
    from tasks.jobs.attack_paths.config import (
        PROVIDER_ELEMENT_ID_PROPERTY,
        PROVIDER_RESOURCE_LABEL,
    )

    # `IF NOT EXISTS` makes the statement safe to run on every sync.
    statement = " ".join(
        (
            "CREATE INDEX provider_element_id_idx IF NOT EXISTS",
            f"FOR (n:`{PROVIDER_RESOURCE_LABEL}`)",
            f"ON (n.`{PROVIDER_ELEMENT_ID_PROPERTY}`)",
        )
    )
    with self.get_session(database) as session:
        session.run(statement).consume()
# Upsert a batch of nodes into `database`, keyed by the provider element id.
def write_nodes(
self,
database: str,
labels: str,
rows: list[dict[str, Any]],
) -> None:
# Nothing to write: skip opening a session.
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
)
# The query reads `row.provider_element_id` and `row.props` from every row.
# `labels` is spliced verbatim after `SET n:` — assumes the caller passes an
# already-escaped label expression; TODO confirm against the sync step.
query = f"""
UNWIND $rows AS row
MERGE (n:`{PROVIDER_RESOURCE_LABEL}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}})
SET n:{labels}
SET n += row.props
"""
with self.get_session(database) as session:
# `.consume()` drains the result before the session is closed.
session.run(query, {"rows": rows}).consume()
# Upsert a batch of `rel_type` relationships between existing provider nodes.
def write_relationships(
self,
database: str,
rel_type: str,
provider_id: str,
rows: list[dict[str, Any]],
) -> None:
# Nothing to write: skip opening a session.
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
# The query reads `row.start_element_id`, `row.end_element_id`,
# `row.provider_element_id` and `row.props` from every row. Both endpoints
# are MATCHed, so a row with a missing endpoint yields no relationship.
query = f"""
UNWIND $rows AS row
MATCH (s:`{PROVIDER_RESOURCE_LABEL}`:`{provider_label}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.start_element_id}})
MATCH (t:`{PROVIDER_RESOURCE_LABEL}`:`{provider_label}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.end_element_id}})
MERGE (s)-[r:`{rel_type}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}}]->(t)
SET r += row.props
"""
with self.get_session(database) as session:
# `.consume()` drains the result before the session is closed.
session.run(query, {"rows": rows}).consume()
# For compatibility with test harnesses that patch the concrete driver
def get_driver(self) -> neo4j.Driver:
# Public alias of `_get_driver()`; builds the driver if needed.
return self._get_driver()
# Helper for tests / external callers that want a read-only session specifically
def get_read_session(
sink: Neo4jSink, database: str
) -> AbstractContextManager[RetryableSession]:
# Opens the session with `neo4j.READ_ACCESS` as its default access mode.
return sink.get_session(database, default_access_mode=neo4j.READ_ACCESS)
@@ -0,0 +1,491 @@
"""AWS Neptune sink implementation.
Dual Bolt drivers: one against the writer endpoint for workers, one against
the reader endpoint for the API read path. If `NEPTUNE_READER_ENDPOINT` is
unset the reader falls back to the writer driver so single-node clusters work.
Neptune is single-database. The `database` argument on the SinkDatabase
protocol is ignored; tenant / provider isolation is enforced by labels that
the sync step already writes on every node (see tasks/jobs/attack_paths/sync.py).
SigV4 auth lives at the bottom of this file as `neptune_auth_provider`. The
neo4j driver invokes the returned callable on each token refresh.
"""
import atexit
import datetime
import json
import logging
import threading
import time
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from urllib.parse import urlsplit
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from api.attack_paths.sink.base import SinkDatabase
from api.attack_paths.sink.drop import (
NODE_DELETE_QUERY_TEMPLATE,
RELATIONSHIP_DELETE_QUERY_TEMPLATES,
delete_batches,
)
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import Session as BotoSession
from config.env import env
from django.conf import settings
from neo4j.auth_management import AuthManagers, ExpiringAuth
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = logging.getLogger(__name__)
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
# Neptune serverless cold-start can be >30s; give the driver room
CONN_ACQUISITION_TIMEOUT = env.int("NEPTUNE_CONN_ACQUISITION_TIMEOUT", default=60)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# endpoint can't pin a request or the readiness probe longer than this. Kept
# generous: cold-start delays query execution, not the socket connect.
CONNECTION_TIMEOUT = env.int("NEPTUNE_CONNECTION_TIMEOUT", default=10)
# Roll connections hourly so SigV4 rotations and cert refreshes don't strand long-lived pool entries
MAX_CONNECTION_LIFETIME = env.int("NEPTUNE_MAX_CONNECTION_LIFETIME", default=3600)
MAX_CONNECTION_POOL_SIZE = env.int("NEPTUNE_MAX_CONNECTION_POOL_SIZE", default=50)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
# Refresh 60s before the 5-minute SigV4 window closes
SIGV4_TOKEN_LIFETIME_MINUTES = 4
class NeptuneSink(SinkDatabase):
"""Neptune-backed sink. Single database; isolation is label-based."""
def __init__(self) -> None:
# Both drivers are created lazily by `init()`; `None` means not built yet or closed.
self._writer: neo4j.Driver | None = None
# May be the same object as `_writer` when no separate reader endpoint is configured.
self._reader: neo4j.Driver | None = None
# Guards driver creation in `init()` and teardown in `close()`.
self._lock = threading.Lock()
# Ensures `close` is registered with `atexit` at most once per instance.
self._atexit_registered = False
# Config
def _config(self) -> dict:
# Connection settings come from the "neptune" entry of Django's DATABASES;
# this class reads its PORT, REGION, WRITER_ENDPOINT and READER_ENDPOINT keys.
return settings.DATABASES["neptune"]
def _bolt_uri(self, endpoint: str, port: str) -> str:
return f"bolt+s://{endpoint}:{port}"
def _https_url(self, endpoint: str, port: str) -> str:
return f"https://{endpoint}:{port}"
def _build_driver(self, endpoint: str) -> neo4j.Driver:
    """Create a Bolt driver for `endpoint`, authenticated with SigV4 tokens.

    Raises:
        RuntimeError: if `endpoint` or the configured region is empty.
    """
    cfg = self._config()
    port = cfg["PORT"]
    region = cfg["REGION"]
    if not (endpoint and region):
        raise RuntimeError(
            "NEPTUNE_WRITER_ENDPOINT and AWS_REGION must be set when "
            "ATTACK_PATHS_SINK_DATABASE=neptune"
        )

    # The driver calls this provider whenever it needs a fresh token.
    token_provider = neptune_auth_provider(region, self._https_url(endpoint, port))
    return neo4j.GraphDatabase.driver(
        self._bolt_uri(endpoint, port),
        auth=AuthManagers.bearer(token_provider),
        keep_alive=True,
        max_connection_lifetime=MAX_CONNECTION_LIFETIME,
        connection_timeout=CONNECTION_TIMEOUT,
        connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
        max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
        # NOTE(review): presumably disables driver-level retries so that
        # `RetryableSession` is the only retry layer — confirm.
        max_transaction_retry_time=0,
    )
# Lifecycle
def init(self) -> None:
    """Build the writer and reader drivers once. Idempotent.

    When no reader endpoint is configured, or it equals the writer endpoint,
    the reader aliases the writer driver. Connectivity probes are
    best-effort and never raise.

    Raises:
        RuntimeError: propagated from `_build_driver` when the endpoint or
            region is missing.
    """
    if self._writer is not None:
        return
    with self._lock:
        if self._writer is not None:
            return
        cfg = self._config()
        writer_endpoint = cfg["WRITER_ENDPOINT"]
        reader_endpoint = cfg["READER_ENDPOINT"] or writer_endpoint
        # Eager connectivity checks are best-effort
        # A Neptune that is down at boot must not crash the process, same degradation model as Postgres
        # Drivers reconnect lazily on first use
        # /health/ready surfaces the outage until it recovers
        #
        # Build into locals first. `_writer` is the sentinel read by the
        # lock-free fast path above, so it must not become visible until
        # `_reader` is set; otherwise a concurrent `_get_reader()` skips
        # `init()` and trips over a `None` reader.
        writer = self._build_driver(writer_endpoint)
        self._verify_best_effort(writer, "writer")
        if reader_endpoint == writer_endpoint:
            reader = writer
        else:
            try:
                reader = self._build_driver(reader_endpoint)
            except Exception:
                # Do not leak the writer or leave a half-built instance:
                # nothing is published, so the next `init()` call retries.
                try:
                    writer.close()
                except Exception:
                    logger.warning(
                        "Failed to close Neptune writer driver after reader setup error",
                        exc_info=True,
                    )
                raise
            self._verify_best_effort(reader, "reader")
        # Publish the reader first and the writer (the sentinel) last.
        self._reader = reader
        self._writer = writer
        if not self._atexit_registered:
            atexit.register(self.close)
            self._atexit_registered = True
def close(self) -> None:
    """Close both drivers and forget them so `init()` can rebuild them.

    A driver that fails to close is logged and skipped; this method does not
    raise for close failures.
    """
    with self._lock:
        # `Driver.close()` is idempotent, so closing the same driver twice
        # (when reader aliases writer on single-endpoint configs) is safe
        for driver in (self._reader, self._writer):
            if driver is None:
                continue
            try:
                driver.close()
            except Exception:  # pragma: no cover - best-effort
                # Still best-effort, but leave a trace instead of hiding it.
                logger.warning("Failed to close Neptune driver", exc_info=True)
        self._writer = None
        self._reader = None
# Sessions
def _get_writer(self) -> neo4j.Driver:
# Lazily builds the drivers on first use.
self.init()
# Narrows the Optional type; `init()` is expected to have set the writer.
assert self._writer is not None
return self._writer
def _get_reader(self) -> neo4j.Driver:
# Lazily builds the drivers on first use.
self.init()
# Narrows the Optional type; `init()` is expected to have set the reader.
assert self._reader is not None
return self._reader
@staticmethod
def _verify_best_effort(driver: neo4j.Driver, role: str) -> None:
# Probe connectivity, logging a warning with the traceback instead of raising.
# `role` only labels the log line ("writer" or "reader" at the call sites).
try:
driver.verify_connectivity()
except Exception:
logger.warning(
"Neptune %s endpoint unreachable at init; continuing with a lazily-reconnecting driver",
role,
exc_info=True,
)
def verify_connectivity(self) -> None:
# Unlike `_verify_best_effort`, nothing is caught here: errors from the
# driver's connectivity check propagate to the caller.
# The API read path uses the reader driver
# On single-endpoint clusters it aliases the writer, so this also covers the writer
# A writer-only outage is a workers' concern (no HTTP probe there) and deliberately does not fail API readiness
self._get_reader().verify_connectivity()
@contextmanager
def get_session(
    self,
    database: str | None = None,  # noqa: ARG002 - ignored on Neptune
    default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
    """Yield a `RetryableSession` and translate driver errors on the way out.

    Read-access sessions use the reader driver; every other access mode uses
    the writer driver. `database` is accepted but not used.

    Any `Neo4jError` raised inside the `with` body is re-raised as one of the
    `api.attack_paths.database` exceptions:

    - `WriteQueryNotAllowedException` for a read-access session whose error
      code is listed in `READ_EXCEPTION_CODES`;
    - `ClientStatementException` for other `Neo.ClientError.Statement.*` codes;
    - `GraphDatabaseQueryException` for everything else.

    The session wrapper is always closed on exit.
    """
    from api.attack_paths.database import (
        ClientStatementException,
        GraphDatabaseQueryException,
        WriteQueryNotAllowedException,
    )

    reading = default_access_mode == neo4j.READ_ACCESS
    # The driver is picked once, before the session wrapper exists.
    driver = self._get_reader() if reading else self._get_writer()

    def _open_session():
        return driver.session(default_access_mode=default_access_mode)

    session_wrapper: RetryableSession | None = None
    try:
        session_wrapper = RetryableSession(
            session_factory=_open_session,
            max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
        )
        yield session_wrapper
    except neo4j.exceptions.Neo4jError as exc:
        code = exc.code
        if reading and code and code in READ_EXCEPTION_CODES:
            raise WriteQueryNotAllowedException(
                message="Read query not allowed", code=READ_EXCEPTION_CODES[0]
            )
        message = str(exc) if exc.message is None else exc.message
        if code and code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
            raise ClientStatementException(message=message, code=code)
        raise GraphDatabaseQueryException(message=message, code=code)
    finally:
        if session_wrapper is not None:
            session_wrapper.close()
# Operations
def execute_read_query(
    self,
    database: str,  # noqa: ARG002 - ignored on Neptune
    cypher: str,
    parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
    """Run `cypher` in a managed read transaction and return its graph.

    The query runs on a read-access session with a timeout of
    `READ_QUERY_TIMEOUT_SECONDS`. A missing `parameters` is treated as an
    empty mapping. `database` is accepted but not used.
    """

    def _collect_graph(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
        outcome = tx.run(
            cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
        )
        return outcome.graph()

    with self.get_session(default_access_mode=neo4j.READ_ACCESS) as session:
        return session.execute_read(_collect_graph)
def create_database(self, database: str) -> None: # noqa: ARG002
# Intentional no-op; `database` is accepted but not used.
# Neptune clusters are single-database; there is nothing to create.
return None
def drop_database(self, database: str) -> None: # noqa: ARG002
# Intentional no-op; `database` is accepted but not used.
# Neptune clusters are single-database; there is nothing to drop.
return None
def drop_subgraph(self, database: str, provider_id: str) -> int:  # noqa: ARG002
    """Remove one provider's subgraph in two bounded phases.

    Neptune write transactions are capped at ~2 minutes, and a plain
    `DETACH DELETE` over a label-scanned batch grows with graph density (a
    single node can pull thousands of relationships into the transaction).
    The drop therefore runs as:

    1. relationships incident to provider nodes, one fixed-size batch per
       transaction (one pass per `RELATIONSHIP_DELETE_QUERY_TEMPLATES` entry);
    2. the now-orphaned nodes, one fixed-size batch per transaction.

    Each transaction's work is proportional to the batch size rather than to
    the graph's branching factor. `database` is accepted but not used.

    Returns:
        The number of deleted nodes.
    """
    from tasks.jobs.attack_paths.config import (
        BATCH_SIZE,
        PROVIDER_RESOURCE_LABEL,
        get_provider_label,
    )

    provider_label = get_provider_label(provider_id)
    deleted_relationships = 0
    relationship_batches = 0
    node_batches = 0
    drop_t0 = time.perf_counter()
    logger.info(
        "Dropping provider graph from Neptune sink "
        "(provider=%s, provider_label=%s)",
        provider_id,
        provider_label,
    )
    logger.info(
        "Opening Neptune writer session for provider graph drop (provider=%s)",
        provider_id,
    )
    with self.get_session() as session:
        logger.info(
            "Opened Neptune writer session for provider graph drop (provider=%s)",
            provider_id,
        )
        # Arguments shared by every `delete_batches` call below.
        shared = {
            "session": session,
            "logger": logger,
            "log_target": "Neptune sink",
            "provider_id": provider_id,
            "batch_size": BATCH_SIZE,
            "drop_t0": drop_t0,
        }
        for phase, template in RELATIONSHIP_DELETE_QUERY_TEMPLATES.items():
            # The running total is threaded through `initial_total`, so the
            # returned count already includes earlier phases.
            deleted_relationships, phase_batches = delete_batches(
                query=template.format(provider_label=provider_label),
                phase=phase,
                count_key="deleted_rels_count",
                total_key="rels",
                deleted_key="deleted_rels",
                initial_total=deleted_relationships,
                **shared,
            )
            relationship_batches += phase_batches
        deleted_nodes, node_batches = delete_batches(
            query=NODE_DELETE_QUERY_TEMPLATE.format(
                provider_label=provider_label,
                provider_resource_label=PROVIDER_RESOURCE_LABEL,
            ),
            phase="node",
            count_key="deleted_nodes_count",
            total_key="nodes",
            deleted_key="deleted_nodes",
            initial_total=0,
            **shared,
        )

    logger.info(
        "Finished dropping provider graph from Neptune sink "
        "(provider=%s, relationship_batches=%s, deleted_rels=%s, "
        "node_batches=%s, deleted_nodes=%s, elapsed=%.3fs)",
        provider_id,
        relationship_batches,
        deleted_relationships,
        node_batches,
        deleted_nodes,
        time.perf_counter() - drop_t0,
    )
    return deleted_nodes
def has_provider_data(self, database: str, provider_id: str) -> bool:  # noqa: ARG002
    """Tell whether the cluster holds at least one node for the provider.

    `database` is accepted but not used. Query failures propagate as the
    exceptions raised by `get_session`.
    """
    from tasks.jobs.attack_paths.config import (
        PROVIDER_RESOURCE_LABEL,
        get_provider_label,
    )

    label = get_provider_label(provider_id)
    # `LIMIT 1` keeps this an existence probe rather than a count.
    probe = f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{label}`) RETURN 1 LIMIT 1"
    with self.get_session(default_access_mode=neo4j.READ_ACCESS) as session:
        return session.run(probe).single() is not None
def clear_cache(self, database: str) -> None: # noqa: ARG002
# Intentional no-op; `database` is accepted but not used.
# Neptune has no user-facing cache-clear procedure; no-op.
return None
# Sync write path
def ensure_sync_indexes(self, database: str) -> None: # noqa: ARG002
# Intentional no-op; `database` is accepted but not used.
# Neptune routes node and relationship lookups through `~id`, which is the cluster's primary key
# No additional index is needed or supported
return None
# Upsert a batch of nodes, keyed by Neptune's system `~id`.
def write_nodes(
self,
database: str, # noqa: ARG002
labels: str,
rows: list[dict[str, Any]],
) -> None:
# Nothing to write: skip opening a session.
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
)
# MERGE on `~id` is the documented and engine-optimized idempotent
# upsert pattern for Neptune openCypher. The label inside the MERGE
# matters: Neptune assigns a default `vertex` label to any node
# created without an explicit one, so we pin `_ProviderResource`
# (which every synced node carries anyway) at MERGE-time. Additional
# labels are added after
#
# We also write `_provider_element_id` as a regular property so
# non-sync code (drop_subgraph, query helpers) keeps a stable contract
# that doesn't know about `~id`
#
# The query reads `row.provider_element_id` and `row.props` from every row.
# `labels` is spliced verbatim after `SET n:` — assumes the caller passes an
# already-escaped label expression; TODO confirm against the sync step.
query = f"""
UNWIND $rows AS row
MERGE (n:`{PROVIDER_RESOURCE_LABEL}` {{`~id`: row.provider_element_id}})
SET n:{labels}
SET n += row.props
SET n.`{PROVIDER_ELEMENT_ID_PROPERTY}` = row.provider_element_id
"""
# No access mode is passed, so `get_session` selects the writer driver.
with self.get_session() as session:
session.run(query, {"rows": rows}).consume()
# Upsert a batch of `rel_type` relationships between existing nodes.
def write_relationships(
self,
database: str, # noqa: ARG002
rel_type: str,
provider_id: str, # noqa: ARG002 - encoded in start/end `~id` already
rows: list[dict[str, Any]],
) -> None:
# Nothing to write: skip opening a session.
if not rows:
return
from tasks.jobs.attack_paths.config import PROVIDER_ELEMENT_ID_PROPERTY
# `id(n) = $value` is Neptune's parameterized fast path; both endpoint
# MATCHes resolve in O(1) via the system `~id`, so per-row work stays
# bounded regardless of batch size
#
# The query reads `row.start_element_id`, `row.end_element_id`,
# `row.provider_element_id` and `row.props` from every row. Both endpoints
# are MATCHed, so a row with a missing endpoint yields no relationship.
query = f"""
UNWIND $rows AS row
MATCH (s) WHERE id(s) = row.start_element_id
MATCH (e) WHERE id(e) = row.end_element_id
MERGE (s)-[r:`{rel_type}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}}]->(e)
SET r += row.props
"""
# No access mode is passed, so `get_session` selects the writer driver.
with self.get_session() as session:
session.run(query, {"rows": rows}).consume()
# Test helpers
def get_writer(self) -> neo4j.Driver:
# Public alias of `_get_writer()`; builds the drivers if needed.
return self._get_writer()
def get_reader(self) -> neo4j.Driver:
# Public alias of `_get_reader()`; builds the drivers if needed.
return self._get_reader()
# SigV4 auth provider
class _NeptuneAuthToken(neo4j.Auth):
    """neo4j `Auth` whose credentials field carries a SigV4-signed request.

    A GET to `<url>/opencypher` is signed for the `neptune-db` service, and
    the resulting headers are serialized as JSON into the credentials.
    """

    def __init__(self, region: str, url: str) -> None:
        """Sign the request with the current AWS credentials.

        Raises:
            RuntimeError: if the botocore session resolves no credentials.
        """
        resolved = BotoSession().get_credentials()
        if resolved is None:
            raise RuntimeError(
                "No AWS credentials available for Neptune SigV4 signing. "
                "Ensure the boto3 credential chain can resolve."
            )
        frozen = resolved.get_frozen_credentials()

        request = AWSRequest(method="GET", url=url + "/opencypher")
        # SigV4 canonical Host must carry the real `host:port`
        # Neptune runs on a non-default port (8182), so `.hostname` would drop it and break signing
        request.headers.add_header("Host", urlsplit(url).netloc)
        SigV4Auth(frozen, "neptune-db", region).add_auth(request)

        # Copy only the headers that are present, in this fixed order;
        # `X-Amz-Security-Token` is skipped when the signer did not add it.
        payload = {}
        for header in (
            "Authorization",
            "X-Amz-Date",
            "X-Amz-Security-Token",
            "Host",
        ):
            if header in request.headers:
                payload[header] = request.headers[header]
        payload["HttpMethod"] = "GET"
        super().__init__("basic", "username", json.dumps(payload))
def neptune_auth_provider(region: str, https_url: str) -> Callable[[], ExpiringAuth]:
    """Build the zero-argument callable the neo4j driver uses to get tokens.

    Each call signs a fresh token and stamps it to expire
    `SIGV4_TOKEN_LIFETIME_MINUTES` minutes from now.
    """

    def _provider() -> ExpiringAuth:
        token = _NeptuneAuthToken(region, https_url)
        lifetime = datetime.timedelta(minutes=SIGV4_TOKEN_LIFETIME_MINUTES)
        deadline = datetime.datetime.now(datetime.UTC) + lifetime
        return ExpiringAuth(auth=token, expires_at=deadline.timestamp())

    return _provider
@@ -5,6 +5,7 @@ from typing import Any
import neo4j
from api.attack_paths import AttackPathsQueryDefinition
from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from api.attack_paths.cypher_sanitizer import (
inject_provider_label,
validate_custom_query,
@@ -14,7 +15,9 @@ from api.attack_paths.queries.schema import (
RAW_SCHEMA_URL,
get_cartography_schema_query,
)
from api.models import AttackPathsScan
from config.custom_logging import BackendLogger
from config.env import env
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import (
INTERNAL_LABELS,
@@ -26,6 +29,10 @@ from tasks.jobs.attack_paths.config import (
logger = logging.getLogger(BackendLogger.API)
def _custom_query_timeout_ms() -> int:
    """Return the read-query timeout in milliseconds.

    Reads `ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS` (default 30) on every
    call and converts it from seconds.
    """
    timeout_seconds = env.int("ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30)
    return timeout_seconds * 1000
# Predefined query helpers
@@ -102,13 +109,13 @@ def execute_query(
definition: AttackPathsQueryDefinition,
parameters: dict[str, Any],
provider_id: str,
scan: AttackPathsScan,
) -> dict[str, Any]:
try:
graph = graph_database.execute_read_query(
database=database_name,
cypher=definition.cypher,
parameters=parameters,
)
# TODO: drop after Neptune cutover
# Route reads by the scan row's recorded sink, not by current settings.
backend = sink_module.get_backend_for_scan(scan)
graph = backend.execute_read_query(database_name, definition.cypher, parameters)
return _serialize_graph(graph, provider_id)
except graph_database.WriteQueryNotAllowedException:
@@ -142,22 +149,31 @@ def execute_custom_query(
database_name: str,
cypher: str,
provider_id: str,
scan: AttackPathsScan,
) -> dict[str, Any]:
# Defense-in-depth for custom queries:
# 1. neo4j.READ_ACCESS — prevents mutations at the driver level
# 2. inject_provider_label() — regex-based label injection scopes node patterns
# 3. _serialize_graph() — post-query filter drops nodes without the provider label
# 1. `neo4j.READ_ACCESS` — prevents mutations at the driver level
# 2. `inject_provider_label()` — regex-based label injection scopes node patterns
# 3. `_serialize_graph()` — post-query filter drops nodes without the provider label
# 4. `USING QUERY:TIMEOUTMILLISECONDS` on Neptune — server-side runaway cutoff
#
# Layer 2 is best-effort (regex can't fully parse Cypher);
# layer 3 is the safety net that guarantees provider isolation.
validate_custom_query(cypher)
cypher = inject_provider_label(cypher, provider_id)
# TODO: drop after Neptune cutover
backend = sink_module.get_backend_for_scan(scan)
# Neptune enforces a cluster-level query timeout; prepending the hint
# makes the limit explicit and matches the client-side read timeout.
# Applies only when the scan's graph lives in Neptune.
if getattr(scan, "sink_backend", None) == "neptune":
timeout_ms = _custom_query_timeout_ms()
cypher = f"USING QUERY:TIMEOUTMILLISECONDS {timeout_ms}\n{cypher}"
try:
graph = graph_database.execute_read_query(
database=database_name,
cypher=cypher,
)
graph = backend.execute_read_query(database_name, cypher, None)
serialized = _serialize_graph(graph, provider_id)
return _truncate_graph(serialized)
@@ -180,10 +196,11 @@ def execute_custom_query(
def get_cartography_schema(
database_name: str, provider_id: str
database_name: str, provider_id: str, scan: AttackPathsScan
) -> dict[str, str] | None:
try:
with graph_database.get_session(
backend = sink_module.get_backend_for_scan(scan)
with backend.get_session(
database_name, default_access_mode=neo4j.READ_ACCESS
) as session:
result = session.run(get_cartography_schema_query(provider_id))
+43 -9
View File
@@ -1,11 +1,14 @@
from math import isfinite
from uuid import UUID
from api.db_router import MainRouter
from api.models import TenantAPIKey, TenantAPIKeyManager
from cryptography.fernet import InvalidToken
from django.core.exceptions import ObjectDoesNotExist
from django.utils import timezone
from drf_simple_apikey.backends import APIKeyAuthentication as BaseAPIKeyAuth
from drf_simple_apikey.crypto import get_crypto
from drf_simple_apikey.settings import package_settings
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.request import Request
@@ -21,18 +24,49 @@ class TenantAPIKeyAuthentication(BaseAPIKeyAuth):
def _authenticate_credentials(self, request, key):
"""
Override to use admin connection, bypassing RLS during authentication.
Delegates to parent after temporarily routing model queries to admin DB.
"""
# Temporarily point the model's manager to admin database
original_objects = self.model.objects
self.model.objects = self.model.objects.using(MainRouter.admin_db)
try:
payload = self.key_crypto.decrypt(key)
except ValueError:
raise AuthenticationFailed("Invalid API Key.")
if not isinstance(payload, dict):
raise AuthenticationFailed("Invalid API Key.")
payload_pk = payload.get("_pk")
payload_exp = payload.get("_exp")
if (
not isinstance(payload_pk, str)
or isinstance(payload_exp, bool)
or not isinstance(payload_exp, (int, float))
or not isfinite(payload_exp)
):
raise AuthenticationFailed("Invalid API Key.")
try:
# Call parent method which will now use admin database
return super()._authenticate_credentials(request, key)
finally:
# Restore original manager
self.model.objects = original_objects
api_key_pk = UUID(payload_pk)
except ValueError:
raise AuthenticationFailed("Invalid API Key.")
if payload_exp < timezone.now().timestamp():
raise AuthenticationFailed("API Key has already expired.")
try:
api_key = self.model.objects.using(MainRouter.admin_db).get(id=api_key_pk)
except ObjectDoesNotExist:
raise AuthenticationFailed("No entity matching this api key.")
if api_key.revoked:
raise AuthenticationFailed("This API Key has been revoked.")
client_ip = request.META.get(package_settings.IP_ADDRESS_HEADER)
if api_key.blacklisted_ips and client_ip in api_key.blacklisted_ips:
raise AuthenticationFailed("Access denied from blacklisted IP.")
if api_key.whitelisted_ips and client_ip not in api_key.whitelisted_ips:
raise AuthenticationFailed("Access restricted to specific IP addresses.")
return api_key.entity, key
def authenticate(self, request: Request):
prefixed_key = self.get_key(request)
+206 -59
View File
@@ -67,6 +67,7 @@ from django_filters.rest_framework import (
)
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
from rest_framework_json_api.serializers import ValidationError
from uuid6 import UUID
class CustomDjangoFilterBackend(DjangoFilterBackend):
@@ -672,35 +673,32 @@ class LatestResourceFilter(ProviderRelationshipFilterSet):
return queryset.filter(tags__text_search=value)
class FindingFilter(CommonFindingFilters):
FINDING_BASE_FILTER_FIELDS = {
"id": ["exact", "in"],
"uid": ["exact", "in"],
"scan": ["exact", "in"],
"delta": ["exact", "in"],
"status": ["exact", "in"],
"severity": ["exact", "in"],
"impact": ["exact", "in"],
"check_id": ["exact", "in", "icontains"],
}
class BaseFindingFilter(CommonFindingFilters):
DATE_FILTER_FIELDS = ()
DATE_FILTER_NAMES = ()
DATE_RANGE_HELP_TEXT = (
f"Maximum date range is {settings.FINDINGS_MAX_DAYS_IN_RANGE} days."
)
DATE_FILTER_REQUIRED_DETAIL = "At least one date filter is required."
scan = UUIDFilter(method="filter_scan_id")
scan__in = UUIDInFilter(method="filter_scan_id_in")
inserted_at = DateFilter(method="filter_inserted_at", lookup_expr="date")
inserted_at__date = DateFilter(method="filter_inserted_at", lookup_expr="date")
inserted_at__gte = DateFilter(
method="filter_inserted_at_gte",
help_text=f"Maximum date range is {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
)
inserted_at__lte = DateFilter(
method="filter_inserted_at_lte",
help_text=f"Maximum date range is {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
)
class Meta:
model = Finding
fields = {
"id": ["exact", "in"],
"uid": ["exact", "in"],
"scan": ["exact", "in"],
"delta": ["exact", "in"],
"status": ["exact", "in"],
"severity": ["exact", "in"],
"impact": ["exact", "in"],
"check_id": ["exact", "in", "icontains"],
"inserted_at": ["date", "gte", "lte"],
"updated_at": ["gte", "lte"],
}
fields = FINDING_BASE_FILTER_FIELDS
filter_overrides = {
FindingDeltaEnumField: {
"filter_class": CharFilter,
@@ -723,17 +721,13 @@ class FindingFilter(CommonFindingFilters):
return queryset.filter(resource_services__contains=[value])
def filter_queryset(self, queryset):
if not (self.data.get("scan") or self.data.get("scan__in")) and not (
self.data.get("inserted_at")
or self.data.get("inserted_at__date")
or self.data.get("inserted_at__gte")
or self.data.get("inserted_at__lte")
if not (self.data.get("scan") or self.data.get("scan__in")) and not any(
self.data.get(filter_name) for filter_name in self.DATE_FILTER_NAMES
):
raise ValidationError(
[
{
"detail": "At least one date filter is required: filter[inserted_at], filter[inserted_at.gte], "
"or filter[inserted_at.lte].",
"detail": self.DATE_FILTER_REQUIRED_DETAIL,
"status": 400,
"source": {"pointer": "/data/attributes/inserted_at"},
"code": "required",
@@ -742,31 +736,42 @@ class FindingFilter(CommonFindingFilters):
)
cleaned = self.form.cleaned_data
exact_date = cleaned.get("inserted_at") or cleaned.get("inserted_at__date")
gte_date = cleaned.get("inserted_at__gte") or exact_date
lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None:
gte_date = datetime.now(UTC).date()
if lte_date is None:
lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
):
raise ValidationError(
[
{
"detail": f"The date range cannot exceed {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
"status": 400,
"source": {"pointer": "/data/attributes/inserted_at"},
"code": "invalid",
}
]
)
for field_name in self.DATE_FILTER_FIELDS:
self.validate_datetime_filter_range(cleaned, field_name)
return super().filter_queryset(queryset)
def validate_datetime_filter_range(self, cleaned, field_name):
    """Reject a date window on ``field_name`` that is wider than allowed.

    Reads the exact (``<field>`` / ``<field>__date``), ``__gte`` and
    ``__lte`` entries from ``cleaned``. An exact value stands in for
    whichever bound is missing; if a bound is still missing it defaults to
    today's UTC date. Nothing is checked when no value was supplied for
    this field at all.

    Raises:
        ValidationError: when the distance between both bounds exceeds
            ``settings.FINDINGS_MAX_DAYS_IN_RANGE`` days.
    """
    pinned = cleaned.get(field_name) or cleaned.get(f"{field_name}__date")
    lower = cleaned.get(f"{field_name}__gte") or pinned
    upper = cleaned.get(f"{field_name}__lte") or pinned

    # No filter on this field: there is no window to validate.
    if not any((pinned, lower, upper)):
        return

    # An open-ended window is closed with today's date (UTC).
    today = datetime.now(UTC).date()
    if not lower:
        lower = today
    if not upper:
        upper = today

    lower_dt = self.filter_value_to_datetime(lower, field_name)
    upper_dt = self.filter_value_to_datetime(upper, field_name)

    # abs(): the limit applies regardless of which bound is the later one.
    window = abs(upper_dt - lower_dt)
    if window > timedelta(days=settings.FINDINGS_MAX_DAYS_IN_RANGE):
        raise ValidationError(
            [
                {
                    "detail": f"The date range cannot exceed {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
                    "status": 400,
                    "source": {"pointer": f"/data/attributes/{field_name}"},
                    "code": "invalid",
                }
            ]
        )
# Convert filter values to UUIDv7 values for use with partitioning
def filter_scan_id(self, queryset, name, value):
try:
@@ -824,27 +829,169 @@ class FindingFilter(CommonFindingFilters):
datetime_value = self.maybe_date_to_datetime(value)
start = uuid7_start(datetime_to_uuid7(datetime_value))
end = uuid7_start(datetime_to_uuid7(datetime_value + timedelta(days=1)))
return queryset.filter(id__gte=start, id__lt=end)
def filter_inserted_at_gte(self, queryset, name, value):
datetime_value = self.maybe_date_to_datetime(value)
start = uuid7_start(datetime_to_uuid7(datetime_value))
return queryset.filter(id__gte=start)
def filter_inserted_at_lte(self, queryset, name, value):
datetime_value = self.maybe_date_to_datetime(value)
end = uuid7_start(datetime_to_uuid7(datetime_value + timedelta(days=1)))
return queryset.filter(id__lt=end)
@staticmethod
def maybe_date_to_datetime(value):
dt = value
if isinstance(value, datetime):
return value
if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt
return datetime.combine(value, datetime.min.time(), tzinfo=UTC)
if isinstance(value, str):
return parse(value)
return value
@classmethod
def filter_value_to_datetime(cls, value, field_name):
    """Convert a raw filter value into a timezone-aware UTC ``datetime``.

    The conversion itself is delegated to ``maybe_date_to_datetime``. A
    naive result is stamped as UTC; an aware result is converted to UTC.

    Raises:
        ValidationError: when the conversion raises ``TypeError``,
            ``ValueError`` or ``OverflowError``; the error pointer names
            ``field_name``.
    """
    try:
        parsed = cls.maybe_date_to_datetime(value)
    except (TypeError, ValueError, OverflowError):
        invalid_value_error = {
            "detail": "Enter a valid date or datetime.",
            "status": 400,
            "source": {"pointer": f"/data/attributes/{field_name}"},
            "code": "invalid",
        }
        raise ValidationError([invalid_value_error])

    is_naive = parsed.tzinfo is None
    return parsed.replace(tzinfo=UTC) if is_naive else parsed.astimezone(UTC)
class FindingFilter(BaseFindingFilter):
    """Finding filters with date/datetime windows on ``inserted_at`` and ``updated_at``.

    The plain, ``__gte`` and ``__lte`` variants are ``CharFilter`` fields, so
    their raw string reaches the filter methods below; the ``__date``
    variants are ``DateFilter`` fields and arrive as ``date`` objects.
    ``inserted_at`` windows are applied to the ``id`` column through UUIDv7
    boundaries, ``updated_at`` windows to the ``updated_at`` column itself.
    """

    DATE_FILTER_FIELDS = ("inserted_at", "updated_at")
    DATE_FILTER_NAMES = (
        "inserted_at",
        "inserted_at__date",
        "inserted_at__gte",
        "inserted_at__lte",
        "updated_at",
        "updated_at__date",
        "updated_at__gte",
        "updated_at__lte",
    )
    DATE_FILTER_REQUIRED_DETAIL = (
        "At least one date filter is required: filter[inserted_at], filter[updated_at], "
        "filter[inserted_at.gte], filter[updated_at.gte], filter[inserted_at.lte], "
        "or filter[updated_at.lte]."
    )

    inserted_at = CharFilter(method="filter_inserted_at")
    inserted_at__date = DateFilter(method="filter_inserted_at", lookup_expr="date")
    inserted_at__gte = CharFilter(
        method="filter_inserted_at",
        help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
    )
    inserted_at__lte = CharFilter(
        method="filter_inserted_at",
        help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
    )

    updated_at = CharFilter(method="filter_updated_at")
    updated_at__date = DateFilter(method="filter_updated_at", lookup_expr="date")
    updated_at__gte = CharFilter(
        method="filter_updated_at",
        help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
    )
    updated_at__lte = CharFilter(
        method="filter_updated_at",
        help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
    )

    class Meta(BaseFindingFilter.Meta):
        fields = FINDING_BASE_FILTER_FIELDS | {
            "inserted_at": ["date", "gte", "lte"],
            "updated_at": ["date", "gte", "lte"],
        }

    def filter_inserted_at(self, queryset, name, value):
        """Apply an ``inserted_at`` window to ``id`` via UUIDv7 boundaries.

        ``name`` selects the shape: a ``__gte`` suffix keeps only the lower
        bound, ``__lte`` only the (exclusive) upper bound, anything else
        applies both.
        """
        window_start, window_end = self.filter_value_to_datetime_bounds(
            value, "inserted_at"
        )
        lookups = {}
        if not name.endswith("__lte"):
            lookups["id__gte"] = self.datetime_to_uuid7_boundary(window_start)
        if not name.endswith("__gte"):
            lookups["id__lt"] = self.datetime_to_uuid7_boundary(window_end)
        return queryset.filter(**lookups)

    def filter_updated_at(self, queryset, name, value):
        """Apply an ``updated_at`` window directly on the ``updated_at`` column.

        Same suffix dispatch as ``filter_inserted_at``: ``__gte`` keeps the
        lower bound, ``__lte`` the exclusive upper bound, otherwise both.
        """
        window_start, window_end = self.filter_value_to_datetime_bounds(
            value, "updated_at"
        )
        lookups = {}
        if not name.endswith("__lte"):
            lookups["updated_at__gte"] = window_start
        if not name.endswith("__gte"):
            lookups["updated_at__lt"] = window_end
        return queryset.filter(**lookups)

    @classmethod
    def filter_value_to_datetime_bounds(cls, value, field_name):
        """Return ``(start, exclusive_end)`` for a filter value.

        A date-like value spans one day; any other value spans one
        millisecond starting at the parsed instant.
        """
        start = cls.filter_value_to_datetime(value, field_name)
        if cls.is_date_filter_value(value):
            span = timedelta(days=1)
        else:
            span = timedelta(milliseconds=1)
        return start, start + span

    @staticmethod
    def datetime_to_uuid7_boundary(datetime_value):
        """Build a UUID whose top 48 bits hold ``datetime_value`` in milliseconds.

        Only the timestamp, the version nibble (7) and the variant bits
        (binary ``10``) are set; every other bit is zero.
        """
        millis = int(datetime_value.timestamp() * 1000) & 0xFFFFFFFFFFFF
        return UUID(int=(millis << 80) | (0x7 << 76) | (0x2 << 62))

    @staticmethod
    def is_date_filter_value(value):
        """Tell whether ``value`` should be treated as a whole day.

        True for ``date`` objects that are not ``datetime`` instances and for
        strings whose stripped length is exactly 10 characters.
        """
        if isinstance(value, str):
            return len(value.strip()) == 10
        return isinstance(value, date) and not isinstance(value, datetime)
# Filter set for finding metadata: only ``inserted_at`` date filters are
# declared here, all of them as ``DateFilter`` fields.
class FindingMetadataFilter(BaseFindingFilter):
# Field names handed to ``validate_datetime_filter_range`` by ``filter_queryset``.
DATE_FILTER_FIELDS = ("inserted_at",)
# Query parameter names that count as "a date filter was supplied" when
# ``filter_queryset`` checks that a scan or date filter is present.
DATE_FILTER_NAMES = (
"inserted_at",
"inserted_at__date",
"inserted_at__gte",
"inserted_at__lte",
)
# Error detail used by ``filter_queryset`` when neither a scan filter nor
# any of the names above is present.
DATE_FILTER_REQUIRED_DETAIL = (
"At least one date filter is required: filter[inserted_at], filter[inserted_at.gte], "
"or filter[inserted_at.lte]."
)
# The filter methods are referenced by name and are not defined in this class.
inserted_at = DateFilter(method="filter_inserted_at", lookup_expr="date")
inserted_at__date = DateFilter(method="filter_inserted_at", lookup_expr="date")
inserted_at__gte = DateFilter(
method="filter_inserted_at_gte",
help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
)
inserted_at__lte = DateFilter(
method="filter_inserted_at_lte",
help_text=BaseFindingFilter.DATE_RANGE_HELP_TEXT,
)
# Shared base fields plus the ``inserted_at`` lookups; ``updated_at`` is not exposed.
class Meta(BaseFindingFilter.Meta):
fields = FINDING_BASE_FILTER_FIELDS | {
"inserted_at": ["date", "gte", "lte"],
}
class LatestFindingFilter(CommonFindingFilters):
+53 -14
View File
@@ -2,8 +2,9 @@
Format (draft-inadarei-api-health-check-06).
Liveness reports only process status. Readiness verifies that PostgreSQL,
Valkey and Neo4j are reachable and returns per-dependency detail when any
of them is unreachable.
Valkey and the attack-paths graph store (Neo4j or Neptune, per
``ATTACK_PATHS_SINK_DATABASE``) are reachable and returns per-dependency
detail when any of them is unreachable.
"""
from __future__ import annotations
@@ -11,6 +12,8 @@ from __future__ import annotations
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from contextlib import suppress
from datetime import UTC, datetime
from typing import Any
@@ -37,9 +40,28 @@ STATUS_FAIL = "fail"
STATUS_WARN = "warn"
# Short socket timeout so a stuck Valkey cannot stall the probe.
# Neo4j inherits its driver-level ``connection_acquisition_timeout``.
VALKEY_PROBE_TIMEOUT_SECONDS = 2
# Probe-scoped budget for the graph database.
# ``Driver.verify_connectivity()`` takes no timeout; its only bound is the
# driver-level ``connection_acquisition_timeout`` (60s on Neptune). The
# probe needs its own budget, independent of the workload driver, so a
# graph-database outage cannot pin a worker thread (and the readiness lock)
# for a minute.
GRAPH_DB_PROBE_TIMEOUT_SECONDS = 5
# Bounded pool that enforces ``GRAPH_DB_PROBE_TIMEOUT_SECONDS``. If the
# graph database is unreachable the probe call blocks until the driver's
# own acquisition timeout fires; we abandon the future after the budget and
# report ``fail``. Orphaned tasks are capped by ``max_workers`` plus the 3s
# readiness cache plus the per-IP throttle, so they cannot pile up: worst
# case during a graph-database outage is every readiness call failing fast
# in ``GRAPH_DB_PROBE_TIMEOUT_SECONDS`` with at most 2 background threads
# stuck for <= the driver acquisition timeout.
_graph_db_probe_executor = ThreadPoolExecutor(
max_workers=2, thread_name_prefix="health-graph-db-probe"
)
# Brief cache window so high-frequency probes (ALB target groups, scrapers)
# do not stampede the actual dependency checks.
CACHE_CONTROL_HEADER = "max-age=3, must-revalidate"
@@ -109,11 +131,24 @@ def _probe_valkey() -> None:
client.close()
def _probe_neo4j() -> None:
# Lazy import: avoids pulling attack_paths into the boot import graph.
from api.attack_paths.database import get_driver
def _graph_db_component_id() -> str:
    """Return ``settings.ATTACK_PATHS_SINK_DATABASE`` stripped and lowercased.

    The result is used as the ``componentId`` of the graph database check.
    """
    configured_sink = settings.ATTACK_PATHS_SINK_DATABASE
    return configured_sink.strip().lower()
get_driver().verify_connectivity()
def _probe_graph_db() -> None:
    """Check graph database connectivity within the probe's own time budget.

    ``verify_connectivity`` runs on ``_graph_db_probe_executor`` and is
    awaited for at most ``GRAPH_DB_PROBE_TIMEOUT_SECONDS``.

    Raises:
        TimeoutError: when the budget elapses before the check finishes.
    """
    # Lazy import: avoids pulling attack_paths into the boot import graph
    from api.attack_paths.database import verify_connectivity

    pending = _graph_db_probe_executor.submit(verify_connectivity)
    try:
        pending.result(timeout=GRAPH_DB_PROBE_TIMEOUT_SECONDS)
    except FuturesTimeoutError as exc:
        # Do not wait for the abandoned task; it ends when the driver's own
        # acquisition timeout fires
        pending.cancel()
        budget_exceeded = TimeoutError(
            f"graph-db probe exceeded {GRAPH_DB_PROBE_TIMEOUT_SECONDS}s"
        )
        raise budget_exceeded from exc
def _build_check_entry(
@@ -176,14 +211,18 @@ def _readiness_payload() -> tuple[dict[str, Any], int]:
):
return snapshot[1], snapshot[2]
graph_db_component_id = _graph_db_component_id()
postgres_result, postgres_ms = _measure("postgres", _probe_postgres)
valkey_result, valkey_ms = _measure("valkey", _probe_valkey)
neo4j_result, neo4j_ms = _measure("neo4j", _probe_neo4j)
graph_db_result, graph_db_ms = _measure(graph_db_component_id, _probe_graph_db)
entries = [
_build_check_entry("postgres", "datastore", postgres_result, postgres_ms),
_build_check_entry("valkey", "datastore", valkey_result, valkey_ms),
_build_check_entry("neo4j", "datastore", neo4j_result, neo4j_ms),
_build_check_entry(
graph_db_component_id, "datastore", graph_db_result, graph_db_ms
),
]
overall = _aggregate_status(entries)
@@ -191,7 +230,7 @@ def _readiness_payload() -> tuple[dict[str, Any], int]:
payload["checks"] = {
"postgres:responseTime": [entries[0]],
"valkey:responseTime": [entries[1]],
"neo4j:responseTime": [entries[2]],
"graphdb:responseTime": [entries[2]],
}
http_status = (
@@ -233,10 +272,10 @@ class LivenessView(APIView):
class ReadinessView(APIView):
"""Readiness probe.
Returns 200 when PostgreSQL, Valkey and Neo4j all respond, or 503 with
per-dependency detail when any of them is unreachable. Per-IP throttle
plus the short in-process result cache cap the real dependency hits
regardless of inbound traffic shape.
Returns 200 when PostgreSQL, Valkey and the attack-paths graph store
all respond, or 503 with per-dependency detail when any of them is
unreachable. Per-IP throttle plus the short in-process result cache cap
the real dependency hits regardless of inbound traffic shape.
"""
authentication_classes: list = []
@@ -0,0 +1,24 @@
from django.db import migrations, models
# Schema migration: adds two fields to the ``attackpathsscan`` model.
class Migration(migrations.Migration):
# Applied after migration 0095 of the ``api`` app.
dependencies = [
("api", "0095_reconcile_orphan_tasks_periodic_task"),
]
operations = [
# ``is_migrated``: boolean flag, False for existing and new rows by default.
migrations.AddField(
model_name="attackpathsscan",
name="is_migrated",
field=models.BooleanField(default=False),
),
# ``sink_backend``: either "neo4j" or "neptune"; defaults to "neo4j".
migrations.AddField(
model_name="attackpathsscan",
name="sink_backend",
field=models.CharField(
choices=[("neo4j", "Neo4j"), ("neptune", "Neptune")],
default="neo4j",
max_length=16,
),
),
]
+16
View File
@@ -757,6 +757,10 @@ class Scan(RowLevelSecurityProtectedModel):
class AttackPathsScan(RowLevelSecurityProtectedModel):
class SinkBackendChoices(models.TextChoices):
NEO4J = "neo4j", "Neo4j"
NEPTUNE = "neptune", "Neptune"
objects = ActiveProviderManager()
all_objects = models.Manager()
@@ -805,6 +809,18 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
)
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
# True when the scan was synced with the current schema (list-typed
# properties materialised as child item nodes). False for pre-cutover scans
# still using the previous graph shape. Query catalog selection uses this
# flag; physical read routing uses sink_backend below.
# TODO: drop after Neptune cutover
is_migrated = models.BooleanField(default=False)
sink_backend = models.CharField(
choices=SinkBackendChoices.choices,
default=SinkBackendChoices.NEO4J,
max_length=16,
)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "attack_paths_scans"
File diff suppressed because it is too large Load Diff
+138 -71
View File
@@ -92,7 +92,9 @@ def test_prepare_parameters_validates_cast(
def test_execute_query_serializes_graph(
attack_paths_query_definition_factory, attack_paths_graph_stub_classes
attack_paths_query_definition_factory,
attack_paths_graph_stub_classes,
sink_backend_stub,
):
definition = attack_paths_query_definition_factory(
id="aws-rds",
@@ -135,18 +137,17 @@ def test_execute_query_serializes_graph(
database_name = "db-tenant-test-tenant-id"
with patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
return_value=graph_result,
) as mock_execute_read_query:
result = views_helpers.execute_query(
database_name, definition, parameters, provider_id=provider_id
)
sink_backend_stub.execute_read_query.return_value = graph_result
result = views_helpers.execute_query(
database_name,
definition,
parameters,
provider_id=provider_id,
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
mock_execute_read_query.assert_called_once_with(
database=database_name,
cypher=definition.cypher,
parameters=parameters,
sink_backend_stub.execute_read_query.assert_called_once_with(
database_name, definition.cypher, parameters
)
assert result["nodes"][0]["id"] == "node-1"
assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
@@ -155,6 +156,7 @@ def test_execute_query_serializes_graph(
def test_execute_query_wraps_graph_errors(
attack_paths_query_definition_factory,
sink_backend_stub,
):
definition = attack_paths_query_definition_factory(
id="aws-rds",
@@ -167,16 +169,17 @@ def test_execute_query_wraps_graph_errors(
database_name = "db-tenant-test-tenant-id"
parameters = {"provider_uid": "123"}
with (
patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
side_effect=graph_database.GraphDatabaseQueryException("boom"),
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
sink_backend_stub.execute_read_query.side_effect = (
graph_database.GraphDatabaseQueryException("boom")
)
with patch("api.attack_paths.views_helpers.logger") as mock_logger:
with pytest.raises(APIException):
views_helpers.execute_query(
database_name, definition, parameters, provider_id="test-provider-123"
database_name,
definition,
parameters,
provider_id="test-provider-123",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
mock_logger.error.assert_called_once()
@@ -184,6 +187,7 @@ def test_execute_query_wraps_graph_errors(
def test_execute_query_raises_permission_denied_on_read_only(
attack_paths_query_definition_factory,
sink_backend_stub,
):
definition = attack_paths_query_definition_factory(
id="aws-rds",
@@ -196,17 +200,20 @@ def test_execute_query_raises_permission_denied_on_read_only(
database_name = "db-tenant-test-tenant-id"
parameters = {"provider_uid": "123"}
with patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
side_effect=graph_database.WriteQueryNotAllowedException(
sink_backend_stub.execute_read_query.side_effect = (
graph_database.WriteQueryNotAllowedException(
message="Read query not allowed",
code="Neo.ClientError.Statement.AccessMode",
),
):
with pytest.raises(PermissionDenied):
views_helpers.execute_query(
database_name, definition, parameters, provider_id="test-provider-123"
)
)
)
with pytest.raises(PermissionDenied):
views_helpers.execute_query(
database_name,
definition,
parameters,
provider_id="test-provider-123",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
def test_serialize_graph_filters_by_provider_label(attack_paths_graph_stub_classes):
@@ -440,6 +447,7 @@ def test_normalize_custom_query_payload_passthrough_for_flat_dict():
def test_execute_custom_query_serializes_graph(
attack_paths_graph_stub_classes,
sink_backend_stub,
):
provider_id = "test-provider-123"
plabel = get_provider_label(provider_id)
@@ -453,50 +461,73 @@ def test_execute_custom_query_serializes_graph(
graph_result.nodes = [node_1, node_2]
graph_result.relationships = [relationship]
with patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
return_value=graph_result,
) as mock_execute:
result = views_helpers.execute_custom_query(
"db-tenant-test", "MATCH (n) RETURN n", provider_id
)
sink_backend_stub.execute_read_query.return_value = graph_result
result = views_helpers.execute_custom_query(
"db-tenant-test",
"MATCH (n) RETURN n",
provider_id,
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
mock_execute.assert_called_once()
call_kwargs = mock_execute.call_args[1]
assert call_kwargs["database"] == "db-tenant-test"
sink_backend_stub.execute_read_query.assert_called_once()
call_args = sink_backend_stub.execute_read_query.call_args[0]
assert call_args[0] == "db-tenant-test"
# The cypher is rewritten with the provider label injection
assert plabel in call_kwargs["cypher"]
assert plabel in call_args[1]
assert len(result["nodes"]) == 2
assert result["relationships"][0]["label"] == "OWNS"
assert result["truncated"] is False
assert result["total_nodes"] == 2
def test_execute_custom_query_raises_permission_denied_on_write():
def test_execute_custom_query_adds_timeout_for_neptune_scan(sink_backend_stub):
graph_result = MagicMock()
graph_result.nodes = []
graph_result.relationships = []
sink_backend_stub.execute_read_query.return_value = graph_result
with patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
side_effect=graph_database.WriteQueryNotAllowedException(
"api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
return_value=sink_backend_stub,
):
views_helpers.execute_custom_query(
"db-tenant-test",
"MATCH (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=True, sink_backend="neptune"),
)
cypher = sink_backend_stub.execute_read_query.call_args[0][1]
assert cypher.startswith("USING QUERY:TIMEOUTMILLISECONDS")
def test_execute_custom_query_raises_permission_denied_on_write(sink_backend_stub):
sink_backend_stub.execute_read_query.side_effect = (
graph_database.WriteQueryNotAllowedException(
message="Read query not allowed",
code="Neo.ClientError.Statement.AccessMode",
),
):
with pytest.raises(PermissionDenied):
views_helpers.execute_custom_query(
"db-tenant-test", "CREATE (n) RETURN n", "provider-1"
)
)
)
with pytest.raises(PermissionDenied):
views_helpers.execute_custom_query(
"db-tenant-test",
"CREATE (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
def test_execute_custom_query_wraps_graph_errors():
with (
patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query",
side_effect=graph_database.GraphDatabaseQueryException("boom"),
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
def test_execute_custom_query_wraps_graph_errors(sink_backend_stub):
sink_backend_stub.execute_read_query.side_effect = (
graph_database.GraphDatabaseQueryException("boom")
)
with patch("api.attack_paths.views_helpers.logger") as mock_logger:
with pytest.raises(APIException):
views_helpers.execute_custom_query(
"db-tenant-test", "MATCH (n) RETURN n", "provider-1"
"db-tenant-test",
"MATCH (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
)
mock_logger.error.assert_called_once()
@@ -561,13 +592,33 @@ def test_truncate_graph_empty_graph():
@pytest.fixture
def mock_neo4j_session():
"""Mock the Neo4j driver so execute_read_query uses a fake session."""
"""Install a Neo4jSink with a mocked Bolt driver into the sink factory.
The yielded mock is the `neo4j.Session` that the Neo4jSink will obtain via
`driver.session(...)`. Tests configure `mock_neo4j_session.execute_read`
return values / side effects to exercise the read-mode error translation
path on the real `Neo4jSink.execute_read_query` and `get_session` code.
"""
from api.attack_paths.sink import factory
from api.attack_paths.sink.neo4j import Neo4jSink
mock_session = MagicMock(spec=neo4j.Session)
mock_driver = MagicMock(spec=neo4j.Driver)
mock_driver.session.return_value = mock_session
with patch("api.attack_paths.database.get_driver", return_value=mock_driver):
sink = Neo4jSink()
sink._driver = mock_driver
previous_backend = factory._backend
previous_secondary = dict(factory._secondary_backends)
factory._backend = sink
factory._secondary_backends.clear()
try:
yield mock_session
finally:
factory._backend = previous_backend
factory._secondary_backends.clear()
factory._secondary_backends.update(previous_secondary)
def test_execute_read_query_succeeds_with_select(mock_neo4j_session):
@@ -663,16 +714,20 @@ def test_execute_read_query_rejects_apoc_real_create(mock_neo4j_session, cypher)
@pytest.fixture
def mock_schema_session():
"""Mock get_session for cartography schema tests."""
"""Mock the routed sink backend session for cartography schema tests."""
mock_result = MagicMock()
mock_session = MagicMock()
mock_session.run.return_value = mock_result
mock_backend = MagicMock()
with patch(
"api.attack_paths.views_helpers.graph_database.get_session"
) as mock_get_session:
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
"api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
return_value=mock_backend,
):
mock_backend.get_session.return_value.__enter__ = MagicMock(
return_value=mock_session
)
mock_backend.get_session.return_value.__exit__ = MagicMock(return_value=False)
yield mock_session, mock_result
@@ -683,7 +738,9 @@ def test_get_cartography_schema_returns_urls(mock_schema_session):
"module_version": "0.129.0",
}
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123")
result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
mock_session.run.assert_called_once()
assert result["id"] == "aws-0.129.0"
@@ -699,7 +756,9 @@ def test_get_cartography_schema_returns_none_when_no_data(mock_schema_session):
_, mock_result = mock_schema_session
mock_result.single.return_value = None
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123")
result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
assert result is None
@@ -721,21 +780,29 @@ def test_get_cartography_schema_extracts_provider(
"module_version": "1.0.0",
}
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123")
result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
assert result["id"] == f"{expected_provider}-1.0.0"
assert result["provider"] == expected_provider
def test_get_cartography_schema_wraps_database_error():
mock_backend = MagicMock()
mock_backend.get_session.side_effect = graph_database.GraphDatabaseQueryException(
"boom"
)
with (
patch(
"api.attack_paths.views_helpers.graph_database.get_session",
side_effect=graph_database.GraphDatabaseQueryException("boom"),
"api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
return_value=mock_backend,
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
with pytest.raises(APIException):
views_helpers.get_cartography_schema("db-tenant-test", "provider-123")
views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
mock_logger.error.assert_called_once()
@@ -1,623 +1,240 @@
"""
Tests for Neo4j database lazy initialization.
"""Tests for the attack-paths database facade.
The Neo4j driver is created on first use for every process type; app startup
never contacts Neo4j. These tests validate the database module behavior itself.
After the Neptune port, `api.attack_paths.database` is a thin routing shim
over `api.attack_paths.ingest` (cartography temp DB, always Neo4j) and
`api.attack_paths.sink` (configurable Neo4j or Neptune). The facade's
contract is routing by database-name prefix and the public exception
hierarchy; sink-internal behavior is exercised in `test_sink.py`.
"""
import threading
from unittest.mock import MagicMock, patch
import api.attack_paths.database as db_module
import neo4j
import neo4j.exceptions
import pytest
class TestLazyInitialization:
"""Test that Neo4j driver is initialized lazily on first use."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
def test_driver_not_initialized_at_import(self):
"""Driver should be None after module import (no eager connection)."""
assert db_module._driver is None
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_creates_connection_on_first_call(
self, mock_driver_factory, mock_settings
):
"""init_driver() should create connection only when called."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
assert db_module._driver is None
result = db_module.init_driver()
mock_driver_factory.assert_called_once()
mock_driver.verify_connectivity.assert_called_once()
assert result is mock_driver
assert db_module._driver is mock_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_leaves_driver_none_when_verify_fails(
self, mock_driver_factory, mock_settings
):
"""A failed verify_connectivity() must not publish or leak the driver."""
mock_driver = MagicMock()
mock_driver.verify_connectivity.side_effect = (
neo4j.exceptions.ServiceUnavailable("down")
class TestDatabaseNameHelper:
def test_tenant_name_lowercases_uuid(self):
    """Non-temporary names use the `db-tenant-` prefix with a lowercased id."""
    database_name = db_module.get_database_name("ABC-123", temporary=False)
    assert database_name == "db-tenant-abc-123"
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
with pytest.raises(neo4j.exceptions.ServiceUnavailable):
db_module.init_driver()
assert db_module._driver is None
mock_driver.close.assert_called_once()
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_returns_cached_driver_on_subsequent_calls(
self, mock_driver_factory, mock_settings
):
"""Subsequent calls should return cached driver without reconnecting."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
first_result = db_module.init_driver()
second_result = db_module.init_driver()
third_result = db_module.init_driver()
# Only one connection attempt
assert mock_driver_factory.call_count == 1
assert mock_driver.verify_connectivity.call_count == 1
# All calls return same instance
assert first_result is second_result is third_result
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_get_driver_delegates_to_init_driver(
self, mock_driver_factory, mock_settings
):
"""get_driver() should use init_driver() for lazy initialization."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
result = db_module.get_driver()
assert result is mock_driver
mock_driver_factory.assert_called_once()
def test_temporary_name_uses_tmp_scan_prefix(self):
    """Temporary names use the `db-tmp-scan-` prefix with a lowercased id."""
    database_name = db_module.get_database_name("XYZ-789", temporary=True)
    assert database_name == "db-tmp-scan-xyz-789"
class TestConnectionAcquisitionTimeout:
"""Test that the connection acquisition timeout is configurable."""
class TestExceptionHierarchy:
"""`tasks/` and `api/v1/views.py` import these from the facade."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
original_driver = db_module._driver
original_acq_timeout = db_module.CONN_ACQUISITION_TIMEOUT
original_conn_timeout = db_module.CONNECTION_TIMEOUT
def test_write_query_is_graph_database_exception(self):
    """WriteQueryNotAllowedException derives from GraphDatabaseQueryException."""
    derived = db_module.WriteQueryNotAllowedException
    base = db_module.GraphDatabaseQueryException
    assert issubclass(derived, base)
db_module._driver = None
def test_client_statement_is_graph_database_exception(self):
    """ClientStatementException derives from GraphDatabaseQueryException."""
    derived = db_module.ClientStatementException
    base = db_module.GraphDatabaseQueryException
    assert issubclass(derived, base)
yield
def test_exception_str_includes_code_when_set(self):
    """str() renders as `<code>: <message>` when a code is supplied."""
    error = db_module.GraphDatabaseQueryException(
        message="boom",
        code="Neo.ClientError.X.Y",
    )
    rendered = str(error)
    assert rendered == "Neo.ClientError.X.Y: boom"
db_module._driver = original_driver
db_module.CONN_ACQUISITION_TIMEOUT = original_acq_timeout
db_module.CONNECTION_TIMEOUT = original_conn_timeout
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_driver_receives_configured_timeout(
self, mock_driver_factory, mock_settings
):
"""init_driver() should pass the configured timeouts to the neo4j driver."""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.CONN_ACQUISITION_TIMEOUT = 42
db_module.CONNECTION_TIMEOUT = 7
db_module.init_driver()
_, kwargs = mock_driver_factory.call_args
assert kwargs["connection_acquisition_timeout"] == 42
assert kwargs["connection_timeout"] == 7
def test_exception_str_falls_back_to_message_without_code(self):
    """str() renders the bare message when no code is supplied."""
    error = db_module.GraphDatabaseQueryException(message="boom")
    rendered = str(error)
    assert rendered == "boom"
class TestAtexitRegistration:
"""Test that atexit cleanup handler is registered correctly."""
class TestExecuteReadQueryRoutes:
def test_execute_read_query_delegates_to_sink(self, sink_backend_stub):
sink_backend_stub.execute_read_query.return_value = "graph"
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
result = db_module.execute_read_query(
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
)
db_module._driver = None
sink_backend_stub.execute_read_query.assert_called_once_with(
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
)
assert result == "graph"
yield
def test_execute_read_query_defaults_parameters_to_none(self, sink_backend_stub):
db_module.execute_read_query("db-tenant-abc", "MATCH (n) RETURN n")
db_module._driver = original_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_on_first_init(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should be called on first initialization."""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
mock_atexit_register.assert_called_once_with(db_module.close_driver)
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_only_once(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should only be called once across multiple inits.
The double-checked locking on _driver ensures the atexit registration
block only executes once (when _driver is first created).
"""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
db_module.init_driver()
db_module.init_driver()
# Only registered once because subsequent calls hit the fast path
assert mock_atexit_register.call_count == 1
sink_backend_stub.execute_read_query.assert_called_once_with(
"db-tenant-abc", "MATCH (n) RETURN n", None
)
class TestCloseDriver:
"""Test driver cleanup functionality."""
class TestScanDatabaseAvailability:
def test_verify_scan_databases_available_checks_ingest_and_sink(self):
with (
patch("api.attack_paths.database.ingest") as mock_ingest,
patch("api.attack_paths.database.get_driver") as mock_get_driver,
):
db_module.verify_scan_databases_available()
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
mock_ingest.get_driver.return_value.verify_connectivity.assert_called_once_with()
mock_get_driver.return_value.verify_connectivity.assert_called_once_with()
db_module._driver = None
yield
db_module._driver = original_driver
def test_close_driver_closes_and_clears_driver(self):
"""close_driver() should close the driver and set it to None."""
mock_driver = MagicMock()
db_module._driver = mock_driver
db_module.close_driver()
mock_driver.close.assert_called_once()
assert db_module._driver is None
def test_close_driver_handles_none_driver(self):
"""close_driver() should handle case where driver is None."""
db_module._driver = None
# Should not raise
db_module.close_driver()
assert db_module._driver is None
def test_close_driver_clears_driver_even_on_close_error(self):
    """close_driver() must drop the cached driver even when close() raises."""
    failing_driver = MagicMock()
    failing_driver.close.side_effect = Exception("Connection error")
    db_module._driver = failing_driver

    # The close() failure still reaches the caller...
    with pytest.raises(Exception, match="Connection error"):
        db_module.close_driver()

    # ...but the module-level reference is reset regardless.
    assert db_module._driver is None
class TestExecuteReadQuery:
"""Test read query execution helper."""
def test_execute_read_query_calls_read_session_and_returns_result(self):
tx = MagicMock()
expected_graph = MagicMock()
run_result = MagicMock()
run_result.graph.return_value = expected_graph
tx.run.return_value = run_result
session = MagicMock()
def execute_read_side_effect(fn):
return fn(tx)
session.execute_read.side_effect = execute_read_side_effect
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
) as mock_get_session:
result = db_module.execute_read_query(
"db-tenant-test-tenant-id",
"MATCH (n) RETURN n",
{"provider_uid": "123"},
def test_verify_scan_databases_available_raises_when_ingest_is_down(self):
with (
patch("api.attack_paths.database.ingest") as mock_ingest,
patch("api.attack_paths.database.get_driver"),
):
mock_ingest.get_driver.return_value.verify_connectivity.side_effect = (
RuntimeError("ingest down")
)
mock_get_session.assert_called_once_with(
"db-tenant-test-tenant-id",
default_access_mode=neo4j.READ_ACCESS,
with pytest.raises(RuntimeError) as exc:
db_module.verify_scan_databases_available()
assert "Attack Paths graph database unavailable before scan start" in str(
exc.value
)
session.execute_read.assert_called_once()
tx.run.assert_called_once_with(
"MATCH (n) RETURN n",
{"provider_uid": "123"},
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
)
run_result.graph.assert_called_once_with()
assert result is expected_graph
assert "ingest Neo4j: ingest down" in str(exc.value)
def test_execute_read_query_defaults_parameters_to_empty_dict(self):
tx = MagicMock()
run_result = MagicMock()
run_result.graph.return_value = MagicMock()
tx.run.return_value = run_result
def test_verify_scan_databases_available_raises_when_sink_is_down(self, settings):
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
session = MagicMock()
session.execute_read.side_effect = lambda fn: fn(tx)
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
with (
patch("api.attack_paths.database.ingest"),
patch("api.attack_paths.database.get_driver") as mock_get_driver,
):
db_module.execute_read_query(
"db-tenant-test-tenant-id",
"MATCH (n) RETURN n",
mock_get_driver.return_value.verify_connectivity.side_effect = RuntimeError(
"writer down"
)
tx.run.assert_called_once_with(
"MATCH (n) RETURN n",
{},
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
)
run_result.graph.assert_called_once_with()
with pytest.raises(RuntimeError) as exc:
db_module.verify_scan_databases_available()
assert "sink neptune: writer down" in str(exc.value)
class TestGetSessionReadOnly:
"""Test that get_session translates Neo4j read-mode errors."""
def test_verify_scan_databases_available_reports_both_failures(self, settings):
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
@pytest.fixture(autouse=True)
def reset_module_state(self):
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@pytest.mark.parametrize(
"neo4j_code",
[
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
],
)
def test_get_session_raises_write_query_not_allowed(self, neo4j_code):
"""Read-mode Neo4j errors should raise `WriteQueryNotAllowedException`."""
mock_session = MagicMock()
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
code=neo4j_code,
message="Write operations are not allowed",
)
mock_session.run.side_effect = neo4j_error
mock_driver = MagicMock()
mock_driver.session.return_value = mock_session
db_module._driver = mock_driver
with pytest.raises(db_module.WriteQueryNotAllowedException):
with db_module.get_session(
default_access_mode=neo4j.READ_ACCESS
) as session:
session.run("CREATE (n) RETURN n")
def test_get_session_raises_generic_exception_for_other_errors(self):
"""Non-read-mode Neo4j errors should raise GraphDatabaseQueryException."""
mock_session = MagicMock()
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
code="Neo.ClientError.Statement.SyntaxError",
message="Invalid syntax",
)
mock_session.run.side_effect = neo4j_error
mock_driver = MagicMock()
mock_driver.session.return_value = mock_session
db_module._driver = mock_driver
with pytest.raises(db_module.GraphDatabaseQueryException):
with db_module.get_session(
default_access_mode=neo4j.READ_ACCESS
) as session:
session.run("INVALID CYPHER")
class TestThreadSafety:
    """Concurrent init_driver() callers must end up sharing one driver."""

    @pytest.fixture(autouse=True)
    def reset_module_state(self):
        """Run each test with no cached driver, then restore the previous one."""
        saved_driver = db_module._driver
        db_module._driver = None
        yield
        db_module._driver = saved_driver

    @patch("api.attack_paths.database.settings")
    @patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
    def test_concurrent_init_creates_single_driver(
        self, mock_driver_factory, mock_settings
    ):
        """Ten threads racing through init_driver() create exactly one driver."""
        shared_driver = MagicMock()
        mock_driver_factory.return_value = shared_driver
        mock_settings.DATABASES = {
            "neo4j": {
                "HOST": "localhost",
                "PORT": 7687,
                "USER": "neo4j",
                "PASSWORD": "password",
            }
        }

        obtained = []
        failures = []

        def worker():
            # Collect either the driver or the exception so the main thread
            # can assert on the outcome after every worker has finished.
            try:
                obtained.append(db_module.init_driver())
            except Exception as exc:
                failures.append(exc)

        workers = [threading.Thread(target=worker) for _ in range(10)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert not failures, f"Threads raised errors: {failures}"

        # The factory ran once even though ten threads asked for a driver.
        assert mock_driver_factory.call_count == 1

        # Every thread received that single instance.
        assert all(driver is shared_driver for driver in obtained)
        assert len(obtained) == 10
class TestHasProviderData:
"""Test has_provider_data helper for checking provider nodes in Neo4j."""
def test_returns_true_when_nodes_exist(self):
mock_session = MagicMock()
mock_result = MagicMock()
mock_result.single.return_value = MagicMock() # non-None record
mock_session.run.return_value = mock_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
with (
patch("api.attack_paths.database.ingest") as mock_ingest,
patch("api.attack_paths.database.get_driver") as mock_get_driver,
):
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
mock_session.run.assert_called_once()
def test_returns_false_when_no_nodes(self):
mock_session = MagicMock()
mock_result = MagicMock()
mock_result.single.return_value = None
mock_session.run.return_value = mock_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is False
def test_returns_false_when_database_not_found(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Database does not exist",
code="Neo.ClientError.Database.DatabaseNotFound",
)
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert (
db_module.has_provider_data("db-tenant-gone", "provider-123") is False
mock_ingest.get_driver.return_value.verify_connectivity.side_effect = (
RuntimeError("ingest down")
)
mock_get_driver.return_value.verify_connectivity.side_effect = RuntimeError(
"sink down"
)
def test_raises_on_other_errors(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Connection refused",
code="Neo.TransientError.General.UnknownError",
with pytest.raises(RuntimeError) as exc:
db_module.verify_scan_databases_available()
assert "ingest Neo4j: ingest down" in str(exc.value)
assert "sink neo4j: sink down" in str(exc.value)
class TestSinkOperationsDelegation:
def test_has_provider_data_delegates_to_sink(self, sink_backend_stub):
    """The facade forwards has_provider_data() to the sink with the same arguments."""
    call_args = ("db-tenant-abc", "provider-123")
    sink_backend_stub.has_provider_data.return_value = True

    outcome = db_module.has_provider_data(*call_args)

    assert outcome is True
    sink_backend_stub.has_provider_data.assert_called_once_with(*call_args)
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
with pytest.raises(db_module.GraphDatabaseQueryException):
db_module.has_provider_data("db-tenant-abc", "provider-123")
def test_drop_subgraph_delegates_to_sink(self, sink_backend_stub):
sink_backend_stub.drop_subgraph.return_value = 42
class TestDropSubgraph:
"""Test drop_subgraph two-phase batched deletion of a provider's graph."""
@staticmethod
def _result(count):
result = MagicMock()
result.single.return_value.get.return_value = count
return result
@staticmethod
def _session_ctx(session):
ctx = MagicMock()
ctx.__enter__.return_value = session
ctx.__exit__.return_value = False
return ctx
def test_deletes_relationships_then_nodes_in_batches(self):
    """drop_subgraph() drains relationships in batches before deleting nodes."""
    session = MagicMock()
    # Counts returned per query, in call order:
    #   relationship phase -> one full batch, then empty
    #   node phase         -> one full batch, then empty
    batch_counts = (1000, 0, 1000, 0)
    session.run.side_effect = [self._result(count) for count in batch_counts]

    with patch(
        "api.attack_paths.database.get_session",
        return_value=self._session_ctx(session),
    ):
        deleted = db_module.drop_subgraph("db-tenant-abc", "provider-123")

    # Only the node-phase counts are reflected in the return value.
    assert deleted == 1000
    assert session.run.call_count == 4

    queries = [call.args[0] for call in session.run.call_args_list]

    # Regression guard: the memory blow-up was caused by DETACH DELETE.
    assert not any("DETACH DELETE" in query for query in queries)

    rel_positions = [i for i, query in enumerate(queries) if "DELETE r" in query]
    node_positions = [i for i, query in enumerate(queries) if "DELETE n" in query]
    assert rel_positions and node_positions

    # DISTINCT avoids double-counting relationships matched from both ends.
    assert all("DISTINCT r" in queries[i] for i in rel_positions)

    # Every relationship query must precede the first node query.
    assert max(rel_positions) < node_positions[0]
def test_returns_zero_when_database_not_found(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Database does not exist",
code="Neo.ClientError.Database.DatabaseNotFound",
assert db_module.drop_subgraph("db-tenant-abc", "provider-123") == 42
sink_backend_stub.drop_subgraph.assert_called_once_with(
"db-tenant-abc", "provider-123"
)
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert db_module.drop_subgraph("db-tenant-gone", "provider-123") == 0
def test_raises_on_other_errors(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Connection refused",
code="Neo.TransientError.General.UnknownError",
)
class TestRoutingByDatabasePrefix:
"""`db-tmp-scan-*` and `None` route to ingest; everything else to sink."""
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
with pytest.raises(db_module.GraphDatabaseQueryException):
db_module.drop_subgraph("db-tenant-abc", "provider-123")
def test_create_database_routes_temp_to_ingest(self, sink_backend_stub):
    """A `db-tmp-scan-*` name is created via the ingest module, not the sink."""
    temp_db = "db-tmp-scan-uuid-1"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.create_database(temp_db)

    ingest_mock.create_database.assert_called_once_with(temp_db)
    sink_backend_stub.create_database.assert_not_called()
def test_create_database_routes_tenant_to_sink(self, sink_backend_stub):
    """A tenant database name is created via the sink, not the ingest module."""
    tenant_db = "db-tenant-abc"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.create_database(tenant_db)

    sink_backend_stub.create_database.assert_called_once_with(tenant_db)
    ingest_mock.create_database.assert_not_called()
def test_drop_database_routes_temp_to_ingest(self, sink_backend_stub):
    """A `db-tmp-scan-*` name is dropped via the ingest module, not the sink."""
    temp_db = "db-tmp-scan-uuid-1"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.drop_database(temp_db)

    ingest_mock.drop_database.assert_called_once_with(temp_db)
    sink_backend_stub.drop_database.assert_not_called()
def test_drop_database_routes_tenant_to_sink(self, sink_backend_stub):
    """A tenant database name is dropped via the sink, not the ingest module."""
    tenant_db = "db-tenant-abc"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.drop_database(tenant_db)

    sink_backend_stub.drop_database.assert_called_once_with(tenant_db)
    ingest_mock.drop_database.assert_not_called()
def test_clear_cache_routes_temp_to_ingest(self, sink_backend_stub):
    """clear_cache() for a `db-tmp-scan-*` name goes to ingest, not the sink."""
    temp_db = "db-tmp-scan-uuid-1"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.clear_cache(temp_db)

    ingest_mock.clear_cache.assert_called_once_with(temp_db)
    sink_backend_stub.clear_cache.assert_not_called()
def test_clear_cache_routes_tenant_to_sink(self, sink_backend_stub):
    """clear_cache() for a tenant database name goes to the sink, not ingest."""
    tenant_db = "db-tenant-abc"
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        db_module.clear_cache(tenant_db)

    sink_backend_stub.clear_cache.assert_called_once_with(tenant_db)
    ingest_mock.clear_cache.assert_not_called()
def test_get_session_routes_temp_to_ingest(self, sink_backend_stub):
    """get_session() for a `db-tmp-scan-*` name returns the ingest session."""
    expected_session = MagicMock()
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        ingest_mock.get_session.return_value = expected_session
        session = db_module.get_session("db-tmp-scan-uuid-1")

    assert session is expected_session
    ingest_mock.get_session.assert_called_once()
    sink_backend_stub.get_session.assert_not_called()
def test_get_session_routes_none_to_ingest(self, sink_backend_stub):
    """get_session(None) returns the ingest session and never touches the sink."""
    expected_session = MagicMock()
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        ingest_mock.get_session.return_value = expected_session
        session = db_module.get_session(None)

    assert session is expected_session
    sink_backend_stub.get_session.assert_not_called()
def test_get_ingest_uri_delegates_to_ingest(self, sink_backend_stub):
    """get_ingest_uri() returns whatever ingest.get_uri() reports."""
    with patch("api.attack_paths.database.ingest") as ingest_mock:
        ingest_mock.get_uri.return_value = "bolt://neo4j:7687"
        uri = db_module.get_ingest_uri()
        assert uri == "bolt://neo4j:7687"

    # Delegation happens with no arguments.
    ingest_mock.get_uri.assert_called_once_with()
def test_get_session_routes_tenant_to_sink(self, sink_backend_stub):
    """get_session() for a tenant database returns the sink session, not ingest's."""
    expected_session = MagicMock()
    sink_backend_stub.get_session.return_value = expected_session

    with patch("api.attack_paths.database.ingest") as ingest_mock:
        session = db_module.get_session("db-tenant-abc")

    assert session is expected_session
    ingest_mock.get_session.assert_not_called()
@@ -7,6 +7,7 @@ import pytest
from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication
from api.db_router import MainRouter
from api.models import TenantAPIKey
from django.db.models.query import QuerySet
from django.test import RequestFactory
from rest_framework.exceptions import AuthenticationFailed
@@ -64,6 +65,54 @@ class TestTenantAPIKeyAuthentication:
# Verify the manager was restored
assert TenantAPIKey.objects == original_manager
def test_authenticate_credentials_keeps_manager_during_lookup(
    self, auth_backend, api_keys_fixture, request_factory
):
    """While the key is looked up, `TenantAPIKey.objects` must still be a manager.

    `QuerySet.get` is wrapped so that every lookup records whether
    `TenantAPIKey.objects` exposes `create_api_key` at that moment.
    """
    raw_key = api_keys_fixture[0]._raw_key
    _prefix, encrypted_key = raw_key.split(TenantAPIKey.objects.separator, 1)

    real_get = QuerySet.get
    observations = []

    def observe_manager(queryset, *args, **kwargs):
        # Record the manager's shape at lookup time, then defer to the real get().
        observations.append(hasattr(TenantAPIKey.objects, "create_api_key"))
        return real_get(queryset, *args, **kwargs)

    request = request_factory.get("/")
    with patch.object(QuerySet, "get", observe_manager):
        auth_backend._authenticate_credentials(request, encrypted_key)

    # At least one lookup happened, and the manager was intact every time.
    assert observations
    assert all(observations)
@pytest.mark.parametrize(
    "payload",
    [
        # Expiry is a string that is not a timestamp.
        {"_pk": str(uuid4()), "_exp": "not-a-timestamp"},
        # Primary key is not a UUID; expiry itself is a valid future timestamp.
        {
            "_pk": "not-a-uuid",
            "_exp": (datetime.now(UTC) + timedelta(days=1)).timestamp(),
        },
        # Expiry is a boolean.
        {"_pk": str(uuid4()), "_exp": True},
    ],
)
def test_authenticate_credentials_rejects_malformed_payloads(
    self, auth_backend, request_factory, payload
):
    """A malformed decrypted payload is rejected with "Invalid API Key."."""
    http_request = request_factory.get("/")
    token = auth_backend.key_crypto.generate(payload)

    with pytest.raises(AuthenticationFailed) as raised:
        auth_backend._authenticate_credentials(http_request, token)

    detail_text = str(raised.value.detail)
    assert detail_text == "Invalid API Key."
def test_authenticate_credentials_restores_manager_on_exception(
self, auth_backend, request_factory
):
@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from api.attack_paths.cypher_sanitizer import (
inject_label,
inject_provider_label,
validate_custom_query,
)
@@ -21,6 +22,13 @@ def _inject(cypher: str) -> str:
return inject_provider_label(cypher, PROVIDER_ID)
def test_generic_inject_label_reuses_provider_injection_pipeline():
    """inject_label() appends the given label to labeled and bare node patterns."""
    rewritten = inject_label("MATCH (n:AWSRole)--(m) RETURN n, m", "_Tenant_test")

    # Labeled node keeps its label and gains the new one.
    assert "(n:AWSRole:_Tenant_test)" in rewritten
    # Bare node gains the new label.
    assert "(m:_Tenant_test)" in rewritten
# ---------------------------------------------------------------------------
# Pass A - Labeled node patterns (all clauses)
# ---------------------------------------------------------------------------
+71 -31
View File
@@ -67,7 +67,7 @@ class TestLivenessEndpoint:
with (
patch("api.health._probe_postgres") as mock_pg,
patch("api.health._probe_valkey") as mock_vk,
patch("api.health._probe_neo4j") as mock_neo,
patch("api.health._probe_graph_db") as mock_neo,
):
response = api_client.get(reverse("health-live"))
@@ -83,14 +83,14 @@ class TestReadinessEndpoint:
return (
patch("api.health._probe_postgres", return_value=None),
patch("api.health._probe_valkey", return_value=None),
patch("api.health._probe_neo4j", return_value=None),
patch("api.health._probe_graph_db", return_value=None),
)
def test_returns_200_and_pass_when_all_dependencies_healthy(self, api_client):
with (
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
response = api_client.get(reverse("health-ready"))
@@ -107,7 +107,7 @@ class TestReadinessEndpoint:
assert set(body["checks"].keys()) == {
"postgres:responseTime",
"valkey:responseTime",
"neo4j:responseTime",
"graphdb:responseTime",
}
for key in body["checks"]:
entries = body["checks"][key]
@@ -122,6 +122,23 @@ class TestReadinessEndpoint:
# `output` must not leak when the check passed.
assert "output" not in entry
@pytest.mark.parametrize("sink", ["neo4j", "neptune"])
def test_graphdb_component_id_reflects_active_sink(self, api_client, sink):
    """The graphdb check's componentId equals ATTACK_PATHS_SINK_DATABASE."""
    from django.test import override_settings

    with (
        override_settings(ATTACK_PATHS_SINK_DATABASE=sink),
        patch("api.health._probe_postgres"),
        patch("api.health._probe_valkey"),
        patch("api.health._probe_graph_db"),
    ):
        response = api_client.get(reverse("health-ready"))

    assert response.status_code == status.HTTP_200_OK

    checks = response.json()["checks"]
    graphdb_entry = checks["graphdb:responseTime"][0]
    # Stable key, but the concrete store is named in componentId.
    assert graphdb_entry["componentId"] == sink
def test_returns_503_and_fail_when_postgres_is_down(self, api_client):
with (
patch(
@@ -129,7 +146,7 @@ class TestReadinessEndpoint:
side_effect=RuntimeError("connection refused"),
),
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
response = api_client.get(reverse("health-ready"))
@@ -141,13 +158,13 @@ class TestReadinessEndpoint:
# Exception detail is never echoed in the response, only logged.
assert "output" not in pg_entry
assert body["checks"]["valkey:responseTime"][0]["status"] == "pass"
assert body["checks"]["neo4j:responseTime"][0]["status"] == "pass"
assert body["checks"]["graphdb:responseTime"][0]["status"] == "pass"
def test_returns_503_and_fail_when_valkey_is_down(self, api_client):
with (
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey", side_effect=ConnectionError("timeout")),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
response = api_client.get(reverse("health-ready"))
@@ -158,12 +175,12 @@ class TestReadinessEndpoint:
assert vk_entry["status"] == "fail"
assert "output" not in vk_entry
def test_returns_503_and_fail_when_neo4j_is_down(self, api_client):
def test_returns_503_and_fail_when_graph_db_is_down(self, api_client):
with (
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"),
patch(
"api.health._probe_neo4j",
"api.health._probe_graph_db",
side_effect=RuntimeError("ServiceUnavailable"),
),
):
@@ -172,15 +189,15 @@ class TestReadinessEndpoint:
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
body = response.json()
assert body["status"] == "fail"
neo_entry = body["checks"]["neo4j:responseTime"][0]
assert neo_entry["status"] == "fail"
assert "output" not in neo_entry
graph_db_entry = body["checks"]["graphdb:responseTime"][0]
assert graph_db_entry["status"] == "fail"
assert "output" not in graph_db_entry
def test_reports_all_failures_simultaneously(self, api_client):
with (
patch("api.health._probe_postgres", side_effect=RuntimeError("pg down")),
patch("api.health._probe_valkey", side_effect=RuntimeError("vk down")),
patch("api.health._probe_neo4j", side_effect=RuntimeError("neo down")),
patch("api.health._probe_graph_db", side_effect=RuntimeError("neo down")),
):
response = api_client.get(reverse("health-ready"))
@@ -190,7 +207,7 @@ class TestReadinessEndpoint:
for key in (
"postgres:responseTime",
"valkey:responseTime",
"neo4j:responseTime",
"graphdb:responseTime",
):
entry = body["checks"][key][0]
assert entry["status"] == "fail"
@@ -209,7 +226,7 @@ class TestReadinessEndpoint:
with (
patch("api.health._probe_postgres", side_effect=RuntimeError(sensitive)),
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
response = api_client.get(reverse("health-ready"))
@@ -229,7 +246,7 @@ class TestReadinessEndpoint:
with (
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
api_client.credentials()
response = api_client.get(reverse("health-ready"))
@@ -244,7 +261,7 @@ class TestReadinessCache:
with (
patch("api.health._probe_postgres") as pg,
patch("api.health._probe_valkey") as vk,
patch("api.health._probe_neo4j") as neo,
patch("api.health._probe_graph_db") as neo,
):
r1 = api_client.get(reverse("health-ready"))
r2 = api_client.get(reverse("health-ready"))
@@ -262,7 +279,7 @@ class TestReadinessCache:
with (
patch("api.health._probe_postgres") as pg,
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
api_client.get(reverse("health-ready"))
assert pg.call_count == 1
@@ -286,7 +303,7 @@ class TestReadinessCache:
with (
patch("api.health._probe_postgres", side_effect=RuntimeError("down")) as pg,
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
):
r1 = api_client.get(reverse("health-ready"))
r2 = api_client.get(reverse("health-ready"))
@@ -320,7 +337,7 @@ class TestRateLimiting:
with (
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"),
patch("api.health._probe_graph_db"),
patch.object(ScopedRateThrottle, "parse_rate", return_value=(2, 60)),
):
statuses = [
@@ -414,19 +431,42 @@ class TestProbeImplementations:
with pytest.raises(RuntimeError, match="bug"):
health._probe_valkey()
def test_neo4j_probe_calls_verify_connectivity(self):
with patch("api.attack_paths.database.get_driver") as mock_get_driver:
mock_get_driver.return_value.verify_connectivity.return_value = None
assert health._probe_neo4j() is None
mock_get_driver.return_value.verify_connectivity.assert_called_once_with()
def test_graph_db_probe_calls_verify_connectivity(self):
with patch("api.attack_paths.database.verify_connectivity") as mock_verify:
mock_verify.return_value = None
assert health._probe_graph_db() is None
mock_verify.assert_called_once_with()
def test_neo4j_probe_propagates_driver_errors(self):
with patch("api.attack_paths.database.get_driver") as mock_get_driver:
mock_get_driver.return_value.verify_connectivity.side_effect = RuntimeError(
"unreachable"
)
def test_graph_db_probe_propagates_errors(self):
with patch(
"api.attack_paths.database.verify_connectivity",
side_effect=RuntimeError("unreachable"),
):
with pytest.raises(RuntimeError, match="unreachable"):
health._probe_neo4j()
health._probe_graph_db()
def test_graph_db_probe_times_out_when_check_exceeds_budget(self):
    """A connectivity check that outlives the probe budget raises ``TimeoutError``."""
    # A sink whose connectivity check blocks past the probe budget must
    # surface as a failure fast, not pin the request thread for the
    # driver's full acquisition timeout.
    import time as _time

    def _hang() -> None:
        # Sleep for 2 s, well past the 0.2 s budget patched in below.
        _time.sleep(2)

    with (
        patch("api.health.GRAPH_DB_PROBE_TIMEOUT_SECONDS", 0.2),
        patch(
            "api.attack_paths.database.verify_connectivity",
            side_effect=_hang,
        ),
    ):
        started = _time.perf_counter()
        with pytest.raises(TimeoutError):
            health._probe_graph_db()
        elapsed = _time.perf_counter() - started
        # NOTE(review): nesting reconstructed from a flattened diff; this
        # assumes the bound is read while the constant is still patched
        # (0.2 + 1 s, i.e. below the 2 s sleep) -- confirm against the file.
        assert elapsed < health.GRAPH_DB_PROBE_TIMEOUT_SECONDS + 1
class TestStatusAggregation:
+629
View File
@@ -0,0 +1,629 @@
"""Tests for the attack-paths sink factory and Neo4j sink.
The sink module picks a backend per ``settings.ATTACK_PATHS_SINK_DATABASE``.
Neo4j is the default and preserves today's behavior; Neptune is opt-in and
builds dual writer/reader Bolt drivers.
"""
import json
from importlib import import_module
from unittest.mock import MagicMock, patch
import pytest
# Prime patch-target resolution. `api.attack_paths.sink/__init__.py` doesn't
# eagerly import these submodules (they're loaded on demand inside the
# factory), so `mock.patch("api.attack_paths.sink.<sub>.…")` would fail with
# AttributeError on first call. Importing here registers them as attributes
# of the package before any decorator runs.
import_module("api.attack_paths.sink.neo4j")
import_module("api.attack_paths.sink.neptune")
@pytest.fixture(autouse=True)
def reset_sink_state():
    """Give every test a blank sink-factory cache and restore it afterwards.

    The cached backends live on ``api.attack_paths.sink.factory`` (not on the
    package itself): ``_backend`` holds the active one and
    ``_secondary_backends`` is a mapping of additional ones.
    """
    from api.attack_paths.sink import factory

    # Snapshot the current state; the mapping is copied because it is
    # cleared in place right below.
    saved_backend = factory._backend
    saved_secondaries = dict(factory._secondary_backends)

    factory._secondary_backends.clear()
    factory._backend = None

    yield

    # Put the pre-test state back, mutating the mapping in place rather
    # than rebinding the module attribute.
    factory._backend = saved_backend
    factory._secondary_backends.clear()
    factory._secondary_backends.update(saved_secondaries)
class TestSinkFactory:
    """Setting resolution and backend construction in the sink factory."""

    def test_default_resolves_to_neo4j(self, settings):
        """``"neo4j"`` is accepted by the resolver and returned unchanged."""
        from api.attack_paths.sink import factory

        settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"

        resolved = factory._resolve_setting()

        assert resolved == "neo4j"

    def test_neptune_resolves_correctly(self, settings):
        """``"neptune"`` is accepted by the resolver and returned unchanged."""
        from api.attack_paths.sink import factory

        settings.ATTACK_PATHS_SINK_DATABASE = "neptune"

        resolved = factory._resolve_setting()

        assert resolved == "neptune"

    def test_invalid_value_raises(self, settings):
        """An unknown value fails with an error that names the setting."""
        from api.attack_paths.sink import factory

        settings.ATTACK_PATHS_SINK_DATABASE = "foo"

        with pytest.raises(RuntimeError, match="ATTACK_PATHS_SINK_DATABASE"):
            factory._resolve_setting()

    @patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
    def test_init_builds_neo4j_backend_by_default(self, mock_driver, settings):
        """``init()`` returns a ``Neo4jSink`` and builds exactly one driver."""
        from api.attack_paths import sink as sink_module
        from api.attack_paths.sink.neo4j import Neo4jSink

        neo4j_config = {
            "HOST": "localhost",
            "PORT": "7687",
            "USER": "neo4j",
            "PASSWORD": "pw",
        }
        settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
        settings.DATABASES = {**settings.DATABASES, "neo4j": neo4j_config}
        mock_driver.return_value = MagicMock()

        backend = sink_module.init()

        assert isinstance(backend, Neo4jSink)
        mock_driver.assert_called_once()

    @patch("api.attack_paths.sink.neptune.neptune_auth_provider")
    @patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
    def test_init_builds_neptune_backend(
        self, mock_driver, mock_auth_provider, settings
    ):
        """``init()`` returns a ``NeptuneSink`` with one driver per endpoint."""
        from api.attack_paths import sink as sink_module
        from api.attack_paths.sink.neptune import NeptuneSink

        neptune_config = {
            "WRITER_ENDPOINT": "writer.example",
            "READER_ENDPOINT": "reader.example",
            "PORT": "8182",
            "REGION": "eu-west-1",
        }
        settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
        settings.DATABASES = {**settings.DATABASES, "neptune": neptune_config}
        mock_driver.return_value = MagicMock()
        mock_auth_provider.return_value = lambda: None

        backend = sink_module.init()

        assert isinstance(backend, NeptuneSink)
        # One driver for the writer endpoint, then one for the reader.
        assert mock_driver.call_count == 2
        driver_uris = [call.args[0] for call in mock_driver.call_args_list]
        assert driver_uris[0] == "bolt+s://writer.example:8182"
        assert driver_uris[1] == "bolt+s://reader.example:8182"

    @patch("api.attack_paths.sink.neptune.neptune_auth_provider")
    @patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
    def test_neptune_reader_falls_back_to_writer(
        self, mock_driver, mock_auth_provider, settings
    ):
        """An empty reader endpoint results in a single driver being built."""
        from api.attack_paths import sink as sink_module

        neptune_config = {
            "WRITER_ENDPOINT": "writer.example",
            "READER_ENDPOINT": "",
            "PORT": "8182",
            "REGION": "eu-west-1",
        }
        settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
        settings.DATABASES = {**settings.DATABASES, "neptune": neptune_config}
        mock_driver.return_value = MagicMock()
        mock_auth_provider.return_value = lambda: None

        sink_module.init()

        # The reader aliases the writer, so no second driver is constructed.
        assert mock_driver.call_count == 1
class TestGetBackendForScan:
    """``get_backend_for_scan`` picks the backend named on the scan row."""

    @patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
    def test_legacy_scan_in_neo4j_process_uses_active_backend(
        self, mock_driver, settings
    ):
        """A ``neo4j`` scan in a Neo4j process resolves to the active backend."""
        from api.attack_paths import sink as sink_module

        neo4j_config = {
            "HOST": "localhost",
            "PORT": "7687",
            "USER": "neo4j",
            "PASSWORD": "pw",
        }
        settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
        settings.DATABASES = {**settings.DATABASES, "neo4j": neo4j_config}
        mock_driver.return_value = MagicMock()
        scan = MagicMock(sink_backend="neo4j")

        resolved = sink_module.get_backend_for_scan(scan)

        assert resolved is sink_module.get_backend()

    def test_neptune_scan_on_neo4j_process_uses_neptune_secondary(self, settings):
        """A ``neptune`` scan in a Neo4j process gets a separately built backend."""
        from api.attack_paths.sink import factory

        settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
        primary = MagicMock(name="neo4j-active")
        factory._backend = primary
        secondary = MagicMock(name="neptune-secondary")
        scan = MagicMock(sink_backend="neptune")

        with patch.object(factory, "_build_backend", return_value=secondary):
            resolved = factory.get_backend_for_scan(scan)

        assert resolved is secondary
        assert resolved is not primary
def _session_ctx(session: MagicMock) -> MagicMock:
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=session)
ctx.__exit__ = MagicMock(return_value=False)
return ctx
def _count_result(key: str, count: int) -> MagicMock:
return MagicMock(single=MagicMock(return_value={key: count}))
def _directed_drop_results(
    outgoing_rels: int,
    incoming_rels: int,
    nodes: int,
) -> list[MagicMock]:
    """Build the six mock results consumed by one ``drop_subgraph`` run.

    The sequence is three (count, zero) pairs, in order: outgoing
    relationships, incoming relationships, then nodes.
    """
    scripted_counts = (
        ("deleted_rels_count", outgoing_rels),
        ("deleted_rels_count", 0),
        ("deleted_rels_count", incoming_rels),
        ("deleted_rels_count", 0),
        ("deleted_nodes_count", nodes),
        ("deleted_nodes_count", 0),
    )
    return [_count_result(key, count) for key, count in scripted_counts]
class TestNeo4jSinkSyncWrites:
    """Cypher emitted by the Neo4j sink for index creation and batch writes."""

    def test_ensure_sync_indexes_runs_create_index_idempotent(self):
        """The index statement is guarded by ``IF NOT EXISTS``."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        session.run.return_value = MagicMock()

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            sink.ensure_sync_indexes("db-tenant-x")

        query = session.run.call_args.args[0]
        expected_fragments = (
            "CREATE INDEX",
            "IF NOT EXISTS",
            "`_ProviderResource`",
            "`_provider_element_id`",
        )
        for fragment in expected_fragments:
            assert fragment in query

    def test_write_nodes_skips_empty_batch(self):
        """An empty batch never opens a session."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()

        with patch.object(sink, "get_session") as get_session:
            sink.write_nodes("db-tenant-x", "`AWSUser`", [])

        get_session.assert_not_called()

    def test_write_nodes_merges_on_provider_resource_label(self):
        """Nodes are merged on the provider element id and then labelled."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        labels = "`AWSUser`:`_ProviderResource`"

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            sink.write_nodes(
                "db-tenant-x",
                labels,
                [{"provider_element_id": "p:e", "props": {"k": "v"}}],
            )

        query, params = session.run.call_args.args
        assert "MERGE (n:`_ProviderResource`" in query
        assert "`_provider_element_id`: row.provider_element_id" in query
        assert "SET n:`AWSUser`:`_ProviderResource`" in query
        # Compared against a fresh literal, not the list handed to the sink.
        expected_params = {
            "rows": [{"provider_element_id": "p:e", "props": {"k": "v"}}]
        }
        assert params == expected_params

    def test_write_relationships_scopes_endpoints_by_provider_label(self):
        """The relationship query carries the per-provider label."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        provider_id = "00000000-0000-0000-0000-000000000abc"
        relationship_row = {
            "start_element_id": "s",
            "end_element_id": "e",
            "provider_element_id": "pe",
            "props": {},
        }

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            sink.write_relationships(
                "db-tenant-x",
                "RESOURCE",
                provider_id,
                [relationship_row],
            )

        query = session.run.call_args.args[0]
        # The expected label is the provider id with its dashes removed.
        assert ":`_Provider_00000000000000000000000000000abc`" in query
        assert ":RESOURCE" in query.replace("`", "")
        assert "MERGE (s)-[r:`RESOURCE`" in query
class TestNeptuneSinkSyncWrites:
    """Cypher emitted by the Neptune sink for batch writes."""

    def test_ensure_sync_indexes_is_noop(self):
        """Index creation never opens a session on Neptune."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()

        with patch.object(sink, "get_session") as get_session:
            sink.ensure_sync_indexes("ignored")

        get_session.assert_not_called()

    def test_write_nodes_merges_on_neptune_id_with_provider_resource_label(self):
        """Nodes are merged on ``~id`` with a real label pinned at creation."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()
        session = MagicMock()
        node_rows = [{"provider_element_id": "p:e", "props": {"k": "v"}}]

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            sink.write_nodes("ignored", "`AWSUser`", node_rows)

        query = session.run.call_args.args[0]
        # Neptune gives unlabeled nodes a default `vertex` label, so the
        # MERGE itself has to carry a real label.
        assert "MERGE (n:`_ProviderResource` {`~id`: row.provider_element_id})" in query
        assert "SET n:`AWSUser`" in query
        assert "SET n.`_provider_element_id` = row.provider_element_id" in query

    def test_write_relationships_matches_endpoints_by_id(self):
        """Both endpoints are looked up with ``id(...)`` before the MERGE."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()
        session = MagicMock()
        relationship_row = {
            "start_element_id": "s",
            "end_element_id": "e",
            "provider_element_id": "pe",
            "props": {},
        }

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            sink.write_relationships(
                "ignored",
                "RESOURCE",
                "provider-1",
                [relationship_row],
            )

        query = session.run.call_args.args[0]
        assert "MATCH (s) WHERE id(s) = row.start_element_id" in query
        assert "MATCH (e) WHERE id(e) = row.end_element_id" in query
        assert "MERGE (s)-[r:`RESOURCE`" in query
class TestNeptuneSinkDropSubgraph:
    """Neptune drop deletes relationships, then nodes (no ``DETACH DELETE``)."""

    def test_drop_subgraph_deletes_directed_rels_before_nodes_in_bounded_batches(self):
        """Six statements run; only the node count is returned."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()
        session = MagicMock()
        session.run.side_effect = _directed_drop_results(
            outgoing_rels=50,
            incoming_rels=30,
            nodes=10,
        )

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            deleted = sink.drop_subgraph("ignored", "provider-1")

        assert deleted == 10
        assert session.run.call_count == 6

        statements = [call.args[0] for call in session.run.call_args_list]
        assert ")-[r]->()" in statements[0]
        assert ")<-[r]-()" in statements[2]
        assert "DELETE n" in statements[4]
        assert all("DETACH DELETE" not in text for text in statements)
        assert all("DISTINCT r" not in text for text in statements)

        # Every relationship delete must come before the first node delete.
        first_node = next(
            pos for pos, text in enumerate(statements) if "DELETE n" in text
        )
        last_rel = max(
            pos for pos, text in enumerate(statements) if "DELETE r" in text
        )
        assert last_rel < first_node
class TestNeo4jSinkDropSubgraph:
    """Neo4j drop deletes relationships then nodes in batches (no ``DETACH DELETE``)."""

    def test_drop_subgraph_deletes_directed_rels_before_nodes_in_bounded_batches(self):
        """Six provider-scoped statements run; only the node count is returned."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        session.run.side_effect = _directed_drop_results(
            outgoing_rels=50,
            incoming_rels=30,
            nodes=10,
        )
        provider_id = "00000000-0000-0000-0000-000000000abc"

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            deleted = sink.drop_subgraph("db-tenant-x", provider_id)

        # Relationship counts are not part of the return value.
        assert deleted == 10
        assert session.run.call_count == 6

        statements = [call.args[0] for call in session.run.call_args_list]

        # Regression guard: DETACH DELETE caused the memory blow-up.
        assert all("DETACH DELETE" not in text for text in statements)
        assert all("DISTINCT r" not in text for text in statements)

        opening = statements[0]
        assert "DELETE r" in opening
        assert ")-[r]->()" in opening
        assert ":`_Provider_00000000000000000000000000000abc`" in opening

        assert ")<-[r]-()" in statements[2]
        assert "DELETE n" in statements[4]

        # Relationships have to be fully drained before any node is deleted.
        first_node = next(
            pos for pos, text in enumerate(statements) if "DELETE n" in text
        )
        last_rel = max(
            pos for pos, text in enumerate(statements) if "DELETE r" in text
        )
        assert last_rel < first_node

    def test_drop_subgraph_returns_zero_when_database_does_not_exist(self):
        """A database-not-found error is reported as zero deletions."""
        from api.attack_paths.database import GraphDatabaseQueryException
        from api.attack_paths.sink.neo4j import DATABASE_NOT_FOUND_CODE, Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        missing_database = GraphDatabaseQueryException(
            message="db missing", code=DATABASE_NOT_FOUND_CODE
        )
        session.run.side_effect = missing_database

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            deleted = sink.drop_subgraph("db-tenant-missing", "provider-1")

        assert deleted == 0
class TestSinkHasProviderData:
    """``has_provider_data`` is the read-path probe used by API views."""

    def test_neo4j_returns_true_when_provider_node_exists(self):
        """A returned record means data is present; the query is provider-scoped."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        session.run.return_value.single.return_value = MagicMock()
        provider_id = "00000000-0000-0000-0000-000000000abc"

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            present = sink.has_provider_data("db-tenant-x", provider_id)

        assert present is True
        query = session.run.call_args.args[0]
        assert ":`_Provider_00000000000000000000000000000abc`" in query

    def test_neo4j_returns_false_when_database_does_not_exist(self):
        """A database-not-found error is reported as "no data"."""
        from api.attack_paths.database import GraphDatabaseQueryException
        from api.attack_paths.sink.neo4j import DATABASE_NOT_FOUND_CODE, Neo4jSink

        sink = Neo4jSink()
        session = MagicMock()
        missing_database = GraphDatabaseQueryException(
            message="db missing", code=DATABASE_NOT_FOUND_CODE
        )
        session.run.side_effect = missing_database

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            present = sink.has_provider_data("db-tenant-missing", "provider-1")

        assert present is False

    def test_neptune_returns_true_when_provider_node_exists(self):
        """A returned record means data is present on Neptune as well."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()
        session = MagicMock()
        session.run.return_value.single.return_value = MagicMock()

        with patch.object(sink, "get_session", return_value=_session_ctx(session)):
            present = sink.has_provider_data("ignored", "provider-1")

        assert present is True
class TestGetBackendForScanCutover:
    """``get_backend_for_scan`` keeps old-sink scans queryable after cutover."""

    def test_legacy_scan_on_neptune_process_uses_neo4j_secondary(self, settings):
        """A ``neo4j`` scan in a Neptune process gets a separately built backend."""
        from api.attack_paths.sink import factory

        settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
        primary = MagicMock(name="neptune-active")
        factory._backend = primary
        secondary = MagicMock(name="neo4j-secondary")
        scan = MagicMock(sink_backend="neo4j")

        with patch.object(factory, "_build_backend", return_value=secondary):
            resolved = factory.get_backend_for_scan(scan)

        assert resolved is secondary
        assert resolved is not primary
class TestSinkVerifyConnectivity:
    """``verify_connectivity`` on each sink, as used by the readiness probe.

    Neo4j checks its single driver. Neptune checks the reader (the API read
    path), which on single-endpoint clusters aliases the writer.
    """

    @patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
    def test_neo4j_verifies_its_driver(self, mock_driver, settings):
        """The Neo4j sink forwards the check to its driver exactly once."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        neo4j_config = {
            "HOST": "localhost",
            "PORT": "7687",
            "USER": "neo4j",
            "PASSWORD": "pw",
        }
        settings.DATABASES = {**settings.DATABASES, "neo4j": neo4j_config}
        driver = MagicMock()
        mock_driver.return_value = driver

        sink = Neo4jSink()
        sink.init()
        # Discard the check recorded during the eager init.
        driver.verify_connectivity.reset_mock()

        sink.verify_connectivity()

        driver.verify_connectivity.assert_called_once_with()

    @patch("api.attack_paths.sink.neptune.neptune_auth_provider")
    @patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
    def test_neptune_verifies_reader_not_writer(
        self, mock_driver, mock_auth_provider, settings
    ):
        """The Neptune sink checks the reader driver and leaves the writer alone."""
        from api.attack_paths.sink.neptune import NeptuneSink

        neptune_config = {
            "WRITER_ENDPOINT": "writer.example",
            "READER_ENDPOINT": "reader.example",
            "PORT": "8182",
            "REGION": "eu-west-1",
        }
        settings.DATABASES = {**settings.DATABASES, "neptune": neptune_config}
        writer = MagicMock(name="writer")
        reader = MagicMock(name="reader")
        # Drivers are handed out in construction order: writer, then reader.
        mock_driver.side_effect = [writer, reader]
        mock_auth_provider.return_value = lambda: None

        sink = NeptuneSink()
        sink.init()
        for built_driver in (writer, reader):
            built_driver.verify_connectivity.reset_mock()

        sink.verify_connectivity()

        reader.verify_connectivity.assert_called_once_with()
        writer.verify_connectivity.assert_not_called()
class TestSinkInitToleratesUnreachableSink:
    """Init must not crash the process when the sink is down at boot.

    Same degradation model as Postgres: the driver is kept and reconnects
    lazily, while /health/ready reports the outage until it recovers.
    """

    @patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
    def test_neo4j_init_continues_when_verify_fails(self, mock_driver, settings):
        """A failing connectivity check still yields and stores the driver."""
        from api.attack_paths.sink.neo4j import Neo4jSink

        neo4j_config = {
            "HOST": "localhost",
            "PORT": "7687",
            "USER": "neo4j",
            "PASSWORD": "pw",
        }
        settings.DATABASES = {**settings.DATABASES, "neo4j": neo4j_config}
        unreachable = MagicMock()
        unreachable.verify_connectivity.side_effect = RuntimeError("unreachable")
        mock_driver.return_value = unreachable

        sink = Neo4jSink()

        # init() must swallow the connectivity failure rather than raise.
        assert sink.init() is unreachable
        assert sink._driver is unreachable

    @patch("api.attack_paths.sink.neptune.neptune_auth_provider")
    @patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
    def test_neptune_init_continues_when_verify_fails(
        self, mock_driver, mock_auth_provider, settings
    ):
        """A failing connectivity check still leaves both drivers in place."""
        from api.attack_paths.sink.neptune import NeptuneSink

        neptune_config = {
            "WRITER_ENDPOINT": "writer.example",
            "READER_ENDPOINT": "reader.example",
            "PORT": "8182",
            "REGION": "eu-west-1",
        }
        settings.DATABASES = {**settings.DATABASES, "neptune": neptune_config}
        unreachable = MagicMock()
        unreachable.verify_connectivity.side_effect = RuntimeError("unreachable")
        mock_driver.return_value = unreachable
        mock_auth_provider.return_value = lambda: None

        sink = NeptuneSink()

        # Must not raise, and both drivers are retained.
        sink.init()
        assert sink._writer is not None
        assert sink._reader is not None
class TestNeptuneAdminNoOps:
    """Neptune is single-database; admin DDL has no work to do."""

    @pytest.mark.parametrize("method", ["create_database", "drop_database"])
    def test_admin_ops_return_none_without_touching_a_session(self, method):
        """Each admin operation returns ``None`` and never opens a session."""
        from api.attack_paths.sink.neptune import NeptuneSink

        sink = NeptuneSink()

        with patch.object(sink, "get_session") as get_session:
            admin_op = getattr(sink, method)
            outcome = admin_op("ignored")
            assert outcome is None

        get_session.assert_not_called()
class TestNeptuneAuthToken:
    """SigV4 signing for the Neptune Bolt endpoint."""

    @patch("api.attack_paths.sink.neptune.SigV4Auth")
    @patch("api.attack_paths.sink.neptune.BotoSession")
    def test_host_header_includes_non_default_port(self, mock_boto, mock_sigv4):
        """The ``Host`` entry in the token credentials keeps the ``:8182`` port."""
        # Neptune listens on 8182; if the SigV4 canonical Host drops the
        # port, the signature is rejected.
        from api.attack_paths.sink.neptune import _NeptuneAuthToken

        boto_credentials = MagicMock()
        boto_credentials.get_frozen_credentials.return_value = MagicMock()
        mock_boto.return_value.get_credentials.return_value = boto_credentials

        token = _NeptuneAuthToken("eu-west-1", "https://writer.example:8182")

        # The credentials payload is a JSON document.
        signed_fields = json.loads(token.credentials)
        assert signed_fields["Host"] == "writer.example:8182"
+270 -5
View File
@@ -57,6 +57,7 @@ from api.models import (
UserRoleRelationship,
)
from api.rls import Tenant
from api.uuid_utils import datetime_to_uuid7
from api.v1.serializers import TokenSerializer
from api.v1.views import ComplianceOverviewViewSet, TenantFinishACSView
from botocore.exceptions import ClientError, NoCredentialsError
@@ -4754,6 +4755,64 @@ class TestAttackPathsScanViewSet:
assert first_attributes["provider_type"] == provider.provider
assert first_attributes["provider_uid"] == provider.uid
def test_attack_paths_scans_list_prefers_active_sink_scan_on_rollback(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
    settings,
):
    """With scans in both sinks, the list shows the one in the active sink."""
    settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
    provider = providers_fixture[0]

    # Both rows share everything except the sink they were written to.
    shared_kwargs = {
        "scan": scans_fixture[0],
        "state": StateChoices.COMPLETED,
        "graph_data_ready": True,
    }
    neo4j_scan = create_attack_paths_scan(
        provider, sink_backend="neo4j", **shared_kwargs
    )
    neptune_scan = create_attack_paths_scan(
        provider, sink_backend="neptune", **shared_kwargs
    )

    response = authenticated_client.get(reverse("attack-paths-scans-list"))

    assert response.status_code == status.HTTP_200_OK
    listed_ids = {entry["id"] for entry in response.json()["data"]}
    assert str(neo4j_scan.id) in listed_ids
    assert str(neptune_scan.id) not in listed_ids
def test_attack_paths_scans_list_falls_back_when_active_sink_has_no_scan(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
    settings,
):
    """A scan from the non-active sink is still listed when it is the only one."""
    settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
    provider = providers_fixture[0]
    legacy_scan = create_attack_paths_scan(
        provider,
        scan=scans_fixture[0],
        state=StateChoices.COMPLETED,
        graph_data_ready=True,
        sink_backend="neo4j",
    )

    response = authenticated_client.get(reverse("attack-paths-scans-list"))

    assert response.status_code == status.HTTP_200_OK
    listed_ids = {entry["id"] for entry in response.json()["data"]}
    assert str(legacy_scan.id) in listed_ids
def test_attack_paths_scans_list_respects_provider_group_visibility(
self,
authenticated_client_no_permissions_rbac,
@@ -4874,7 +4933,8 @@ class TestAttackPathsScanViewSet:
)
assert response.status_code == status.HTTP_200_OK
mock_get_queries.assert_called_once_with(provider.provider)
# TODO: drop the is_migrated argument after Neptune cutover
mock_get_queries.assert_called_once_with(provider.provider, is_migrated=False)
payload = response.json()["data"]
assert len(payload) == 1
assert payload[0]["id"] == "aws-rds"
@@ -4974,7 +5034,8 @@ class TestAttackPathsScanViewSet:
)
assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds")
# TODO: drop the is_migrated argument after Neptune cutover
mock_get_query.assert_called_once_with("aws-rds", is_migrated=False)
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
provider_id = str(attack_paths_scan.provider_id)
mock_prepare.assert_called_once_with(
@@ -4988,6 +5049,7 @@ class TestAttackPathsScanViewSet:
query_definition,
prepared_parameters,
provider_id,
scan=attack_paths_scan,
)
result = response.json()["data"]
attributes = result["attributes"]
@@ -5339,6 +5401,7 @@ class TestAttackPathsScanViewSet:
"db-test",
"MATCH (n) RETURN n",
str(attack_paths_scan.provider_id),
scan=attack_paths_scan,
)
attributes = response.json()["data"]["attributes"]
assert len(attributes["nodes"]) == 1
@@ -5875,9 +5938,10 @@ class TestAttackPathsScanViewSet:
)
assert response.status_code == status.HTTP_200_OK
mock_get_schema.assert_called_once_with(
"db-test", str(attack_paths_scan.provider_id)
)
mock_get_schema.assert_called_once()
schema_args = mock_get_schema.call_args[0]
assert schema_args[:2] == ("db-test", str(attack_paths_scan.provider_id))
assert schema_args[2].id == attack_paths_scan.id
attributes = response.json()["data"]["attributes"]
assert attributes["provider"] == "aws"
assert attributes["cartography_version"] == "0.129.0"
@@ -7155,6 +7219,26 @@ class TestFindingViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["code"] == "invalid"
def test_findings_updated_at_range_too_large_with_inserted_at_filter(
    self, authenticated_client
):
    """An oversized ``updated_at`` range is rejected even with ``inserted_at`` set."""
    # One day wider than the configured maximum range.
    range_start = today_after_n_days(-(settings.FINDINGS_MAX_DAYS_IN_RANGE + 1))
    query_params = {
        "filter[inserted_at]": TODAY,
        "filter[updated_at.gte]": range_start,
        "filter[updated_at.lte]": TODAY,
    }

    response = authenticated_client.get(reverse("finding-list"), query_params)

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    first_error = response.json()["errors"][0]
    assert first_error["code"] == "invalid"
    # The error must point at updated_at, not at inserted_at.
    assert first_error["source"]["pointer"] == "/data/attributes/updated_at"
def test_findings_list(self, authenticated_client, findings_fixture):
response = authenticated_client.get(
reverse("finding-list"), {"filter[inserted_at]": TODAY}
@@ -7166,6 +7250,170 @@ class TestFindingViewSet:
== findings_fixture[0].status
)
def test_findings_list_inserted_at_accepts_timestamp_precision_filters(
    self, authenticated_client, scans_fixture
):
    """``inserted_at`` filters distinguish findings 100 ms apart."""
    scan = scans_fixture[0]

    def make_finding(uid, moment):
        """Create a finding whose id and timestamps all derive from ``moment``."""
        created = Finding.objects.create(
            id=datetime_to_uuid7(moment),
            tenant_id=scan.tenant_id,
            uid=uid,
            scan=scan,
            status=Status.FAIL,
            status_extended="timestamp precision status",
            impact=Severity.medium,
            severity=Severity.medium,
            check_id="timestamp_precision_check",
            check_metadata={
                "CheckId": "timestamp_precision_check",
                "Description": "timestamp precision check",
                "servicename": "ec2",
            },
            first_seen_at=moment,
        )
        # Timestamps are forced with a queryset update and then reloaded;
        # presumably the model fills them in itself on create -- confirm.
        Finding.all_objects.filter(pk=created.pk).update(
            inserted_at=moment,
            updated_at=moment,
        )
        created.refresh_from_db()
        return created

    def uids_matching(query_params):
        """GET the findings list and return the set of returned uids."""
        response = authenticated_client.get(reverse("finding-list"), query_params)
        assert response.status_code == status.HTTP_200_OK
        return {entry["attributes"]["uid"] for entry in response.json()["data"]}

    make_finding(
        "timestamp_precision_early",
        datetime(2026, 1, 15, 10, 30, 0, 100000, tzinfo=UTC),
    )
    late_finding = make_finding(
        "timestamp_precision_late",
        datetime(2026, 1, 15, 10, 30, 0, 200000, tzinfo=UTC),
    )

    # Range query: the window covers .200 but not .100.
    in_window = uids_matching(
        {
            "filter[inserted_at.gte]": "2026-01-15T10:30:00.150Z",
            "filter[inserted_at.lte]": "2026-01-15T10:30:00.250Z",
        }
    )
    assert in_window == {late_finding.uid}

    # Exact-match query at millisecond precision.
    exact = uids_matching({"filter[inserted_at]": "2026-01-15T10:30:00.200Z"})
    assert exact == {late_finding.uid}
def test_findings_list_updated_at_accepts_timestamp_precision_filters(
    self, authenticated_client, findings_fixture
):
    """``updated_at`` filters distinguish findings 100 ms apart."""
    early_finding, late_finding, *_ = findings_fixture

    # Force the two fixture findings 100 ms apart.
    forced_timestamps = (
        (early_finding, datetime(2026, 1, 15, 10, 30, 0, 100000, tzinfo=UTC)),
        (late_finding, datetime(2026, 1, 15, 10, 30, 0, 200000, tzinfo=UTC)),
    )
    for finding, stamp in forced_timestamps:
        Finding.all_objects.filter(pk=finding.pk).update(updated_at=stamp)

    def uids_matching(query_params):
        """GET the findings list and return the set of returned uids."""
        response = authenticated_client.get(reverse("finding-list"), query_params)
        assert response.status_code == status.HTTP_200_OK
        return {entry["attributes"]["uid"] for entry in response.json()["data"]}

    # Range query: the window covers .200 but not .100.
    in_window = uids_matching(
        {
            "filter[updated_at.gte]": "2026-01-15T10:30:00.150Z",
            "filter[updated_at.lte]": "2026-01-15T10:30:00.250Z",
        }
    )
    assert in_window == {late_finding.uid}

    # Exact-match query at millisecond precision.
    exact = uids_matching({"filter[updated_at]": "2026-01-15T10:30:00.200Z"})
    assert exact == {late_finding.uid}
def test_findings_list_inserted_at_and_updated_at_filters_are_combined(
    self, authenticated_client, scans_fixture
):
    """Only a finding inside both timestamp windows is returned."""
    scan = scans_fixture[0]

    def make_finding(uid, inserted_moment, updated_moment):
        """Create a finding with independent inserted/updated timestamps."""
        created = Finding.objects.create(
            id=datetime_to_uuid7(inserted_moment),
            tenant_id=scan.tenant_id,
            uid=uid,
            scan=scan,
            status=Status.FAIL,
            status_extended="timestamp precision status",
            impact=Severity.medium,
            severity=Severity.medium,
            check_id="timestamp_precision_check",
            check_metadata={
                "CheckId": "timestamp_precision_check",
                "Description": "timestamp precision check",
                "servicename": "ec2",
            },
            first_seen_at=inserted_moment,
        )
        # Timestamps are forced with a queryset update and then reloaded;
        # presumably the model fills them in itself on create -- confirm.
        Finding.all_objects.filter(pk=created.pk).update(
            inserted_at=inserted_moment,
            updated_at=updated_moment,
        )
        created.refresh_from_db()
        return created

    # Inside both windows.
    matching_finding = make_finding(
        "timestamp_precision_combined_match",
        datetime(2026, 1, 15, 10, 30, 0, 200000, tzinfo=UTC),
        datetime(2026, 1, 15, 11, 30, 0, 200000, tzinfo=UTC),
    )
    # Inside the inserted_at window only.
    make_finding(
        "timestamp_precision_combined_inserted_only",
        datetime(2026, 1, 15, 10, 30, 0, 200000, tzinfo=UTC),
        datetime(2026, 1, 15, 12, 30, 0, 200000, tzinfo=UTC),
    )
    # Inside the updated_at window only.
    make_finding(
        "timestamp_precision_combined_updated_only",
        datetime(2026, 1, 15, 9, 30, 0, 200000, tzinfo=UTC),
        datetime(2026, 1, 15, 11, 30, 0, 200000, tzinfo=UTC),
    )

    combined_filters = {
        "filter[inserted_at.gte]": "2026-01-15T10:30:00.150Z",
        "filter[inserted_at.lte]": "2026-01-15T10:30:00.250Z",
        "filter[updated_at.gte]": "2026-01-15T11:30:00.150Z",
        "filter[updated_at.lte]": "2026-01-15T11:30:00.250Z",
    }
    response = authenticated_client.get(reverse("finding-list"), combined_filters)

    assert response.status_code == status.HTTP_200_OK
    returned_uids = {
        entry["attributes"]["uid"] for entry in response.json()["data"]
    }
    assert returned_uids == {matching_finding.uid}
def test_findings_list_resource_tags_no_n_plus_one(
self, authenticated_client, findings_fixture
):
@@ -7631,6 +7879,23 @@ class TestFindingViewSet:
]
}
@pytest.mark.parametrize(
    "filter_name",
    ["inserted_at", "inserted_at.gte", "inserted_at.lte"],
)
def test_findings_metadata_rejects_timestamp_precision_filters(
    self, authenticated_client, filter_name
):
    """The metadata endpoint answers 400 when ``inserted_at`` is a full timestamp."""
    query_params = {f"filter[{filter_name}]": "2048-01-01T10:30:00Z"}

    response = authenticated_client.get(reverse("finding-metadata"), query_params)

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    first_error = response.json()["errors"][0]
    assert first_error["detail"] == "Enter a valid date."
    assert first_error["code"] == "invalid"
def test_findings_metadata_backfill(
self, authenticated_client, scans_fixture, findings_fixture
):
+29 -7
View File
@@ -50,6 +50,7 @@ from api.filters import (
FindingGroupAggregatedComputedFilter,
FindingGroupFilter,
FindingGroupSummaryFilter,
FindingMetadataFilter,
IntegrationFilter,
IntegrationJiraFindingsFilter,
InvitationFilter,
@@ -1888,8 +1889,8 @@ class ProviderViewSet(DisablePaginationMixin, BaseRLSViewSet):
description=(
"Download a specific compliance report as an OCSF JSON file. "
"Only universal frameworks that declare an output configuration "
"produce this artifact (currently 'dora_2022_2554' and 'csa_ccm_4.0'); any "
"other framework returns 404."
"produce this artifact (currently 'dora_2022_2554', 'csa_ccm_4.0' "
"and 'cis_controls_8.1'); any other framework returns 404."
),
parameters=[
OpenApiParameter(
@@ -2876,13 +2877,22 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
active_sink_backend = django_settings.ATTACK_PATHS_SINK_DATABASE
latest_per_provider = queryset.annotate(
active_sink_rank=Case(
When(sink_backend=active_sink_backend, then=Value(0)),
default=Value(1),
output_field=IntegerField(),
),
latest_scan_rank=Window(
expression=RowNumber(),
partition_by=[F("provider_id")],
order_by=[F("inserted_at").desc()],
)
order_by=[
F("active_sink_rank").asc(),
F("inserted_at").desc(),
],
),
).filter(latest_scan_rank=1)
page = self.paginate_queryset(latest_per_provider)
@@ -2909,7 +2919,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
)
def attack_paths_queries(self, request, pk=None):
attack_paths_scan = self.get_object()
queries = get_queries_for_provider(attack_paths_scan.provider.provider)
# TODO: drop the is_migrated argument after Neptune cutover
queries = get_queries_for_provider(
attack_paths_scan.provider.provider,
is_migrated=attack_paths_scan.is_migrated,
)
if not queries:
return Response(
@@ -2942,7 +2956,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
serializer = AttackPathsQueryRunRequestSerializer(data=payload)
serializer.is_valid(raise_exception=True)
query_definition = get_query_by_id(serializer.validated_data["id"])
# TODO: drop the is_migrated argument after Neptune cutover
query_definition = get_query_by_id(
serializer.validated_data["id"],
is_migrated=attack_paths_scan.is_migrated,
)
if (
query_definition is None
or query_definition.provider != attack_paths_scan.provider.provider
@@ -2968,6 +2986,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
query_definition,
parameters,
provider_id,
scan=attack_paths_scan,
)
query_duration = time.monotonic() - start
@@ -3035,6 +3054,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
database_name,
serializer.validated_data["query"],
provider_id,
scan=attack_paths_scan,
)
query_duration = time.monotonic() - start
@@ -3091,7 +3111,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
provider_id = str(attack_paths_scan.provider_id)
schema = attack_paths_views_helpers.get_cartography_schema(
database_name, provider_id
database_name, provider_id, attack_paths_scan
)
if not schema:
return Response(
@@ -3814,6 +3834,8 @@ class FindingViewSet(PaginateByPkMixin, BaseRLSViewSet):
def get_filterset_class(self):
    """Pick the filterset class that matches the current view action.

    Returns:
        `FindingMetadataFilter` for the `metadata` action,
        `LatestFindingFilter` for `latest` and `metadata_latest`, and
        `FindingFilter` for every other action.
    """
    current_action = self.action
    if current_action == "metadata":
        return FindingMetadataFilter
    if current_action in ("latest", "metadata_latest"):
        return LatestFindingFilter
    return FindingFilter
def get_queryset(self):
+5
View File
@@ -311,6 +311,11 @@ ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES = env.int(
"ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880
) # 48h
# Selects where the persistent attack-paths graph is stored. The scan
# temporary database is always Neo4j; only the sink is configurable.
# Valid values: "neo4j" (default, OSS and local dev), "neptune" (hosted).
ATTACK_PATHS_SINK_DATABASE = env.str("ATTACK_PATHS_SINK_DATABASE", default="neo4j")
# Orphan task recovery feature flags. The master switch is OFF by default, so task
# recovery is opt-in; enable it with DJANGO_TASK_RECOVERY_ENABLED=true. The per-group
# toggles default to enabled, so once the master is on every group recovers unless a
+6
View File
@@ -50,6 +50,12 @@ DATABASES = {
"USER": env.str("NEO4J_USER", "neo4j"),
"PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"),
},
"neptune": {
"WRITER_ENDPOINT": env.str("NEPTUNE_WRITER_ENDPOINT", ""),
"READER_ENDPOINT": env.str("NEPTUNE_READER_ENDPOINT", ""),
"PORT": env.str("NEPTUNE_PORT", "8182"),
"REGION": env.str("AWS_REGION", ""),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
@@ -49,12 +49,19 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
# TODO: drop after Neptune cutover just loosen defaults to `""`
"neo4j": {
"HOST": env.str("NEO4J_HOST"),
"PORT": env.str("NEO4J_PORT"),
"USER": env.str("NEO4J_USER"),
"PASSWORD": env.str("NEO4J_PASSWORD"),
},
"neptune": {
"WRITER_ENDPOINT": env.str("NEPTUNE_WRITER_ENDPOINT", default=""),
"READER_ENDPOINT": env.str("NEPTUNE_READER_ENDPOINT", default=""),
"PORT": env.str("NEPTUNE_PORT", default="8182"),
"REGION": env.str("AWS_REGION", default=""),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
+24 -4
View File
@@ -83,12 +83,32 @@ def _warm_compliance_caches_in_background():
def post_fork(_server, worker):
"""Warm compliance caches after each worker fork.
"""Re-initialize attack-paths drivers and warm compliance caches per worker.
Warm compliance caches in a background thread so the worker becomes ready
immediately. A request for a not-yet-warmed provider lazily loads just that
provider, which stays well under the worker timeout.
Neo4j / Neptune drivers spawn background IO threads that do not survive
``fork()``. When the gunicorn master runs with ``preload_app=True``, the
child inherits driver objects whose pool references dead threads and
hangs on the first ``pool.acquire`` call until the watchdog kills the
worker. Re-initializing per worker guarantees each child owns its own
live threads. See GUNICORN_WORKER_TIMEOUTS_ANALYSIS.md for detail.
Compliance caches are then warmed in a background thread so the worker
becomes ready immediately. A request for a not-yet-warmed provider lazily
loads just that provider, which stays well under the worker timeout.
"""
from api.attack_paths import database as graph_database
try:
graph_database.close_driver()
except Exception: # pragma: no cover - best-effort cleanup
gunicorn_logger.debug(
"Failed to close inherited Neo4j driver in post_fork for worker pid=%s",
worker.pid,
exc_info=True,
)
graph_database.init_driver()
gunicorn_logger.info(f"Attack-paths drivers initialized for worker {worker.pid}")
threading.Thread(
target=_warm_compliance_caches_in_background,
name="warm-compliance-caches",
+30
View File
@@ -1821,6 +1821,36 @@ def attack_paths_query_definition_factory():
return _create
@pytest.fixture
def sink_backend_stub():
    """Swap the sink factory's cached backend for a `SinkDatabase` mock.

    The sink factory caches a process-wide backend and lazily initializes it
    against `settings.DATABASES["neo4j"]` / `["neptune"]`. Tests that should
    not stand up a real Bolt driver can request this fixture and configure
    the yielded mock directly:

        sink_backend_stub.execute_read_query.return_value = some_graph

    On teardown both the active backend and the secondary-backend cache are
    put back exactly as they were, so tests stay isolated from each other.
    """
    from api.attack_paths.sink import factory
    from api.attack_paths.sink.base import SinkDatabase

    # Snapshot the factory state before touching it.
    saved_backend = factory._backend
    saved_secondary = dict(factory._secondary_backends)

    mock_backend = MagicMock(spec=SinkDatabase)
    factory._backend = mock_backend
    factory._secondary_backends.clear()

    try:
        yield mock_backend
    finally:
        # Restore in place so other references to the cache dict stay valid.
        factory._backend = saved_backend
        factory._secondary_backends.clear()
        factory._secondary_backends.update(saved_secondary)
@pytest.fixture
def attack_paths_graph_stub_classes():
"""Provide lightweight graph element stubs for Attack Paths serialization tests."""
+24 -8
View File
@@ -6,6 +6,7 @@ from typing import Any
import aioboto3
import boto3
import botocore
import neo4j
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
@@ -73,13 +74,28 @@ def start_aws_ingestion(
# Adding an extra field
common_job_parameters["AWS_ID"] = prowler_api_provider.uid
cartography_aws._autodiscover_accounts(
neo4j_session,
boto3_session,
prowler_api_provider.uid,
cartography_config.update_tag,
common_job_parameters,
)
# AWS Organizations account autodiscovery. Inlined from Cartography's removed
# `_autodiscover_accounts` (deleted in `0.137.0`), as `load_aws_accounts` is still public.
try:
org_client = boto3_session.client("organizations")
paginator = org_client.get_paginator("list_accounts")
discovered = []
for page in paginator.paginate():
discovered.extend(page["Accounts"])
active_accounts = {
a["Name"]: a["Id"] for a in discovered if a["Status"] == "ACTIVE"
}
cartography_aws.organizations.load_aws_accounts(
neo4j_session,
active_accounts,
cartography_config.update_tag,
common_job_parameters,
)
except botocore.exceptions.ClientError:
logger.warning(
f"Account {prowler_api_provider.uid} lacks permissions for AWS "
"Organizations autodiscovery."
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4)
failed_syncs = sync_aws_account(
@@ -277,7 +293,7 @@ def sync_aws_account(
sync_args: dict[str, Any],
attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, str]:
current_progress = 4 # `cartography_aws._autodiscover_accounts`
current_progress = 4 # AWS Organizations account autodiscovery
max_progress = (
87 # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1
)
@@ -8,7 +8,7 @@ from celery import states
from celery.utils.log import get_task_logger
from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES
from tasks.jobs.attack_paths.db_utils import (
_mark_scan_finished,
mark_scan_finished,
recover_graph_data_ready,
)
from tasks.jobs.orphan_recovery import is_worker_alive as _is_worker_alive
@@ -87,7 +87,7 @@ def _cleanup_stale_executing_scans(cutoff: datetime) -> list[str]:
else:
reason = "Worker dead — cleaned up by periodic task"
else:
# No worker recorded time-based heuristic only
# No worker recorded, time-based heuristic only
if scan.started_at and scan.started_at >= cutoff:
continue
reason = (
@@ -160,7 +160,7 @@ def _cleanup_scan(scan, task_result, reason: str) -> bool:
"""
scan_id_str = str(scan.id)
# 1. Drop temp Neo4j database
# Drop temp Neo4j database
tmp_db_name = graph_database.get_database_name(scan.id, temporary=True)
try:
graph_database.drop_database(tmp_db_name)
@@ -225,6 +225,6 @@ def _finalize_failed_scan(scan, expected_state: str, reason: str):
logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
return None
_mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})
mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})
return fresh_scan
@@ -1,9 +1,14 @@
from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID
from config.env import env
from tasks.jobs.attack_paths import aws
from tasks.jobs.attack_paths import provider_config as _provider_config
# Re-export provider config objects so existing imports keep working.
AWS_CONFIG = _provider_config.AWS_CONFIG
NormalizedList = _provider_config.NormalizedList
PROVIDER_CONFIGS = _provider_config.PROVIDER_CONFIGS
ProviderConfig = _provider_config.ProviderConfig
# Batch size for Neo4j write operations (resource labeling, cleanup)
BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000)
@@ -21,42 +26,12 @@ PROWLER_FINDING_LABEL = "ProwlerFinding"
PROVIDER_RESOURCE_LABEL = "_ProviderResource"
# Dynamic isolation labels that contain entity UUIDs and are added to every synced node during sync
# Format: _Tenant_{uuid_no_hyphens}, _Provider_{uuid_no_hyphens}
# Format: `_Tenant_{uuid_no_hyphens}`, `_Provider_{uuid_no_hyphens}`
TENANT_LABEL_PREFIX = "_Tenant_"
PROVIDER_LABEL_PREFIX = "_Provider_"
DYNAMIC_ISOLATION_PREFIXES = [TENANT_LABEL_PREFIX, PROVIDER_LABEL_PREFIX]
@dataclass(frozen=True)
class ProviderConfig:
"""Configuration for a cloud provider's Attack Paths integration."""
name: str
root_node_label: str # e.g., "AWSAccount"
uid_field: str # e.g., "arn"
# Label for resources connected to the account node, enabling indexed finding lookups.
resource_label: str # e.g., "_AWSResource"
ingestion_function: Callable
# Maps a Postgres resource UID (e.g. full ARN) to the short-id form Cartography stores on some node types (e.g. `i-xxx` for EC2Instance).
short_uid_extractor: Callable[[str], str]
# Provider Configurations
# -----------------------
AWS_CONFIG = ProviderConfig(
name="aws",
root_node_label="AWSAccount",
uid_field="arn",
resource_label="_AWSResource",
ingestion_function=aws.start_aws_ingestion,
short_uid_extractor=aws.extract_short_uid,
)
PROVIDER_CONFIGS: dict[str, ProviderConfig] = {
"aws": AWS_CONFIG,
}
# Labels added by Prowler that should be filtered from API responses
# Derived from provider configs + common internal labels
INTERNAL_LABELS: list[str] = [
@@ -87,7 +62,6 @@ INTERNAL_PROPERTIES: list[str] = [
# Provider Config Accessors
# -------------------------
def is_provider_available(provider_type: str) -> bool:
@@ -135,7 +109,6 @@ def get_short_uid_extractor(provider_type: str) -> Callable[[str], str]:
# Dynamic Isolation Label Helpers
# --------------------------------
def _normalize_uuid(value: str | UUID) -> str:
@@ -8,6 +8,8 @@ from api.models import Provider as ProwlerAPIProvider
from api.models import StateChoices
from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger
from django.conf import settings
from django.db.models import Case, IntegerField, Value, When
from tasks.jobs.attack_paths.config import is_provider_available
logger = get_task_logger(__name__)
@@ -29,13 +31,33 @@ def create_attack_paths_scan(
return None
with rls_transaction(tenant_id):
# Inherit graph_data_ready from the previous scan for this provider,
# so queries remain available while the new scan runs.
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
graph_data_ready=True,
).exists()
# Inherit metadata from the previous ready scan for this provider so
# queries remain available while the new scan runs. The new row only
# flips to the target sink after its own graph sync succeeds.
active_sink_backend = settings.ATTACK_PATHS_SINK_DATABASE
previous_ready = (
ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
graph_data_ready=True,
)
.annotate(
active_sink_rank=Case(
When(sink_backend=active_sink_backend, then=Value(0)),
default=Value(1),
output_field=IntegerField(),
)
)
.order_by("active_sink_rank", "-inserted_at")
.first()
)
previous_data_ready = previous_ready is not None
inherited_is_migrated = previous_ready.is_migrated if previous_ready else False
inherited_sink_backend = (
previous_ready.sink_backend
if previous_ready
else ProwlerAPIAttackPathsScan.SinkBackendChoices.NEO4J
)
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
tenant_id=tenant_id,
@@ -44,6 +66,8 @@ def create_attack_paths_scan(
state=StateChoices.SCHEDULED,
started_at=datetime.now(tz=UTC),
graph_data_ready=previous_data_ready,
is_migrated=inherited_is_migrated,
sink_backend=inherited_sink_backend,
)
attack_paths_scan.save()
@@ -114,7 +138,7 @@ def starting_attack_paths_scan(
return True
def _mark_scan_finished(
def mark_scan_finished(
attack_paths_scan: ProwlerAPIAttackPathsScan,
state: StateChoices,
ingestion_exceptions: dict[str, Any],
@@ -148,7 +172,7 @@ def finish_attack_paths_scan(
ingestion_exceptions: dict[str, Any],
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
_mark_scan_finished(attack_paths_scan, state, ingestion_exceptions)
mark_scan_finished(attack_paths_scan, state, ingestion_exceptions)
def update_attack_paths_scan_progress(
@@ -169,19 +193,45 @@ def set_graph_data_ready(
attack_paths_scan.save(update_fields=["graph_data_ready"])
def set_scan_migrated(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    migrated: bool,
    sink_backend: str | None = None,
) -> None:
    """Persist the scan's `is_migrated` flag and, optionally, its sink backend.

    Called after a successful sync so the read catalog and sink backend only
    switch once the new graph is actually live.

    Args:
        attack_paths_scan: Scan row to update; it is saved inside an RLS
            transaction for its own tenant.
        migrated: New value for `is_migrated`.
        sink_backend: When not `None`, also stored in `sink_backend`; when
            `None`, the existing `sink_backend` value is left untouched.
    """
    # TODO: drop after Neptune cutover
    # Insertion order matters: it becomes the `update_fields` order.
    changes = {"is_migrated": migrated}
    if sink_backend is not None:
        changes["sink_backend"] = sink_backend

    with rls_transaction(attack_paths_scan.tenant_id):
        for field_name, new_value in changes.items():
            setattr(attack_paths_scan, field_name, new_value)
        attack_paths_scan.save(update_fields=list(changes))
def set_provider_graph_data_ready(
attack_paths_scan: ProwlerAPIAttackPathsScan,
ready: bool,
sink_backend: str | None = None,
) -> None:
"""
Set `graph_data_ready` for ALL scans of the same provider.
Set `graph_data_ready` for scans of the same provider in one sink.
Used before drop/sync so that older scan IDs cannot bypass the query gate while the graph is being replaced.
Used before drop/sync so that older scan IDs in the target sink cannot
bypass the query gate while that sink's graph is being replaced. Scans
preserved in another sink stay queryable for rollback.
"""
target_sink_backend = sink_backend or attack_paths_scan.sink_backend
with rls_transaction(attack_paths_scan.tenant_id):
ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=attack_paths_scan.tenant_id,
provider_id=attack_paths_scan.provider_id,
sink_backend=target_sink_backend,
).update(graph_data_ready=ready)
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
@@ -202,10 +252,15 @@ def recover_graph_data_ready(
next successful scan) is a worse outcome for the user.
"""
try:
from api.attack_paths import sink as sink_module
tenant_db = graph_database.get_database_name(attack_paths_scan.tenant_id)
if graph_database.has_provider_data(
tenant_db, str(attack_paths_scan.provider_id)
):
# TODO: drop after Neptune cutover
# Check the backend that actually holds this scan's data, not the
# currently configured sink: a stale `EXECUTING` scan from before a
# backend switch must still be recoverable.
backend = sink_module.get_backend_for_scan(attack_paths_scan)
if backend.has_provider_data(tenant_db, str(attack_paths_scan.provider_id)):
set_provider_graph_data_ready(attack_paths_scan, True)
logger.info(
f"Recovered `graph_data_ready` for provider {attack_paths_scan.provider_id}"
@@ -247,6 +302,6 @@ def fail_attack_paths_scan(
return
if fresh.state in (StateChoices.COMPLETED, StateChoices.FAILED):
return
_mark_scan_finished(fresh, StateChoices.FAILED, {"global_error": error})
mark_scan_finished(fresh, StateChoices.FAILED, {"global_error": error})
recover_graph_data_ready(fresh)
@@ -82,7 +82,6 @@ def _to_neo4j_dict(
# Public API
# ----------
def analysis(
@@ -196,7 +195,6 @@ def load_findings(
# Findings Streaming (Generator-based)
# -------------------------------------
def stream_findings_with_resources(
@@ -275,7 +273,6 @@ def _fetch_findings_batch(
# Batch Enrichment
# -----------------
def _enrich_batch_with_resources(
@@ -1,5 +1,6 @@
import neo4j
from cartography.client.core.tx import run_write_query
from cartography.intel import create_indexes as cartography_create_indexes
from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths.config import (
INTERNET_NODE_LABEL,
@@ -30,14 +31,34 @@ SYNC_INDEX_STATEMENTS = [
def create_findings_indexes(neo4j_session: neo4j.Session) -> None:
"""Create indexes for Prowler findings and resource lookups."""
"""Create indexes for Prowler findings and resource lookups.
Runs `CREATE INDEX`, so the caller must only invoke this against a Neo4j
session (the temp ingest DB or a Neo4j sink). Neptune auto-manages indexes
and rejects `CREATE INDEX`, so callers skip it for the Neptune sink.
"""
logger.info("Creating indexes for Prowler Findings node types")
for statement in FINDINGS_INDEX_STATEMENTS:
run_write_query(neo4j_session, statement)
def create_cartography_indexes(neo4j_session: neo4j.Session, config) -> None:
    """Build Cartography's standard indexes in the session's database.

    Delegates to `cartography.intel.create_indexes.run`, which issues
    `CREATE INDEX` statements. Only call this with a Neo4j session (the temp
    ingest DB or a Neo4j sink): Neptune auto-manages indexes and rejects
    `CREATE INDEX`, so callers skip this step for the Neptune sink.

    Args:
        neo4j_session: Open Neo4j session bound to the target database.
        config: Cartography configuration object, passed through unchanged.
    """
    cartography_create_indexes.run(neo4j_session, config)
def create_sync_indexes(neo4j_session: neo4j.Session) -> None:
"""Create indexes for provider resource sync operations."""
"""Create indexes for provider resource sync operations.
Runs `CREATE INDEX`, so the caller must only invoke this against a Neo4j
session (the temp ingest DB or a Neo4j sink). Neptune auto-manages indexes
and rejects `CREATE INDEX`, so callers skip it for the Neptune sink.
"""
logger.info("Ensuring ProviderResource indexes exist")
for statement in SYNC_INDEX_STATEMENTS:
neo4j_session.run(statement)
@@ -0,0 +1,431 @@
"""
Provider-level Attack Paths configuration.
Each `ProviderConfig` carries the cloud provider's ingestion entry point and
the catalog of list-typed node properties (`normalized_lists`). The sync
layer reads this catalog and materialises each list element as a child node
connected to the parent by a typed edge, so queries traverse the graph
instead of working on serialised list values. Both Neo4j and Neptune sinks
write the same shape and queries are portable across them.
"""
from collections.abc import Callable
from dataclasses import dataclass, field
from tasks.jobs.attack_paths import aws
@dataclass(frozen=True)
class NormalizedList:
    """Catalog entry describing one list-typed node property.

    Tells the sync layer how to turn a parent node's list-typed property
    into child item nodes, each linked to the parent by a typed edge.

    Naming conventions (mechanical, do not invent):

    - `child_label` is `<SourceLabel><PropertyPascal>Item`,
      e.g. AWSPolicyStatement.resource -> AWSPolicyStatementResourceItem
    - `rel_type` is `HAS_<PROPERTY_UPPER>`,
      e.g. resource -> HAS_RESOURCE
    - child node properties depend on `field_map`:
        * `field_map = []` (scalar list, ~95% case): the child stores
          `value: str`
        * `field_map = [(src_key, child_field), ...]` (list of dicts, rare):
          the child stores those fields

    Raises:
        ValueError: if a non-empty `field_map` maps a source key to the
            reserved child field `value`, repeats a source key, or repeats
            a child field.
    """

    source_label: str
    source_property: str
    child_label: str
    rel_type: str
    field_map: list[tuple[str, str]] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Scalar mode: an empty `field_map` has nothing to validate.
        if not self.field_map:
            return

        source_keys = []
        target_fields = []
        for source_key, target_field in self.field_map:
            source_keys.append(source_key)
            target_fields.append(target_field)

        # The checks run in a fixed order, so the first violated rule is the
        # one reported.
        if "value" in target_fields:
            raise ValueError(
                f"NormalizedList {self.source_label}.{self.source_property}: "
                "`value` is reserved for scalar mode; do not map a source key to it"
            )
        if len(source_keys) != len(set(source_keys)):
            raise ValueError(
                f"NormalizedList {self.source_label}.{self.source_property}: "
                "duplicate source key in field_map"
            )
        if len(target_fields) != len(set(target_fields)):
            raise ValueError(
                f"NormalizedList {self.source_label}.{self.source_property}: "
                "duplicate child field in field_map"
            )
@dataclass(frozen=True)
class ProviderConfig:
    """Static description of one cloud provider's Attack Paths integration.

    Attributes:
        name: Provider identifier, e.g. "aws".
        root_node_label: Label of the provider's root node, e.g. "AWSAccount".
        uid_field: Name of the UID property, e.g. "arn".
        resource_label: Label for resources connected to the account node,
            enabling indexed finding lookups, e.g. "_AWSResource".
        ingestion_function: Ingestion entry point for this provider.
        short_uid_extractor: Maps a Postgres resource UID (e.g. full ARN) to
            the short-id form Cartography stores on some node types
            (e.g. `i-xxx` for EC2Instance).
        normalized_lists: List-typed properties to materialise as child
            nodes + edges at sync time. Mandatory (may be []). Without an
            entry here, a list-typed property falls back to comma-string
            flatten and emits a one-time warning.
    """

    name: str
    root_node_label: str
    uid_field: str
    resource_label: str
    ingestion_function: Callable
    short_uid_extractor: Callable[[str], str]
    normalized_lists: list[NormalizedList]
# AWS list-typed property catalog.
# One entry per Cartography node property whose runtime value is a list. The
# sync layer materialises each element as a `<child_label>` node and links it
# to the parent with a `<rel_type>` edge; see the `NormalizedList` docstring
# above for the naming conventions.
AWS_NORMALIZED_LISTS: list[NormalizedList] = [
# AWSPolicyStatement - the hot path driving the 53-query perf fix.
NormalizedList(
"AWSPolicyStatement", "action", "AWSPolicyStatementActionItem", "HAS_ACTION"
),
NormalizedList(
"AWSPolicyStatement",
"notaction",
"AWSPolicyStatementNotactionItem",
"HAS_NOTACTION",
),
NormalizedList(
"AWSPolicyStatement",
"resource",
"AWSPolicyStatementResourceItem",
"HAS_RESOURCE",
),
NormalizedList(
"AWSPolicyStatement",
"notresource",
"AWSPolicyStatementNotresourceItem",
"HAS_NOTRESOURCE",
),
# S3PolicyStatement - same shape as IAM policies; AWS allows list or string.
NormalizedList(
"S3PolicyStatement", "action", "S3PolicyStatementActionItem", "HAS_ACTION"
),
NormalizedList(
"S3PolicyStatement", "resource", "S3PolicyStatementResourceItem", "HAS_RESOURCE"
),
# IAM / Cognito / KMS / Secrets
NormalizedList(
"CognitoIdentityPool", "roles", "CognitoIdentityPoolRolesItem", "HAS_ROLES"
),
NormalizedList(
"KMSKey",
"encryption_algorithms",
"KMSKeyEncryptionAlgorithmsItem",
"HAS_ENCRYPTION_ALGORITHMS",
),
NormalizedList(
"KMSKey",
"signing_algorithms",
"KMSKeySigningAlgorithmsItem",
"HAS_SIGNING_ALGORITHMS",
),
NormalizedList(
"KMSKey",
"anonymous_actions",
"KMSKeyAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"KMSGrant", "operations", "KMSGrantOperationsItem", "HAS_OPERATIONS"
),
NormalizedList(
"SecretsManagerSecretVersion",
"version_stages",
"SecretsManagerSecretVersionVersionStagesItem",
"HAS_VERSION_STAGES",
),
NormalizedList(
"SecretsManagerSecretVersion",
"kms_key_ids",
"SecretsManagerSecretVersionKmsKeyIdsItem",
"HAS_KMS_KEY_IDS",
),
NormalizedList(
"SecretsManagerSecretVersion",
"tags",
"SecretsManagerSecretVersionTagsItem",
"HAS_TAGS",
field_map=[("Key", "key"), ("Value", "value_")],
# `value` is reserved for scalar mode; map `Value` to `value_` to keep dict shape.
),
# Lambda / Compute
NormalizedList(
"AWSLambda", "architectures", "AWSLambdaArchitecturesItem", "HAS_ARCHITECTURES"
),
NormalizedList(
"AWSLambda",
"anonymous_actions",
"AWSLambdaAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"CodeBuildProject",
"environment_variables",
"CodeBuildProjectEnvironmentVariablesItem",
"HAS_ENVIRONMENT_VARIABLES",
),
# ECS family
NormalizedList(
"ECSCluster",
"capacity_providers",
"ECSClusterCapacityProvidersItem",
"HAS_CAPACITY_PROVIDERS",
),
NormalizedList(
"ECSTaskDefinition",
"compatibilities",
"ECSTaskDefinitionCompatibilitiesItem",
"HAS_COMPATIBILITIES",
),
NormalizedList(
"ECSTaskDefinition",
"requires_compatibilities",
"ECSTaskDefinitionRequiresCompatibilitiesItem",
"HAS_REQUIRES_COMPATIBILITIES",
),
NormalizedList(
"ECSContainerDefinition",
"links",
"ECSContainerDefinitionLinksItem",
"HAS_LINKS",
),
NormalizedList(
"ECSContainerDefinition",
"entry_point",
"ECSContainerDefinitionEntryPointItem",
"HAS_ENTRY_POINT",
),
NormalizedList(
"ECSContainerDefinition",
"command",
"ECSContainerDefinitionCommandItem",
"HAS_COMMAND",
),
NormalizedList(
"ECSContainerDefinition",
"dns_servers",
"ECSContainerDefinitionDnsServersItem",
"HAS_DNS_SERVERS",
),
NormalizedList(
"ECSContainerDefinition",
"dns_search_domains",
"ECSContainerDefinitionDnsSearchDomainsItem",
"HAS_DNS_SEARCH_DOMAINS",
),
NormalizedList(
"ECSContainerDefinition",
"docker_security_options",
"ECSContainerDefinitionDockerSecurityOptionsItem",
"HAS_DOCKER_SECURITY_OPTIONS",
),
NormalizedList("ECSContainer", "gpu_ids", "ECSContainerGpuIdsItem", "HAS_GPU_IDS"),
# ECR
NormalizedList(
"ECRImage", "layer_diff_ids", "ECRImageLayerDiffIdsItem", "HAS_LAYER_DIFF_IDS"
),
NormalizedList(
"ECRImage",
"child_image_digests",
"ECRImageChildImageDigestsItem",
"HAS_CHILD_IMAGE_DIGESTS",
),
# EC2 / Networking
NormalizedList(
"EC2Instance",
"exposed_internet_type",
"EC2InstanceExposedInternetTypeItem",
"HAS_EXPOSED_INTERNET_TYPE",
),
NormalizedList(
"AutoScalingGroup",
"exposed_internet_type",
"AutoScalingGroupExposedInternetTypeItem",
"HAS_EXPOSED_INTERNET_TYPE",
),
NormalizedList(
"LaunchConfiguration",
"security_groups",
"LaunchConfigurationSecurityGroupsItem",
"HAS_SECURITY_GROUPS",
),
NormalizedList(
"LaunchTemplateVersion",
"security_group_ids",
"LaunchTemplateVersionSecurityGroupIdsItem",
"HAS_SECURITY_GROUP_IDS",
),
NormalizedList(
"LaunchTemplateVersion",
"security_groups",
"LaunchTemplateVersionSecurityGroupsItem",
"HAS_SECURITY_GROUPS",
),
NormalizedList(
"AWSVpcEndpoint",
"route_table_ids",
"AWSVpcEndpointRouteTableIdsItem",
"HAS_ROUTE_TABLE_IDS",
),
NormalizedList(
"AWSVpcEndpoint",
"network_interface_ids",
"AWSVpcEndpointNetworkInterfaceIdsItem",
"HAS_NETWORK_INTERFACE_IDS",
),
NormalizedList(
"AWSVpcEndpoint",
"subnet_ids",
"AWSVpcEndpointSubnetIdsItem",
"HAS_SUBNET_IDS",
),
NormalizedList(
"ELBListener", "policy_names", "ELBListenerPolicyNamesItem", "HAS_POLICY_NAMES"
),
# CloudFront / Route53 / CloudWatch / CloudTrail
NormalizedList(
"CloudFrontDistribution",
"aliases",
"CloudFrontDistributionAliasesItem",
"HAS_ALIASES",
),
NormalizedList(
"CloudFrontDistribution",
"geo_restriction_locations",
"CloudFrontDistributionGeoRestrictionLocationsItem",
"HAS_GEO_RESTRICTION_LOCATIONS",
),
NormalizedList(
"CloudWatchLogGroup",
"inherited_properties",
"CloudWatchLogGroupInheritedPropertiesItem",
"HAS_INHERITED_PROPERTIES",
),
# RDS / Storage
NormalizedList(
"RDSCluster",
"availability_zones",
"RDSClusterAvailabilityZonesItem",
"HAS_AVAILABILITY_ZONES",
),
NormalizedList(
"RDSEventSubscription",
"event_categories",
"RDSEventSubscriptionEventCategoriesItem",
"HAS_EVENT_CATEGORIES",
),
NormalizedList(
"RDSEventSubscription",
"source_ids",
"RDSEventSubscriptionSourceIdsItem",
"HAS_SOURCE_IDS",
),
NormalizedList(
"S3Bucket",
"anonymous_actions",
"S3BucketAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
# Inspector / Config / SSM / ACM / APIGateway / Glue / SageMaker / Bedrock
NormalizedList(
"AWSInspectorFinding",
"referenceurls",
"AWSInspectorFindingReferenceurlsItem",
"HAS_REFERENCEURLS",
),
NormalizedList(
"AWSInspectorFinding",
"relatedvulnerabilities",
"AWSInspectorFindingRelatedvulnerabilitiesItem",
"HAS_RELATEDVULNERABILITIES",
),
NormalizedList(
"AWSInspectorFinding",
"vulnerablepackageids",
"AWSInspectorFindingVulnerablepackageidsItem",
"HAS_VULNERABLEPACKAGEIDS",
),
NormalizedList(
"AWSConfigurationRecorder",
"recording_group_resource_types",
"AWSConfigurationRecorderRecordingGroupResourceTypesItem",
"HAS_RECORDING_GROUP_RESOURCE_TYPES",
),
NormalizedList(
"AWSConfigRule",
"scope_compliance_resource_types",
"AWSConfigRuleScopeComplianceResourceTypesItem",
"HAS_SCOPE_COMPLIANCE_RESOURCE_TYPES",
),
NormalizedList(
"AWSConfigRule",
"source_details",
"AWSConfigRuleSourceDetailsItem",
"HAS_SOURCE_DETAILS",
),
NormalizedList(
"SSMInstancePatch", "cve_ids", "SSMInstancePatchCveIdsItem", "HAS_CVE_IDS"
),
NormalizedList(
"ACMCertificate", "in_use_by", "ACMCertificateInUseByItem", "HAS_IN_USE_BY"
),
NormalizedList(
"APIGatewayRestAPI",
"anonymous_actions",
"APIGatewayRestAPIAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"GlueJob", "connections", "GlueJobConnectionsItem", "HAS_CONNECTIONS"
),
NormalizedList(
"AWSBedrockFoundationModel",
"input_modalities",
"AWSBedrockFoundationModelInputModalitiesItem",
"HAS_INPUT_MODALITIES",
),
NormalizedList(
"AWSBedrockFoundationModel",
"output_modalities",
"AWSBedrockFoundationModelOutputModalitiesItem",
"HAS_OUTPUT_MODALITIES",
),
NormalizedList(
"AWSBedrockFoundationModel",
"customizations_supported",
"AWSBedrockFoundationModelCustomizationsSupportedItem",
"HAS_CUSTOMIZATIONS_SUPPORTED",
),
NormalizedList(
"AWSBedrockFoundationModel",
"inference_types_supported",
"AWSBedrockFoundationModelInferenceTypesSupportedItem",
"HAS_INFERENCE_TYPES_SUPPORTED",
),
]
AWS_CONFIG = ProviderConfig(
name="aws",
root_node_label="AWSAccount",
uid_field="arn",
resource_label="_AWSResource",
ingestion_function=aws.start_aws_ingestion,
short_uid_extractor=aws.extract_short_uid,
normalized_lists=AWS_NORMALIZED_LISTS,
)
PROVIDER_CONFIGS: dict[str, ProviderConfig] = {
"aws": AWS_CONFIG,
}
@@ -1,8 +1,6 @@
# Cypher query templates for Attack Paths operations
from tasks.jobs.attack_paths.config import (
INTERNET_NODE_LABEL,
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
PROWLER_FINDING_LABEL,
)
@@ -21,7 +19,6 @@ def render_cypher_template(template: str, replacements: dict[str, str]) -> str:
# Findings queries (used by findings.py)
# ---------------------------------------
ADD_RESOURCE_LABEL_TEMPLATE = """
MATCH (account:__ROOT_LABEL__ {id: $provider_uid})-->(r)
@@ -88,7 +85,6 @@ INSERT_FINDING_TEMPLATE = f"""
"""
# Internet queries (used by internet.py)
# ---------------------------------------
CREATE_INTERNET_NODE = f"""
MERGE (internet:{INTERNET_NODE_LABEL} {{id: 'Internet'}})
@@ -118,8 +114,8 @@ CREATE_CAN_ACCESS_RELATIONSHIPS_TEMPLATE = f"""
RETURN COUNT(r) AS relationships_merged
"""
# Sync queries (used by sync.py)
# -------------------------------
# Sync queries (used by sync.py to fetch from the cartography temp Neo4j DB)
# The write side of sync lives in each sink (`api/attack_paths/sink/`).
NODE_FETCH_QUERY = """
MATCH (n)
@@ -143,17 +139,3 @@ RELATIONSHIPS_FETCH_QUERY = """
ORDER BY internal_id
LIMIT $batch_size
"""
NODE_SYNC_TEMPLATE = f"""
UNWIND $rows AS row
MERGE (n:__NODE_LABELS__ {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.provider_element_id}})
SET n += row.props
"""
RELATIONSHIP_SYNC_TEMPLATE = f"""
UNWIND $rows AS row
MATCH (s:{PROVIDER_RESOURCE_LABEL} {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.start_element_id}})
MATCH (t:{PROVIDER_RESOURCE_LABEL} {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.end_element_id}})
MERGE (s)-[r:__REL_TYPE__ {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.provider_element_id}}]->(t)
SET r += row.props
"""
+93 -32
View File
@@ -39,8 +39,8 @@ Pipeline steps:
7. Sync the temp database into the tenant database:
- Drop the old provider subgraph (matched by dynamic _Provider_{uuid} label).
graph_data_ready is set to False for all scans of this provider while
the swap happens so the API doesn't serve partial data.
graph_data_ready is set to False for scans of this provider in the
target sink while the swap happens so the API doesn't serve partial data.
- Copy nodes and relationships in batches. Every synced node gets a
_ProviderResource label and dynamic _Tenant_{uuid} / _Provider_{uuid}
isolation labels, plus a _provider_element_id property for MERGE keys.
@@ -64,10 +64,17 @@ from api.models import StateChoices
from api.utils import initialize_prowler_provider
from cartography.config import Config as CartographyConfig
from cartography.intel import analysis as cartography_analysis
from cartography.intel import create_indexes as cartography_create_indexes
from cartography.intel import ontology as cartography_ontology
from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths import db_utils, findings, indexes, internet, sync, utils
from django.conf import settings
from tasks.jobs.attack_paths import (
db_utils,
findings,
indexes,
internet,
sync,
utils,
)
from tasks.jobs.attack_paths.config import get_cartography_ingestion_function
# Without this Celery goes crazy with Cartography logging
@@ -96,7 +103,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)
# Idempotency guard: cleanup may have flipped this row to a terminal state
# while the message was still in flight. Bail out before touching state.
# while the message was still in flight. Bail out before touching state
if attack_paths_scan and attack_paths_scan.state in (
StateChoices.FAILED,
StateChoices.COMPLETED,
@@ -125,7 +132,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
else:
if not attack_paths_scan:
# Safety net for in-flight messages or direct task invocations; dispatcher normally pre-creates the row.
# Safety net for in-flight messages or direct task invocations; dispatcher normally pre-creates the row
logger.warning(
f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
)
@@ -143,10 +150,18 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
tenant_database_name = graph_database.get_database_name(
prowler_api_provider.tenant_id
)
target_sink_backend = settings.ATTACK_PATHS_SINK_DATABASE
target_description = (
f"tenant Neo4j database {tenant_database_name}"
if target_sink_backend == "neo4j"
else f"{target_sink_backend} sink"
)
# While creating the Cartography configuration, attributes `neo4j_user` and `neo4j_password` are not really needed in this config object
tmp_cartography_config = CartographyConfig(
neo4j_uri=graph_database.get_uri(),
# The temp ingest database is always Neo4j, so use the ingest URI here
# rather than the sink URI (which points at Neptune when configured).
neo4j_uri=graph_database.get_ingest_uri(),
neo4j_database=tmp_database_name,
update_tag=int(time.time()),
)
@@ -156,6 +171,8 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
update_tag=tmp_cartography_config.update_tag,
)
graph_database.verify_scan_databases_available()
# Starting the Attack Paths scan
if not db_utils.starting_attack_paths_scan(
attack_paths_scan, tenant_cartography_config
@@ -168,7 +185,8 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
scan_t0 = time.perf_counter()
logger.info(
f"Starting Attack Paths scan ({attack_paths_scan.id}) for "
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id} "
f"(staging=Neo4j database {tmp_database_name}, target={target_description})"
)
subgraph_dropped = False
@@ -177,7 +195,8 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
try:
logger.info(
f"Creating Neo4j database {tmp_cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}"
f"Creating staging Neo4j database {tmp_cartography_config.neo4j_database} "
f"for tenant {prowler_api_provider.tenant_id}"
)
graph_database.create_database(tmp_cartography_config.neo4j_database)
@@ -191,7 +210,9 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
tmp_cartography_config.neo4j_database
) as tmp_neo4j_session:
# Indexes creation
cartography_create_indexes.run(tmp_neo4j_session, tmp_cartography_config)
indexes.create_cartography_indexes(
tmp_neo4j_session, tmp_cartography_config
)
indexes.create_findings_indexes(tmp_neo4j_session)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2)
@@ -223,7 +244,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
cartography_analysis.run(tmp_neo4j_session, tmp_cartography_config)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95)
# Creating Internet node and CAN_ACCESS relationships
# Creating Internet node and `CAN_ACCESS` relationships
logger.info(
f"Creating Internet graph for AWS account {prowler_api_provider.uid}"
)
@@ -247,23 +268,41 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 97)
logger.info(
f"Clearing Neo4j cache for database {tmp_cartography_config.neo4j_database}"
f"Clearing Neo4j cache for staging database {tmp_cartography_config.neo4j_database}"
)
graph_database.clear_cache(tmp_cartography_config.neo4j_database)
t0 = time.perf_counter()
logger.info(
f"Ensuring tenant database {tenant_database_name}, and its indexes, exists for tenant {prowler_api_provider.tenant_id}"
f"Preparing target {target_description} for tenant {prowler_api_provider.tenant_id}"
)
graph_database.create_database(tenant_database_name)
with graph_database.get_session(tenant_database_name) as tenant_neo4j_session:
cartography_create_indexes.run(
tenant_neo4j_session, tenant_cartography_config
)
indexes.create_findings_indexes(tenant_neo4j_session)
indexes.create_sync_indexes(tenant_neo4j_session)
# Sink-side index creation: Neptune auto-manages indexes and rejects
# `CREATE INDEX`, so only run it when the sink is Neo4j
# The temp ingest DB is always Neo4j and is always indexed above
if target_sink_backend != "neptune":
logger.info(f"Ensuring indexes exist for {target_description}")
with graph_database.get_session(
tenant_database_name
) as tenant_neo4j_session:
indexes.create_cartography_indexes(
tenant_neo4j_session, tenant_cartography_config
)
indexes.create_findings_indexes(tenant_neo4j_session)
indexes.create_sync_indexes(tenant_neo4j_session)
else:
logger.info("Skipping tenant database indexes for neptune sink")
logger.info(
f"Prepared target {target_description} in {time.perf_counter() - t0:.3f}s"
)
logger.info(f"Deleting existing provider graph in {tenant_database_name}")
db_utils.set_provider_graph_data_ready(attack_paths_scan, False)
logger.info(
f"Deleting existing provider graph from {target_description} "
f"(tenant={prowler_api_provider.tenant_id}, provider={prowler_api_provider.id})"
)
db_utils.set_provider_graph_data_ready(
attack_paths_scan, False, target_sink_backend
)
provider_gated = True
t0 = time.perf_counter()
@@ -272,14 +311,17 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
provider_id=str(prowler_api_provider.id),
)
logger.info(
f"Deleted existing provider graph in {time.perf_counter() - t0:.3f}s "
f"(deleted_nodes={deleted_nodes})"
f"Deleted existing provider graph from {target_description} "
f"in {time.perf_counter() - t0:.3f}s (deleted_nodes={deleted_nodes})"
)
subgraph_dropped = True
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 98)
logger.info(
f"Syncing graph from {tmp_database_name} into {tenant_database_name}"
f"Syncing staging graph {tmp_database_name} into {target_description} "
f"for provider {prowler_api_provider.id} "
f"(tenant {prowler_api_provider.tenant_id}, "
f"type {prowler_api_provider.provider})"
)
t0 = time.perf_counter()
sync_result = sync.sync_graph(
@@ -287,17 +329,34 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
target_database=tenant_database_name,
tenant_id=str(prowler_api_provider.tenant_id),
provider_id=str(prowler_api_provider.id),
provider_type=prowler_api_provider.provider,
)
elapsed = time.perf_counter() - t0
total_nodes = sync_result["nodes"] + sync_result["child_nodes"]
elements = total_nodes + sync_result["relationships"]
rate = elements / elapsed if elapsed else 0
logger.info(
f"Synced graph in {time.perf_counter() - t0:.3f}s "
f"(nodes={sync_result['nodes']}, relationships={sync_result['relationships']})"
f"Synced staging graph into {target_description} in {elapsed:.3f}s - "
f"nodes={total_nodes} (source={sync_result['nodes']}, "
f"items={sync_result['child_nodes']}), "
f"relationships={sync_result['relationships']} "
f"(structural={sync_result['structural_relationships']}, "
f"items={sync_result['item_relationships']}), "
f"~{rate:.0f} elem/s"
)
sync_completed = True
# Flip metadata only now: the new schema is live in the target sink, so
# reads can switch to the current catalog/backend. The target-sink gate
# is already closed, so the switch is atomic from the API's view.
db_utils.set_scan_migrated(attack_paths_scan, True, target_sink_backend)
db_utils.set_graph_data_ready(attack_paths_scan, True)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 99)
logger.info(f"Clearing Neo4j cache for database {tenant_database_name}")
graph_database.clear_cache(tenant_database_name)
if target_sink_backend == "neptune":
logger.info("Skipping cache clear for neptune sink")
else:
logger.info(f"Clearing Neo4j cache for target {target_description}")
graph_database.clear_cache(tenant_database_name)
logger.info(f"Dropping temporary Neo4j database {tmp_database_name}")
graph_database.drop_database(tmp_database_name)
@@ -316,14 +375,16 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
logger.exception(exception_message)
ingestion_exceptions["global_error"] = exception_message
# Recover graph_data_ready based on how far the swap got.
# Partial drop (mid-batch failure) may leave `subgraph_dropped=False`
# with data partially deleted, so we prefer that over permanently blocked queries.
# Recover `graph_data_ready` based on how far the swap got
# Partial drop (mid-batch failure) may leave `subgraph_dropped=False` with data partially deleted,
# so we prefer that over permanently blocked queries
try:
if sync_completed:
db_utils.set_graph_data_ready(attack_paths_scan, True)
elif provider_gated and not subgraph_dropped:
db_utils.set_provider_graph_data_ready(attack_paths_scan, True)
db_utils.set_provider_graph_data_ready(
attack_paths_scan, True, target_sink_backend
)
except Exception:
logger.error(
+372 -42
View File
@@ -1,40 +1,58 @@
"""
Graph sync operations for Attack Paths.
This module handles syncing graph data from temporary scan databases
to the tenant database, adding provider isolation labels and properties.
Reads nodes and relationships out of the cartography temp database (always
Neo4j) and hands them to the configured sink (Neo4j or Neptune) in batches.
Backend-specific Cypher (MERGE shape, ID strategy, indexes) lives in each
sink; this module owns the source read loop, per-batch grouping, and the
list-property materialisation policy (see `NormalizedList`).
Each list-typed node property that appears in the provider's
`normalized_lists` catalog becomes a set of child item nodes connected to
the parent by a typed edge. A list-typed property that is not in the
catalog is serialised to a comma-delimited string and emits a one-time
warning per (label, property), surfacing Cartography fields that should be
added to the catalog.
"""
import json
import time
from collections import defaultdict
from collections.abc import Iterator
from typing import Any
import neo4j
from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths.config import (
PROVIDER_CONFIGS,
PROVIDER_ISOLATION_PROPERTIES,
PROVIDER_RESOURCE_LABEL,
SYNC_BATCH_SIZE,
NormalizedList,
get_provider_label,
get_tenant_label,
)
from tasks.jobs.attack_paths.queries import (
NODE_FETCH_QUERY,
NODE_SYNC_TEMPLATE,
RELATIONSHIP_SYNC_TEMPLATE,
RELATIONSHIPS_FETCH_QUERY,
render_cypher_template,
)
logger = get_task_logger(__name__)
# (label, property) tuples for which we've already emitted the
# "unnormalised list" warning. Module-level so the warning fires once per
# process, not once per node.
_WARNED_UNNORMALIZED: set[tuple[str, str]] = set()
def sync_graph(
source_database: str,
target_database: str,
tenant_id: str,
provider_id: str,
provider_type: str,
) -> dict[str, int]:
"""
Sync all nodes and relationships from source to target database.
@@ -44,25 +62,38 @@ def sync_graph(
`target_database`: The tenant database
`tenant_id`: The tenant ID for isolation
`provider_id`: The provider ID for isolation
`provider_type`: Provider type key (e.g. "aws"), used to resolve the
`NormalizedList` catalog from `PROVIDER_CONFIGS`.
Returns:
Dict with counts of synced nodes and relationships
Dict with counts of synced nodes, child item nodes, and relationships.
"""
nodes_synced = sync_nodes(
sink = sink_module.get_backend()
sink.ensure_sync_indexes(target_database)
normalized_lists = _resolve_normalized_lists(provider_type)
node_result = sync_nodes(
source_database,
target_database,
tenant_id,
provider_id,
sink,
normalized_lists,
)
relationships_synced = sync_relationships(
source_database,
target_database,
provider_id,
sink,
)
return {
"nodes": nodes_synced,
"relationships": relationships_synced,
"nodes": node_result["parents"],
"child_nodes": node_result["children"],
"relationships": relationships_synced + node_result["parent_child_rels"],
"structural_relationships": relationships_synced,
"item_relationships": node_result["parent_child_rels"],
}
@@ -71,22 +102,35 @@ def sync_nodes(
target_database: str,
tenant_id: str,
provider_id: str,
) -> int:
sink: Any,
normalized_lists: list[NormalizedList],
) -> dict[str, int]:
"""
Sync nodes from source to target database.
Sync nodes from source to target database, exploding catalogued list
properties into child nodes + parent->child edges.
Adds `_ProviderResource` label and dynamic `_Tenant_{id}` and `_Provider_{id}`
isolation labels to all nodes.
isolation labels to all nodes (parents and children alike).
Source and target sessions are opened sequentially per batch to avoid
holding two Bolt connections simultaneously for the entire sync duration.
"""
t0 = time.perf_counter()
last_id = -1
total_synced = 0
parents_synced = 0
children_synced = 0
parent_child_rels = 0
catalog = _build_catalog_index(normalized_lists)
extra_labels = _build_extra_labels(tenant_id, provider_id)
while True:
grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
tb = time.perf_counter()
prev_children = children_synced
prev_rels = parent_child_rels
parent_groups: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
child_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
rel_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
batch_count = 0
with graph_database.get_session(source_database) as source_session:
@@ -97,43 +141,66 @@ def sync_nodes(
for record in result:
batch_count += 1
last_id = record["internal_id"]
key, value = _node_to_sync_dict(record, provider_id)
grouped[key].append(value)
key, parent_dict, children, rels = _node_to_sync_dict(
record, provider_id, catalog
)
parent_groups[key].append(parent_dict)
for child in children:
child_groups[child["_child_label"]].append(child["row"])
for rel in rels:
rel_groups[rel["rel_type"]].append(rel["row"])
if batch_count == 0:
break
with graph_database.get_session(target_database) as target_session:
for labels, batch in grouped.items():
label_set = set(labels)
label_set.add(PROVIDER_RESOURCE_LABEL)
label_set.add(get_tenant_label(tenant_id))
label_set.add(get_provider_label(provider_id))
node_labels = ":".join(f"`{label}`" for label in sorted(label_set))
for labels, batch in parent_groups.items():
rendered_labels = _render_labels(labels, extra_labels)
for sink_batch in _iter_sink_batches(batch):
sink.write_nodes(target_database, rendered_labels, sink_batch)
query = render_cypher_template(
NODE_SYNC_TEMPLATE, {"__NODE_LABELS__": node_labels}
for child_label, batch in child_groups.items():
rendered_labels = _render_labels((child_label,), extra_labels)
for sink_batch in _iter_sink_batches(batch):
sink.write_nodes(target_database, rendered_labels, sink_batch)
children_synced += len(batch)
for rel_type, batch in rel_groups.items():
for sink_batch in _iter_sink_batches(batch):
sink.write_relationships(
target_database, rel_type, provider_id, sink_batch
)
target_session.run(query, {"rows": batch})
parent_child_rels += len(batch)
total_synced += batch_count
parents_synced += batch_count
batch_dt = time.perf_counter() - tb
batch_elements = (
batch_count
+ (children_synced - prev_children)
+ (parent_child_rels - prev_rels)
)
rate = batch_elements / batch_dt if batch_dt else 0
logger.info(
f"Synced {total_synced} nodes from {source_database} to {target_database} in {time.perf_counter() - t0:.3f}s"
f"[sync nodes] {parents_synced} source (+{children_synced} items, "
f"+{parent_child_rels} item rels) · batch {batch_dt:.1f}s · "
f"elapsed {time.perf_counter() - t0:.1f}s · ~{rate:.0f} elem/s"
)
return total_synced
return {
"parents": parents_synced,
"children": children_synced,
"parent_child_rels": parent_child_rels,
}
def sync_relationships(
source_database: str,
target_database: str,
provider_id: str,
sink: Any,
) -> int:
"""
Sync relationships from source to target database.
Matches source and target nodes by `_provider_element_id` in the tenant database.
Source and target sessions are opened sequentially per batch to avoid
holding two Bolt connections simultaneously for the entire sync duration.
"""
@@ -142,6 +209,7 @@ def sync_relationships(
total_synced = 0
while True:
tb = time.perf_counter()
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
batch_count = 0
@@ -159,32 +227,213 @@ def sync_relationships(
if batch_count == 0:
break
with graph_database.get_session(target_database) as target_session:
for rel_type, batch in grouped.items():
query = render_cypher_template(
RELATIONSHIP_SYNC_TEMPLATE, {"__REL_TYPE__": rel_type}
for rel_type, batch in grouped.items():
for sink_batch in _iter_sink_batches(batch):
sink.write_relationships(
target_database, rel_type, provider_id, sink_batch
)
target_session.run(query, {"rows": batch})
total_synced += batch_count
batch_dt = time.perf_counter() - tb
rate = batch_count / batch_dt if batch_dt else 0
logger.info(
f"Synced {total_synced} relationships from {source_database} to {target_database} in {time.perf_counter() - t0:.3f}s"
f"[sync rels] {total_synced} structural · batch {batch_dt:.1f}s · "
f"elapsed {time.perf_counter() - t0:.1f}s · ~{rate:.0f}/s"
)
return total_synced
def _iter_sink_batches(
rows: list[dict[str, Any]],
batch_size: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
"""Yield final sink write batches after source rows have been transformed."""
batch_size = SYNC_BATCH_SIZE if batch_size is None else batch_size
if batch_size <= 0:
raise ValueError("Sink batch size must be greater than zero")
for index in range(0, len(rows), batch_size):
yield rows[index : index + batch_size]
def _node_to_sync_dict(
record: neo4j.Record, provider_id: str
) -> tuple[tuple[str, ...], dict[str, Any]]:
"""Transform a source node record into a (grouping_key, sync_dict) pair."""
record: neo4j.Record,
provider_id: str,
catalog: dict[tuple[str, str], NormalizedList],
) -> tuple[
tuple[str, ...],
dict[str, Any],
list[dict[str, Any]],
list[dict[str, Any]],
]:
"""Transform a source node record into a (grouping_key, sync_dict, children, rels) tuple.
Catalogued list properties are popped from `props` and emitted as child
nodes + parent->child relationships.
"""
props = dict(record["props"] or {})
_strip_internal_properties(props)
labels = tuple(sorted(set(record["labels"] or [])))
return labels, {
"provider_element_id": f"{provider_id}:{record['element_id']}",
parent_element_id = f"{provider_id}:{record['element_id']}"
children, rels = _explode_catalogued_lists(
labels, props, catalog, provider_id, parent_element_id
)
_normalize_sink_properties(props, labels)
parent = {
"provider_element_id": parent_element_id,
"props": props,
}
return labels, parent, children, rels
def _explode_catalogued_lists(
    labels: tuple[str, ...],
    props: dict[str, Any],
    catalog: dict[tuple[str, str], NormalizedList],
    provider_id: str,
    parent_element_id: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Turn catalogued list properties into child-node and relationship emits.

    `props` is mutated in place: every property whose (label, name) pair is in
    `catalog` is removed, unless its value is a non-list, in which case it is
    put back untouched. Each label of the node is checked independently, since
    a node may carry several labels.

    Returns:
        A `(children, rels)` pair where
        - `children` holds `{"_child_label": str, "row": <node row>}` dicts and
        - `rels` holds `{"rel_type": str, "row": <rel row>}` dicts.
    """
    child_emits: list[dict[str, Any]] = []
    rel_emits: list[dict[str, Any]] = []

    for label in labels:
        # Walk a snapshot of the keys: `props` is modified during the loop.
        for prop_name in tuple(props):
            spec = catalog.get((label, prop_name))
            if spec is None:
                continue

            raw_value = props.pop(prop_name)
            if raw_value is None:
                # A null catalogued property is simply dropped.
                continue
            if not isinstance(raw_value, list):
                # Catalogued, but not a list on this node: restore it so the
                # value is not lost.
                props[prop_name] = raw_value
                continue

            for element in raw_value:
                dedup_key, child_props = _build_child_props(spec, element)
                if dedup_key is None:
                    # The element could not be mapped to a child; skip it.
                    continue

                child_id = _build_child_id(provider_id, spec.child_label, dedup_key)
                child_row = {
                    "provider_element_id": child_id,
                    "props": child_props,
                }
                rel_row = {
                    "start_element_id": parent_element_id,
                    "end_element_id": child_id,
                    "provider_element_id": (
                        f"{parent_element_id}::{spec.rel_type}::"
                        f"{child_id}"
                    ),
                    "props": {},
                }
                child_emits.append({"_child_label": spec.child_label, "row": child_row})
                rel_emits.append({"rel_type": spec.rel_type, "row": rel_row})

    return child_emits, rel_emits
def _build_child_props(
    spec: NormalizedList, item: Any
) -> tuple[str | None, dict[str, Any]]:
    """Map a single list element to `(dedup_key, child_props)`.

    The dedup key is the part of the child's identity derived from its value.

    - Scalar mode (`spec.field_map` is empty): the element is stringified
      (JSON with sorted keys when it is a dict or list) and stored under
      `"value"`; that string is also the dedup key.
    - Dict mode (`spec.field_map` is set): each `(source_key, child_field)`
      pair copies `item[source_key]` into `child_field`, using `""` for a
      missing or null value. The dedup key joins `child_field=value` parts
      with `::` in `field_map` order.

    Returns `(None, {})` when dict mode receives a non-dict element.
    """
    if spec.field_map:
        if not isinstance(item, dict):
            # Catalogued as dict-shaped but the element is not a dict: signal
            # "skip" to the caller.
            return None, {}

        child_props: dict[str, Any] = {}
        key_segments: list[str] = []
        for source_key, target_field in spec.field_map:
            source_value = item.get(source_key)
            if source_value is None:
                rendered = ""
            else:
                rendered = _to_sink_property_value(source_value)
            child_props[target_field] = rendered
            key_segments.append(f"{target_field}={rendered}")
        return "::".join(key_segments), child_props

    # Scalar mode. Structured elements are unexpected here, so serialise them
    # to a stable string rather than dropping them.
    if isinstance(item, (dict, list)):
        scalar_text = json.dumps(item, sort_keys=True, default=str)
    else:
        scalar_text = str(item)
    return scalar_text, {"value": scalar_text}
def _build_child_id(provider_id: str, child_label: str, value_key: str) -> str:
"""Deterministic `_provider_element_id` for a list-item child node.
Dedupes within (tenant, provider): multiple parents referencing the same
value share one child node via the existing MERGE-on-_provider_element_id
index in both sinks.
"""
return f"{provider_id}::{child_label}::{value_key}"
def _build_catalog_index(
    normalized_lists: list[NormalizedList],
) -> dict[tuple[str, str], NormalizedList]:
    """Build a `(source_label, source_property) -> spec` lookup table.

    When two specs share the same pair, the later one in `normalized_lists`
    wins.
    """
    index: dict[tuple[str, str], NormalizedList] = {}
    for spec in normalized_lists:
        index[(spec.source_label, spec.source_property)] = spec
    return index
def _build_extra_labels(tenant_id: str, provider_id: str) -> tuple[str, ...]:
    """Return the labels added to every synced node.

    These are the shared provider-resource label plus the tenant- and
    provider-specific labels, in that order.
    """
    tenant_label = get_tenant_label(tenant_id)
    provider_label = get_provider_label(provider_id)
    return (PROVIDER_RESOURCE_LABEL, tenant_label, provider_label)
def _render_labels(base_labels: tuple[str, ...], extra_labels: tuple[str, ...]) -> str:
"""Render the Cypher label string for a node-write batch."""
label_set = set(base_labels) | set(extra_labels)
return ":".join(f"`{label}`" for label in sorted(label_set))
def _resolve_normalized_lists(provider_type: str) -> list[NormalizedList]:
    """Return the `normalized_lists` catalog registered for `provider_type`.

    If the provider type has no entry in `PROVIDER_CONFIGS`, a warning is
    logged and an empty catalog is returned. In that case no list property is
    exploded into child nodes.
    """
    config = PROVIDER_CONFIGS.get(provider_type)
    if config is not None:
        return config.normalized_lists

    # Unknown provider: list-typed properties will all take the
    # comma-delimited string path, each warning once per (label, property).
    logger.warning(
        "Provider type %s not in PROVIDER_CONFIGS; no normalized_lists active",
        provider_type,
    )
    return []
def _rel_to_sync_dict(
@@ -193,7 +442,11 @@ def _rel_to_sync_dict(
"""Transform a source relationship record into a (grouping_key, sync_dict) pair."""
props = dict(record["props"] or {})
_strip_internal_properties(props)
# Relationship properties go through the same primitive coercion as
# nodes; catalog-driven materialisation applies to node properties only.
_normalize_sink_properties(props, labels=None)
rel_type = record["rel_type"]
return rel_type, {
"start_element_id": f"{provider_id}:{record['start_element_id']}",
"end_element_id": f"{provider_id}:{record['end_element_id']}",
@@ -206,3 +459,80 @@ def _strip_internal_properties(props: dict[str, Any]) -> None:
"""Remove provider isolation properties before the += spread in sync templates."""
for key in PROVIDER_ISOLATION_PROPERTIES:
props.pop(key, None)
def _normalize_sink_properties(
    props: dict[str, Any], labels: tuple[str, ...] | None
) -> None:
    """Rewrite every value in `props`, in place, to a primitive sink literal.

    The same conversion is applied whichever sink is active (Neo4j or
    Neptune). The rules come from Neptune's openCypher restrictions, which
    reject list, map, temporal and spatial property values. Applying them
    uniformly keeps queries portable across sinks.

    Conversion, delegated to `_to_sink_property_value`:
    - Values exposing a callable `iso_format` become its string result.
    - Values whose type lives in a `neo4j.spatial*` module become `str(value)`.
    - Dicts become a JSON string with sorted keys.
    - Lists become a comma-delimited string.
    - Anything else is kept unchanged.

    Catalogued list properties are expected to have been removed upstream by
    `_explode_catalogued_lists`, so a list seen here is treated as
    uncatalogued and reported once per (label, property).

    `labels` is only used for that warning. Pass `None` for relationship
    props, which have no label context and therefore never warn.
    """
    should_warn = labels is not None
    # Iterate over a key snapshot because the values are reassigned as we go.
    for prop_name in list(props):
        current = props[prop_name]
        if should_warn and isinstance(current, list):
            _warn_unnormalized_list(labels, prop_name)
        props[prop_name] = _to_sink_property_value(current)
def _warn_unnormalized_list(labels: tuple[str, ...], key: str) -> None:
    """Log the "unnormalized list" warning at most once per (label, property).

    Labels starting with `_` are internal isolation labels; warning on them
    only adds noise, so they are ignored whenever the node has at least one
    other label. If every label is internal, all of them are used instead.
    Pairs already reported are tracked in `_WARNED_UNNORMALIZED`.
    """
    public_labels = [name for name in labels if not name.startswith("_")]
    targets = public_labels if public_labels else labels

    for label in targets:
        seen_key = (label, key)
        if seen_key not in _WARNED_UNNORMALIZED:
            _WARNED_UNNORMALIZED.add(seen_key)
            logger.warning(
                "Unnormalized list property %s.%s reached sink as comma-string; "
                "add a NormalizedList entry to the provider catalog to explode it",
                label,
                key,
            )
def _to_sink_property_value(value: Any) -> Any:
if hasattr(value, "iso_format") and callable(value.iso_format):
return value.iso_format()
if type(value).__module__.startswith("neo4j.spatial"):
return str(value)
if isinstance(value, dict):
# openCypher `SET` rejects map property values: encode as JSON so the structured payload
# survives the round-trip and is queryable with `CONTAINS` substring checks
return json.dumps(value, sort_keys=True, default=str)
if isinstance(value, list):
# openCypher `SET` rejects list/array property values: encode as a
# delimited string read back with split() inside queries
return ",".join(str(_to_sink_property_value(v)) for v in value)
return value
+14 -1
View File
@@ -1,4 +1,5 @@
from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from api.db_router import MainRouter
from api.db_utils import batch_delete, rls_transaction
from api.models import (
@@ -76,6 +77,12 @@ def delete_provider(tenant_id: str, pk: str):
"id", flat=True
)
)
attack_paths_sink_backends = list(
AttackPathsScan.all_objects.filter(provider=instance)
.values_list("sink_backend", flat=True)
.distinct()
.order_by("sink_backend")
)
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
@@ -97,7 +104,13 @@ def delete_provider(tenant_id: str, pk: str):
# Delete the Attack Paths' graph data related to the provider from the tenant database
tenant_database_name = graph_database.get_database_name(tenant_id)
try:
graph_database.drop_subgraph(tenant_database_name, str(pk))
if attack_paths_sink_backends:
for sink_backend in attack_paths_sink_backends:
sink_module.get_backend_for_name(sink_backend).drop_subgraph(
tenant_database_name, str(pk)
)
else:
graph_database.drop_subgraph(tenant_database_name, str(pk))
except graph_database.GraphDatabaseQueryException as gdb_error:
logger.error(f"Error deleting Provider graph data: {gdb_error}")
+55 -19
View File
@@ -19,7 +19,7 @@ from api.db_utils import (
psycopg_connection,
rls_transaction,
)
from api.exceptions import ProviderConnectionError
from api.exceptions import ProviderConnectionError, ProviderDeletedException
from api.models import (
AttackSurfaceOverview,
ComplianceOverviewSummary,
@@ -48,7 +48,7 @@ from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from config.env import env
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db import DatabaseError, IntegrityError, OperationalError, transaction
from django.db.models import (
Case,
Count,
@@ -117,6 +117,20 @@ ATTACK_SURFACE_PROVIDER_COMPATIBILITY = {
_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {}
def _save_scan_instance(
    scan_instance: Scan, provider_id: str, update_fields: list[str]
) -> None:
    """Persist selected fields of a scan, detecting a scan row that has vanished.

    The save runs inside a nested `transaction.atomic()` block so that a
    failure rolls back only that savepoint and not the caller's surrounding
    `rls_transaction`.

    Args:
        scan_instance: The scan to save.
        provider_id: Provider id, used only in the error message.
        update_fields: Field names passed to `save(update_fields=...)`.

    Raises:
        ProviderDeletedException: If the save fails with a `DatabaseError`
            and the scan row no longer exists.
        DatabaseError: Re-raised unchanged if the save fails while the scan
            row still exists.
    """
    try:
        # Savepoint for not killing the `rls_transaction`
        with transaction.atomic():
            scan_instance.save(update_fields=update_fields)
    except DatabaseError:
        scan_still_exists = Scan.objects.filter(pk=scan_instance.id).exists()
        if scan_still_exists:
            # The row is still there, so this is a genuine database failure.
            raise
        message = f"Provider '{provider_id}' for scan '{scan_instance.id}' was deleted during the scan"
        raise ProviderDeletedException(message) from None
def aggregate_category_counts(
categories: list[str],
severity: str,
@@ -1029,13 +1043,18 @@ def perform_prowler_scan(
group_resources_cache: dict[str, set] = {}
start_time = time.time()
exc = None
skip_final_scan_update = False
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
scan_instance.state = StateChoices.EXECUTING
scan_instance.started_at = datetime.now(tz=UTC)
scan_instance.save(update_fields=["state", "started_at", "updated_at"])
_save_scan_instance(
scan_instance,
provider_id,
["state", "started_at", "updated_at"],
)
# Find the mutelist processor if it exists
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
@@ -1101,7 +1120,7 @@ def perform_prowler_scan(
# Throttle scan_instance progress writes to avoid hammering the writer:
# only persist when progress moves by at least `PROGRESS_THROTTLE_DELTA`
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (1.0)
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (100)
# always persists in the `finally` block below.
last_persisted_progress = -1.0
last_persisted_progress_at = 0.0
@@ -1143,7 +1162,11 @@ def perform_prowler_scan(
):
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save(update_fields=["progress", "updated_at"])
_save_scan_instance(
scan_instance,
provider_id,
["progress", "updated_at"],
)
last_persisted_progress = progress
last_persisted_progress_at = now
@@ -1170,26 +1193,39 @@ def perform_prowler_scan(
batch_size=SCAN_DB_BATCH_SIZE,
)
except ProviderDeletedException as e:
logger.warning(str(e))
exception = e
skip_final_scan_update = True
except Exception as e:
logger.error(f"Error performing scan {scan_id}: {e}")
exception = e
scan_instance.state = StateChoices.FAILED
finally:
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=UTC)
scan_instance.unique_resource_count = len(unique_resources)
scan_instance.save(
update_fields=[
"state",
"duration",
"completed_at",
"unique_resource_count",
"progress",
"updated_at",
]
)
if not skip_final_scan_update:
try:
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=UTC)
scan_instance.unique_resource_count = len(unique_resources)
if exception is None:
scan_instance.progress = 100
_save_scan_instance(
scan_instance,
provider_id,
[
"state",
"duration",
"completed_at",
"unique_resource_count",
"progress",
"updated_at",
],
)
except ProviderDeletedException as e:
logger.warning(str(e))
exception = e
if exception is not None:
raise exception
@@ -0,0 +1,30 @@
from tasks.jobs.attack_paths.provider_config import AWS_NORMALIZED_LISTS
from tasks.jobs.attack_paths.sync import _build_catalog_index, _node_to_sync_dict
def test_aws_vpc_endpoint_id_lists_are_normalized():
    """The three id-list props of an AWSVpcEndpoint node are split out.

    After conversion the parent keeps only its scalar ``id`` prop, and each
    list prop yields a child label and a ``HAS_*`` relationship type.
    """
    catalog_index = _build_catalog_index(AWS_NORMALIZED_LISTS)
    vpc_endpoint_node = {
        "element_id": "node-1",
        "labels": ["AWSVpcEndpoint"],
        "props": {
            "id": "vpce-123",
            "route_table_ids": ["rtb-1"],
            "network_interface_ids": ["eni-1"],
            "subnet_ids": ["subnet-1"],
        },
    }

    _unused, parent_node, child_nodes, relationships = _node_to_sync_dict(
        vpc_endpoint_node, "provider-id", catalog_index
    )

    assert parent_node["props"] == {"id": "vpce-123"}

    child_labels = {child_node["_child_label"] for child_node in child_nodes}
    assert child_labels == {
        "AWSVpcEndpointRouteTableIdsItem",
        "AWSVpcEndpointNetworkInterfaceIdsItem",
        "AWSVpcEndpointSubnetIdsItem",
    }

    relationship_types = {
        relationship["rel_type"] for relationship in relationships
    }
    assert relationship_types == {
        "HAS_ROUTE_TABLE_IDS",
        "HAS_NETWORK_INTERFACE_IDS",
        "HAS_SUBNET_IDS",
    }
@@ -23,15 +23,31 @@ from tasks.jobs.attack_paths import internet as internet_module
from tasks.jobs.attack_paths import sync as sync_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run
# All-zero counter payload that the tests in this module hand to
# `patch(..., return_value=...)` as the mocked result of `sync.sync_graph`.
SYNC_RESULT_EMPTY = {
"nodes": 0,
"child_nodes": 0,
"relationships": 0,
"structural_relationships": 0,
"item_relationships": 0,
}
@pytest.mark.django_db
class TestAttackPathsRun:
# Autouse fixture: for every test in this class, replaces
# `graph_database.verify_scan_databases_available` (as looked up through the
# scan module) with a mock, and yields that mock so a test can set a
# `side_effect` on it or assert how it was called.
@pytest.fixture(autouse=True)
def mock_graph_database_preflight(self):
with patch(
"tasks.jobs.attack_paths.scan.graph_database.verify_scan_databases_available"
) as mock_preflight:
yield mock_preflight
# Patching with decorators as we got a `SyntaxError: too many statically nested blocks` error if we use context managers
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_database")
@patch(
"tasks.jobs.attack_paths.scan.utils.call_within_event_loop",
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
)
@patch("tasks.jobs.attack_paths.scan.db_utils.set_scan_migrated")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready")
@patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan")
@@ -39,7 +55,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0},
return_value=SYNC_RESULT_EMPTY,
)
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", return_value=0)
@patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes")
@@ -48,11 +64,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
)
@patch(
@@ -66,7 +82,7 @@ class TestAttackPathsRun:
def test_run_success_flow(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_create_db,
mock_clear_cache,
mock_cartography_indexes,
@@ -83,6 +99,7 @@ class TestAttackPathsRun:
mock_finish,
mock_set_provider_graph_data_ready,
mock_set_graph_data_ready,
mock_set_scan_migrated,
mock_event_loop,
mock_drop_db,
tenants_fixture,
@@ -159,6 +176,7 @@ class TestAttackPathsRun:
target_database="tenant-db",
tenant_id=str(provider.tenant_id),
provider_id=str(provider.id),
provider_type="aws",
)
mock_get_ingestion.assert_called_once_with(provider.provider)
mock_event_loop.assert_called_once()
@@ -172,9 +190,70 @@ class TestAttackPathsRun:
attack_paths_scan, StateChoices.COMPLETED, ingestion_result
)
mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False
attack_paths_scan, False, "neo4j"
)
mock_set_graph_data_ready.assert_called_once_with(attack_paths_scan, True)
# is_migrated is flipped to True only after the sync succeeds, so reads
# don't switch to the new catalog/sink before the graph is live.
mock_set_scan_migrated.assert_called_once_with(attack_paths_scan, True, "neo4j")
# When the graph-database preflight check raises, the error must propagate out
# of `attack_paths_run`, and neither `starting_attack_paths_scan` nor
# `create_database` may have been called.
def test_run_preflight_failure_does_not_start_scan(
self,
mock_graph_database_preflight,
tenants_fixture,
providers_fixture,
scans_fixture,
):
# Arrange: an AWS provider with a SCHEDULED attack-paths scan.
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
attack_paths_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.SCHEDULED,
)
# Make the autouse preflight mock fail when the run invokes it.
mock_graph_database_preflight.side_effect = RuntimeError("graph unavailable")
with (
# Replace `rls_transaction` with a no-op context manager.
patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
),
patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
),
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
),
patch(
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch(
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
return_value=MagicMock(return_value={}),
),
# The two mocks below are captured so the test can assert they were never reached.
patch(
"tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"
) as mock_starting,
patch(
"tasks.jobs.attack_paths.scan.graph_database.create_database"
) as mock_create_db,
):
with pytest.raises(RuntimeError, match="graph unavailable"):
attack_paths_run(str(tenant.id), str(scan.id), "task-123")
# The preflight ran exactly once, with no arguments, and nothing after it started.
mock_graph_database_preflight.assert_called_once_with()
mock_starting.assert_not_called()
mock_create_db.assert_not_called()
@patch(
"tasks.jobs.attack_paths.scan.utils.stringify_exception",
@@ -194,13 +273,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id",
)
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri")
@patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -212,7 +291,7 @@ class TestAttackPathsRun:
def test_run_failure_marks_scan_failed(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_get_db_name,
mock_create_db,
mock_cartography_indexes,
@@ -293,13 +372,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id",
)
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri")
@patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -311,7 +390,7 @@ class TestAttackPathsRun:
def test_failure_before_gate_does_not_flip_graph_data_ready_true(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_get_db_name,
mock_create_db,
mock_cartography_indexes,
@@ -396,13 +475,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id",
)
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri")
@patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -414,7 +493,7 @@ class TestAttackPathsRun:
def test_run_failure_marks_scan_failed_even_when_drop_database_fails(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_get_db_name,
mock_create_db,
mock_cartography_indexes,
@@ -493,7 +572,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0},
return_value=SYNC_RESULT_EMPTY,
)
@patch(
"tasks.jobs.attack_paths.scan.graph_database.drop_subgraph",
@@ -505,11 +584,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
)
@patch(
@@ -523,7 +602,7 @@ class TestAttackPathsRun:
def test_failure_after_gate_before_drop_restores_graph_data_ready(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_create_db,
mock_clear_cache,
mock_cartography_indexes,
@@ -589,8 +668,8 @@ class TestAttackPathsRun:
attack_paths_run(str(tenant.id), str(scan.id), "task-456")
assert mock_set_provider_graph_data_ready.call_args_list == [
call(attack_paths_scan, False),
call(attack_paths_scan, True),
call(attack_paths_scan, False, "neo4j"),
call(attack_paths_scan, True, "neo4j"),
]
@patch(
@@ -618,11 +697,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
)
@patch(
@@ -636,7 +715,7 @@ class TestAttackPathsRun:
def test_failure_after_drop_before_sync_leaves_graph_data_ready_false(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_create_db,
mock_clear_cache,
mock_cartography_indexes,
@@ -703,7 +782,7 @@ class TestAttackPathsRun:
# Only called with False (gate), never with True (no recovery for partial data)
mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False
attack_paths_scan, False, "neo4j"
)
@patch(
@@ -716,6 +795,7 @@ class TestAttackPathsRun:
)
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_database")
@patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_scan_migrated")
@patch(
"tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready",
side_effect=[RuntimeError("flag failed"), None],
@@ -725,7 +805,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0},
return_value=SYNC_RESULT_EMPTY,
)
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph")
@patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes")
@@ -734,11 +814,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
)
@patch(
@@ -752,7 +832,7 @@ class TestAttackPathsRun:
def test_failure_after_sync_restores_graph_data_ready(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_create_db,
mock_clear_cache,
mock_cartography_indexes,
@@ -768,6 +848,7 @@ class TestAttackPathsRun:
mock_update_progress,
mock_set_provider_graph_data_ready,
mock_set_graph_data_ready,
mock_set_scan_migrated,
mock_finish,
mock_drop_db,
mock_event_loop,
@@ -824,8 +905,11 @@ class TestAttackPathsRun:
]
# set_provider_graph_data_ready only called once with False (the gate)
mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False
attack_paths_scan, False, "neo4j"
)
# is_migrated is flipped once after the sync and is not touched again by
# the failure-recovery branch
mock_set_scan_migrated.assert_called_once_with(attack_paths_scan, True, "neo4j")
@patch(
"tasks.jobs.attack_paths.scan.utils.stringify_exception",
@@ -843,7 +927,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0},
return_value=SYNC_RESULT_EMPTY,
)
@patch(
"tasks.jobs.attack_paths.scan.graph_database.drop_subgraph",
@@ -855,11 +939,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
"tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j",
)
@patch(
@@ -873,7 +957,7 @@ class TestAttackPathsRun:
def test_recovery_failure_does_not_suppress_original_exception(
self,
mock_init_provider,
mock_get_uri,
mock_get_ingest_uri,
mock_create_db,
mock_clear_cache,
mock_cartography_indexes,
@@ -1116,7 +1200,7 @@ class TestFailAttackPathsScan:
fail_attack_paths_scan(str(tenant.id), "nonexistent", "setup exploded")
def test_fail_recovers_graph_data_ready_when_data_exists(
self, tenants_fixture, providers_fixture, scans_fixture
self, tenants_fixture, providers_fixture, scans_fixture, sink_backend_stub
):
from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan
@@ -1135,16 +1219,18 @@ class TestFailAttackPathsScan:
state=StateChoices.EXECUTING,
)
# `recover_graph_data_ready` routes `has_provider_data` through
# `sink_module.get_backend_for_scan(scan)`. With `is_migrated=False`
# and the default `ATTACK_PATHS_SINK_DATABASE=neo4j`, the factory
# returns the active backend, which `sink_backend_stub` replaces.
sink_backend_stub.has_provider_data.return_value = True
with (
patch(
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data",
return_value=True,
),
patch(
"tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready"
) as mock_set_ready,
@@ -1154,7 +1240,7 @@ class TestFailAttackPathsScan:
mock_set_ready.assert_called_once_with(attack_paths_scan, True)
def test_fail_leaves_graph_data_ready_false_when_no_data(
self, tenants_fixture, providers_fixture, scans_fixture
self, tenants_fixture, providers_fixture, scans_fixture, sink_backend_stub
):
from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan
@@ -1173,16 +1259,14 @@ class TestFailAttackPathsScan:
state=StateChoices.EXECUTING,
)
sink_backend_stub.has_provider_data.return_value = False
with (
patch(
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data",
return_value=False,
),
patch(
"tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready"
) as mock_set_ready,
@@ -1271,6 +1355,20 @@ class TestAttackPathsFindingsHelpers:
[call(mock_session, stmt) for stmt in FINDINGS_INDEX_STATEMENTS]
)
def test_create_findings_indexes_runs_even_when_sink_is_neptune(self, settings):
    """Findings indexes are still created when the configured sink is Neptune.

    Regression test for the dropped in-helper sink gate: the index helpers run
    against the temp ingest DB, which is always Neo4j regardless of the
    configured sink, so a Neptune sink must not suppress index creation there.
    """
    from tasks.jobs.attack_paths.indexes import FINDINGS_INDEX_STATEMENTS

    settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
    ingest_session = MagicMock()

    write_query_patch = patch("tasks.jobs.attack_paths.indexes.run_write_query")
    with write_query_patch as run_write_query_mock:
        indexes_module.create_findings_indexes(ingest_session)

    # One write per index statement.
    assert run_write_query_mock.call_count == len(FINDINGS_INDEX_STATEMENTS)
def test_load_findings_batches_requests(self, providers_fixture):
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
@@ -1802,7 +1900,13 @@ def _make_session_ctx(session, call_order=None, name=None):
class TestSyncNodes:
def test_sync_nodes_adds_private_label(self):
def test_iter_sink_batches_rejects_zero_batch_size(self):
    """A batch size of zero makes `_iter_sink_batches` raise ValueError."""
    expected_message = "Sink batch size must be greater than zero"
    with pytest.raises(ValueError, match=expected_message):
        # Wrap in list() so the iterable is fully consumed.
        list(sync_module._iter_sink_batches([], batch_size=0))
def test_sync_nodes_passes_isolation_labels_to_sink(self):
row = {
"internal_id": 1,
"element_id": "elem-1",
@@ -1812,29 +1916,32 @@ class TestSyncNodes:
mock_source_1 = MagicMock()
mock_source_1.run.return_value = [row]
mock_target = MagicMock()
mock_source_2 = MagicMock()
mock_source_2.run.return_value = []
sink = MagicMock()
with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(mock_source_1),
_make_session_ctx(mock_target),
_make_session_ctx(mock_source_2),
],
):
total = sync_module.sync_nodes(
"source-db", "target-db", "tenant-1", "prov-1"
result = sync_module.sync_nodes(
"source-db", "target-db", "tenant-1", "prov-1", sink, []
)
assert total == 1
query = mock_target.run.call_args.args[0]
assert "_ProviderResource" in query
assert "_Tenant_tenant1" in query
assert "_Provider_prov1" in query
assert result["parents"] == 1
sink.write_nodes.assert_called_once()
target_db, labels, batch = sink.write_nodes.call_args.args
assert target_db == "target-db"
assert "_ProviderResource" in labels
assert "_Tenant_tenant1" in labels
assert "_Provider_prov1" in labels
assert batch[0]["provider_element_id"] == "prov-1:elem-1"
assert batch[0]["props"] == {"key": "value"}
def test_sync_nodes_source_closes_before_target_opens(self):
def test_sync_nodes_writes_after_source_session_closes(self):
row = {
"internal_id": 1,
"element_id": "elem-1",
@@ -1846,21 +1953,23 @@ class TestSyncNodes:
src_1 = MagicMock()
src_1.run.return_value = [row]
tgt = MagicMock()
src_2 = MagicMock()
src_2.run.return_value = []
sink = MagicMock()
sink.write_nodes.side_effect = lambda *_a, **_kw: call_order.append(
"sink:write"
)
with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1, call_order, "source1"),
_make_session_ctx(tgt, call_order, "target"),
_make_session_ctx(src_2, call_order, "source2"),
],
):
sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1")
sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1", sink, [])
assert call_order.index("source1:exit") < call_order.index("target:enter")
assert call_order.index("source1:exit") < call_order.index("sink:write")
def test_sync_nodes_pagination_with_batch_size_1(self):
row_a = {
@@ -1882,44 +1991,89 @@ class TestSyncNodes:
src_2.run.return_value = [row_b]
src_3 = MagicMock()
src_3.run.return_value = []
tgt_1 = MagicMock()
tgt_2 = MagicMock()
sink = MagicMock()
with (
patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1),
_make_session_ctx(tgt_1),
_make_session_ctx(src_2),
_make_session_ctx(tgt_2),
_make_session_ctx(src_3),
],
),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
):
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1")
result = sync_module.sync_nodes("src", "tgt", "t-1", "p-1", sink, [])
assert total == 2
assert result["parents"] == 2
assert sink.write_nodes.call_count == 2
assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1
# A single source row whose `values` prop holds five items and is declared as a
# normalized list. With SYNC_BATCH_SIZE patched to 2, the expanded child nodes
# and their relationships must reach the sink in batches of 2, 2 and 1.
def test_sync_nodes_chunks_expanded_list_rows_before_sink_write(self):
row = {
"internal_id": 1,
"element_id": "elem-1",
"labels": ["SomeLabel"],
"props": {"values": ["a", "b", "c", "d", "e"]},
}
# Declares `SomeLabel.values` as a list to normalize into
# `SomeLabelValuesItem` children linked by `HAS_VALUES`.
normalized_lists = [
sync_module.NormalizedList(
"SomeLabel",
"values",
"SomeLabelValuesItem",
"HAS_VALUES",
)
]
# First source session returns the row; the second returns nothing.
src_1 = MagicMock()
src_1.run.return_value = [row]
src_2 = MagicMock()
src_2.run.return_value = []
sink = MagicMock()
with (
patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1),
_make_session_ctx(src_2),
],
),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2),
):
result = sync_module.sync_nodes(
"src", "tgt", "t-1", "p-1", sink, normalized_lists
)
assert result == {"parents": 1, "children": 5, "parent_child_rels": 5}
# The first `write_nodes` call is skipped - presumably the parent-node write;
# TODO confirm against `sync_nodes`. The remaining calls carry the batches
# (third positional argument) whose sizes are checked.
assert [
len(call_args.args[2]) for call_args in sink.write_nodes.call_args_list[1:]
] == [2, 2, 1]
# Relationship batches are the fourth positional argument of `write_relationships`.
assert [
len(call_args.args[3])
for call_args in sink.write_relationships.call_args_list
] == [2, 2, 1]
def test_sync_nodes_empty_source_returns_zero(self):
src = MagicMock()
src.run.return_value = []
sink = MagicMock()
with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[_make_session_ctx(src)],
) as mock_get_session:
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1")
result = sync_module.sync_nodes("src", "tgt", "t-1", "p-1", sink, [])
assert total == 0
assert result["parents"] == 0
assert mock_get_session.call_count == 1
sink.write_nodes.assert_not_called()
class TestSyncRelationships:
def test_sync_relationships_source_closes_before_target_opens(self):
def test_sync_relationships_writes_after_source_session_closes(self):
row = {
"internal_id": 1,
"rel_type": "HAS",
@@ -1932,21 +2086,23 @@ class TestSyncRelationships:
src_1 = MagicMock()
src_1.run.return_value = [row]
tgt = MagicMock()
src_2 = MagicMock()
src_2.run.return_value = []
sink = MagicMock()
sink.write_relationships.side_effect = lambda *_a, **_kw: call_order.append(
"sink:write"
)
with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1, call_order, "source1"),
_make_session_ctx(tgt, call_order, "target"),
_make_session_ctx(src_2, call_order, "source2"),
],
):
sync_module.sync_relationships("src", "tgt", "p-1")
sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert call_order.index("source1:exit") < call_order.index("target:enter")
assert call_order.index("source1:exit") < call_order.index("sink:write")
def test_sync_relationships_pagination_with_batch_size_1(self):
row_a = {
@@ -1970,40 +2126,76 @@ class TestSyncRelationships:
src_2.run.return_value = [row_b]
src_3 = MagicMock()
src_3.run.return_value = []
tgt_1 = MagicMock()
tgt_2 = MagicMock()
sink = MagicMock()
with (
patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1),
_make_session_ctx(tgt_1),
_make_session_ctx(src_2),
_make_session_ctx(tgt_2),
_make_session_ctx(src_3),
],
),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
):
total = sync_module.sync_relationships("src", "tgt", "p-1")
total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert total == 2
assert sink.write_relationships.call_count == 2
assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1
def test_sync_relationships_chunks_grouped_rows_before_sink_write(self):
    """Five `HAS` relationships reach the sink in batches of 2, 2 and 1.

    The source returns all five rows from one session while
    `SYNC_BATCH_SIZE` is patched to 2.
    """
    relationship_rows = []
    for internal_id in range(1, 6):
        relationship_rows.append(
            {
                "internal_id": internal_id,
                "rel_type": "HAS",
                "start_element_id": f"s-{internal_id}",
                "end_element_id": f"e-{internal_id}",
                "props": {},
            }
        )

    # First source session yields every row; the second one is empty.
    populated_session = MagicMock()
    populated_session.run.return_value = relationship_rows
    empty_session = MagicMock()
    empty_session.run.return_value = []
    sink = MagicMock()

    session_patch = patch(
        "tasks.jobs.attack_paths.sync.graph_database.get_session",
        side_effect=[
            _make_session_ctx(populated_session),
            _make_session_ctx(empty_session),
        ],
    )
    batch_size_patch = patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2)
    with session_patch, batch_size_patch:
        total = sync_module.sync_relationships("src", "tgt", "p-1", sink)

    assert total == 5
    # The batch is the fourth positional argument of each sink write.
    written_batch_sizes = [
        len(write_call.args[3])
        for write_call in sink.write_relationships.call_args_list
    ]
    assert written_batch_sizes == [2, 2, 1]
def test_sync_relationships_empty_source_returns_zero(self):
src = MagicMock()
src.run.return_value = []
sink = MagicMock()
with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[_make_session_ctx(src)],
) as mock_get_session:
total = sync_module.sync_relationships("src", "tgt", "p-1")
total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert total == 0
assert mock_get_session.call_count == 1
sink.write_relationships.assert_not_called()
class TestInternetAnalysis:
@@ -2075,6 +2267,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is False
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_create_attack_paths_scan_inherits_true_from_previous(
self, tenants_fixture, providers_fixture, scans_fixture
@@ -2095,6 +2289,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=True,
sink_backend="neptune",
)
new_scan = Scan.objects.create(
@@ -2115,6 +2311,109 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True
# is_migrated tracks the data being served: inherited from the ready scan
assert attack_paths_scan.is_migrated is True
assert attack_paths_scan.sink_backend == "neptune"
# With two ready scans for the same provider - one on "neo4j" (not migrated)
# and one on "neptune" (migrated) - and the active sink set to "neo4j", the
# newly created attack-paths scan must carry the neo4j scan's
# `is_migrated`/`sink_backend` values.
def test_create_attack_paths_scan_prefers_active_sink_ready_scan(
self, tenants_fixture, providers_fixture, scans_fixture, settings
):
from tasks.jobs.attack_paths.db_utils import create_attack_paths_scan
# Select the active sink for this test.
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
# Ready scan on the active sink.
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=False,
sink_backend="neo4j",
)
# Ready scan on the other sink, created afterwards.
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=True,
sink_backend="neptune",
)
new_scan = Scan.objects.create(
name="New Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
tenant_id=tenant.id,
)
# Replace `rls_transaction` with a no-op context manager.
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
)
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True
# Values match the neo4j ready scan, not the neptune one.
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
# When the only ready scan for the provider has `is_migrated=False` on
# "neo4j", a newly created attack-paths scan must report
# `graph_data_ready=True`, `is_migrated=False` and `sink_backend="neo4j"`.
def test_create_attack_paths_scan_inherits_is_migrated_false_from_legacy_ready(
self, tenants_fixture, providers_fixture, scans_fixture
):
from tasks.jobs.attack_paths.db_utils import create_attack_paths_scan
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
# Previous scan is ready but pre-cutover (legacy Neo4j graph shape)
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=False,
sink_backend="neo4j",
)
new_scan = Scan.objects.create(
name="New Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
tenant_id=tenant.id,
)
# Replace `rls_transaction` with a no-op context manager.
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
)
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True
# Reads stay on the legacy catalog/backend until this scan's own sync
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_create_attack_paths_scan_inherits_false_when_no_previous_ready(
self, tenants_fixture, providers_fixture, scans_fixture
@@ -2135,6 +2434,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan,
state=StateChoices.FAILED,
graph_data_ready=False,
sink_backend="neptune",
)
new_scan = Scan.objects.create(
@@ -2155,6 +2455,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is False
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_set_graph_data_ready_updates_field(
self, tenants_fixture, providers_fixture, scans_fixture
@@ -2261,7 +2563,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan.state == StateChoices.FAILED
assert attack_paths_scan.graph_data_ready is True
def test_set_provider_graph_data_ready_updates_all_scans_for_provider(
def test_set_provider_graph_data_ready_updates_all_scans_for_provider_sink(
self, tenants_fixture, providers_fixture, scans_fixture
):
from tasks.jobs.attack_paths.db_utils import set_provider_graph_data_ready
@@ -2289,6 +2591,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan_a,
state=StateChoices.COMPLETED,
graph_data_ready=True,
sink_backend="neptune",
)
new_ap_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
@@ -2296,6 +2599,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan_b,
state=StateChoices.EXECUTING,
graph_data_ready=True,
sink_backend="neptune",
)
with patch(
@@ -2309,6 +2613,48 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert old_ap_scan.graph_data_ready is False
assert new_ap_scan.graph_data_ready is False
def test_set_provider_graph_data_ready_preserves_other_sink_scans(
    self, tenants_fixture, providers_fixture, scans_fixture
):
    """Clearing readiness through a Neptune scan leaves the Neo4j scan ready.

    Two graph-ready ``AttackPathsScan`` rows are created for the same
    provider and scan, one per sink backend. After calling
    ``set_provider_graph_data_ready(neptune_scan, False)`` the Neo4j row
    must still be ready while the Neptune row must not.
    """
    from tasks.jobs.attack_paths.db_utils import set_provider_graph_data_ready

    def _passthrough_rls(*args, **kwargs):
        # Replacement for rls_transaction: accepts anything, wraps nothing.
        return nullcontext()

    tenant = tenants_fixture[0]

    provider = providers_fixture[0]
    provider.provider = Provider.ProviderChoices.AWS
    provider.save()

    scan = scans_fixture[0]
    scan.provider = provider
    scan.save()

    # Fields shared by both rows; only state and sink backend differ.
    shared_fields = {
        "tenant_id": tenant.id,
        "provider": provider,
        "scan": scan,
        "graph_data_ready": True,
    }
    legacy_scan = AttackPathsScan.objects.create(
        state=StateChoices.COMPLETED,
        sink_backend="neo4j",
        **shared_fields,
    )
    neptune_scan = AttackPathsScan.objects.create(
        state=StateChoices.EXECUTING,
        sink_backend="neptune",
        **shared_fields,
    )

    with patch(
        "tasks.jobs.attack_paths.db_utils.rls_transaction",
        new=_passthrough_rls,
    ):
        set_provider_graph_data_ready(neptune_scan, False)

    # Reload both rows so the assertions see the persisted values.
    for row in (legacy_scan, neptune_scan):
        row.refresh_from_db()

    assert legacy_scan.graph_data_ready is True
    assert neptune_scan.graph_data_ready is False
def test_set_provider_graph_data_ready_does_not_affect_other_providers(
self, tenants_fixture, providers_fixture, scans_fixture
):
@@ -2871,3 +3217,57 @@ class TestCleanupStaleAttackPathsScans:
ap_scan.refresh_from_db()
assert ap_scan.state == StateChoices.SCHEDULED
mock_revoke.assert_not_called()
class TestNormalizeSinkProperties:
    """Checks for ``sync_module._normalize_sink_properties``.

    The helper is expected to rewrite a property dict in place so that every
    value is a sink-portable primitive: lists become comma-joined strings,
    dicts become JSON strings, temporal values become ISO strings and spatial
    values become their ``str()`` form. Scalars and ``None`` pass through
    untouched. The same coercion runs regardless of the active sink so
    queries are portable.
    """

    @pytest.mark.parametrize(
        "raw, expected",
        [
            # Plain scalars and None are left exactly as they were.
            (
                {"a": "x", "b": 1, "c": 1.5, "d": True, "e": None},
                {"a": "x", "b": 1, "c": 1.5, "d": True, "e": None},
            ),
            # Lists collapse to a comma-joined string; an empty list to "".
            (
                {"actions": ["s3:GetObject", "s3:PutObject"], "tags": []},
                {"actions": "s3:GetObject,s3:PutObject", "tags": ""},
            ),
            # Nested dicts are serialised to a JSON string.
            (
                {"condition": {"StringEquals": {"aws:SourceAccount": "123456789012"}}},
                {
                    "condition": '{"StringEquals": {"aws:SourceAccount": "123456789012"}}'
                },
            ),
        ],
    )
    def test_primitive_list_and_dict_branches(self, raw, expected):
        """The input dict is mutated in place into the expected shape."""
        sync_module._normalize_sink_properties(raw, labels=None)

        assert raw == expected

    def test_temporal_and_spatial_become_strings(self):
        """Temporal and spatial look-alikes are replaced by plain strings."""

        class FakeDateTime:
            def iso_format(self) -> str:
                return "2026-05-13T10:00:00+00:00"

        class FakeSpatialPoint:
            # The spatial branch is detected by module prefix, not by base class.
            __module__ = "neo4j.spatial.fake"

            def __str__(self) -> str:
                return "POINT(1.0 2.0)"

        props = {
            "created_at": FakeDateTime(),
            "location": FakeSpatialPoint(),
        }
        normalized = {
            "created_at": "2026-05-13T10:00:00+00:00",
            "location": "POINT(1.0 2.0)",
        }

        sync_module._normalize_sink_properties(props, labels=None)

        assert props == normalized
+50 -3
View File
@@ -1,4 +1,4 @@
from unittest.mock import call, patch
from unittest.mock import MagicMock, call, patch
import pytest
from api.attack_paths import database as graph_database
@@ -60,10 +60,12 @@ class TestDeleteProvider:
aps1 = create_attack_paths_scan(instance)
aps2 = create_attack_paths_scan(instance)
backend = MagicMock()
with (
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph",
"tasks.jobs.deletion.sink_module.get_backend_for_name",
return_value=backend,
),
patch(
"tasks.jobs.deletion.graph_database.drop_database",
@@ -72,12 +74,55 @@ class TestDeleteProvider:
result = delete_provider(tenant_id, instance.id)
assert result
backend.drop_subgraph.assert_called_once_with(
graph_database.get_database_name(tenant_id), str(instance.id)
)
expected_tmp_calls = [
call(f"db-tmp-scan-{str(aps1.id).lower()}"),
call(f"db-tmp-scan-{str(aps2.id).lower()}"),
]
mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True)
def test_delete_provider_drops_graph_data_from_all_recorded_sinks(
    self, providers_fixture, create_attack_paths_scan
):
    """Deleting a provider drops its subgraph on every sink it has scans in.

    One attack-paths scan is recorded per sink backend ("neo4j" and
    "neptune"). ``delete_provider`` must resolve both backends by name and
    call ``drop_subgraph`` exactly once on each with the tenant database
    name and the provider id.
    """
    instance = providers_fixture[0]
    tenant_id = str(instance.tenant_id)

    for sink_name in ("neo4j", "neptune"):
        create_attack_paths_scan(instance, sink_backend=sink_name)

    # One independent mock per sink so calls can be checked separately.
    backends = {
        "neo4j": MagicMock(),
        "neptune": MagicMock(),
    }

    def get_backend_for_name(name):
        return backends[name]

    with (
        patch(
            "tasks.jobs.deletion.graph_database.get_database_name",
            return_value="tenant-db",
        ),
        patch(
            "tasks.jobs.deletion.sink_module.get_backend_for_name",
            side_effect=get_backend_for_name,
        ) as mock_get_backend_for_name,
        patch("tasks.jobs.deletion.graph_database.drop_database"),
    ):
        result = delete_provider(tenant_id, instance.id)

    assert result

    expected_lookups = [call("neo4j"), call("neptune")]
    mock_get_backend_for_name.assert_has_calls(expected_lookups, any_order=True)

    provider_key = str(instance.id)
    backends["neo4j"].drop_subgraph.assert_called_once_with("tenant-db", provider_key)
    backends["neptune"].drop_subgraph.assert_called_once_with(
        "tenant-db", provider_key
    )
def test_delete_provider_continues_when_temp_db_drop_fails(
self, providers_fixture, create_attack_paths_scan
):
@@ -85,10 +130,12 @@ class TestDeleteProvider:
tenant_id = str(instance.tenant_id)
create_attack_paths_scan(instance)
backend = MagicMock()
with (
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph",
"tasks.jobs.deletion.sink_module.get_backend_for_name",
return_value=backend,
),
patch(
"tasks.jobs.deletion.graph_database.drop_database",
+70 -1
View File
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
import pytest
from api.db_router import MainRouter
from api.exceptions import ProviderConnectionError
from api.exceptions import ProviderConnectionError, ProviderDeletedException
from api.models import (
Finding,
MuteRule,
@@ -262,6 +262,75 @@ class TestPerformScan:
assert provider.connected is False
assert isinstance(provider.connection_last_checked_at, datetime)
def test_perform_prowler_scan_provider_deleted_during_progress_update(
    self,
    tenants_fixture,
    scans_fixture,
    providers_fixture,
):
    """A provider deleted mid-scan surfaces as ``ProviderDeletedException``.

    The fake scan generator deletes the provider row just before yielding
    its first progress tuple. ``perform_prowler_scan`` is expected to raise
    ``ProviderDeletedException`` without calling ``logger.error``, and the
    scan row must no longer exist afterwards.
    """
    tenant_id = str(tenants_fixture[0].id)
    scan_id = str(scans_fixture[0].id)
    provider_id = str(providers_fixture[0].id)

    def scan_results():
        # Runs lazily on first iteration: remove the provider, then report progress.
        Provider.objects.filter(pk=provider_id).delete()
        yield 50, []

    with (
        patch(
            "tasks.jobs.scan.initialize_prowler_provider",
            return_value=MagicMock(),
        ),
        patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
        patch("tasks.jobs.scan.logger.error") as mock_logger_error,
    ):
        fake_scanner = MagicMock()
        fake_scanner.scan.return_value = scan_results()
        mock_prowler_scan_class.return_value = fake_scanner

        with pytest.raises(ProviderDeletedException):
            perform_prowler_scan(tenant_id, scan_id, provider_id, [])

    mock_logger_error.assert_not_called()
    assert not Scan.objects.filter(pk=scan_id).exists()
def test_perform_prowler_scan_sets_final_progress_when_progress_updates_are_throttled(
    self,
    tenants_fixture,
    scans_fixture,
    providers_fixture,
):
    """The scan ends at 100% progress even with very high throttle limits.

    ``PROGRESS_THROTTLE_DELTA`` is patched to 200 and
    ``PROGRESS_THROTTLE_SECONDS`` to 3600 while the fake scan reports 99
    and then 100. After ``perform_prowler_scan`` returns, the stored scan
    must be ``COMPLETED`` with ``progress == 100``.
    """
    scan = scans_fixture[0]
    tenant_id = str(tenants_fixture[0].id)
    scan_id = str(scan.id)
    provider_id = str(providers_fixture[0].id)

    # (progress, findings) tuples produced by the fake scanner.
    progress_steps = [(99, []), (100, [])]

    with (
        patch(
            "tasks.jobs.scan.initialize_prowler_provider",
            return_value=MagicMock(),
        ),
        patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
        patch("tasks.jobs.scan.PROGRESS_THROTTLE_DELTA", 200),
        patch("tasks.jobs.scan.PROGRESS_THROTTLE_SECONDS", 3600),
    ):
        fake_scanner = MagicMock()
        fake_scanner.scan.return_value = progress_steps
        mock_prowler_scan_class.return_value = fake_scanner

        perform_prowler_scan(tenant_id, scan_id, provider_id, [])

    scan.refresh_from_db()
    assert scan.state == StateChoices.COMPLETED
    assert scan.progress == 100
@pytest.mark.parametrize(
"last_status, new_status, expected_delta",
[
Generated
+278 -20
View File
@@ -110,7 +110,7 @@ constraints = [
{ name = "blinker", specifier = "==1.9.0" },
{ name = "boto3", specifier = "==1.40.61" },
{ name = "botocore", specifier = "==1.40.61" },
{ name = "cartography", specifier = "==0.135.0" },
{ name = "cartography", specifier = "==0.138.1" },
{ name = "celery", specifier = "==5.6.2" },
{ name = "certifi", specifier = "==2026.1.4" },
{ name = "cffi", specifier = "==2.0.0" },
@@ -135,7 +135,6 @@ constraints = [
{ name = "debugpy", specifier = "==1.8.20" },
{ name = "decorator", specifier = "==5.2.1" },
{ name = "defusedxml", specifier = "==0.7.1" },
{ name = "detect-secrets", specifier = "==1.5.0" },
{ name = "dill", specifier = "==0.4.1" },
{ name = "distro", specifier = "==1.9.0" },
{ name = "dj-rest-auth", specifier = "==7.0.1" },
@@ -218,6 +217,7 @@ constraints = [
{ name = "jsonschema", specifier = "==4.23.0" },
{ name = "jsonschema-specifications", specifier = "==2025.9.1" },
{ name = "keystoneauth1", specifier = "==5.13.0" },
{ name = "kingfisher-bin", specifier = "==1.104.0" },
{ name = "kiwisolver", specifier = "==1.4.9" },
{ name = "knack", specifier = "==0.11.0" },
{ name = "kombu", specifier = "==5.6.2" },
@@ -364,7 +364,7 @@ constraints = [
{ name = "wcwidth", specifier = "==0.5.3" },
{ name = "websocket-client", specifier = "==1.9.0" },
{ name = "werkzeug", specifier = "==3.1.7" },
{ name = "workos", specifier = "==6.0.4" },
{ name = "workos", specifier = "==6.0.8" },
{ name = "wrapt", specifier = "==1.17.3" },
{ name = "xlsxwriter", specifier = "==3.2.9" },
{ name = "xmlsec", specifier = "==1.3.17" },
@@ -376,6 +376,7 @@ constraints = [
{ name = "zstd", specifier = "==1.5.7.3" },
]
overrides = [
{ name = "azure-mgmt-containerservice", specifier = "==34.1.0" },
{ name = "dulwich", specifier = "==1.2.5" },
{ name = "microsoft-kiota-abstractions", specifier = "==1.9.9" },
{ name = "okta", specifier = "==3.4.2" },
@@ -1407,6 +1408,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3d/66/0d8ae9ca4d75e57746026a1f9a10a7e25029511c128cf20166fce516bda9/azure_mgmt_logic-10.0.0-py3-none-any.whl", hash = "sha256:525c78afedf3edb35eb0a16152c8beba89769ee1bc6af01bcdc42842a551e443", size = 235433, upload-time = "2022-06-13T01:38:27.333Z" },
]
[[package]]
name = "azure-mgmt-managementgroups"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "azure-mgmt-core" },
{ name = "isodate" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fd/73/ac5e064ed7343e1b3172f32f09be3efca906087218d3046b5038f2f394ed/azure_mgmt_managementgroups-1.1.0.tar.gz", hash = "sha256:e6199baf118890ba2bda35dda83a88861c0b1bbef126311b20ec12eed9681951", size = 60101, upload-time = "2026-02-13T03:45:45.439Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/92/bc/993158de03cc0a49f2cf8192615ffedbc508c417cb3522e88f6652b714cc/azure_mgmt_managementgroups-1.1.0-py3-none-any.whl", hash = "sha256:140934589559ef6afcac6f1d24f995588a1965aaa89d47851c1cc639fafb1942", size = 83586, upload-time = "2026-02-13T03:45:46.836Z" },
]
[[package]]
name = "azure-mgmt-monitor"
version = "6.0.2"
@@ -1726,7 +1741,7 @@ wheels = [
[[package]]
name = "cartography"
version = "0.135.0"
version = "0.138.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "adal" },
@@ -1746,6 +1761,7 @@ dependencies = [
{ name = "azure-mgmt-eventhub" },
{ name = "azure-mgmt-keyvault" },
{ name = "azure-mgmt-logic" },
{ name = "azure-mgmt-managementgroups" },
{ name = "azure-mgmt-monitor" },
{ name = "azure-mgmt-network" },
{ name = "azure-mgmt-resource" },
@@ -1754,6 +1770,7 @@ dependencies = [
{ name = "azure-mgmt-storage" },
{ name = "azure-mgmt-synapse" },
{ name = "azure-mgmt-web" },
{ name = "azure-storage-blob" },
{ name = "azure-synapse-artifacts" },
{ name = "backoff" },
{ name = "boto3" },
@@ -1765,8 +1782,12 @@ dependencies = [
{ name = "duo-client" },
{ name = "google-api-python-client" },
{ name = "google-auth" },
{ name = "google-cloud-aiplatform" },
{ name = "google-cloud-artifact-registry" },
{ name = "google-cloud-asset" },
{ name = "google-cloud-resource-manager" },
{ name = "google-cloud-run" },
{ name = "google-cloud-storage" },
{ name = "httpx" },
{ name = "kubernetes" },
{ name = "marshmallow" },
@@ -1792,9 +1813,9 @@ dependencies = [
{ name = "workos" },
{ name = "xmltodict" },
]
sdist = { url = "https://files.pythonhosted.org/packages/39/47/606851d2403a983b63813b9e95427a5dd896e49bc5a501868c041262e9a5/cartography-0.135.0.tar.gz", hash = "sha256:3f500cd22c3b392d00e8b49f62acc95fd4dcd559ce514aafe2eb8101133c7a49", size = 9106458, upload-time = "2026-04-10T16:25:34.898Z" }
sdist = { url = "https://files.pythonhosted.org/packages/51/cd/0eb6a5a3c89cc179801d902ade9719af1a583c516c00f50d72b8207db1eb/cartography-0.138.1.tar.gz", hash = "sha256:356e946a0bcac899cba293d57803c71bd35fdeabe623f5f67d9405d7a643af9f", size = 9756966, upload-time = "2026-06-19T22:11:32.411Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/e1/99a26b3e662202be77961aba73338e1448623490710b81783e53a4bbef15/cartography-0.135.0-py3-none-any.whl", hash = "sha256:c62c32a6917b8f23a8b98fe2b6c7c4a918b50f55918482966c4dae1cf5f538e1", size = 1590545, upload-time = "2026-04-10T16:25:37.669Z" },
{ url = "https://files.pythonhosted.org/packages/a8/15/4447ec968825b2a19cba26ecb74964208aa3f941d9181a7782572e30b43d/cartography-0.138.1-py3-none-any.whl", hash = "sha256:88ec0898ea1a1b3f4653be9a3e7e61144f5cee20384b9040e92039617d39f029", size = 2014725, upload-time = "2026-06-19T22:11:29.886Z" },
]
[[package]]
@@ -2224,16 +2245,15 @@ wheels = [
]
[[package]]
name = "detect-secrets"
version = "1.5.0"
name = "deprecated"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyyaml" },
{ name = "requests" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/69/67/382a863fff94eae5a0cf05542179169a1c49a4c8784a9480621e2066ca7d/detect_secrets-1.5.0.tar.gz", hash = "sha256:6bb46dcc553c10df51475641bb30fd69d25645cc12339e46c824c1e0c388898a", size = 97351, upload-time = "2024-05-06T17:46:19.721Z" }
sdist = { url = "https://files.pythonhosted.org/packages/49/85/12f0a49a7c4ffb70572b6c2ef13c90c88fd190debda93b23f026b25f9634/deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223", size = 2932523, upload-time = "2025-10-30T08:19:02.757Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4e/5e/4f5fe4b89fde1dc3ed0eb51bd4ce4c0bca406246673d370ea2ad0c58d747/detect_secrets-1.5.0-py3-none-any.whl", hash = "sha256:e24e7b9b5a35048c313e983f76c4bd09dad89f045ff059e354f9943bf45aa060", size = 120341, upload-time = "2024-05-06T17:46:16.628Z" },
{ url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" },
]
[[package]]
@@ -2511,6 +2531,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
]
[[package]]
name = "docstring-parser"
version = "0.18.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e0/4d/f332313098c1de1b2d2ff91cf2674415cc7cddab2ca1b01ae29774bd5fdf/docstring_parser-0.18.0.tar.gz", hash = "sha256:292510982205c12b1248696f44959db3cdd1740237a968ea1e2e7a900eeb2015", size = 29341, upload-time = "2026-04-14T04:09:19.867Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/5f/ed01f9a3cdffbd5a008556fc7b2a08ddb1cc6ace7effa7340604b1d16699/docstring_parser-0.18.0-py3-none-any.whl", hash = "sha256:b3fcbed555c47d8479be0796ef7e19c2670d428d72e96da63f3a40122860374b", size = 22484, upload-time = "2026-04-14T04:09:18.638Z" },
]
[[package]]
name = "dogpile-cache"
version = "1.5.0"
@@ -2851,6 +2880,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f", size = 236499, upload-time = "2026-01-26T19:22:45.099Z" },
]
[package.optional-dependencies]
requests = [
{ name = "requests" },
]
[[package]]
name = "google-auth-httplib2"
version = "0.2.0"
@@ -2877,6 +2911,46 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ca/94/24b010493660dd55e2d9769ae7ef44164aebd7e1f4a9266cf9459affd687/google_cloud_access_context_manager-0.3.0-py3-none-any.whl", hash = "sha256:5d15ad51547f06c281e35f16b4ffcb3e98bb2d898b01470f88b94edfb2eeb0a3", size = 58852, upload-time = "2025-10-17T02:30:33.768Z" },
]
[[package]]
name = "google-cloud-aiplatform"
version = "1.153.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "docstring-parser" },
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "google-cloud-bigquery" },
{ name = "google-cloud-resource-manager" },
{ name = "google-cloud-storage" },
{ name = "google-genai" },
{ name = "packaging" },
{ name = "proto-plus" },
{ name = "protobuf" },
{ name = "pydantic" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d5/97/1779e66ab845550bc602364311ea093ba156cb805a1c31b7c4d6f25b5863/google_cloud_aiplatform-1.153.1.tar.gz", hash = "sha256:445b6c683d5c630f174d81ae1f69f7da9e27e4d4ec5b70c5fe96de5c1247cfbc", size = 11011349, upload-time = "2026-05-15T06:34:14.851Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/01/8a1900e7a742ed480e6037ac4f6541466cb981d81bd4cbd34a9d46204ea1/google_cloud_aiplatform-1.153.1-py2.py3-none-any.whl", hash = "sha256:033fa1595a7e8ed1d97066e261e630f38fbc60e10c98c6487cf228fe9c7ec151", size = 9170782, upload-time = "2026-05-15T06:34:10.887Z" },
]
[[package]]
name = "google-cloud-artifact-registry"
version = "1.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "grpc-google-iam-v1" },
{ name = "grpcio" },
{ name = "proto-plus" },
{ name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/2b/24e6956789bc1244efb18143aa4f124e03d870228e5bfd065c04d38a4d6b/google_cloud_artifact_registry-1.21.0.tar.gz", hash = "sha256:546e51eb5d463a6e5c668be6727d14f8ec82bc798031398006b2213d703e184c", size = 315219, upload-time = "2026-03-30T22:50:38.875Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e1/8c/a5c68031728f38d3306bad5ac10c0ca670cbdf414db308ddefa2c47f2b34/google_cloud_artifact_registry-1.21.0-py3-none-any.whl", hash = "sha256:a07079035438fd0f2e7264d4318b388650495f011db575405c18c9881449025c", size = 250544, upload-time = "2026-03-30T22:48:49.345Z" },
]
[[package]]
name = "google-cloud-asset"
version = "4.2.0"
@@ -2897,6 +2971,37 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/88/9a43fae1d2fed94d7f5f46b6f4c44bd15e5ea0e8657632108b5ec5f53d9d/google_cloud_asset-4.2.0-py3-none-any.whl", hash = "sha256:fd7ea04c64948a4779790343204cd5b41d4772d6ab1d05a9125e28a637ac0862", size = 282707, upload-time = "2026-01-09T14:53:03.081Z" },
]
[[package]]
name = "google-cloud-bigquery"
version = "3.41.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "google-cloud-core" },
{ name = "google-resumable-media" },
{ name = "packaging" },
{ name = "python-dateutil" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ce/13/6515c7aab55a4a0cf708ffd309fb9af5bab54c13e32dc22c5acd6497193c/google_cloud_bigquery-3.41.0.tar.gz", hash = "sha256:2217e488b47ed576360c9b2cc07d59d883a54b83167c0ef37f915c26b01a06fe", size = 513434, upload-time = "2026-03-30T22:50:55.347Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/33/1d3902efadef9194566d499d61507e1f038454e0b55499d2d7f8ab2a4fee/google_cloud_bigquery-3.41.0-py3-none-any.whl", hash = "sha256:2a5b5a737b401cbd824a6e5eac7554100b878668d908e6548836b5d8aaa4dcaa", size = 262343, upload-time = "2026-03-30T22:48:45.444Z" },
]
[[package]]
name = "google-cloud-core"
version = "2.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core" },
{ name = "google-auth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/dd/1eef226e470369b26824a505c34482c0b493bc35fe8e0c6b003b5feca21a/google_cloud_core-2.6.0.tar.gz", hash = "sha256:e76149739f90fac1fc6757c09f47eaccb3145b54adbd7759b0f7c4b235f46c83", size = 36001, upload-time = "2026-05-07T08:04:04.124Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/4a/98da8930ab109c73d9a5d13782a9ebb81ea8c111f6d534a567b71d23e52b/google_cloud_core-2.6.0-py3-none-any.whl", hash = "sha256:6d63ac8e5eca6d9e4319d0a1e2265fadcd7f1049904378caecfa01cf52dd869e", size = 29390, upload-time = "2026-05-07T08:02:34.672Z" },
]
[[package]]
name = "google-cloud-org-policy"
version = "1.16.0"
@@ -2946,6 +3051,93 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" },
]
[[package]]
name = "google-cloud-run"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "grpc-google-iam-v1" },
{ name = "grpcio" },
{ name = "proto-plus" },
{ name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b7/89/dcaf0dc97e39b41e446456ceb60657ab025de79cfccd39cbd739d1a9849e/google_cloud_run-0.16.0.tar.gz", hash = "sha256:d52cf4e6ad3702ae48caccf6abcab543afee6f61c2a6ec753cc62a31e5b629f1", size = 514452, upload-time = "2026-03-26T22:17:05.589Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/c7/46153dc13713b5e4276d86f28ff4563332f9e4bae5ebc83abc5bfd994801/google_cloud_run-0.16.0-py3-none-any.whl", hash = "sha256:d7d2dd7307130fde2a0ce27e96d580dd23b7b2d973b6484b94d902e6b2618860", size = 459112, upload-time = "2026-03-26T22:16:00.018Z" },
]
[[package]]
name = "google-cloud-storage"
version = "3.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core" },
{ name = "google-auth" },
{ name = "google-cloud-core" },
{ name = "google-crc32c" },
{ name = "google-resumable-media" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" },
]
[[package]]
name = "google-crc32c"
version = "1.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" },
{ url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" },
{ url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" },
{ url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" },
{ url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" },
{ url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" },
{ url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" },
{ url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" },
{ url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" },
{ url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" },
{ url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" },
{ url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" },
]
[[package]]
name = "google-genai"
version = "1.68.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "google-auth", extra = ["requests"] },
{ name = "httpx" },
{ name = "pydantic" },
{ name = "requests" },
{ name = "sniffio" },
{ name = "tenacity" },
{ name = "typing-extensions" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" },
]
[[package]]
name = "google-resumable-media"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-crc32c" },
]
sdist = { url = "https://files.pythonhosted.org/packages/00/4b/0b235beccc310d0a48adbc7246b719d173cca6c88c572dfa4b090e39143c/google_resumable_media-2.9.0.tar.gz", hash = "sha256:f7cfb224846a9dd444d125115dfbe8ef02a2b893e78f087762fe716a255a734b", size = 2164534, upload-time = "2026-05-07T08:04:44.236Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/73/3518e63deb1667c5409a4579e28daf5e84479a87a72c547e0487f7883dcd/google_resumable_media-2.9.0-py3-none-any.whl", hash = "sha256:c8901e88e389af8bed64d9696c74d8bad961865eb2236e13e0bfca9bb0a65ca3", size = 81507, upload-time = "2026-05-07T08:03:23.809Z" },
]
[[package]]
name = "googleapis-common-protos"
version = "1.72.0"
@@ -3420,6 +3612,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/dd/99/76476a1057b349c860bae72e45d6ef438feb877c84ee7d565faf464e54c3/keystoneauth1-5.13.0-py3-none-any.whl", hash = "sha256:5ab81412eb0923ceb9c602cc3decce514b399523cb83d16b409ed3b0f9b03d41", size = 343585, upload-time = "2026-01-19T10:47:00.762Z" },
]
[[package]]
name = "kingfisher-bin"
version = "1.104.0"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/2b/324212f1baf482a7d4b66a2edf33073336735b67bb6b04a38d18fd9e67fb/kingfisher_bin-1.104.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:8e3840e67004a971fef80aba240ee5c3c5f7a3a343a6d1083a2751aaf866d5d3", size = 14057606, upload-time = "2026-06-22T03:03:01.419Z" },
{ url = "https://files.pythonhosted.org/packages/21/0a/cbf964da5102657cb9be4a59db7c9f7807ef88f9419673b7486daba785d3/kingfisher_bin-1.104.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b838313411fa2166a318a45aec2cfcc238e2f30f5292e309ca1129a73180c851", size = 12468386, upload-time = "2026-06-22T03:03:03.951Z" },
{ url = "https://files.pythonhosted.org/packages/0b/a0/cc7ef0ac28f147cdfc9d80e4239fff11c1329831c6f57510c929e848753c/kingfisher_bin-1.104.0-py3-none-manylinux_2_17_aarch64.musllinux_1_2_aarch64.whl", hash = "sha256:0a94abbf2154ef8a3b4845cc0240e2321cdc19e0f5c7f585ea5252e76b242f68", size = 13943188, upload-time = "2026-06-22T03:03:06.378Z" },
{ url = "https://files.pythonhosted.org/packages/17/79/827cfd7787885798a00b5ab905bdc866ef6f8deeff0f708679b06bc9baaa/kingfisher_bin-1.104.0-py3-none-manylinux_2_17_x86_64.musllinux_1_2_x86_64.whl", hash = "sha256:f381274b946f7f68ed72911770fff72024f2192c6e2e2158f2a7fbfda8c482fb", size = 14757594, upload-time = "2026-06-22T03:03:08.66Z" },
{ url = "https://files.pythonhosted.org/packages/da/93/b0061fc69cd10382f647f9266823f213fd0b3f168f8b5bd9151a2370abb1/kingfisher_bin-1.104.0-py3-none-win_amd64.whl", hash = "sha256:f228d0dd61a738673b1c536e965a5661a83b1ee6ca64186a46ba6ea81ab4fd0b", size = 27697957, upload-time = "2026-06-22T03:03:11.268Z" },
{ url = "https://files.pythonhosted.org/packages/a5/fb/f062665b4eb3f77e799cb6335e56bc2945aea83787888a6c1ab329858d0a/kingfisher_bin-1.104.0-py3-none-win_arm64.whl", hash = "sha256:a7774d9d11815ca946bd80b8c9df0f1d39c36cb5a21def3323b99d148dc63065", size = 26063704, upload-time = "2026-06-22T03:03:14.08Z" },
]
[[package]]
name = "kiwisolver"
version = "1.4.9"
@@ -3513,6 +3718,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/08/10/9f8af3e6f569685ce3af7faab51c8dd9d93b9c38eba339ca31c746119447/kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998", size = 1988070, upload-time = "2025-02-18T21:06:31.391Z" },
]
[[package]]
name = "linode-api4"
version = "5.45.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "deprecated" },
{ name = "polling" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b2/b5/fce03d9b81008dcc0fe4961ce10e140ac3ae5ab17f2cdd659763e4964c0d/linode_api4-5.45.0.tar.gz", hash = "sha256:af8a0a5638345ad467447112dcf5d58ec47e7dd192b89ce0c8537a1e5c435d04", size = 283375, upload-time = "2026-06-11T18:05:13.671Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/38/19e3c8f7b7a9dbeea2aa5af61f70162bff5131b3d39acbe73e8d0dd12972/linode_api4-5.45.0-py3-none-any.whl", hash = "sha256:3cc2650b13d8d3bc7735fa8e92a639669618f320471dc8e519db778c6020eacd", size = 158336, upload-time = "2026-06-11T18:05:11.799Z" },
]
[[package]]
name = "lxml"
version = "6.1.0"
@@ -4332,6 +4551,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/41/f5/65b66420c275e9b26513fdd6d84687403d11ac8be4650b67d1e5572b8f48/policyuniverse-1.5.1.20231109-py2.py3-none-any.whl", hash = "sha256:0b0ece0ee8285af31fc39ce09c82a551ca62e62bc2842e23952503bccb973321", size = 484251, upload-time = "2023-11-30T19:12:43.463Z" },
]
[[package]]
name = "polling"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8f/c5/4249317962180d97ec7a60fe38aa91f86216533bd478a427a5468945c5c9/polling-0.3.2.tar.gz", hash = "sha256:3afd62320c99b725c70f379964bf548b302fc7f04d4604e6c315d9012309cc9a", size = 5189, upload-time = "2021-05-22T19:48:41.466Z" }
[[package]]
name = "portalocker"
version = "2.10.1"
@@ -4448,8 +4673,8 @@ wheels = [
[[package]]
name = "prowler"
version = "5.31.0"
source = { git = "https://github.com/prowler-cloud/prowler.git?rev=master#b5bb85c9564f6ca6a7f66c851bb56bde719205ee" }
version = "5.32.0"
source = { git = "https://github.com/prowler-cloud/prowler.git?rev=master#5dac8a0a53272e4db68c476fb969dc03e88beb68" }
dependencies = [
{ name = "alibabacloud-actiontrail20200706" },
{ name = "alibabacloud-credentials" },
@@ -4500,13 +4725,14 @@ dependencies = [
{ name = "dash" },
{ name = "dash-bootstrap-components" },
{ name = "defusedxml" },
{ name = "detect-secrets" },
{ name = "dulwich" },
{ name = "google-api-python-client" },
{ name = "google-auth-httplib2" },
{ name = "h2" },
{ name = "jsonschema" },
{ name = "kingfisher-bin" },
{ name = "kubernetes" },
{ name = "linode-api4" },
{ name = "markdown" },
{ name = "microsoft-kiota-abstractions" },
{ name = "msgraph-sdk" },
@@ -4536,7 +4762,7 @@ dependencies = [
[[package]]
name = "prowler-api"
version = "1.33.0"
version = "1.34.0"
source = { virtual = "." }
dependencies = [
{ name = "cartography" },
@@ -4606,7 +4832,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "cartography", specifier = "==0.135.0" },
{ name = "cartography", specifier = "==0.138.1" },
{ name = "celery", specifier = "==5.6.2" },
{ name = "defusedxml", specifier = "==0.7.1" },
{ name = "dj-rest-auth", extras = ["with-social", "jwt"], specifier = "==7.0.1" },
@@ -5931,6 +6157,38 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" },
]
[[package]]
name = "websockets"
version = "16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" },
{ url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" },
{ url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" },
{ url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" },
{ url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" },
{ url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" },
{ url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" },
{ url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" },
{ url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" },
{ url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" },
{ url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" },
{ url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" },
{ url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" },
{ url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" },
{ url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" },
{ url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" },
{ url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" },
{ url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" },
{ url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" },
{ url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" },
{ url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" },
{ url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" },
{ url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" },
{ url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" },
]
[[package]]
name = "werkzeug"
version = "3.1.7"
@@ -5945,16 +6203,16 @@ wheels = [
[[package]]
name = "workos"
version = "6.0.4"
version = "6.0.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
{ name = "httpx" },
{ name = "pyjwt", extra = ["crypto"] },
]
sdist = { url = "https://files.pythonhosted.org/packages/3c/2f/99fb8718274116c5c146c745755620fd5c5943f78ca52ca9b17e94348286/workos-6.0.4.tar.gz", hash = "sha256:b0bfe8fd212b8567422c4ea3732eb33608794033eb3a69900c6b04db183c32d6", size = 172217, upload-time = "2026-04-16T03:09:28.583Z" }
sdist = { url = "https://files.pythonhosted.org/packages/ca/0d/0a7f78912657f99412c788932ea1f3f4089916e77bdef7d2463842febe08/workos-6.0.8.tar.gz", hash = "sha256:43aa3f1992a0a4ca8933d9b6e5ada846dd3b1fe0ee10e64c876ee2000fc6090d", size = 178137, upload-time = "2026-04-24T18:48:03.203Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/f1/d2ab661e6dc2828a4c73e38f12630c3b109cfe2bc664ab70631c04f0db4b/workos-6.0.4-py3-none-any.whl", hash = "sha256:548668b3702673536f853ba72a7b5bbbc269e467aaf9ac4f477b6e0177df5e21", size = 511418, upload-time = "2026-04-16T03:09:27.098Z" },
{ url = "https://files.pythonhosted.org/packages/b2/3f/3d96da80d650b2f97d58af626053354584f619dbb769051e118bd9cd1ca5/workos-6.0.8-py3-none-any.whl", hash = "sha256:a00dd4930333aded2babbba824f8032eea05c5ca8c44d04a3fa068cf6be6e21a", size = 524505, upload-time = "2026-04-24T18:48:01.389Z" },
]
[[package]]
@@ -1,7 +1,7 @@
# Build command
# docker build --platform=linux/amd64 --no-cache -t prowler:latest .
ARG PROWLER_VERSION=latest@sha256:4b796c6df40a3350c7947747b59bdda230d0da6222287500e13b0a8e1574aad4
ARG PROWLER_VERSION=latest@sha256:ebb4ab999f10cb7e7c256226c2873de9b3bf2f3d855f385e0164bcf34104bfba
FROM toniblyx/prowler:${PROWLER_VERSION}
@@ -16,7 +16,7 @@
services:
nginx:
image: nginx:alpine@sha256:8b1e78743a03dbb2c95171cc58639fef29abc8816598e27fb910ed2e621e589a
image: nginx:alpine@sha256:54f2a904c251d5a34adf545a72d32515a15e08418dae0266e23be2e18c66fefa
container_name: prowler-nginx
restart: unless-stopped
ports:
@@ -0,0 +1,24 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
    """Build the dashboard table for this CIS compliance framework.

    Selects the requirement and finding columns from ``data``, copies the
    selection, and hands it to ``get_section_containers_cis`` together with
    the requirement-ID and section column names.

    Args:
        data: Tabular compliance data indexable by a list of column names
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        Whatever ``get_section_containers_cis`` returns for the selection.
    """
    columns = [
        "REQUIREMENTS_ID",
        "REQUIREMENTS_DESCRIPTION",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
        "CHECKID",
        "STATUS",
        "REGION",
        "ACCOUNTID",
        "RESOURCEID",
    ]
    # Work on a copy so the helper never touches the caller's data.
    selection = data[columns].copy()
    return get_section_containers_cis(
        selection,
        "REQUIREMENTS_ID",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
    )
+24
View File
@@ -0,0 +1,24 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
    """Build the dashboard table for this CIS compliance framework.

    Selects the requirement and finding columns from ``data``, copies the
    selection, and hands it to ``get_section_containers_cis`` together with
    the requirement-ID and section column names.

    Args:
        data: Tabular compliance data indexable by a list of column names
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        Whatever ``get_section_containers_cis`` returns for the selection.
    """
    columns = [
        "REQUIREMENTS_ID",
        "REQUIREMENTS_DESCRIPTION",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
        "CHECKID",
        "STATUS",
        "REGION",
        "ACCOUNTID",
        "RESOURCEID",
    ]
    # Work on a copy so the helper never touches the caller's data.
    selection = data[columns].copy()
    return get_section_containers_cis(
        selection,
        "REQUIREMENTS_ID",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
    )
+25
View File
@@ -0,0 +1,25 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
    """Build the dashboard table for this CIS compliance framework.

    Selects the requirement and finding columns from ``data``, copies the
    selection, and hands it to ``get_section_containers_cis`` together with
    the requirement-ID and section column names.

    Args:
        data: Tabular compliance data indexable by a list of column names
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        Whatever ``get_section_containers_cis`` returns for the selection.
    """
    columns = [
        "REQUIREMENTS_ID",
        "REQUIREMENTS_DESCRIPTION",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
        "CHECKID",
        "STATUS",
        "REGION",
        "ACCOUNTID",
        "RESOURCEID",
    ]
    # Work on a copy so the helper never touches the caller's data.
    selection = data[columns].copy()
    return get_section_containers_cis(
        selection,
        "REQUIREMENTS_ID",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
    )
+24
View File
@@ -0,0 +1,24 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
    """Build the dashboard table for this CIS compliance framework.

    Selects the requirement and finding columns from ``data``, copies the
    selection, and hands it to ``get_section_containers_cis`` together with
    the requirement-ID and section column names.

    Args:
        data: Tabular compliance data indexable by a list of column names
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        Whatever ``get_section_containers_cis`` returns for the selection.
    """
    columns = [
        "REQUIREMENTS_ID",
        "REQUIREMENTS_DESCRIPTION",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
        "CHECKID",
        "STATUS",
        "REGION",
        "ACCOUNTID",
        "RESOURCEID",
    ]
    # Work on a copy so the helper never touches the caller's data.
    selection = data[columns].copy()
    return get_section_containers_cis(
        selection,
        "REQUIREMENTS_ID",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
    )
+24
View File
@@ -0,0 +1,24 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
    """Build the dashboard table for this CIS compliance framework.

    Selects the requirement and finding columns from ``data``, copies the
    selection, and hands it to ``get_section_containers_cis`` together with
    the requirement-ID and section column names.

    Args:
        data: Tabular compliance data indexable by a list of column names
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        Whatever ``get_section_containers_cis`` returns for the selection.
    """
    columns = [
        "REQUIREMENTS_ID",
        "REQUIREMENTS_DESCRIPTION",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
        "CHECKID",
        "STATUS",
        "REGION",
        "ACCOUNTID",
        "RESOURCEID",
    ]
    # Work on a copy so the helper never touches the caller's data.
    selection = data[columns].copy()
    return get_section_containers_cis(
        selection,
        "REQUIREMENTS_ID",
        "REQUIREMENTS_ATTRIBUTES_SECTION",
    )
+2
View File
@@ -445,3 +445,5 @@ The metadata structure is enforced in code using a Pydantic model. For reference
## Specific Check Patterns
Details for specific providers can be found in documentation pages named using the pattern `<provider_name>-details`.
Checks that scan resources for plaintext secrets follow a dedicated batched structure. Refer to [Secret-Scanning Checks](/developer-guide/secret-scanning-checks) before creating or updating one.
+5 -2
View File
@@ -42,10 +42,14 @@ When adding a new configurable check to Prowler, update the following files:
```
- **Provider Schema:** Add the typed field to the provider's Pydantic schema in `prowler/config/schema/<provider>.py`. This is required: the loader validates user configs against these schemas and the shipped `config.yaml` must round-trip with zero warnings. See [Adding a Parameter to the Provider Schema](#adding-a-parameter-to-the-provider-schema) below.
- **Test Fixtures:** If tests depend on this configuration, add the variable to `tests/config/fixtures/config.yaml`.
- **Documentation:** Document the new variable in the list of configurable checks in `docs/tutorials/configuration_file.md`.
- **Documentation:** Document the new variable in the list of configurable checks in [Configuration File](/user-guide/cli/tutorials/configuration_file) (`docs/user-guide/cli/tutorials/configuration_file.mdx`).
For a complete list of checks that already support configuration, see the [Configuration File Tutorial](/user-guide/cli/tutorials/configuration_file).
<Note>
Because a configurable check's verdict depends on the `audit_config` value it reads, a compliance requirement can lose meaning if the scan ran with a looser threshold than the control demands. Compliance frameworks can guard against this with **configuration guardrails**: a requirement declares the strictest configuration it tolerates and is forced to FAIL when the scan's config falls short. See [Configuration Guardrails for Requirements](/developer-guide/security-compliance-framework#configuration-guardrails-for-requirements).
</Note>
## Adding a Parameter to the Provider Schema
Most providers have a typed Pydantic schema in `prowler/config/schema/`, registered in `prowler/config/schema/registry.py`. When a config is loaded and the provider has a registered schema, `validate_provider_config` checks each user-supplied key against it, logs a warning, and drops any field that fails validation. The consumer's `.get(key, default)` then falls back to the built-in default. Providers without a registered schema are passed through unchanged.
@@ -149,7 +153,6 @@ Only fields with a numeric range, a fixed value set, or a length cap are listed.
| `max_days_secret_unused` | `7..365` days | |
| `max_days_secret_unrotated` | `1..180` days | NIST IA-5: rotate quarterly; CIS ≤90 |
| `min_kinesis_stream_retention_hours` | `24..8760` h | 1 day .. 1 year |
| `detect_secrets_plugins[].limit` | `0.0..10.0` | Shannon entropy threshold |
| `shodan_api_key` | ≤512 chars | |
### Azure
@@ -0,0 +1,119 @@
---
title: 'Secret-Scanning Checks'
---
import { VersionBadge } from "/snippets/version-badge.mdx"
<VersionBadge version="5.32.0" />
Prowler scans audited resources for plaintext secrets using [Kingfisher](https://github.com/mongodb/kingfisher), an open-source secret-scanning engine that Prowler invokes as a subprocess. This guide explains the structure every secret-scanning check must follow to keep scanning correct and efficient on large accounts.
<Note>
Since Prowler 5.32.0, the secret-scanning checks scan with Kingfisher. Earlier versions used the `detect-secrets` library.
</Note>
## Overview
Secret detection runs through a single helper in `prowler/lib/utils/utils.py`:
- **`detect_secrets_scan_batch(payloads, excluded_secrets=..., validate=...)`** scans many payloads in chunked subprocess invocations and returns a `{key: [findings]}` dictionary. To scan a single payload, pass a one-entry mapping (for example, `{0: data}`).
Every Kingfisher invocation carries a fixed process-startup cost (around 100 ms). Scanning once per resource would spawn thousands of subprocesses on large accounts (for example, thousands of CloudWatch log groups). `detect_secrets_scan_batch` amortizes that cost: it writes each payload to a temporary file as it consumes it, runs one subprocess per chunk (500 payloads by default), and maps the findings back to each payload by key.
## The Batched Structure
Every secret-scanning check follows three phases.
### Phase 1: Collect
Define a generator that yields `(key, payload)` for each scannable unit. The generator builds payload strings only — it does not call Kingfisher. Lazy yielding keeps memory and temporary-disk usage bounded to a single chunk, which matters when an account holds thousands of resources.
### Phase 2: Batch
Call `detect_secrets_scan_batch` once with the generator. The helper consumes it in chunks, runs Kingfisher per chunk, and returns a dictionary that maps each key that produced findings to its list of findings.
### Phase 3: Report
Iterate the resources, look up the findings by key, and build one report per resource. Emit a finding for **every** iterated resource — never drop one silently. When a resource's payload cannot be prepared for scanning (for example, user data that fails to base64-decode or decompress), report it as `MANUAL` with a status explaining the scan could not inspect it, rather than omitting it or claiming `PASS`.
```python
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.lib.utils.utils import (
annotate_verified_secrets,
detect_secrets_scan_batch,
)
from prowler.providers.aws.services.example.example_client import example_client
class example_resource_no_secrets(Check):
def execute(self):
findings = []
excluded = example_client.audit_config.get("secrets_ignore_patterns", [])
validate = example_client.audit_config.get("secrets_validate", False)
resources = list(example_client.resources)
# Phase 1: collect — builds strings only, no scan.
def payloads():
for index, resource in enumerate(resources):
if resource.scannable_data:
yield index, serialize(resource)
# Phase 2: batch — one call, chunked subprocesses.
batch_results = detect_secrets_scan_batch(
payloads(), excluded_secrets=excluded, validate=validate
)
# Phase 3: report — look up findings by key.
for index, resource in enumerate(resources):
report = Check_Report_AWS(metadata=self.metadata(), resource=resource)
report.status = "PASS"
report.status_extended = f"No secrets found in {resource.name}."
detect_secrets_output = batch_results.get(index)
if detect_secrets_output:
report.status = "FAIL"
report.status_extended = (
f"Potential secret found in {resource.name} -> ..."
)
annotate_verified_secrets(report, detect_secrets_output)
findings.append(report)
return findings
```
## Choosing the Key
The key maps each finding back to its source. Two shapes cover every check:
- **One payload per resource:** use the resource index. This fits checks that serialize a single payload per resource, such as launch configurations, CloudFormation outputs, SSM documents, Step Functions definitions, and OpenStack metadata.
- **Several payloads per resource:** use a `(resource_index, fragment)` tuple, where the fragment identifies the variable, log stream, container, file, or version. Phase 3 groups the per-fragment findings to build the resource report. This fits CloudWatch log streams, ECS containers, CodeBuild variables, Glue arguments, and Lambda code files.
Derive the indices from the same `list(...)` of resources in both Phase 1 and Phase 3 so the order stays stable and the keys align.
## Preserving Per-Payload Results
`detect_secrets_scan_batch` runs Kingfisher with `--no-dedup`, so a secret that appears in more than one payload is reported for each one. This reproduces the result of scanning each payload individually. Build payload strings exactly as a single scan would: serialize the same data and keep line ordering, because messages often map a finding's `line_number` back to a variable name or metadata key.
## Validation and Severity
`detect_secrets_scan_batch` accepts `validate`, read from `secrets_validate` in the provider configuration or the `--scan-secrets-validate` flag. When enabled, Kingfisher confirms whether each secret is live, and confirmed secrets carry `is_verified: True`.
After marking a report as `FAIL`, pass the findings to `annotate_verified_secrets(report, findings)`. When any secret is verified, the helper escalates the finding to critical severity and appends a note that the secret was confirmed live. Validation stays off by default because it sends the discovered secret to the provider API.
## Excluded Secrets
`detect_secrets_scan_batch` applies `secrets_ignore_patterns` — regular expressions from the provider configuration — against each finding's source line and drops the matches, mirroring single-scan behavior.
## Testing
To assert on the verified-secret path, mock `detect_secrets_scan_batch` in the check module and return the keyed dictionary. For a single resource scanned at index `0`:
```python
mock.patch(
"prowler.providers.aws.services.example.example_resource_no_secrets.example_resource_no_secrets.detect_secrets_scan_batch",
return_value={
0: [{"type": "...", "line_number": 1, "is_verified": True}]
},
)
```
Most tests need no mock at all: they seed resources that contain example secrets and assert on the `FAIL` status and message, which exercises the real batched path. Refer to the [Testing](/developer-guide/unit-testing) documentation for the general structure.
@@ -2,6 +2,8 @@
title: 'Creating a New Security Compliance Framework in Prowler'
---
import { VersionBadge } from "/snippets/version-badge.mdx"
This guide explains how to add a new security compliance framework to Prowler, end to end. It covers directory layout, the two supported JSON schemas (universal and legacy), the Pydantic models that validate each framework, check mapping conventions, output formatting, local validation, testing, and the pull request process.
## Introduction
@@ -23,7 +25,7 @@ Requirement coverage feeds the compliance percentage calculations and the metada
| **Universal (recommended for new frameworks)** | Multi-provider frameworks, or single-provider frameworks that benefit from declarative table/PDF rendering | `prowler/compliance/<framework>.json` (top-level) | Available for **every** provider whose key appears in any `requirement.checks` dict |
| **Legacy provider-specific** | Single-provider frameworks with framework-specific attribute classes already declared in the codebase (CIS, ENS, ISO 27001, etc.) | `prowler/compliance/<provider>/<framework>_<version>_<provider>.json` | Available only under that provider |
Auto-discovery happens in `get_bulk_compliance_frameworks_universal(provider)` (`prowler/lib/check/compliance_models.py:915`), which scans **both** the top-level `prowler/compliance/` directory and every per-provider sub-directory. Legacy frameworks are transparently converted to the universal `ComplianceFramework` model via `adapt_legacy_to_universal()` before being returned, so the rest of Prowler — CLI table rendering, CSV/OCSF outputs, PDF generation — works the same regardless of the source schema.
Auto-discovery happens in `get_bulk_compliance_frameworks_universal(provider)` (`prowler/lib/check/compliance_models.py`), which scans **both** the top-level `prowler/compliance/` directory and every per-provider sub-directory. Legacy frameworks are transparently converted to the universal `ComplianceFramework` model via `adapt_legacy_to_universal()` before being returned, so the rest of Prowler — CLI table rendering, CSV/OCSF outputs, PDF generation — works the same regardless of the source schema.
> The legacy entry-point `Compliance.get_bulk(provider)` (used by older code paths) only scans per-provider sub-directories. Universal top-level files are picked up exclusively via the universal loader; this matters if you are wiring a new code path against the legacy API.
@@ -70,13 +72,13 @@ The file is auto-discovered — there is **no** need to register it in any `__in
}
```
A `provider` field at the top level is **optional**. The framework's effective provider list is derived by `ComplianceFramework.get_providers()` (`compliance_models.py:739`) from the union of all keys appearing in `requirement.checks` across all requirements; the explicit `provider` field is used **only as a fallback** when no requirement carries any `checks` key. This is what enables a single file (e.g. `dora_2022_2554.json`) to cover AWS today and add Azure / GCP / etc. tomorrow without restructuring.
A `provider` field at the top level is **optional**. The framework's effective provider list is derived by `ComplianceFramework.get_providers()` (`compliance_models.py`) from the union of all keys appearing in `requirement.checks` across all requirements; the explicit `provider` field is used **only as a fallback** when no requirement carries any `checks` key. This is what enables a single file (e.g. `dora_2022_2554.json`) to cover AWS today and add Azure / GCP / etc. tomorrow without restructuring.
Provider keys inside `requirement.checks` must match the directory names under `prowler/providers/`. The valid keys at present are: `aws`, `azure`, `gcp`, `m365`, `kubernetes`, `iac`, `github`, `googleworkspace`, `alibabacloud`, `cloudflare`, `mongodbatlas`, `nhn`, `openstack`, `oraclecloud`, `llm`. Comparison in `supports_provider()` is case-insensitive, but lowercase is the convention used everywhere in the repository.
### `attributes_metadata`
Declares the shape of the per-requirement `attributes` dict. When this field is present, the root validator `validate_attributes_against_metadata` (`compliance_models.py:669`) enforces the schema at load time and rejects:
Declares the shape of the per-requirement `attributes` dict. When this field is present, the root validator `validate_attributes_against_metadata` (`compliance_models.py`) enforces the schema at load time and rejects:
- Missing keys marked `required: true`.
- Keys present in `attributes` but not declared in `attributes_metadata` (typo / drift guard).
@@ -192,6 +194,7 @@ Per requirement:
- `name`: short title shown alongside the id.
- `attributes`: flat dict; keys must conform to `attributes_metadata`.
- `checks`: dict keyed by provider name (the same lowercase keys listed in the previous section). Each value is a list of Prowler check names that evidence this requirement for that provider. The list **may be empty** and the dict itself defaults to `{}` if omitted; either way the requirement is still loaded and listed by `--list-compliance-requirements`, it just has zero checks to execute. Note: there is **no automatic check-existence validation** at load time — referencing a non-existent check name will silently produce a requirement with no findings. Validate this yourself (see "Validating Your Framework" below).
- `config_requirements`: optional list of configuration guardrails. Each entry asserts that a configurable check referenced by this requirement ran with a configuration strict enough to actually satisfy the requirement; otherwise the requirement is forced to FAIL. See [Configuration Guardrails for Requirements](#configuration-guardrails-for-requirements) for the full schema and semantics. In the universal schema the field name is lowercase (`config_requirements`); legacy files use `ConfigRequirements`.
For MITRE-style frameworks, additional optional fields are available on the requirement: `tactics`, `sub_techniques`, `platforms`, `technique_url` (these are populated automatically when adapting a legacy MITRE JSON to the universal model).
@@ -258,7 +261,7 @@ prowler/lib/outputs/compliance/<framework>/
### JSON schema reference
Every legacy compliance file is a JSON document with the following top-level keys. `Framework`, `Name` and `Provider` are validated non-empty by the root validator `framework_and_provider_must_not_be_empty` (`compliance_models.py:329`).
Every legacy compliance file is a JSON document with the following top-level keys. `Framework`, `Name` and `Provider` are validated non-empty by the root validator `framework_and_provider_must_not_be_empty` (`compliance_models.py`).
| Field | Type | Required | Description |
|---|---|---|---|
@@ -280,10 +283,11 @@ Each entry in `Requirements` describes one control or requirement.
| `Description` | string | Yes | Verbatim description from the source framework. |
| `Attributes` | array | Yes | List of [attribute objects](#attribute-objects). The shape depends on the framework. |
| `Checks` | array of strings | Yes | Prowler check identifiers that automate the requirement. Leave the list empty when the control cannot be automated. |
| `ConfigRequirements` | array of objects | No | Optional [configuration guardrails](#configuration-guardrails-for-requirements). Each entry asserts that a configurable check ran with a configuration strict enough to satisfy the requirement; when it did not, the requirement is forced to FAIL. |
#### Attribute Objects
`Attributes` is parsed against the union declared in `Compliance_Requirement.Attributes` (`compliance_models.py:293`). Pydantic v1 tries each member of the union in declaration order and falls back to `Generic_Compliance_Requirement_Attribute` (the last entry) when nothing else matches — so a brand-new shape that doesn't match any existing class will silently be accepted as Generic, losing its specific fields.
`Attributes` is parsed against the union declared in `Compliance_Requirement.Attributes` (`compliance_models.py`). Pydantic v1 tries each member of the union in declaration order and falls back to `Generic_Compliance_Requirement_Attribute` (the last entry) when nothing else matches — so a brand-new shape that doesn't match any existing class will silently be accepted as Generic, losing its specific fields.
As of today, the registered attribute classes are: `CIS_Requirement_Attribute`, `ENS_Requirement_Attribute`, `ASDEssentialEight_Requirement_Attribute`, `ISO27001_2013_Requirement_Attribute`, `AWS_Well_Architected_Requirement_Attribute`, `KISA_ISMSP_Requirement_Attribute`, `Prowler_ThreatScore_Requirement_Attribute`, `CCC_Requirement_Attribute`, `C5Germany_Requirement_Attribute`, `CSA_CCM_Requirement_Attribute`, and `Generic_Compliance_Requirement_Attribute` (fallback). MITRE-style frameworks use the separate `Mitre_Requirement` model with `Tactics` / `SubTechniques` / `Platforms` / `TechniqueURL` at the requirement top level. The most common shapes are summarized below.
@@ -472,13 +476,188 @@ For NIST-style catalogs that use `Generic_Compliance_Requirement_Attribute`, no
### Legacy-to-universal adapter
At load time, every legacy file is transparently adapted to a `ComplianceFramework` via `adapt_legacy_to_universal()` (`compliance_models.py:819`), which: (a) flattens the first element of `Attributes` into a flat `attributes` dict, (b) wraps `Checks` as `{provider_lower: [...]}`, (c) infers `attributes_metadata` from the matched Pydantic class via `_infer_attribute_metadata()`. The rest of Prowler (CSV/OCSF/PDF output, CLI table) then treats both formats identically.
At load time, every legacy file is transparently adapted to a `ComplianceFramework` via `adapt_legacy_to_universal()` (`compliance_models.py`), which: (a) flattens the first element of `Attributes` into a flat `attributes` dict, (b) wraps `Checks` as `{provider_lower: [...]}`, (c) infers `attributes_metadata` from the matched Pydantic class via `_infer_attribute_metadata()`. The rest of Prowler (CSV/OCSF/PDF output, CLI table) then treats both formats identically.
Loader-error behaviour differs between the two entry points:
- `load_compliance_framework()` (legacy) is **fail-fast**: it calls `sys.exit(1)` on any `ValidationError` (`compliance_models.py:464`).
- `load_compliance_framework()` (legacy) is **fail-fast**: it calls `sys.exit(1)` on any `ValidationError` (`compliance_models.py`).
- `load_compliance_framework_universal()` is more lenient — it logs the error and returns `None`, so `get_bulk_compliance_frameworks_universal()` simply skips the broken file and keeps loading the rest.
## Configuration Guardrails for Requirements
<VersionBadge version="5.32.0" />
Some requirements are only truly satisfied when the configurable checks behind them ran with a configuration strict enough to meet the control. A [configurable check](/developer-guide/configurable-checks) reads thresholds from the scan's `audit_config`, so loosening a value can make the check PASS while the requirement it backs is, in fact, not satisfied.
A worked example: CIS AWS 6.0 requirement 2.11 ("credentials unused for 45 days or more are disabled") maps to `iam_user_accesskey_unused`, which is driven by the `max_unused_access_keys_days` config key. If a user raises that value to `120`, the check passes for a key unused for 90 days — yet the requirement explicitly demands a 45-day threshold, so the PASS is misleading.
Configuration guardrails close that gap. A requirement declares the configuration it expects, and when the scan ran with a configuration too loose to honor it, the requirement is forced to **FAIL** in every compliance output, with the reason surfaced in the finding's extended status.
<Note>
Guardrails are an **optional** safety net for configurable checks. A requirement that maps only to non-configurable checks does not need them. When the field is absent, behavior is unchanged.
</Note>
### Where guardrails are declared
The field is attached to each requirement and exists in both schemas:
- **Legacy** (`prowler/compliance/<provider>/...`): `ConfigRequirements`, a list of objects, validated against the `Compliance_Requirement_ConfigConstraint` Pydantic model (`prowler/lib/check/compliance_models.py`).
- **Universal** (`prowler/compliance/...`): `config_requirements`, the same list of objects as plain dicts on `UniversalComplianceRequirement`.
When a legacy file is adapted to the universal model, `adapt_legacy_to_universal()` copies `ConfigRequirements` into `config_requirements` (`compliance_models.py`), so downstream code only ever reads one shape.
### Constraint schema
Each entry in the list is a single constraint with the following fields:
| Field | Type | Required | Description |
|---|---|---|---|
| `Check` | string | Yes | The configurable check this constraint guards. Should be one of the requirement's `Checks`. Used only to build a human-readable reason. |
| `ConfigKey` | string | Yes | The `audit_config` key the check reads (for example `max_unused_access_keys_days`). |
| `Operator` | enum | Yes | How to compare the applied value against `Value`. One of `lte`, `gte`, `eq`, `in`, `subset`, `superset`. |
| `Value` | bool, int, float, string, or list | Yes | The strictest configuration the requirement tolerates. The accepted Python type depends on the operator (see below). |
| `Provider` | string | No | The provider this constraint applies to (e.g. `aws`). **Required for universal (multi-provider) frameworks**, where the same requirement maps checks across providers — the constraint is only evaluated when the scanned provider matches. Single-provider (legacy) frameworks omit it. |
### Operators
| Operator | Applied value satisfies the guardrail when… | Typical use |
|---|---|---|
| `lte` | `applied <= Value` | Maximum-age / maximum-count thresholds (e.g. `max_unused_access_keys_days <= 45`). |
| `gte` | `applied >= Value` | Minimum-retention / minimum-count thresholds. |
| `eq` | `applied == Value` | Boolean toggles or an exact required value (e.g. `mute_non_default_regions == false`). |
| `in` | `applied` is one of `Value` (a list) | The applied scalar must belong to an allowed set. |
| `subset` | `set(applied) <= set(Value)` | **Allowlist** configs — every applied value must already be permitted. Widening the allowlist with a weaker value (e.g. adding TLS `1.0` to `recommended_minimal_tls_versions`) breaks the guardrail. |
| `superset` | `set(applied) >= set(Value)` | **Denylist** configs — every forbidden value must remain forbidden. Removing an entry from a denylist (e.g. dropping a weak algorithm from `insecure_key_algorithms`) breaks the guardrail. |
<Note>
`subset` / `superset` require both the applied value and `Value` to be lists; any other type is treated as not satisfied. For `eq` against a boolean, declare `Value` as a JSON boolean (`false`, not `0`) — the model keeps booleans distinct from integers.
</Note>
### How guardrails are evaluated
All evaluation lives in one shared module, `prowler/lib/check/compliance_config_eval.py`, consumed by every compliance output (CSV, OCSF, and the CLI tables) and reused by the Prowler App backend so the rule is defined exactly once.
1. The applied configuration is the scan-global `audit_config` (the same mapping for every resource and region), resolved via `get_scan_audit_config()`.
2. For each requirement that declares constraints, `evaluate_config_constraints()` walks the list and returns `(is_compliant, reason)`. The requirement is compliant when **every** explicitly-set key satisfies its constraint.
3. A constraint tagged with a `Provider` that does **not** match the provider being scanned (resolved via `get_scan_provider_type()`) is **skipped**. This scopes a universal framework's constraints to the right provider, so a guardrail authored for an AWS check never affects a GCP or Azure scan of the same requirement. Untagged constraints (legacy single-provider frameworks) always apply.
4. A constraint whose `ConfigKey` is **not present** in `audit_config` is **skipped** — the check's built-in default is assumed to already match what the requirement expects. This is why nothing changes for the default configuration.
5. When a constraint is violated, the finding's status is overridden to `FAIL` and a plain-language explanation is prepended to `status_extended` (via `apply_config_status()`). The message opens with `Configuration not valid for this requirement.` and names the check, the value the scan applied, what the requirement needs and how to fix it. For the table generators, `get_effective_status()` applies the same FAIL roll-up so per-section counts stay consistent.
<Warning>
Guardrails only ever make a result **stricter** (they can turn PASS into FAIL); they never relax a real FAIL into PASS. A requirement with no constraints, or whose keys all use defaults, is reported exactly as before.
</Warning>
### Example: legacy framework
From `prowler/compliance/aws/cis_6.0_aws.json`, requirement 2.11 declares two guardrails — one per configurable check it maps to:
```json title="prowler/compliance/aws/cis_6.0_aws.json"
{
"Id": "2.11",
"Description": "Ensure credentials unused for 45 days or more are disabled.",
"Checks": [
"iam_user_accesskey_unused",
"iam_user_console_access_unused"
],
"ConfigRequirements": [
{
"Check": "iam_user_accesskey_unused",
"ConfigKey": "max_unused_access_keys_days",
"Operator": "lte",
"Value": 45
},
{
"Check": "iam_user_console_access_unused",
"ConfigKey": "max_console_access_days",
"Operator": "lte",
"Value": 45
}
],
"Attributes": [ /* ... */ ]
}
```
A boolean guardrail from the same file: requirement 2.5 (IAM Access Analyzer) only holds when regions are not muted, so a scan with `mute_non_default_regions: true` cannot be trusted for it:
```json
"ConfigRequirements": [
{
"Check": "accessanalyzer_enabled",
"ConfigKey": "mute_non_default_regions",
"Operator": "eq",
"Value": false
}
]
```
### Example: universal framework
The universal schema uses the lowercase `config_requirements` key with the identical object shape:
```json
{
"id": "MF-2.1",
"name": "Restrict TLS to modern versions",
"description": "Endpoints must negotiate only TLS 1.2 or higher.",
"checks": {
"aws": ["elbv2_listener_ssl_listeners"]
},
"config_requirements": [
{
"Check": "elbv2_listener_ssl_listeners",
"Provider": "aws",
"ConfigKey": "recommended_minimal_tls_versions",
"Operator": "subset",
"Value": ["TLS 1.2", "TLS 1.3"]
}
]
}
```
Each constraint declares the `Provider` it targets so the guardrail is only evaluated on scans of that provider — essential for universal frameworks like CSA CCM and DORA, where one requirement maps checks across `aws`, `azure`, `gcp` and more. Because the operator is `subset`, adding `"TLS 1.0"` to `recommended_minimal_tls_versions` widens the allowlist beyond `["TLS 1.2", "TLS 1.3"]` and the requirement is forced to FAIL.
### What the user sees
With a loosened config, the affected requirement's findings report:
```text
Status: FAIL
StatusExtended: Configuration not valid for this requirement. The check
iam_user_accesskey_unused has max_unused_access_keys_days set
to 120, but the requirement needs a value of 45 or lower.
Update it to 45 or lower. <original status_extended>
```
The same `Configuration not valid for this requirement.` message appears identically across the CSV, OCSF, and console-table outputs.
### Authoring guidelines
- Declare a guardrail only for keys whose value actually changes whether the requirement is met. Most configurable checks do not need one.
- Set `Value` to the **strictest** configuration the control tolerates — the same number the control text cites (CIS 45 days, NIST ≤90, and so on).
- Keep `ConfigKey` spelled exactly as the check reads it from `audit_config`; a misspelled key is never found in the config, so the constraint is silently skipped.
- In a **universal (multi-provider) framework**, always set `Provider` to the provider that owns `Check` — otherwise the guardrail would leak onto scans of the other providers the requirement maps. Legacy single-provider files omit it.
- Pick the operator from the value's role: a max threshold is `lte`, a min threshold is `gte`, a toggle is `eq`, an allowlist is `subset`, a denylist is `superset`.
- An unrecognized operator does **not** block the requirement — a malformed constraint is treated as satisfied rather than failing the whole framework. Validate your JSON with the tests below.
### Testing guardrails
The shared evaluator and the per-output integration are covered by:
- `tests/lib/check/compliance_config_eval_test.py` — operator semantics, skipped-key behavior, and the FAIL override.
- `tests/lib/check/compliance_config_constraint_model_test.py` — model validation (types, operator enum, bool-vs-int).
- `tests/lib/check/compliance_config_requirements_data_test.py` — sanity-checks the guardrails shipped in the JSON catalog.
- Per-output tests under `tests/lib/outputs/compliance/` (CIS AWS/Azure, ENS AWS, OCSF, universal table) confirm the override reaches each format.
Run them with:
```bash
uv run pytest -n auto \
tests/lib/check/compliance_config_eval_test.py \
tests/lib/check/compliance_config_constraint_model_test.py \
tests/lib/check/compliance_config_requirements_data_test.py \
tests/lib/outputs/compliance/
```
## Version handling
Prowler matches frameworks by concatenating `Framework` and `Version`. A missing or empty `Version` collapses several frameworks to the same key and breaks CLI filtering with `--compliance`.
@@ -609,7 +788,7 @@ The following issues are the most common when contributing a compliance framewor
- **`ValidationError: field required` during scan (legacy).** The JSON is missing a required attribute field. Re-check the matching Pydantic model in `prowler/lib/check/compliance_models.py`.
- **All attributes collapse to `Generic_Compliance_Requirement_Attribute` values (legacy).** The Pydantic `Union` is ordered incorrectly, or the JSON matches only the generic shape. Keep the generic model in the last Union position and ensure every required field is present in the JSON.
- **`attributes_metadata validation failed` (universal).** The root validator in `compliance_models.py:669` rejected the file. The error message lists each offending requirement; common causes are unknown attribute keys (typo or missing entry in `attributes_metadata`), enum violations, or missing required keys.
- **`attributes_metadata validation failed` (universal).** The root validator in `compliance_models.py` rejected the file. The error message lists each offending requirement; common causes are unknown attribute keys (typo or missing entry in `attributes_metadata`), enum violations, or missing required keys.
- **`--compliance` filter does not find the framework.** For legacy: the filename does not match `<framework>_<version>_<provider>.json`, the version is empty, or the file lives outside `prowler/compliance/<provider>/`. For universal: the file is not at the top level of `prowler/compliance/` or it loaded as `None` (check logs for the validation error).
- **CLI summary table is empty but the CSV is populated (legacy).** The dispatcher branch in `prowler/lib/outputs/compliance/compliance.py` is missing or its substring match does not catch the framework key.
- **CSV file is missing after the scan (legacy).** The transformer class is not registered in `prowler/lib/outputs/compliance/compliance_output.py`, or `transform()` raises silently. Run the scan with `--log-level DEBUG`.
+7 -1
View File
@@ -125,7 +125,10 @@
"user-guide/tutorials/prowler-app-multi-tenant",
"user-guide/tutorials/prowler-app-api-keys",
"user-guide/tutorials/prowler-import-findings",
"user-guide/tutorials/prowler-scan-scheduling",
"user-guide/tutorials/prowler-alerts",
"user-guide/tutorials/prowler-app-scan-configuration",
"user-guide/tutorials/prowler-app-findings-triage",
{
"group": "Mutelist",
"expanded": true,
@@ -235,6 +238,7 @@
"user-guide/providers/azure/authentication",
"user-guide/providers/azure/use-non-default-cloud",
"user-guide/providers/azure/subscriptions",
"user-guide/providers/azure/resource-groups",
"user-guide/providers/azure/create-prowler-service-principal"
]
},
@@ -357,7 +361,8 @@
"group": "Okta",
"pages": [
"user-guide/providers/okta/getting-started-okta",
"user-guide/providers/okta/authentication"
"user-guide/providers/okta/authentication",
"user-guide/providers/okta/retry-configuration"
]
},
{
@@ -396,6 +401,7 @@
"developer-guide/provider",
"developer-guide/services",
"developer-guide/checks",
"developer-guide/secret-scanning-checks",
"developer-guide/outputs",
"developer-guide/integrations",
"developer-guide/security-compliance-framework",
@@ -128,8 +128,8 @@ To update the environment file:
Edit the `.env` file and change version values:
```env
PROWLER_UI_VERSION="5.31.0"
PROWLER_API_VERSION="5.31.0"
PROWLER_UI_VERSION="5.32.0"
PROWLER_API_VERSION="5.32.0"
```
<Note>
Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 481 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 456 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

+8
View File
@@ -0,0 +1,8 @@
export const SubscriptionBanner = ({ children }) => {
return (
<Note>
This feature is available exclusively in <b>Prowler Cloud</b> and <b>Prowler Enterprise</b> with a <a href="https://prowler.com/pricing">subscription</a>.
{children}
</Note>
);
};
@@ -2,6 +2,8 @@
title: "Configuration File"
---
import { VersionBadge } from "/snippets/version-badge.mdx"
Several of Prowler's checks have user-configurable variables that can be modified in a common **configuration file**. This file can be found in the following [path](https://github.com/prowler-cloud/prowler/blob/master/prowler/config/config.yaml):
```
@@ -24,6 +26,7 @@ The following list includes all the AWS checks with configurable variables that
|---------------------------------------------------------------|--------------------------------------------------|-----------------|
| `acm_certificates_expiration_check` | `days_to_expire_threshold` | Integer |
| `acmpca_certificate_authority_pqc_key_algorithm` | `acmpca_pqc_key_algorithms` | List of Strings |
| `apigateway_restapi_no_secrets_in_stage_variables` | `secrets_ignore_patterns` | List of Strings |
| `appstream_fleet_maximum_session_duration` | `max_session_duration_seconds` | Integer |
| `appstream_fleet_session_disconnect_timeout` | `max_disconnect_timeout_in_seconds` | Integer |
| `appstream_fleet_session_idle_disconnect_timeout` | `max_idle_disconnect_timeout_in_seconds` | Integer |
@@ -86,6 +89,91 @@ The following list includes all the AWS checks with configurable variables that
| `vpc_endpoint_services_allowed_principals_trust_boundaries` | `trusted_account_ids` | List of Strings |
| `opensearch_service_domains_not_publicly_accessible` | `trusted_ips` | List of Strings |
### Resource Scan Limit
<VersionBadge version="5.32.0" />
Some AWS services accumulate large numbers of resources (EBS snapshots, backup recovery points, CloudWatch log groups, Lambda functions, ECS task definitions, and CodeArtifact packages). Scanning every resource increases scan time, cost, API throttling, and finding volume. By default, Prowler scans every resource. Configure a positive resource scan limit to cap how many resources Prowler analyzes for these high-volume AWS resource paths.
The global default applies to the supported resources below and is overridable per resource. The default global value is `0`, which disables the limit and scans every resource. A global `null` value is also unlimited. For per-resource values, `null` means inherit the global default; set `0` or a negative value to disable that resource limit explicitly. Positive values enable limits.
<Warning>
When positive resource scan limits are configured, compliance results are based only on the selected resources, not on the full set of matching resources in the account. Treat compliance summaries and percentages as partial evidence, because unselected resources are not analyzed and can change the real compliance posture.
</Warning>
#### Global Behavior
Resource scan limits select resources for analysis. They do not cap, prioritize, or reorder findings.
* **`0`, negative, or global `null` values:** Disable the limit and keep the legacy behavior for that resource path. Prowler analyzes every discovered matching resource.
* **Positive values:** Select at most that many resources for the affected resource path. A selected resource can produce zero, one, or many findings.
* **No PASS/FAIL prioritization:** Prowler does not inspect the compliance result before selecting resources. Limits do not prefer failed resources, passed resources, or resources with more findings.
* **Latest-first where possible:** When AWS exposes timestamps or useful ordering, Prowler selects the newest resources first. When AWS only exposes API order, Prowler preserves that API order and documents the behavior as best effort.
* **Findings are downstream:** Checks only evaluate the resources exposed by the service client after selection. Findings from unselected resources are not produced because those resources are not analyzed.
The exact reduction in list API calls depends on each AWS API's ordering and pagination capabilities. When Prowler must enumerate candidates locally to select the latest resources, list calls may still read candidates, but expensive per-resource enrichment calls are bounded to the selected resources for the supported paths below.
#### Full Collections Versus Limited Analysis Sets
Some checks need lightweight evidence from a complete resource collection to avoid incorrect cross-service conclusions, while other checks perform primary analysis on a limited resource set.
Prowler keeps full lightweight collections where they are needed for cross-service evidence. For example:
* **Lambda security groups and regions:** Prowler records security groups used by all discovered Lambda functions and the regions where functions exist before it limits Lambda functions for primary Lambda checks. This helps Amazon EC2 and Amazon Inspector checks avoid false positives such as treating Lambda security groups as unused or assuming a region has no Lambda functions.
* **CloudWatch `all_log_groups`:** Prowler records all discovered CloudWatch log groups in `all_log_groups` before limiting the primary `log_groups` analysis set. Other services can still resolve log group evidence, while CloudWatch log group checks only analyze the selected log groups.
This split is intentional. It reduces expensive per-resource analysis calls without discarding lightweight context that other services need for accurate results.
#### Supported AWS Resource Limits
| Value | Scope | Type |
|-------|-------|------|
| `max_scanned_resources_per_service` | Global default for all supported high-volume AWS resources (default `0`, disabled/unlimited) | Integer |
| `max_ebs_snapshots` | EBS snapshots (`ec2_ebs_*` checks) | Integer |
| `max_backup_recovery_points` | Backup recovery points (`backup_recovery_point_*`) | Integer |
| `max_cloudwatch_log_groups` | CloudWatch log groups (`cloudwatch_log_group_*`) | Integer |
| `max_lambda_functions` | Lambda functions (`awslambda_function_*`) | Integer |
| `max_ecs_task_definitions` | ECS task definitions (`ecs_task_definitions_*`) | Integer |
| `max_codeartifact_packages` | CodeArtifact packages (`codeartifact_packages_*`) | Integer |
#### Resource Limit Behavior By Resource Path
| Resource Path | What Prowler Discovers | What A Positive Limit Selects For Analysis | Ordering And Latest Behavior | AWS Calls Reduced | Drawbacks And Consequences |
|---------------|------------------------|--------------------------------------------|------------------------------|-------------------|----------------------------|
| EBS snapshots (`max_ebs_snapshots`) | Prowler lists self-owned snapshots and keeps lightweight evidence that volumes and regions have snapshots. | The selected EBS snapshots exposed to `ec2_ebs_*` checks. | Prowler sorts discovered snapshots by `StartTime` newest first, then applies the limit. Snapshots without a timestamp sort last. | Bounds expensive per-snapshot public attribute checks to selected snapshots. Snapshot listing still runs so Prowler can choose the newest snapshots and keep volume/region evidence. | Older unselected snapshots are not analyzed by snapshot checks. A public, unencrypted, or otherwise noncompliant older snapshot can be missed when the limit is lower than the number of snapshots. |
| Backup recovery points (`max_backup_recovery_points`) | Prowler lists backup vaults, plans, selections, and recovery point candidates in discovered vaults. | The selected recovery points exposed to `backup_recovery_point_*` checks and tag hydration. | Prowler sorts discovered recovery points by `CreationDate` newest first across vaults, then applies the limit. Recovery points without a timestamp sort last. | Bounds recovery point tag calls to selected recovery points. Vault and recovery point list calls still run so Prowler can choose the newest points. | Older unselected recovery points are not analyzed. A nonencrypted or otherwise noncompliant older recovery point can be missed. |
| CloudWatch log groups (`max_cloudwatch_log_groups`) | Prowler lists log groups into both `all_log_groups` and the primary `log_groups` collection. `all_log_groups` remains available as lightweight cross-service evidence. | The selected log groups exposed to `cloudwatch_log_group_*` checks, tag hydration, and log event retrieval for checks that need log contents. | Prowler sorts discovered log groups by `creationTime` newest first, then applies the limit. Log groups without a creation time sort last. | Bounds tag calls and log event retrieval to selected log groups. Log group listing still runs to build `all_log_groups` and choose newest log groups. | Older unselected log groups are not analyzed by CloudWatch log group checks. Retention, encryption, or secrets-in-logs issues in older log groups can be missed, although cross-service evidence can still use `all_log_groups`. |
| Lambda functions (`max_lambda_functions`) | Prowler lists Lambda functions and records lightweight security group and region evidence for all discovered functions. | The selected Lambda functions exposed to `awslambda_function_*` checks and per-function enrichment such as tags, policies, function URLs, and event source mappings. | Prowler sorts discovered functions by `LastModified` newest first, then applies the limit. Functions without `LastModified` sort last. | Bounds per-function enrichment calls to selected functions. Function listing still runs to choose newest functions and keep security group/region evidence. | Older unselected functions are not analyzed by Lambda checks. Runtime, policy, URL, environment secret, or dead-letter queue issues in older functions can be missed. Cross-service checks can still use full Lambda security group and region evidence to avoid false positives. |
| ECS task definitions (`max_ecs_task_definitions`) | Prowler lists ECS task definition ARN candidates in each region. Candidate ARNs can remain visible and discoverable through AWS list operations, even when not all are described. | The selected task definitions that Prowler describes and exposes to `ecs_task_definitions_*` checks. | Selection is not random. Prowler calls `ListTaskDefinitions` with `sort=DESC`, which asks AWS to return task definition ARNs in descending family and revision order. Prowler then interleaves regional candidate lists to avoid starving later regions before applying the limit. This selects the latest task definition revisions according to the ARN order AWS provides, while preserving regional fairness. | Bounds `DescribeTaskDefinition` calls to selected task definitions. Prowler may still list candidates so it can select the bounded set and keep discovery deterministic. | Unselected task definitions are not described or analyzed. Issues in older task definition revisions, or in lower-priority families outside the selected AWS `sort=DESC` order, can be missed. Because ECS ordering is family/revision based rather than a registration timestamp sort across every family, this is latest-first according to AWS task definition ARN ordering, not a global newest-by-time guarantee. |
| CodeArtifact packages (`max_codeartifact_packages`) | Prowler lists CodeArtifact repositories and lazily lists packages inside them. | The selected packages exposed to `codeartifact_packages_*` checks, including latest-version metadata for those packages. | AWS `ListPackages` does not provide a newest-package timestamp ordering in this path. Prowler preserves repository order and package API order, then applies the limit. Latest package version metadata is retrieved for selected packages with `sortBy=PUBLISHED_TIME` and `maxResults=1`. | Bounds `ListPackageVersions` calls to selected packages and can stop package listing once the limit is reached. Repository listing still runs. | Package selection is best effort by API order, not newest package order. Packages outside the selected repository/API order are not analyzed, so origin restriction or latest-version issues can be missed. |
Use limits when scan duration, API throttling, or cost are more important than exhaustive coverage for these high-volume resources. Keep limits disabled when you need complete evidence for every resource in the affected checks.
### Validating Discovered Secrets
<VersionBadge version="5.32.0" />
By default, the secret-scanning checks run fully offline: secrets are detected but never sent anywhere. Setting `secrets_validate` to `True` additionally confirms whether each discovered secret is live by authenticating with it against the corresponding provider API. The discovered secret itself serves as the credential, so Prowler requires no additional permissions to validate it.
`secrets_validate` applies to every AWS secret-scanning check listed above (those that accept `secrets_ignore_patterns`). The `--scan-secrets-validate` CLI flag is provider-agnostic: it also enables validation for the secret-scanning checks of other providers, such as the OpenStack metadata checks.
To enable validation through the configuration file, set the value under the `aws` section:
```yaml
aws:
secrets_validate: True
```
To enable validation for a single scan (any provider), use Prowler CLI:
```
prowler aws --scan-secrets-validate
```
<Warning>
Secret validation makes outbound network calls that authenticate with each discovered secret. The credential is exercised against the provider, so the call appears in the audited account's logs and can trigger its monitoring (for example, AWS CloudTrail records the validation request). Validation stays disabled by default so that scans remain fully offline.
</Warning>
## Azure
@@ -191,6 +279,19 @@ aws:
# AWS Global Configuration
# aws.mute_non_default_regions --> Set to True to mute failed findings in non-default regions for AccessAnalyzer, GuardDuty, SecurityHub, DRS and Config
mute_non_default_regions: False
# AWS Resource Scan Limit Configuration
# Disabled by default: scan every resource unless a positive limit is configured.
# Findings are not capped. Set to 0 (or a negative value) to disable the limit.
# aws.max_scanned_resources_per_service --> global default for all services below
max_scanned_resources_per_service: 0
# Per-service overrides. Leave as null to fall back to the global default.
max_ebs_snapshots: null
max_backup_recovery_points: null
max_cloudwatch_log_groups: null
max_lambda_functions: null
max_ecs_task_definitions: null
max_codeartifact_packages: null
# If you want to mute failed findings only in specific regions, create a file with the following syntax and run it with `prowler aws -w mutelist.yaml`:
# Mutelist:
# Accounts:
+17 -3
View File
@@ -6,20 +6,34 @@ Prowler has some checks that analyse pentesting risks (Secrets, Internet Exposed
## Detect Secrets
Prowler uses the `detect-secrets` library to search for any secrets that are stored in plaintext within your environment.
Prowler scans for secrets stored in plaintext within the audited environment using [Kingfisher](https://github.com/mongodb/kingfisher), an open-source secret-scanning engine. By default these scans run fully offline, so no data leaves the audited environment. Discovered secrets can optionally be validated against the provider APIs to confirm whether they are live — see [Validating Discovered Secrets](/user-guide/cli/tutorials/configuration_file#validating-discovered-secrets).
The actual checks that have this functionality are the following:
The checks with this functionality are the following.
AWS:
- apigateway\_restapi\_no\_secrets\_in\_stage\_variables
- autoscaling\_find\_secrets\_ec2\_launch\_configuration
- awslambda\_function\_no\_secrets\_in\_code
- awslambda\_function\_no\_secrets\_in\_variables
- cloudformation\_stack\_outputs\_find\_secrets
- cloudwatch\_log\_group\_no\_secrets\_in\_logs
- codebuild\_project\_no\_secrets\_in\_variables
- ec2\_instance\_secrets\_user\_data
- ec2\_launch\_template\_no\_secrets
- ecs\_task\_definitions\_no\_environment\_secrets
- glue\_etl\_jobs\_no\_secrets\_in\_arguments
- ssm\_document\_secrets
- stepfunctions\_statemachine\_no\_secrets\_in\_definition
To execute detect-secrets related checks, you can run the following command:
OpenStack:
- compute\_instance\_metadata\_sensitive\_data
- blockstorage\_volume\_metadata\_sensitive\_data
- blockstorage\_snapshot\_metadata\_sensitive\_data
- objectstorage\_container\_metadata\_sensitive\_data
To execute the secret-scanning checks, run the following command:
```console
prowler <provider> --categories secrets
@@ -0,0 +1,47 @@
---
title: 'Azure Resource Group Scope'
---
Prowler supports narrowing security scans to specific resource groups within Azure subscriptions. This is useful when you want to audit only a subset of resources rather than scanning an entire subscription.
By default, Prowler scans all resource groups it has permission to access. Passing `--azure-resource-group` limits the scan to only the specified resource groups across all accessible subscriptions.
## Configuring Resource Group Scoped Scans
To restrict a scan to one or more resource groups, pass them as arguments using the `--azure-resource-group` flag:
```console
prowler azure --az-cli-auth --azure-resource-group <resource-group-1> <resource-group-2> ... <resource-group-N>
```
For example, to scan only `rg-prod1` and `rg-prod2`:
```console
prowler azure --az-cli-auth --azure-resource-group rg-prod1 rg-prod2
```
This works with all supported authentication methods:
```console
# Service Principal
prowler azure --sp-env-auth --azure-resource-group rg-production
# Browser
prowler azure --browser-auth --tenant-id <tenant-id> --azure-resource-group rg-production
# Managed Identity
prowler azure --managed-identity-auth --azure-resource-group rg-production
```
## How It Works
When `--azure-resource-group` is provided, Prowler validates each specified resource group against all accessible subscriptions. A resource group is included in the scan if it exists in **at least one** subscription.
- If a resource group is found in one or more subscriptions, it will be scanned in those subscriptions only.
- If a resource group is **not found in any** subscription, Prowler logs a warning and skips it.
- If **none** of the provided resource groups are found across any subscription, Prowler logs a warning and no resource group scoped checks will run.
- Resource group names are matched case-insensitively, so `MyGroup` and `mygroup` are treated as the same group, mirroring Azure's own behavior.
<Warning>
When `--azure-resource-group` is used, checks that apply to specific resources are limited to the specified resource groups. Checks that apply at tenant or subscription scope (identity, policy, or subscription-level configuration checks) are not restricted by this flag and still run in their natural scope.
</Warning>
@@ -0,0 +1,123 @@
---
title: "Okta Rate Limit Configuration in Prowler"
---
import { VersionBadge } from "/snippets/version-badge.mdx"
<VersionBadge version="5.32.0" />
Prowler's Okta Provider manages API rate limits with two complementary controls:
- **Request throttling (proactive):** Prowler paces outbound requests through a shared limiter so scans stay under Okta's rate limits and rarely trigger a rate-limit response in the first place.
- **Retries (reactive):** When Okta still returns a rate-limit response (HTTP 429), the official Okta Python SDK reads the `X-Rate-Limit-Reset` header and waits until the window resets before retrying. This acts as a safety net for occasional bursts.
Both controls are configurable through the configuration file or command line flags.
## Request Throttling (Requests per Second)
Throttling is the primary control for avoiding rate limits. Prowler limits the aggregate number of Okta API requests per second across every service in a scan.
### Using the Command Line Flag
```bash
prowler okta --okta-requests-per-second 4
```
Set the value to `0` to disable throttling.
### Using the Configuration File
```yaml
okta:
# Maximum aggregate Okta API requests per second. Default: 4. Set to 0 to disable.
okta_requests_per_second: 4
```
Okta enforces rate limits per endpoint, so this single global cap is a deliberately simple control. Lower the value if scans still hit limits on large organizations; raise it to scan faster when the organization has generous limits.
## Retries
Retries cover the cases throttling does not prevent, such as short bursts or per-endpoint limits lower than the global cap.
### Using the Command Line Flag
```bash
prowler okta --okta-retries-max-attempts 8
```
### Using the Configuration File
```yaml
okta:
# Maximum retries on HTTP 429. Default: 5.
okta_max_retries: 8
# Per-request timeout in seconds. Default: 300.
okta_request_timeout: 300
```
The command line flags override the configuration file values.
## How It Works
- **Automatic detection:** The Okta SDK automatically retries requests that fail with a retryable status: HTTP 429, 503, or 504.
- **Reset-aware backoff:** On a 429 response, the SDK sleeps until the rate-limit window indicated by the `X-Rate-Limit-Reset` header resets before each retry, rather than using a fixed delay.
- **Bounded attempts:** `okta_max_retries` caps how many times a single request is retried. The Okta SDK default is 2, which is often too low for large organizations, so Prowler defaults to 5.
## Request Timeout
The `okta_request_timeout` setting plays a dual role in the Okta SDK:
- It is the per-request socket timeout, bounding how long a single HTTP call can hang.
- It is also the total wall-clock budget for the whole retry-and-backoff loop of one request.
For this reason, the value defaults to 300 seconds rather than 0 (no timeout). A value of 0 leaves hung connections unbounded, while a value that is too low cuts the rate-limit waits short and reintroduces the errors. As a guideline, keep `okta_request_timeout` greater than or equal to `okta_max_retries` multiplied by 60 when raising the retry count, because Okta reset windows are typically up to one minute.
## Error Example Handled
```
Okta HTTP 429: Too Many Requests. Hit rate limit. Retry request in 42 seconds.
```
## Validation
### Debug Logging
To confirm that throttling and retries are active, run a scan with debug logging:
```bash
prowler okta --okta-requests-per-second 4 --log-level DEBUG --log-file debuglogs.txt
```
### Check the Messages
```bash
grep -i "throttling\|rate limit\|retry" debuglogs.txt
```
### Expected Output
When throttling is enabled, Prowler logs the configured rate at startup:
```
Okta request throttling enabled at 4 req/s
```
If a rate limit is still hit, the SDK logs the backoff:
```
Hit rate limit. Retry request in 42 seconds.
```
## Troubleshooting
If scans continue to hit rate limits:
1. Lower `--okta-requests-per-second` so requests are paced more conservatively.
2. Raise `--okta-retries-max-attempts` (and keep `okta_request_timeout` proportionally large) so the safety net absorbs more bursts.
3. Review the rate-limit allocation for the Okta organization and request an increase if needed.
4. Verify throttling and retry behavior with debug logging.
## Official References
- [Okta Rate Limits](https://developer.okta.com/docs/reference/rate-limits/)
- [Okta SDK for Python](https://github.com/okta/okta-sdk-python)
+2 -3
View File
@@ -4,14 +4,13 @@ description: 'Create email alerts from Prowler Cloud findings to monitor relevan
---
import { VersionBadge } from "/snippets/version-badge.mdx"
import { SubscriptionBanner } from "/snippets/subscription-banner.mdx"
<VersionBadge version="5.26.0" />
Alerts notify recipients by email when security findings match saved filter conditions. Use Alerts to track high-priority findings, monitor specific providers or services, and keep teams informed about scan results that match defined criteria.
<Note>
This feature is available exclusively in **Prowler Cloud** and **Prowler Enterprise** with a [paid subscription](https://prowler.com/pricing).
</Note>
<SubscriptionBanner />
## Prerequisites

Some files were not shown because too many files have changed in this diff Show More