mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-06-10 21:42:29 +00:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68e11086e9 | |||
| 6bb8dc6168 | |||
| 9e7ecb39fa | |||
| 255ce0e866 | |||
| dce406b39b | |||
| 28c36cc5fc | |||
| 8242b21f34 | |||
| 1897e38c6b | |||
| 19e6daeac3 | |||
| 3d6aa6c650 | |||
| ee93ad6cbc | |||
| 7f4c02c738 | |||
| d386730770 | |||
| 598035b381 | |||
| 5784592437 | |||
| 35f263dea6 | |||
| a1637ec46b | |||
| 6c6a6c55cf | |||
| 31b53f091b | |||
| f7a16fff99 | |||
| cb5c9ea1c5 | |||
| cb367da97d | |||
| be2a58dc82 | |||
| 29133f2d7e | |||
| babf18ffea | |||
| b6a34d2220 | |||
| 77dc79df32 | |||
| 91e3c01f51 | |||
| 6cb0edf3e1 | |||
| 7dfafb9337 | |||
| dce05295ef | |||
| 03d4c19ed5 | |||
| 963ece9a0b | |||
| a32eff6946 | |||
| 3bb326133a | |||
| 799826758e | |||
| 1208005a94 | |||
| ecdece9f1e | |||
| 9c2c555628 | |||
| e4640a0497 | |||
| ca2f3ccc1c | |||
| 9ffa0043ab | |||
| e76ecfdd4d | |||
| f11f71bc42 | |||
| 607cfd61ef | |||
| 9c76dafaa4 | |||
| 7b839d9f9e | |||
| f39a82fdf4 | |||
| d1a7eed5fa | |||
| 5be4ec511f | |||
| a0166aede7 | |||
| 1a2a2ea3cc | |||
| e61d1401b9 | |||
| a2789b7fc6 | |||
| 34217492d0 | |||
| ed50ed1e6d | |||
| 186977f81c | |||
| c33f20ad72 | |||
| d0b0c66ef0 | |||
| e849959fd5 | |||
| 7c090a6a07 | |||
| bc4484f269 | |||
| 7601142e42 | |||
| f47310bceb | |||
| 032499c29a | |||
| d7af97b30a | |||
| aa24034ca7 | |||
| ec4eb70539 | |||
| 76a8610121 | |||
| d5e2c930a9 | |||
| 2c4f866e42 | |||
| 31845df1a7 | |||
| d8c1273a57 | |||
| 3317c0a5e0 | |||
| 847645543a | |||
| 76aa65cb61 | |||
| 484a1d1fef | |||
| c8bc0576ea | |||
| 76cda6d777 | |||
| 28978f6db6 | |||
| d4bc6d7531 | |||
| 1bf49747ad | |||
| 2cde4c939d | |||
| 9844379d30 | |||
| 211b1b67f9 | |||
| 864b2099c3 | |||
| 270266c906 | |||
| c8fab497fd | |||
| b0eea61468 | |||
| 463fc32fca | |||
| 17f5633a8d | |||
| 48274f1d54 | |||
| 9719f9ee86 |
@@ -48,6 +48,26 @@ POSTGRES_DB=prowler_db
|
||||
# POSTGRES_REPLICA_MAX_ATTEMPTS=3
|
||||
# POSTGRES_REPLICA_RETRY_BASE_DELAY=0.5
|
||||
|
||||
# Neo4j auth
|
||||
NEO4J_HOST=neo4j
|
||||
NEO4J_PORT=7687
|
||||
NEO4J_USER=neo4j
|
||||
NEO4J_PASSWORD=neo4j_password
|
||||
# Neo4j settings
|
||||
NEO4J_DBMS_MAX__DATABASES=1000
|
||||
NEO4J_SERVER_MEMORY_PAGECACHE_SIZE=1G
|
||||
NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE=1G
|
||||
NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE=1G
|
||||
NEO4J_POC_EXPORT_FILE_ENABLED=true
|
||||
NEO4J_APOC_IMPORT_FILE_ENABLED=true
|
||||
NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG=true
|
||||
NEO4J_PLUGINS=["apoc"]
|
||||
NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST=apoc.*
|
||||
NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED=apoc.*
|
||||
NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS=0.0.0.0:7687
|
||||
# Neo4j Prowler settings
|
||||
ATTACK_PATHS_FINDINGS_BATCH_SIZE=1000
|
||||
|
||||
# Celery-Prowler task settings
|
||||
TASK_RETRY_DELAY_SECONDS=0.1
|
||||
TASK_RETRY_ATTEMPTS=5
|
||||
@@ -117,7 +137,6 @@ SENTRY_ENVIRONMENT=local
|
||||
SENTRY_RELEASE=local
|
||||
NEXT_PUBLIC_SENTRY_ENVIRONMENT=${SENTRY_ENVIRONMENT}
|
||||
|
||||
|
||||
#### Prowler release version ####
|
||||
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.16.0
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ runs:
|
||||
run: |
|
||||
BRANCH_NAME="${GITHUB_HEAD_REF:-${GITHUB_REF_NAME}}"
|
||||
echo "Using branch: $BRANCH_NAME"
|
||||
sed -i "s|@master|@$BRANCH_NAME|g" pyproject.toml
|
||||
sed -i "s|\(git+https://github.com/prowler-cloud/prowler[^@]*\)@master|\1@$BRANCH_NAME|g" pyproject.toml
|
||||
|
||||
- name: Install poetry
|
||||
shell: bash
|
||||
|
||||
+13
-2
@@ -46,12 +46,17 @@ provider/oci:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: "prowler/providers/oraclecloud/**"
|
||||
- any-glob-to-any-file: "tests/providers/oraclecloud/**"
|
||||
|
||||
|
||||
provider/alibabacloud:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: "prowler/providers/alibabacloud/**"
|
||||
- any-glob-to-any-file: "tests/providers/alibabacloud/**"
|
||||
|
||||
provider/cloudflare:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: "prowler/providers/cloudflare/**"
|
||||
- any-glob-to-any-file: "tests/providers/cloudflare/**"
|
||||
|
||||
github_actions:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: ".github/workflows/*"
|
||||
@@ -67,15 +72,21 @@ mutelist:
|
||||
- any-glob-to-any-file: "prowler/providers/azure/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/gcp/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/kubernetes/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/m365/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/mongodbatlas/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/oraclecloud/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/alibabacloud/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "prowler/providers/cloudflare/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/aws/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/azure/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/gcp/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/kubernetes/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/m365/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/mongodbatlas/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/oci/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/oraclecloud/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/alibabacloud/lib/mutelist/**"
|
||||
- any-glob-to-any-file: "tests/providers/cloudflare/lib/mutelist/**"
|
||||
|
||||
integration/s3:
|
||||
- changed-files:
|
||||
|
||||
@@ -46,6 +46,7 @@ jobs:
|
||||
api/docs/**
|
||||
api/README.md
|
||||
api/CHANGELOG.md
|
||||
api/AGENTS.md
|
||||
|
||||
- name: Setup Python with Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -74,6 +74,7 @@ jobs:
|
||||
api/docs/**
|
||||
api/README.md
|
||||
api/CHANGELOG.md
|
||||
api/AGENTS.md
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -46,6 +46,7 @@ jobs:
|
||||
api/docs/**
|
||||
api/README.md
|
||||
api/CHANGELOG.md
|
||||
api/AGENTS.md
|
||||
|
||||
- name: Setup Python with Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
@@ -60,7 +61,8 @@ jobs:
|
||||
|
||||
- name: Safety
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
run: poetry run safety check
|
||||
run: poetry run safety check --ignore 79023,79027
|
||||
# TODO: 79023 & 79027 knack ReDoS until `azure-cli-core` (via `cartography`) allows `knack` >=0.13.0
|
||||
|
||||
- name: Vulture
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -86,6 +86,7 @@ jobs:
|
||||
api/docs/**
|
||||
api/README.md
|
||||
api/CHANGELOG.md
|
||||
api/AGENTS.md
|
||||
|
||||
- name: Setup Python with Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -42,14 +42,16 @@ jobs:
|
||||
ui/**
|
||||
prowler/**
|
||||
mcp_server/**
|
||||
poetry.lock
|
||||
pyproject.toml
|
||||
|
||||
- name: Check for folder changes and changelog presence
|
||||
id: check-folders
|
||||
run: |
|
||||
missing_changelogs=""
|
||||
|
||||
# Check api folder
|
||||
if [[ "${{ steps.changed-files.outputs.any_changed }}" == "true" ]]; then
|
||||
# Check monitored folders
|
||||
for folder in $MONITORED_FOLDERS; do
|
||||
# Get files changed in this folder
|
||||
changed_in_folder=$(echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr ' ' '\n' | grep "^${folder}/" || true)
|
||||
@@ -64,6 +66,22 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Check root-level dependency files (poetry.lock, pyproject.toml)
|
||||
# These are associated with the prowler folder changelog
|
||||
root_deps_changed=$(echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr ' ' '\n' | grep -E "^(poetry\.lock|pyproject\.toml)$" || true)
|
||||
if [ -n "$root_deps_changed" ]; then
|
||||
echo "Detected changes in root dependency files: $root_deps_changed"
|
||||
# Check if prowler/CHANGELOG.md was already updated (might have been caught above)
|
||||
prowler_changelog_updated=$(echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr ' ' '\n' | grep "^prowler/CHANGELOG.md$" || true)
|
||||
if [ -z "$prowler_changelog_updated" ]; then
|
||||
# Only add if prowler wasn't already flagged
|
||||
if ! echo "$missing_changelogs" | grep -q "prowler"; then
|
||||
echo "No changelog update found for root dependency changes"
|
||||
missing_changelogs="${missing_changelogs}- \`prowler\` (root dependency files changed)"$'\n'
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
{
|
||||
|
||||
@@ -47,6 +47,7 @@ jobs:
|
||||
ui/**
|
||||
dashboard/**
|
||||
mcp_server/**
|
||||
skills/**
|
||||
README.md
|
||||
mkdocs.yml
|
||||
.backportrc.json
|
||||
@@ -55,6 +56,7 @@ jobs:
|
||||
examples/**
|
||||
.gitignore
|
||||
contrib/**
|
||||
**/AGENTS.md
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
@@ -83,7 +85,7 @@ jobs:
|
||||
|
||||
- name: Check format with black
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
run: poetry run black --exclude api ui skills --check .
|
||||
run: poetry run black --exclude "api|ui|skills" --check .
|
||||
|
||||
- name: Lint with pylint
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -78,6 +78,7 @@ jobs:
|
||||
ui/**
|
||||
dashboard/**
|
||||
mcp_server/**
|
||||
skills/**
|
||||
README.md
|
||||
mkdocs.yml
|
||||
.backportrc.json
|
||||
@@ -86,6 +87,7 @@ jobs:
|
||||
examples/**
|
||||
.gitignore
|
||||
contrib/**
|
||||
**/AGENTS.md
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -42,6 +42,7 @@ jobs:
|
||||
ui/**
|
||||
dashboard/**
|
||||
mcp_server/**
|
||||
skills/**
|
||||
README.md
|
||||
mkdocs.yml
|
||||
.backportrc.json
|
||||
@@ -50,6 +51,7 @@ jobs:
|
||||
examples/**
|
||||
.gitignore
|
||||
contrib/**
|
||||
**/AGENTS.md
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -47,6 +47,7 @@ jobs:
|
||||
ui/**
|
||||
dashboard/**
|
||||
mcp_server/**
|
||||
skills/**
|
||||
README.md
|
||||
mkdocs.yml
|
||||
.backportrc.json
|
||||
@@ -55,6 +56,7 @@ jobs:
|
||||
examples/**
|
||||
.gitignore
|
||||
contrib/**
|
||||
**/AGENTS.md
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -73,6 +73,7 @@ jobs:
|
||||
files_ignore: |
|
||||
ui/CHANGELOG.md
|
||||
ui/README.md
|
||||
ui/AGENTS.md
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
@@ -116,7 +116,7 @@ jobs:
|
||||
- name: Setup Node.js environment
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0
|
||||
with:
|
||||
node-version: '20.x'
|
||||
node-version: '24.13.0'
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
@@ -125,16 +125,20 @@ jobs:
|
||||
- name: Get pnpm store directory
|
||||
shell: bash
|
||||
run: echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
- name: Setup pnpm cache
|
||||
- name: Setup pnpm and Next.js cache
|
||||
uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('ui/pnpm-lock.yaml') }}
|
||||
path: |
|
||||
${{ env.STORE_PATH }}
|
||||
./ui/node_modules
|
||||
./ui/.next/cache
|
||||
key: ${{ runner.os }}-pnpm-nextjs-${{ hashFiles('ui/pnpm-lock.yaml') }}-${{ hashFiles('ui/**/*.ts', 'ui/**/*.tsx', 'ui/**/*.js', 'ui/**/*.jsx') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
${{ runner.os }}-pnpm-nextjs-${{ hashFiles('ui/pnpm-lock.yaml') }}-
|
||||
${{ runner.os }}-pnpm-nextjs-
|
||||
- name: Install UI dependencies
|
||||
working-directory: ./ui
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile --prefer-offline
|
||||
- name: Build UI application
|
||||
working-directory: ./ui
|
||||
run: pnpm run build
|
||||
|
||||
@@ -16,7 +16,7 @@ concurrency:
|
||||
|
||||
env:
|
||||
UI_WORKING_DIR: ./ui
|
||||
NODE_VERSION: '20.x'
|
||||
NODE_VERSION: '24.13.0'
|
||||
|
||||
jobs:
|
||||
ui-tests:
|
||||
@@ -42,6 +42,7 @@ jobs:
|
||||
files_ignore: |
|
||||
ui/CHANGELOG.md
|
||||
ui/README.md
|
||||
ui/AGENTS.md
|
||||
|
||||
- name: Setup Node.js ${{ env.NODE_VERSION }}
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
@@ -61,18 +62,22 @@ jobs:
|
||||
shell: bash
|
||||
run: echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
|
||||
- name: Setup pnpm cache
|
||||
- name: Setup pnpm and Next.js cache
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('ui/pnpm-lock.yaml') }}
|
||||
path: |
|
||||
${{ env.STORE_PATH }}
|
||||
${{ env.UI_WORKING_DIR }}/node_modules
|
||||
${{ env.UI_WORKING_DIR }}/.next/cache
|
||||
key: ${{ runner.os }}-pnpm-nextjs-${{ hashFiles('ui/pnpm-lock.yaml') }}-${{ hashFiles('ui/**/*.ts', 'ui/**/*.tsx', 'ui/**/*.js', 'ui/**/*.jsx') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
${{ runner.os }}-pnpm-nextjs-${{ hashFiles('ui/pnpm-lock.yaml') }}-
|
||||
${{ runner.os }}-pnpm-nextjs-
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile --prefer-offline
|
||||
|
||||
- name: Run healthcheck
|
||||
if: steps.check-changes.outputs.any_changed == 'true'
|
||||
|
||||
+3
-1
@@ -150,8 +150,10 @@ node_modules
|
||||
# Persistent data
|
||||
_data/
|
||||
|
||||
# Claude
|
||||
# AI Instructions (generated by skills/setup.sh from AGENTS.md)
|
||||
CLAUDE.md
|
||||
GEMINI.md
|
||||
.github/copilot-instructions.md
|
||||
|
||||
# Compliance report
|
||||
*.pdf
|
||||
|
||||
@@ -42,7 +42,7 @@ repos:
|
||||
"--remove-unused-variable",
|
||||
]
|
||||
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
@@ -120,7 +120,8 @@ repos:
|
||||
name: safety
|
||||
description: "Safety is a tool that checks your installed dependencies for known security vulnerabilities"
|
||||
# TODO: Botocore needs urllib3 1.X so we need to ignore these vulnerabilities 77744,77745. Remove this once we upgrade to urllib3 2.X
|
||||
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745'
|
||||
# TODO: 79023 & 79027 knack ReDoS until `azure-cli-core` (via `cartography`) allows `knack` >=0.13.0
|
||||
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745,79023,79027'
|
||||
language: system
|
||||
|
||||
- id: vulture
|
||||
|
||||
@@ -20,6 +20,7 @@ Use these skills for detailed patterns on-demand:
|
||||
| `playwright` | Page Object Model, MCP workflow, selectors | [SKILL.md](skills/playwright/SKILL.md) |
|
||||
| `pytest` | Fixtures, mocking, markers, parametrize | [SKILL.md](skills/pytest/SKILL.md) |
|
||||
| `django-drf` | ViewSets, Serializers, Filters | [SKILL.md](skills/django-drf/SKILL.md) |
|
||||
| `jsonapi` | Strict JSON:API v1.1 spec compliance | [SKILL.md](skills/jsonapi/SKILL.md) |
|
||||
| `zod-4` | New API (z.email(), z.uuid()) | [SKILL.md](skills/zod-4/SKILL.md) |
|
||||
| `zustand-5` | Persist, selectors, slices | [SKILL.md](skills/zustand-5/SKILL.md) |
|
||||
| `ai-sdk-5` | UIMessage, streaming, LangChain | [SKILL.md](skills/ai-sdk-5/SKILL.md) |
|
||||
@@ -38,10 +39,74 @@ Use these skills for detailed patterns on-demand:
|
||||
| `prowler-compliance` | Compliance framework structure | [SKILL.md](skills/prowler-compliance/SKILL.md) |
|
||||
| `prowler-compliance-review` | Review compliance framework PRs | [SKILL.md](skills/prowler-compliance-review/SKILL.md) |
|
||||
| `prowler-provider` | Add new cloud providers | [SKILL.md](skills/prowler-provider/SKILL.md) |
|
||||
| `prowler-changelog` | Changelog entries (keepachangelog.com) | [SKILL.md](skills/prowler-changelog/SKILL.md) |
|
||||
| `prowler-ci` | CI checks and PR gates (GitHub Actions) | [SKILL.md](skills/prowler-ci/SKILL.md) |
|
||||
| `prowler-commit` | Professional commits (conventional-commits) | [SKILL.md](skills/prowler-commit/SKILL.md) |
|
||||
| `prowler-pr` | Pull request conventions | [SKILL.md](skills/prowler-pr/SKILL.md) |
|
||||
| `prowler-docs` | Documentation style guide | [SKILL.md](skills/prowler-docs/SKILL.md) |
|
||||
| `skill-creator` | Create new AI agent skills | [SKILL.md](skills/skill-creator/SKILL.md) |
|
||||
|
||||
### Auto-invoke Skills
|
||||
|
||||
When performing these actions, ALWAYS invoke the corresponding skill FIRST:
|
||||
|
||||
| Action | Skill |
|
||||
|--------|-------|
|
||||
| Add changelog entry for a PR or feature | `prowler-changelog` |
|
||||
| Adding DRF pagination or permissions | `django-drf` |
|
||||
| Adding new providers | `prowler-provider` |
|
||||
| Adding services to existing providers | `prowler-provider` |
|
||||
| After creating/modifying a skill | `skill-sync` |
|
||||
| App Router / Server Actions | `nextjs-15` |
|
||||
| Building AI chat features | `ai-sdk-5` |
|
||||
| Committing changes | `prowler-commit` |
|
||||
| Create PR that requires changelog entry | `prowler-changelog` |
|
||||
| Create a PR with gh pr create | `prowler-pr` |
|
||||
| Creating API endpoints | `jsonapi` |
|
||||
| Creating ViewSets, serializers, or filters in api/ | `django-drf` |
|
||||
| Creating Zod schemas | `zod-4` |
|
||||
| Creating a git commit | `prowler-commit` |
|
||||
| Creating new checks | `prowler-sdk-check` |
|
||||
| Creating new skills | `skill-creator` |
|
||||
| Creating/modifying Prowler UI components | `prowler-ui` |
|
||||
| Creating/modifying models, views, serializers | `prowler-api` |
|
||||
| Creating/updating compliance frameworks | `prowler-compliance` |
|
||||
| Debug why a GitHub Actions job is failing | `prowler-ci` |
|
||||
| Fill .github/pull_request_template.md (Context/Description/Steps to review/Checklist) | `prowler-pr` |
|
||||
| General Prowler development questions | `prowler` |
|
||||
| Implementing JSON:API endpoints | `django-drf` |
|
||||
| Inspect PR CI checks and gates (.github/workflows/*) | `prowler-ci` |
|
||||
| Inspect PR CI workflows (.github/workflows/*): conventional-commit, pr-check-changelog, pr-conflict-checker, labeler | `prowler-pr` |
|
||||
| Mapping checks to compliance controls | `prowler-compliance` |
|
||||
| Mocking AWS with moto in tests | `prowler-test-sdk` |
|
||||
| Modifying API responses | `jsonapi` |
|
||||
| Regenerate AGENTS.md Auto-invoke tables (sync.sh) | `skill-sync` |
|
||||
| Review PR requirements: template, title conventions, changelog gate | `prowler-pr` |
|
||||
| Review changelog format and conventions | `prowler-changelog` |
|
||||
| Reviewing JSON:API compliance | `jsonapi` |
|
||||
| Reviewing compliance framework PRs | `prowler-compliance-review` |
|
||||
| Testing RLS tenant isolation | `prowler-test-api` |
|
||||
| Troubleshoot why a skill is missing from AGENTS.md auto-invoke | `skill-sync` |
|
||||
| Understand CODEOWNERS/labeler-based automation | `prowler-ci` |
|
||||
| Understand PR title conventional-commit validation | `prowler-ci` |
|
||||
| Understand changelog gate and no-changelog label behavior | `prowler-ci` |
|
||||
| Understand review ownership with CODEOWNERS | `prowler-pr` |
|
||||
| Update CHANGELOG.md in any component | `prowler-changelog` |
|
||||
| Updating existing checks and metadata | `prowler-sdk-check` |
|
||||
| Using Zustand stores | `zustand-5` |
|
||||
| Working on MCP server tools | `prowler-mcp` |
|
||||
| Working on Prowler UI structure (actions/adapters/types/hooks) | `prowler-ui` |
|
||||
| Working with Prowler UI test helpers/pages | `prowler-test-ui` |
|
||||
| Working with Tailwind classes | `tailwind-4` |
|
||||
| Writing Playwright E2E tests | `playwright` |
|
||||
| Writing Prowler API tests | `prowler-test-api` |
|
||||
| Writing Prowler SDK tests | `prowler-test-sdk` |
|
||||
| Writing Prowler UI E2E tests | `prowler-test-ui` |
|
||||
| Writing Python tests with pytest | `pytest` |
|
||||
| Writing React components | `react-19` |
|
||||
| Writing TypeScript types/interfaces | `typescript` |
|
||||
| Writing documentation | `prowler-docs` |
|
||||
|
||||
---
|
||||
|
||||
## Project Overview
|
||||
|
||||
@@ -80,6 +80,23 @@ prowler dashboard
|
||||
```
|
||||

|
||||
|
||||
|
||||
## Attack Paths
|
||||
|
||||
Attack Paths automatically extends every completed AWS scan with a Neo4j graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan and therefore requires:
|
||||
|
||||
- An accessible Neo4j instance (the Docker Compose files already ships a `neo4j` service).
|
||||
- The following environment variables so Django and Celery can connect:
|
||||
|
||||
| Variable | Description | Default |
|
||||
| --- | --- | --- |
|
||||
| `NEO4J_HOST` | Hostname used by the API containers. | `neo4j` |
|
||||
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
|
||||
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
|
||||
|
||||
Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations.
|
||||
|
||||
|
||||
# Prowler at a Glance
|
||||
> [!Tip]
|
||||
> For the most accurate and up-to-date information about checks, services, frameworks, and categories, visit [**Prowler Hub**](https://hub.prowler.com).
|
||||
@@ -87,16 +104,17 @@ prowler dashboard
|
||||
|
||||
| Provider | Checks | Services | [Compliance Frameworks](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/compliance/) | [Categories](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/misc/#categories) | Support | Interface |
|
||||
|---|---|---|---|---|---|---|
|
||||
| AWS | 584 | 85 | 40 | 17 | Official | UI, API, CLI |
|
||||
| GCP | 89 | 17 | 14 | 5 | Official | UI, API, CLI |
|
||||
| Azure | 169 | 22 | 15 | 8 | Official | UI, API, CLI |
|
||||
| Kubernetes | 84 | 7 | 6 | 9 | Official | UI, API, CLI |
|
||||
| AWS | 584 | 84 | 40 | 17 | Official | UI, API, CLI |
|
||||
| Azure | 169 | 22 | 16 | 12 | Official | UI, API, CLI |
|
||||
| GCP | 100 | 17 | 14 | 7 | Official | UI, API, CLI |
|
||||
| Kubernetes | 84 | 7 | 7 | 9 | Official | UI, API, CLI |
|
||||
| GitHub | 20 | 2 | 1 | 2 | Official | UI, API, CLI |
|
||||
| M365 | 70 | 7 | 3 | 2 | Official | UI, API, CLI |
|
||||
| OCI | 52 | 15 | 1 | 12 | Official | UI, API, CLI |
|
||||
| Alibaba Cloud | 63 | 10 | 1 | 9 | Official | CLI |
|
||||
| M365 | 71 | 7 | 4 | 3 | Official | UI, API, CLI |
|
||||
| OCI | 52 | 14 | 1 | 12 | Official | UI, API, CLI |
|
||||
| Alibaba Cloud | 64 | 9 | 2 | 9 | Official | UI, API, CLI |
|
||||
| Cloudflare | 23 | 2 | 0 | 5 | Official | CLI |
|
||||
| IaC | [See `trivy` docs.](https://trivy.dev/latest/docs/coverage/iac/) | N/A | N/A | N/A | Official | UI, API, CLI |
|
||||
| MongoDB Atlas | 10 | 4 | 0 | 3 | Official | UI, API, CLI |
|
||||
| MongoDB Atlas | 10 | 3 | 0 | 3 | Official | UI, API, CLI |
|
||||
| LLM | [See `promptfoo` docs.](https://www.promptfoo.dev/docs/red-team/plugins/) | N/A | N/A | N/A | Official | CLI |
|
||||
| NHN | 6 | 2 | 1 | 0 | Unofficial | CLI |
|
||||
|
||||
|
||||
@@ -4,8 +4,34 @@
|
||||
> - [`prowler-api`](../skills/prowler-api/SKILL.md) - Models, Serializers, Views, RLS patterns
|
||||
> - [`prowler-test-api`](../skills/prowler-test-api/SKILL.md) - Testing patterns (pytest-django)
|
||||
> - [`django-drf`](../skills/django-drf/SKILL.md) - Generic DRF patterns
|
||||
> - [`jsonapi`](../skills/jsonapi/SKILL.md) - Strict JSON:API v1.1 spec compliance
|
||||
> - [`pytest`](../skills/pytest/SKILL.md) - Generic pytest patterns
|
||||
|
||||
### Auto-invoke Skills
|
||||
|
||||
When performing these actions, ALWAYS invoke the corresponding skill FIRST:
|
||||
|
||||
| Action | Skill |
|
||||
|--------|-------|
|
||||
| Add changelog entry for a PR or feature | `prowler-changelog` |
|
||||
| Adding DRF pagination or permissions | `django-drf` |
|
||||
| Committing changes | `prowler-commit` |
|
||||
| Create PR that requires changelog entry | `prowler-changelog` |
|
||||
| Creating API endpoints | `jsonapi` |
|
||||
| Creating ViewSets, serializers, or filters in api/ | `django-drf` |
|
||||
| Creating a git commit | `prowler-commit` |
|
||||
| Creating/modifying models, views, serializers | `prowler-api` |
|
||||
| Implementing JSON:API endpoints | `django-drf` |
|
||||
| Modifying API responses | `jsonapi` |
|
||||
| Review changelog format and conventions | `prowler-changelog` |
|
||||
| Reviewing JSON:API compliance | `jsonapi` |
|
||||
| Testing RLS tenant isolation | `prowler-test-api` |
|
||||
| Update CHANGELOG.md in any component | `prowler-changelog` |
|
||||
| Writing Prowler API tests | `prowler-test-api` |
|
||||
| Writing Python tests with pytest | `pytest` |
|
||||
|
||||
---
|
||||
|
||||
## CRITICAL RULES - NON-NEGOTIABLE
|
||||
|
||||
### Models
|
||||
|
||||
+155
-65
@@ -2,45 +2,85 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.18.0] (Prowler UNRELEASED)
|
||||
## [1.19.0] (Prowler UNRELEASED)
|
||||
|
||||
### Added
|
||||
- `/api/v1/overviews/compliance-watchlist` to retrieve the compliance watchlist [(#9596)](https://github.com/prowler-cloud/prowler/pull/9596)
|
||||
- Support AlibabaCloud provider [(#9485)](https://github.com/prowler-cloud/prowler/pull/9485)
|
||||
- `provider_id` and `provider_id__in` filter aliases for findings endpoints to enable consistent frontend parameter naming [(#9701)](https://github.com/prowler-cloud/prowler/pull/9701)
|
||||
### 🚀 Added
|
||||
|
||||
- Attack Paths: Bedrock Code Interpreter and AttachRolePolicy privilege escalation queries [(#9885)](https://github.com/prowler-cloud/prowler/pull/9885)
|
||||
- Added memory optimizations for large compliance report generation [(#9444)](https://github.com/prowler-cloud/prowler/pull/9444)
|
||||
- `GET /api/v1/resources/{id}/events` endpoint to retrieve AWS resource modification history from CloudTrail [(#9101)](https://github.com/prowler-cloud/prowler/pull/9101)
|
||||
|
||||
### 🔄 Changed
|
||||
|
||||
- Lazy-load providers and compliance data to reduce API/worker startup memory and time [(#9857)](https://github.com/prowler-cloud/prowler/pull/9857)
|
||||
|
||||
---
|
||||
|
||||
## [1.17.2] (Prowler v5.16.2)
|
||||
## [1.18.1] (Prowler v5.17.1)
|
||||
|
||||
### Security
|
||||
- Updated dependencies to patch security vulnerabilities: Django 5.1.15 (CVE-2025-64460, CVE-2025-13372), Werkzeug 3.1.4 (CVE-2025-66221), sqlparse 0.5.5 (PVE-2025-82038), fonttools 4.60.2 (CVE-2025-66034) [(#9730)](https://github.com/prowler-cloud/prowler/pull/9730)
|
||||
### 🐞 Fixed
|
||||
|
||||
- Improve API startup process by `manage.py` argument detection [(#9856)](https://github.com/prowler-cloud/prowler/pull/9856)
|
||||
- Deleting providers don't try to delete a `None` Neo4j database when an Attack Paths scan is scheduled [(#9858)](https://github.com/prowler-cloud/prowler/pull/9858)
|
||||
- Use replica database for reading Findings to add them to the Attack Paths graph [(#9861)](https://github.com/prowler-cloud/prowler/pull/9861)
|
||||
- Attack paths findings loading query to use streaming generator for O(batch_size) memory instead of O(total_findings) [(#9862)](https://github.com/prowler-cloud/prowler/pull/9862)
|
||||
- Lazy load Neo4j driver [(#9868)](https://github.com/prowler-cloud/prowler/pull/9868)
|
||||
- Use `Findings.all_objects` to avoid the `ActiveProviderPartitionedManager` [(#9869)](https://github.com/prowler-cloud/prowler/pull/9869)
|
||||
- Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872)
|
||||
- Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874)
|
||||
- Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877)
|
||||
- Deduplicated scheduled scans for long-running providers [(#9829)](https://github.com/prowler-cloud/prowler/pull/9829)
|
||||
|
||||
---
|
||||
|
||||
## [1.18.0] (Prowler v5.17.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
- `/api/v1/overviews/compliance-watchlist` endpoint to retrieve the compliance watchlist [(#9596)](https://github.com/prowler-cloud/prowler/pull/9596)
|
||||
- AlibabaCloud provider support [(#9485)](https://github.com/prowler-cloud/prowler/pull/9485)
|
||||
- `/api/v1/overviews/resource-groups` endpoint to retrieve an overview of resource groups based on finding severities [(#9694)](https://github.com/prowler-cloud/prowler/pull/9694)
|
||||
- `group` filter for `GET /findings` and `GET /findings/metadata/latest` endpoints [(#9694)](https://github.com/prowler-cloud/prowler/pull/9694)
|
||||
- `provider_id` and `provider_id__in` filter aliases for findings endpoints to enable consistent frontend parameter naming [(#9701)](https://github.com/prowler-cloud/prowler/pull/9701)
|
||||
- Attack Paths: `/api/v1/attack-paths-scans` for AWS providers backed by Neo4j [(#9805)](https://github.com/prowler-cloud/prowler/pull/9805)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
- Django 5.1.15 (CVE-2025-64460, CVE-2025-13372), Werkzeug 3.1.4 (CVE-2025-66221), sqlparse 0.5.5 (PVE-2025-82038), fonttools 4.60.2 (CVE-2025-66034) [(#9730)](https://github.com/prowler-cloud/prowler/pull/9730)
|
||||
- `safety` to `3.7.0` and `filelock` to `3.20.3` due to [Safety vulnerability 82754 (CVE-2025-68146)](https://data.safetycli.com/v/82754/97c/) [(#9816)](https://github.com/prowler-cloud/prowler/pull/9816)
|
||||
- `pyasn1` to v0.6.2 to address [CVE-2026-23490](https://nvd.nist.gov/vuln/detail/CVE-2026-23490) [(#9818)](https://github.com/prowler-cloud/prowler/pull/9818)
|
||||
- `django-allauth[saml]` to v65.13.0 to address [CVE-2025-65431](https://nvd.nist.gov/vuln/detail/CVE-2025-65431) [(#9575)](https://github.com/prowler-cloud/prowler/pull/9575)
|
||||
|
||||
---
|
||||
|
||||
## [1.17.1] (Prowler v5.16.1)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Security Hub integration error when no regions [(#9635)](https://github.com/prowler-cloud/prowler/pull/9635)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Orphan scheduled scans caused by transaction isolation during provider creation [(#9633)](https://github.com/prowler-cloud/prowler/pull/9633)
|
||||
|
||||
---
|
||||
|
||||
## [1.17.0] (Prowler v5.16.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- New endpoint to retrieve and overview of the categories based on finding severities [(#9529)](https://github.com/prowler-cloud/prowler/pull/9529)
|
||||
- Endpoints `GET /findings` and `GET /findings/latests` can now use the category filter [(#9529)](https://github.com/prowler-cloud/prowler/pull/9529)
|
||||
- Account id, alias and provider name to PDF reporting table [(#9574)](https://github.com/prowler-cloud/prowler/pull/9574)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Endpoint `GET /overviews/attack-surfaces` no longer returns the related check IDs [(#9529)](https://github.com/prowler-cloud/prowler/pull/9529)
|
||||
- OpenAI provider to only load chat-compatible models with tool calling support [(#9523)](https://github.com/prowler-cloud/prowler/pull/9523)
|
||||
- Increased execution delay for the first scheduled scan tasks to 5 seconds[(#9558)](https://github.com/prowler-cloud/prowler/pull/9558)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Made `scan_id` a required filter in the compliance overview endpoint [(#9560)](https://github.com/prowler-cloud/prowler/pull/9560)
|
||||
- Reduced unnecessary UPDATE resources operations by only saving when tag mappings change, lowering write load during scans [(#9569)](https://github.com/prowler-cloud/prowler/pull/9569)
|
||||
|
||||
@@ -48,19 +88,22 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.16.1] (Prowler v5.15.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Race condition in scheduled scan creation by adding countdown to task [(#9516)](https://github.com/prowler-cloud/prowler/pull/9516)
|
||||
|
||||
## [1.16.0] (Prowler v5.15.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- New endpoint to retrieve an overview of the attack surfaces [(#9309)](https://github.com/prowler-cloud/prowler/pull/9309)
|
||||
- New endpoint `GET /api/v1/overviews/findings_severity/timeseries` to retrieve daily aggregated findings by severity level [(#9363)](https://github.com/prowler-cloud/prowler/pull/9363)
|
||||
- Lighthouse AI support for Amazon Bedrock API key [(#9343)](https://github.com/prowler-cloud/prowler/pull/9343)
|
||||
- Exception handler for provider deletions during scans [(#9414)](https://github.com/prowler-cloud/prowler/pull/9414)
|
||||
- Support to use admin credentials through the read replica database [(#9440)](https://github.com/prowler-cloud/prowler/pull/9440)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Error messages from Lighthouse celery tasks [(#9165)](https://github.com/prowler-cloud/prowler/pull/9165)
|
||||
- Restore the compliance overview endpoint's mandatory filters [(#9338)](https://github.com/prowler-cloud/prowler/pull/9338)
|
||||
|
||||
@@ -68,7 +111,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.15.2] (Prowler v5.14.2)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Unique constraint violation during compliance overviews task [(#9436)](https://github.com/prowler-cloud/prowler/pull/9436)
|
||||
- Division by zero error in ENS PDF report when all requirements are manual [(#9443)](https://github.com/prowler-cloud/prowler/pull/9443)
|
||||
|
||||
@@ -76,7 +120,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.15.1] (Prowler v5.14.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Fix typo in PDF reporting [(#9345)](https://github.com/prowler-cloud/prowler/pull/9345)
|
||||
- Fix IaC provider initialization failure when mutelist processor is configured [(#9331)](https://github.com/prowler-cloud/prowler/pull/9331)
|
||||
- Match logic for ThreatScore when counting findings [(#9348)](https://github.com/prowler-cloud/prowler/pull/9348)
|
||||
@@ -85,7 +130,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.15.0] (Prowler v5.14.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- IaC (Infrastructure as Code) provider support for remote repositories [(#8751)](https://github.com/prowler-cloud/prowler/pull/8751)
|
||||
- Extend `GET /api/v1/providers` with provider-type filters and optional pagination disable to support the new Overview filters [(#8975)](https://github.com/prowler-cloud/prowler/pull/8975)
|
||||
- New endpoint to retrieve the number of providers grouped by provider type [(#8975)](https://github.com/prowler-cloud/prowler/pull/8975)
|
||||
@@ -104,11 +150,13 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Enhanced compliance overview endpoint with provider filtering and latest scan aggregation [(#9244)](https://github.com/prowler-cloud/prowler/pull/9244)
|
||||
- New endpoint `GET /api/v1/overview/regions` to retrieve aggregated findings data by region [(#9273)](https://github.com/prowler-cloud/prowler/pull/9273)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Optimized database write queries for scan related tasks [(#9190)](https://github.com/prowler-cloud/prowler/pull/9190)
|
||||
- Date filters are now optional for `GET /api/v1/overviews/services` endpoint; returns latest scan data by default [(#9248)](https://github.com/prowler-cloud/prowler/pull/9248)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Scans no longer fail when findings have UIDs exceeding 300 characters; such findings are now skipped with detailed logging [(#9246)](https://github.com/prowler-cloud/prowler/pull/9246)
|
||||
- Updated unique constraint for `Provider` model to exclude soft-deleted entries, resolving duplicate errors when re-deleting providers [(#9054)](https://github.com/prowler-cloud/prowler/pull/9054)
|
||||
- Removed compliance generation for providers without compliance frameworks [(#9208)](https://github.com/prowler-cloud/prowler/pull/9208)
|
||||
@@ -116,14 +164,16 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Severity overview endpoint now ignores muted findings as expected [(#9283)](https://github.com/prowler-cloud/prowler/pull/9283)
|
||||
- Fixed discrepancy between ThreatScore PDF report values and database calculations [(#9296)](https://github.com/prowler-cloud/prowler/pull/9296)
|
||||
|
||||
### Security
|
||||
### 🔐 Security
|
||||
|
||||
- Django updated to the latest 5.1 security release, 5.1.14, due to problems with potential [SQL injection](https://github.com/prowler-cloud/prowler/security/dependabot/113) and [denial-of-service vulnerability](https://github.com/prowler-cloud/prowler/security/dependabot/114) [(#9176)](https://github.com/prowler-cloud/prowler/pull/9176)
|
||||
|
||||
---
|
||||
|
||||
## [1.14.1] (Prowler v5.13.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- `/api/v1/overviews/providers` collapses data by provider type so the UI receives a single aggregated record per cloud family even when multiple accounts exist [(#9053)](https://github.com/prowler-cloud/prowler/pull/9053)
|
||||
- Added retry logic to database transactions to handle Aurora read replica connection failures during scale-down events [(#9064)](https://github.com/prowler-cloud/prowler/pull/9064)
|
||||
- Security Hub integrations stop failing when they read relationships via the replica by allowing replica relations and saving updates through the primary [(#9080)](https://github.com/prowler-cloud/prowler/pull/9080)
|
||||
@@ -132,7 +182,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.14.0] (Prowler v5.13.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Default JWT keys are generated and stored if they are missing from configuration [(#8655)](https://github.com/prowler-cloud/prowler/pull/8655)
|
||||
- `compliance_name` for each compliance [(#7920)](https://github.com/prowler-cloud/prowler/pull/7920)
|
||||
- Support C5 compliance framework for the AWS provider [(#8830)](https://github.com/prowler-cloud/prowler/pull/8830)
|
||||
@@ -145,35 +196,41 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Support Common Cloud Controls for AWS, Azure and GCP [(#8000)](https://github.com/prowler-cloud/prowler/pull/8000)
|
||||
- Add `provider_id__in` filter support to findings and findings severity overview endpoints [(#8951)](https://github.com/prowler-cloud/prowler/pull/8951)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Now the MANAGE_ACCOUNT permission is required to modify or read user permissions instead of MANAGE_USERS [(#8281)](https://github.com/prowler-cloud/prowler/pull/8281)
|
||||
- Now at least one user with MANAGE_ACCOUNT permission is required in the tenant [(#8729)](https://github.com/prowler-cloud/prowler/pull/8729)
|
||||
|
||||
### Security
|
||||
### 🔐 Security
|
||||
|
||||
- Django updated to the latest 5.1 security release, 5.1.13, due to problems with potential [SQL injection](https://github.com/prowler-cloud/prowler/security/dependabot/104) and [directory traversals](https://github.com/prowler-cloud/prowler/security/dependabot/103) [(#8842)](https://github.com/prowler-cloud/prowler/pull/8842)
|
||||
|
||||
---
|
||||
|
||||
## [1.13.2] (Prowler v5.12.3)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- 500 error when deleting user [(#8731)](https://github.com/prowler-cloud/prowler/pull/8731)
|
||||
|
||||
---
|
||||
|
||||
## [1.13.1] (Prowler v5.12.2)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Renamed compliance overview task queue to `compliance` [(#8755)](https://github.com/prowler-cloud/prowler/pull/8755)
|
||||
|
||||
### Security
|
||||
### 🔐 Security
|
||||
|
||||
- Django updated to the latest 5.1 security release, 5.1.12, due to [problems](https://www.djangoproject.com/weblog/2025/sep/03/security-releases/) with potential SQL injection in FilteredRelation column aliases [(#8693)](https://github.com/prowler-cloud/prowler/pull/8693)
|
||||
|
||||
---
|
||||
|
||||
## [1.13.0] (Prowler v5.12.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Integration with JIRA, enabling sending findings to a JIRA project [(#8622)](https://github.com/prowler-cloud/prowler/pull/8622), [(#8637)](https://github.com/prowler-cloud/prowler/pull/8637)
|
||||
- `GET /overviews/findings_severity` now supports `filter[status]` and `filter[status__in]` to aggregate by specific statuses (`FAIL`, `PASS`)[(#8186)](https://github.com/prowler-cloud/prowler/pull/8186)
|
||||
- Throttling options for `/api/v1/tokens` using the `DJANGO_THROTTLE_TOKEN_OBTAIN` environment variable [(#8647)](https://github.com/prowler-cloud/prowler/pull/8647)
|
||||
@@ -182,101 +239,120 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.12.0] (Prowler v5.11.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Lighthouse support for OpenAI GPT-5 [(#8527)](https://github.com/prowler-cloud/prowler/pull/8527)
|
||||
- Integration with Amazon Security Hub, enabling sending findings to Security Hub [(#8365)](https://github.com/prowler-cloud/prowler/pull/8365)
|
||||
- Generate ASFF output for AWS providers with SecurityHub integration enabled [(#8569)](https://github.com/prowler-cloud/prowler/pull/8569)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- GitHub provider always scans user instead of organization when using provider UID [(#8587)](https://github.com/prowler-cloud/prowler/pull/8587)
|
||||
|
||||
---
|
||||
|
||||
## [1.11.0] (Prowler v5.10.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Github provider support [(#8271)](https://github.com/prowler-cloud/prowler/pull/8271)
|
||||
- Integration with Amazon S3, enabling storage and retrieval of scan data via S3 buckets [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Avoid sending errors to Sentry in M365 provider when user authentication fails [(#8420)](https://github.com/prowler-cloud/prowler/pull/8420)
|
||||
|
||||
---
|
||||
|
||||
## [1.10.2] (Prowler v5.9.2)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Optimized queries for resources views [(#8336)](https://github.com/prowler-cloud/prowler/pull/8336)
|
||||
|
||||
---
|
||||
|
||||
## [v1.10.1] (Prowler v5.9.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Calculate failed findings during scans to prevent heavy database queries [(#8322)](https://github.com/prowler-cloud/prowler/pull/8322)
|
||||
|
||||
---
|
||||
|
||||
## [v1.10.0] (Prowler v5.9.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- SSO with SAML support [(#8175)](https://github.com/prowler-cloud/prowler/pull/8175)
|
||||
- `GET /resources/metadata`, `GET /resources/metadata/latest` and `GET /resources/latest` to expose resource metadata and latest scan results [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- `/processors` endpoints to post-process findings. Currently, only the Mutelist processor is supported to allow to mute findings.
|
||||
- Optimized the underlying queries for resources endpoints [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
|
||||
- Optimized include parameters for resources view [(#8229)](https://github.com/prowler-cloud/prowler/pull/8229)
|
||||
- Optimized overview background tasks [(#8300)](https://github.com/prowler-cloud/prowler/pull/8300)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Search filter for findings and resources [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
|
||||
- RBAC is now applied to `GET /overviews/providers` [(#8277)](https://github.com/prowler-cloud/prowler/pull/8277)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- `POST /schedules/daily` returns a `409 CONFLICT` if already created [(#8258)](https://github.com/prowler-cloud/prowler/pull/8258)
|
||||
|
||||
### Security
|
||||
### 🔐 Security
|
||||
|
||||
- Enhanced password validation to enforce 12+ character passwords with special characters, uppercase, lowercase, and numbers [(#8225)](https://github.com/prowler-cloud/prowler/pull/8225)
|
||||
|
||||
---
|
||||
|
||||
## [v1.9.1] (Prowler v5.8.1)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Custom exception for provider connection errors during scans [(#8234)](https://github.com/prowler-cloud/prowler/pull/8234)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Summary and overview tasks now use a dedicated queue and no longer propagate errors to compliance tasks [(#8214)](https://github.com/prowler-cloud/prowler/pull/8214)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Scan with no resources will not trigger legacy code for findings metadata [(#8183)](https://github.com/prowler-cloud/prowler/pull/8183)
|
||||
- Invitation email comparison case-insensitive [(#8206)](https://github.com/prowler-cloud/prowler/pull/8206)
|
||||
|
||||
### Removed
|
||||
### ❌ Removed
|
||||
|
||||
- Validation of the provider's secret type during updates [(#8197)](https://github.com/prowler-cloud/prowler/pull/8197)
|
||||
|
||||
---
|
||||
|
||||
## [v1.9.0] (Prowler v5.8.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Support GCP Service Account key [(#7824)](https://github.com/prowler-cloud/prowler/pull/7824)
|
||||
- `GET /compliance-overviews` endpoints to retrieve compliance metadata and specific requirements statuses [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
|
||||
- Lighthouse configuration support [(#7848)](https://github.com/prowler-cloud/prowler/pull/7848)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Reworked `GET /compliance-overviews` to return proper requirement metrics [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
|
||||
- Optional `user` and `password` for M365 provider [(#7992)](https://github.com/prowler-cloud/prowler/pull/7992)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Scheduled scans are no longer deleted when their daily schedule run is disabled [(#8082)](https://github.com/prowler-cloud/prowler/pull/8082)
|
||||
|
||||
---
|
||||
|
||||
## [v1.8.5] (Prowler v5.7.5)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Normalize provider UID to ensure safe and unique export directory paths [(#8007)](https://github.com/prowler-cloud/prowler/pull/8007).
|
||||
- Blank resource types in `/metadata` endpoints [(#8027)](https://github.com/prowler-cloud/prowler/pull/8027)
|
||||
|
||||
@@ -284,20 +360,24 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.8.4] (Prowler v5.7.4)
|
||||
|
||||
### Removed
|
||||
### ❌ Removed
|
||||
|
||||
- Reverted RLS transaction handling and DB custom backend [(#7994)](https://github.com/prowler-cloud/prowler/pull/7994)
|
||||
|
||||
---
|
||||
|
||||
## [v1.8.3] (Prowler v5.7.3)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Database backend to handle already closed connections [(#7935)](https://github.com/prowler-cloud/prowler/pull/7935)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Renamed field encrypted_password to password for M365 provider [(#7784)](https://github.com/prowler-cloud/prowler/pull/7784)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Transaction persistence with RLS operations [(#7916)](https://github.com/prowler-cloud/prowler/pull/7916)
|
||||
- Reverted the change `get_with_retry` to use the original `get` method for retrieving tasks [(#7932)](https://github.com/prowler-cloud/prowler/pull/7932)
|
||||
|
||||
@@ -305,7 +385,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.8.2] (Prowler v5.7.2)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Task lookup to use task_kwargs instead of task_args for scan report resolution [(#7830)](https://github.com/prowler-cloud/prowler/pull/7830)
|
||||
- Kubernetes UID validation to allow valid context names [(#7871)](https://github.com/prowler-cloud/prowler/pull/7871)
|
||||
- Connection status verification before launching a scan [(#7831)](https://github.com/prowler-cloud/prowler/pull/7831)
|
||||
@@ -316,14 +397,16 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.8.1] (Prowler v5.7.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Added database index to improve performance on finding lookup [(#7800)](https://github.com/prowler-cloud/prowler/pull/7800)
|
||||
|
||||
---
|
||||
|
||||
## [v1.8.0] (Prowler v5.7.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Huge improvements to `/findings/metadata` and resource related filters for findings [(#7690)](https://github.com/prowler-cloud/prowler/pull/7690)
|
||||
- Improvements to `/overviews` endpoints [(#7690)](https://github.com/prowler-cloud/prowler/pull/7690)
|
||||
- Queue to perform backfill background tasks [(#7690)](https://github.com/prowler-cloud/prowler/pull/7690)
|
||||
@@ -334,7 +417,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.7.0] (Prowler v5.6.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- M365 as a new provider [(#7563)](https://github.com/prowler-cloud/prowler/pull/7563)
|
||||
- `compliance/` folder and ZIP‐export functionality for all compliance reports [(#7653)](https://github.com/prowler-cloud/prowler/pull/7653)
|
||||
@@ -344,7 +427,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.6.0] (Prowler v5.5.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Support for developing new integrations [(#7167)](https://github.com/prowler-cloud/prowler/pull/7167)
|
||||
- HTTP Security Headers [(#7289)](https://github.com/prowler-cloud/prowler/pull/7289)
|
||||
@@ -356,14 +439,16 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.5.4] (Prowler v5.4.4)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Bug with periodic tasks when trying to delete a provider [(#7466)](https://github.com/prowler-cloud/prowler/pull/7466)
|
||||
|
||||
---
|
||||
|
||||
## [v1.5.3] (Prowler v5.4.3)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Duplicated scheduled scans handling [(#7401)](https://github.com/prowler-cloud/prowler/pull/7401)
|
||||
- Environment variable to configure the deletion task batch size [(#7423)](https://github.com/prowler-cloud/prowler/pull/7423)
|
||||
|
||||
@@ -371,14 +456,16 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.5.2] (Prowler v5.4.2)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Refactored deletion logic and implemented retry mechanism for deletion tasks [(#7349)](https://github.com/prowler-cloud/prowler/pull/7349)
|
||||
|
||||
---
|
||||
|
||||
## [v1.5.1] (Prowler v5.4.1)
|
||||
|
||||
### Fixed
|
||||
### 🐞 Fixed
|
||||
|
||||
- Handle response in case local files are missing [(#7183)](https://github.com/prowler-cloud/prowler/pull/7183)
|
||||
- Race condition when deleting export files after the S3 upload [(#7172)](https://github.com/prowler-cloud/prowler/pull/7172)
|
||||
- Handle exception when a provider has no secret in test connection [(#7283)](https://github.com/prowler-cloud/prowler/pull/7283)
|
||||
@@ -387,19 +474,22 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [v1.5.0] (Prowler v5.4.0)
|
||||
|
||||
### Added
|
||||
### 🚀 Added
|
||||
|
||||
- Social login integration with Google and GitHub [(#6906)](https://github.com/prowler-cloud/prowler/pull/6906)
|
||||
- API scan report system, now all scans launched from the API will generate a compressed file with the report in OCSF, CSV and HTML formats [(#6878)](https://github.com/prowler-cloud/prowler/pull/6878)
|
||||
- Configurable Sentry integration [(#6874)](https://github.com/prowler-cloud/prowler/pull/6874)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Optimized `GET /findings` endpoint to improve response time and size [(#7019)](https://github.com/prowler-cloud/prowler/pull/7019)
|
||||
|
||||
---
|
||||
|
||||
## [v1.4.0] (Prowler v5.3.0)
|
||||
|
||||
### Changed
|
||||
### 🔄 Changed
|
||||
|
||||
- Daily scheduled scan instances are now created beforehand with `SCHEDULED` state [(#6700)](https://github.com/prowler-cloud/prowler/pull/6700)
|
||||
- Findings endpoints now require at least one date filter [(#6800)](https://github.com/prowler-cloud/prowler/pull/6800)
|
||||
- Findings metadata endpoint received a performance improvement [(#6863)](https://github.com/prowler-cloud/prowler/pull/6863)
|
||||
|
||||
@@ -32,7 +32,7 @@ start_prod_server() {
|
||||
|
||||
start_worker() {
|
||||
echo "Starting the worker..."
|
||||
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview,integrations,compliance -E --max-tasks-per-child 1
|
||||
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview,integrations,compliance,attack-paths-scans -E --max-tasks-per-child 1
|
||||
}
|
||||
|
||||
start_worker_beat() {
|
||||
|
||||
Generated
+1485
-360
File diff suppressed because it is too large
Load Diff
+7
-4
@@ -8,7 +8,7 @@ dependencies = [
|
||||
"celery[pytest] (>=5.4.0,<6.0.0)",
|
||||
"dj-rest-auth[with_social,jwt] (==7.0.1)",
|
||||
"django (==5.1.15)",
|
||||
"django-allauth[saml] (>=65.8.0,<66.0.0)",
|
||||
"django-allauth[saml] (>=65.13.0,<66.0.0)",
|
||||
"django-celery-beat (>=2.7.0,<3.0.0)",
|
||||
"django-celery-results (>=2.5.1,<3.0.0)",
|
||||
"django-cors-headers==4.4.0",
|
||||
@@ -36,6 +36,8 @@ dependencies = [
|
||||
"drf-simple-apikey (==2.2.1)",
|
||||
"matplotlib (>=3.10.6,<4.0.0)",
|
||||
"reportlab (>=4.4.4,<5.0.0)",
|
||||
"neo4j (<6.0.0)",
|
||||
"cartography @ git+https://github.com/prowler-cloud/cartography@master",
|
||||
"gevent (>=25.9.1,<26.0.0)",
|
||||
"werkzeug (>=3.1.4)",
|
||||
"sqlparse (>=0.5.4)",
|
||||
@@ -47,7 +49,7 @@ name = "prowler-api"
|
||||
package-mode = false
|
||||
# Needed for the SDK compatibility
|
||||
requires-python = ">=3.11,<3.13"
|
||||
version = "1.18.0"
|
||||
version = "1.19.0"
|
||||
|
||||
[project.scripts]
|
||||
celery = "src.backend.config.settings.celery"
|
||||
@@ -68,6 +70,7 @@ pytest-env = "1.1.3"
|
||||
pytest-randomly = "3.15.0"
|
||||
pytest-xdist = "3.6.1"
|
||||
ruff = "0.5.0"
|
||||
safety = "3.2.9"
|
||||
tqdm = "4.67.1"
|
||||
safety = "3.7.0"
|
||||
filelock = "3.20.3"
|
||||
vulture = "2.14"
|
||||
tqdm = "4.67.1"
|
||||
|
||||
@@ -30,16 +30,48 @@ class ApiConfig(AppConfig):
|
||||
def ready(self):
|
||||
from api import schema_extensions # noqa: F401
|
||||
from api import signals # noqa: F401
|
||||
from api.compliance import load_prowler_compliance
|
||||
from api.attack_paths import database as graph_database
|
||||
|
||||
# Generate required cryptographic keys if not present, but only if:
|
||||
# `"manage.py" not in sys.argv`: If an external server (e.g., Gunicorn) is running the app
|
||||
# `"manage.py" not in sys.argv[0]`: If an external server (e.g., Gunicorn) is running the app
|
||||
# `os.environ.get("RUN_MAIN")`: If it's not a Django command or using `runserver`,
|
||||
# only the main process will do it
|
||||
if "manage.py" not in sys.argv or os.environ.get("RUN_MAIN"):
|
||||
if (len(sys.argv) >= 1 and "manage.py" not in sys.argv[0]) or os.environ.get(
|
||||
"RUN_MAIN"
|
||||
):
|
||||
self._ensure_crypto_keys()
|
||||
|
||||
load_prowler_compliance()
|
||||
# Commands that don't need Neo4j
|
||||
SKIP_NEO4J_DJANGO_COMMANDS = [
|
||||
"makemigrations",
|
||||
"migrate",
|
||||
"pgpartition",
|
||||
"check",
|
||||
"help",
|
||||
"showmigrations",
|
||||
"check_and_fix_socialaccount_sites_migration",
|
||||
]
|
||||
|
||||
# Skip Neo4j initialization during tests, some Django commands, and Celery
|
||||
if getattr(settings, "TESTING", False) or (
|
||||
len(sys.argv) > 1
|
||||
and (
|
||||
(
|
||||
"manage.py" in sys.argv[0]
|
||||
and sys.argv[1] in SKIP_NEO4J_DJANGO_COMMANDS
|
||||
)
|
||||
or "celery" in sys.argv[0]
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
"Skipping Neo4j initialization because tests, some Django commands or Celery"
|
||||
)
|
||||
|
||||
else:
|
||||
graph_database.init_driver()
|
||||
|
||||
# Neo4j driver is initialized at API startup (see api.attack_paths.database)
|
||||
# It remains lazy for Celery workers and selected Django commands
|
||||
|
||||
def _ensure_crypto_keys(self):
|
||||
"""
|
||||
@@ -54,7 +86,7 @@ class ApiConfig(AppConfig):
|
||||
global _keys_initialized
|
||||
|
||||
# Skip key generation if running tests
|
||||
if hasattr(settings, "TESTING") and settings.TESTING:
|
||||
if getattr(settings, "TESTING", False):
|
||||
return
|
||||
|
||||
# Skip if already initialized in this process
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from api.attack_paths.query_definitions import (
|
||||
AttackPathsQueryDefinition,
|
||||
AttackPathsQueryParameterDefinition,
|
||||
get_queries_for_provider,
|
||||
get_query_by_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AttackPathsQueryDefinition",
|
||||
"AttackPathsQueryParameterDefinition",
|
||||
"get_queries_for_provider",
|
||||
"get_query_by_id",
|
||||
]
|
||||
@@ -0,0 +1,161 @@
|
||||
import atexit
|
||||
import logging
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
from django.conf import settings
|
||||
|
||||
from api.attack_paths.retryable_session import RetryableSession
|
||||
|
||||
# Without this Celery goes crazy with Neo4j logging
|
||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||
logging.getLogger("neo4j").propagate = False
|
||||
|
||||
SERVICE_UNAVAILABLE_MAX_RETRIES = 3
|
||||
|
||||
# Module-level process-wide driver singleton
|
||||
_driver: neo4j.Driver | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
# Base Neo4j functions
|
||||
|
||||
|
||||
def get_uri() -> str:
|
||||
host = settings.DATABASES["neo4j"]["HOST"]
|
||||
port = settings.DATABASES["neo4j"]["PORT"]
|
||||
return f"bolt://{host}:{port}"
|
||||
|
||||
|
||||
def init_driver() -> neo4j.Driver:
|
||||
global _driver
|
||||
if _driver is not None:
|
||||
return _driver
|
||||
|
||||
with _lock:
|
||||
if _driver is None:
|
||||
uri = get_uri()
|
||||
config = settings.DATABASES["neo4j"]
|
||||
|
||||
_driver = neo4j.GraphDatabase.driver(
|
||||
uri,
|
||||
auth=(config["USER"], config["PASSWORD"]),
|
||||
keep_alive=True,
|
||||
max_connection_lifetime=7200,
|
||||
connection_acquisition_timeout=120,
|
||||
max_connection_pool_size=50,
|
||||
)
|
||||
_driver.verify_connectivity()
|
||||
|
||||
# Register cleanup handler (only runs once since we're inside the _driver is None block)
|
||||
atexit.register(close_driver)
|
||||
|
||||
return _driver
|
||||
|
||||
|
||||
def get_driver() -> neo4j.Driver:
|
||||
return init_driver()
|
||||
|
||||
|
||||
def close_driver() -> None: # TODO: Use it
|
||||
global _driver
|
||||
with _lock:
|
||||
if _driver is not None:
|
||||
try:
|
||||
_driver.close()
|
||||
|
||||
finally:
|
||||
_driver = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session(database: str | None = None) -> Iterator[RetryableSession]:
|
||||
session_wrapper: RetryableSession | None = None
|
||||
|
||||
try:
|
||||
session_wrapper = RetryableSession(
|
||||
session_factory=lambda: get_driver().session(database=database),
|
||||
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
|
||||
)
|
||||
yield session_wrapper
|
||||
|
||||
except neo4j.exceptions.Neo4jError as exc:
|
||||
raise GraphDatabaseQueryException(message=exc.message, code=exc.code)
|
||||
|
||||
finally:
|
||||
if session_wrapper is not None:
|
||||
session_wrapper.close()
|
||||
|
||||
|
||||
def create_database(database: str) -> None:
|
||||
query = "CREATE DATABASE $database IF NOT EXISTS"
|
||||
parameters = {"database": database}
|
||||
|
||||
with get_session() as session:
|
||||
session.run(query, parameters)
|
||||
|
||||
|
||||
def drop_database(database: str) -> None:
|
||||
query = f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA"
|
||||
|
||||
with get_session() as session:
|
||||
session.run(query)
|
||||
|
||||
|
||||
def drop_subgraph(database: str, root_node_label: str, root_node_id: str) -> int:
|
||||
query = """
|
||||
MATCH (a:__ROOT_NODE_LABEL__ {id: $root_node_id})
|
||||
CALL apoc.path.subgraphNodes(a, {})
|
||||
YIELD node
|
||||
DETACH DELETE node
|
||||
RETURN COUNT(node) AS deleted_nodes_count
|
||||
""".replace("__ROOT_NODE_LABEL__", root_node_label)
|
||||
parameters = {"root_node_id": root_node_id}
|
||||
|
||||
with get_session(database) as session:
|
||||
result = session.run(query, parameters)
|
||||
|
||||
try:
|
||||
return result.single()["deleted_nodes_count"]
|
||||
|
||||
except neo4j.exceptions.ResultConsumedError:
|
||||
return 0 # As there are no nodes to delete, the result is empty
|
||||
|
||||
|
||||
def clear_cache(database: str) -> None:
|
||||
query = "CALL db.clearQueryCaches()"
|
||||
|
||||
try:
|
||||
with get_session(database) as session:
|
||||
session.run(query)
|
||||
|
||||
except GraphDatabaseQueryException as exc:
|
||||
logging.warning(f"Failed to clear query cache for database `{database}`: {exc}")
|
||||
|
||||
|
||||
# Neo4j functions related to Prowler + Cartography
|
||||
DATABASE_NAME_TEMPLATE = "db-{attack_paths_scan_id}"
|
||||
|
||||
|
||||
def get_database_name(attack_paths_scan_id: UUID) -> str:
|
||||
attack_paths_scan_id_str = str(attack_paths_scan_id).lower()
|
||||
return DATABASE_NAME_TEMPLATE.format(attack_paths_scan_id=attack_paths_scan_id_str)
|
||||
|
||||
|
||||
# Exceptions
|
||||
|
||||
|
||||
class GraphDatabaseQueryException(Exception):
|
||||
def __init__(self, message: str, code: str | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.code:
|
||||
return f"{self.code}: {self.message}"
|
||||
|
||||
return self.message
|
||||
@@ -0,0 +1,693 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# Dataclases for handling API's Attack Path query definitions and their parameters
|
||||
@dataclass
|
||||
class AttackPathsQueryParameterDefinition:
|
||||
"""
|
||||
Metadata describing a parameter that must be provided to an Attack Paths query.
|
||||
"""
|
||||
|
||||
name: str
|
||||
label: str
|
||||
data_type: str = "string"
|
||||
cast: type = str
|
||||
description: str | None = None
|
||||
placeholder: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackPathsQueryDefinition:
|
||||
"""
|
||||
Immutable representation of an Attack Path query.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
provider: str
|
||||
cypher: str
|
||||
parameters: list[AttackPathsQueryParameterDefinition] = field(default_factory=list)
|
||||
|
||||
|
||||
# Accessor functions for API's Attack Paths query definitions
|
||||
def get_queries_for_provider(provider: str) -> list[AttackPathsQueryDefinition]:
|
||||
return _QUERY_DEFINITIONS.get(provider, [])
|
||||
|
||||
|
||||
def get_query_by_id(query_id: str) -> AttackPathsQueryDefinition | None:
|
||||
return _QUERIES_BY_ID.get(query_id)
|
||||
|
||||
|
||||
# API's Attack Paths query definitions
|
||||
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
|
||||
"aws": [
|
||||
# Custom query for detecting internet-exposed EC2 instances with sensitive S3 access
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-internet-exposed-ec2-sensitive-s3-access",
|
||||
name="Identify internet-exposed EC2 instances with sensitive S3 access",
|
||||
description="Detect EC2 instances with SSH exposed to the internet that can assume higher-privileged roles to read tagged sensitive S3 buckets despite bucket-level public access blocks.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
MATCH path_s3 = (aws:AWSAccount {id: $provider_uid})--(s3:S3Bucket)--(t:AWSTag)
|
||||
WHERE toLower(t.key) = toLower($tag_key) AND toLower(t.value) = toLower($tag_value)
|
||||
|
||||
MATCH path_ec2 = (aws)--(ec2:EC2Instance)--(sg:EC2SecurityGroup)--(ipi:IpPermissionInbound)
|
||||
WHERE ec2.exposed_internet = true
|
||||
AND ipi.toport = 22
|
||||
|
||||
MATCH path_role = (r:AWSRole)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
|
||||
WHERE ANY(x IN stmt.resource WHERE x CONTAINS s3.name)
|
||||
AND ANY(x IN stmt.action WHERE toLower(x) =~ 's3:(listbucket|getobject).*')
|
||||
|
||||
MATCH path_assume_role = (ec2)-[p:STS_ASSUMEROLE_ALLOW*1..9]-(r:AWSRole)
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, ec2)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path_s3) + nodes(path_ec2) + nodes(path_role) + nodes(path_assume_role) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path_s3, path_ec2, path_role, path_assume_role, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[
|
||||
AttackPathsQueryParameterDefinition(
|
||||
name="tag_key",
|
||||
label="Tag key",
|
||||
description="Tag key to filter the S3 bucket, e.g. DataClassification.",
|
||||
placeholder="DataClassification",
|
||||
),
|
||||
AttackPathsQueryParameterDefinition(
|
||||
name="tag_value",
|
||||
label="Tag value",
|
||||
description="Tag value to filter the S3 bucket, e.g. Sensitive.",
|
||||
placeholder="Sensitive",
|
||||
),
|
||||
],
|
||||
),
|
||||
# Regular Cartography Attack Paths queries
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-rds-instances",
|
||||
name="Identify provisioned RDS instances",
|
||||
description="List the selected AWS account alongside the RDS instances it owns.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(rds:RDSInstance)
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-rds-unencrypted-storage",
|
||||
name="Identify RDS instances without storage encryption",
|
||||
description="Find RDS instances with storage encryption disabled within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(rds:RDSInstance)
|
||||
WHERE rds.storage_encrypted = false
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-s3-anonymous-access-buckets",
|
||||
name="Identify S3 buckets with anonymous access",
|
||||
description="Find S3 buckets that allow anonymous access within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(s3:S3Bucket)
|
||||
WHERE s3.anonymous_access = true
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-iam-statements-allow-all-actions",
|
||||
name="Identify IAM statements that allow all actions",
|
||||
description="Find IAM policy statements that allow all actions via '*' within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
|
||||
WHERE stmt.effect = 'Allow'
|
||||
AND any(x IN stmt.action WHERE x = '*')
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-iam-statements-allow-delete-policy",
|
||||
name="Identify IAM statements that allow iam:DeletePolicy",
|
||||
description="Find IAM policy statements that allow the iam:DeletePolicy action within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
|
||||
WHERE stmt.effect = 'Allow'
|
||||
AND any(x IN stmt.action WHERE x = "iam:DeletePolicy")
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-iam-statements-allow-create-actions",
|
||||
name="Identify IAM statements that allow create actions",
|
||||
description="Find IAM policy statements that allow actions containing 'create' within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
|
||||
WHERE stmt.effect = "Allow"
|
||||
AND any(x IN stmt.action WHERE toLower(x) CONTAINS "create")
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-ec2-instances-internet-exposed",
|
||||
name="Identify internet-exposed EC2 instances",
|
||||
description="Find EC2 instances flagged as exposed to the internet within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(ec2:EC2Instance)
|
||||
WHERE ec2.exposed_internet = true
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, ec2)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-security-groups-open-internet-facing",
|
||||
name="Identify internet-facing resources with open security groups",
|
||||
description="Find internet-facing resources associated with security groups that allow inbound access from '0.0.0.0/0'.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
MATCH path_open = (aws:AWSAccount {id: $provider_uid})-[r0]-(open)
|
||||
MATCH path_sg = (open)-[r1:MEMBER_OF_EC2_SECURITY_GROUP]-(sg:EC2SecurityGroup)
|
||||
MATCH path_ip = (sg)-[r2:MEMBER_OF_EC2_SECURITY_GROUP]-(ipi:IpPermissionInbound)
|
||||
MATCH path_ipi = (ipi)-[r3]-(ir:IpRange)
|
||||
WHERE ir.range = "0.0.0.0/0"
|
||||
OPTIONAL MATCH path_dns = (dns:AWSDNSRecord)-[:DNS_POINTS_TO]->(lb)
|
||||
WHERE open.scheme = 'internet-facing'
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, open)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path_open) + nodes(path_sg) + nodes(path_ip) + nodes(path_ipi) + nodes(path_dns) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path_open, path_sg, path_ip, path_ipi, path_dns, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-classic-elb-internet-exposed",
|
||||
name="Identify internet-exposed Classic Load Balancers",
|
||||
description="Find Classic Load Balancers exposed to the internet along with their listeners.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(elb:LoadBalancer)--(listener:ELBListener)
|
||||
WHERE elb.exposed_internet = true
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, elb)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-elbv2-internet-exposed",
|
||||
name="Identify internet-exposed ELBv2 load balancers",
|
||||
description="Find ELBv2 load balancers exposed to the internet along with their listeners.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})--(elbv2:LoadBalancerV2)--(listener:ELBV2Listener)
|
||||
WHERE elbv2.exposed_internet = true
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, elbv2)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-public-ip-resource-lookup",
|
||||
name="Identify resources by public IP address",
|
||||
description="Given a public IP address, find the related AWS resource and its adjacent node within the selected account.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
|
||||
YIELD node AS internet
|
||||
|
||||
CALL () {
|
||||
MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:EC2PrivateIp)-[q]-(y)
|
||||
WHERE x.public_ip = $ip
|
||||
RETURN path, x
|
||||
|
||||
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:EC2Instance)-[q]-(y)
|
||||
WHERE x.publicipaddress = $ip
|
||||
RETURN path, x
|
||||
|
||||
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:NetworkInterface)-[q]-(y)
|
||||
WHERE x.public_ip = $ip
|
||||
RETURN path, x
|
||||
|
||||
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:ElasticIPAddress)-[q]-(y)
|
||||
WHERE x.public_ip = $ip
|
||||
RETURN path, x
|
||||
}
|
||||
|
||||
WITH path, x, internet
|
||||
|
||||
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, x)
|
||||
YIELD rel AS can_access
|
||||
|
||||
UNWIND nodes(path) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
|
||||
""",
|
||||
parameters=[
|
||||
AttackPathsQueryParameterDefinition(
|
||||
name="ip",
|
||||
label="IP address",
|
||||
description="Public IP address, e.g. 192.0.2.0.",
|
||||
placeholder="192.0.2.0",
|
||||
),
|
||||
],
|
||||
),
|
||||
# Privilege Escalation Queries (based on pathfinding.cloud research): https://github.com/DataDog/pathfinding.cloud
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-iam-privesc-passrole-ec2",
|
||||
name="Privilege Escalation: iam:PassRole + ec2:RunInstances",
|
||||
description="Detect principals who can launch EC2 instances with privileged IAM roles attached. This allows gaining the permissions of the passed role by accessing the EC2 instance metadata service. This is a new-passrole escalation path (pathfinding.cloud: ec2-001).",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
// Create a single shared virtual EC2 instance node
|
||||
CALL apoc.create.vNode(['EC2Instance'], {
|
||||
id: 'potential-ec2-passrole',
|
||||
name: 'New EC2 Instance',
|
||||
description: 'Attacker-controlled EC2 with privileged role'
|
||||
})
|
||||
YIELD node AS ec2_node
|
||||
|
||||
// Create a single shared virtual escalation outcome node (styled like a finding)
|
||||
CALL apoc.create.vNode(['PrivilegeEscalation'], {
|
||||
id: 'effective-administrator-passrole-ec2',
|
||||
check_title: 'Privilege Escalation',
|
||||
name: 'Effective Administrator',
|
||||
status: 'FAIL',
|
||||
severity: 'critical'
|
||||
})
|
||||
YIELD node AS escalation_outcome
|
||||
|
||||
WITH ec2_node, escalation_outcome
|
||||
|
||||
// Find principals in the account
|
||||
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
|
||||
|
||||
// Find statements granting iam:PassRole
|
||||
MATCH path_passrole = (principal)--(passrole_policy:AWSPolicy)--(stmt_passrole:AWSPolicyStatement)
|
||||
WHERE stmt_passrole.effect = 'Allow'
|
||||
AND any(action IN stmt_passrole.action WHERE
|
||||
toLower(action) = 'iam:passrole'
|
||||
OR toLower(action) = 'iam:*'
|
||||
OR action = '*'
|
||||
)
|
||||
|
||||
// Find statements granting ec2:RunInstances
|
||||
MATCH path_ec2 = (principal)--(ec2_policy:AWSPolicy)--(stmt_ec2:AWSPolicyStatement)
|
||||
WHERE stmt_ec2.effect = 'Allow'
|
||||
AND any(action IN stmt_ec2.action WHERE
|
||||
toLower(action) = 'ec2:runinstances'
|
||||
OR toLower(action) = 'ec2:*'
|
||||
OR action = '*'
|
||||
)
|
||||
|
||||
// Find roles that trust EC2 service (can be passed to EC2)
|
||||
MATCH path_target = (aws)--(target_role:AWSRole)
|
||||
WHERE target_role.arn CONTAINS $provider_uid
|
||||
// Check if principal can pass this role
|
||||
AND any(resource IN stmt_passrole.resource WHERE
|
||||
resource = '*'
|
||||
OR target_role.arn CONTAINS resource
|
||||
OR resource CONTAINS target_role.name
|
||||
)
|
||||
|
||||
// Check if target role has elevated permissions (optional, for severity assessment)
|
||||
OPTIONAL MATCH (target_role)--(role_policy:AWSPolicy)--(role_stmt:AWSPolicyStatement)
|
||||
WHERE role_stmt.effect = 'Allow'
|
||||
AND (
|
||||
any(action IN role_stmt.action WHERE action = '*')
|
||||
OR any(action IN role_stmt.action WHERE toLower(action) = 'iam:*')
|
||||
)
|
||||
|
||||
CALL apoc.create.vRelationship(principal, 'CAN_LAUNCH', {
|
||||
via: 'ec2:RunInstances + iam:PassRole'
|
||||
}, ec2_node)
|
||||
YIELD rel AS launch_rel
|
||||
|
||||
CALL apoc.create.vRelationship(ec2_node, 'ASSUMES_ROLE', {}, target_role)
|
||||
YIELD rel AS assumes_rel
|
||||
|
||||
CALL apoc.create.vRelationship(target_role, 'GRANTS_ACCESS', {
|
||||
reference: 'https://pathfinding.cloud/paths/ec2-001'
|
||||
}, escalation_outcome)
|
||||
YIELD rel AS grants_rel
|
||||
|
||||
UNWIND nodes(path_principal) + nodes(path_passrole) + nodes(path_ec2) + nodes(path_target) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path_principal, path_passrole, path_ec2, path_target,
|
||||
ec2_node, escalation_outcome, launch_rel, assumes_rel, grants_rel,
|
||||
collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-glue-privesc-passrole-dev-endpoint",
|
||||
name="Privilege Escalation: Glue Dev Endpoint with PassRole",
|
||||
description="Detect principals that can escalate privileges by passing a role to a Glue development endpoint. The attacker creates a dev endpoint with an arbitrary role attached, then accesses those credentials through the endpoint.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['PrivilegeEscalation'], {
|
||||
id: 'effective-administrator-glue',
|
||||
check_title: 'Privilege Escalation',
|
||||
name: 'Effective Administrator (Glue)',
|
||||
status: 'FAIL',
|
||||
severity: 'critical'
|
||||
})
|
||||
YIELD node AS escalation_outcome
|
||||
|
||||
WITH escalation_outcome
|
||||
|
||||
// Find principals in the account
|
||||
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
|
||||
|
||||
// Principal can assume roles (up to 2 hops)
|
||||
OPTIONAL MATCH path_assume = (principal)-[:STS_ASSUMEROLE_ALLOW*0..2]->(acting_as:AWSRole)
|
||||
WITH escalation_outcome, principal, path_principal, path_assume,
|
||||
CASE WHEN path_assume IS NULL THEN principal ELSE acting_as END AS effective_principal
|
||||
|
||||
// Find iam:PassRole permission
|
||||
MATCH path_passrole = (effective_principal)--(passrole_policy:AWSPolicy)--(passrole_stmt:AWSPolicyStatement)
|
||||
WHERE passrole_stmt.effect = 'Allow'
|
||||
AND any(action IN passrole_stmt.action WHERE toLower(action) = 'iam:passrole' OR action = '*')
|
||||
|
||||
// Find Glue CreateDevEndpoint permission
|
||||
MATCH (effective_principal)--(glue_policy:AWSPolicy)--(glue_stmt:AWSPolicyStatement)
|
||||
WHERE glue_stmt.effect = 'Allow'
|
||||
AND any(action IN glue_stmt.action WHERE toLower(action) = 'glue:createdevendpoint' OR action = '*' OR toLower(action) = 'glue:*')
|
||||
|
||||
// Find target role with elevated permissions
|
||||
MATCH (aws)--(target_role:AWSRole)--(target_policy:AWSPolicy)--(target_stmt:AWSPolicyStatement)
|
||||
WHERE target_stmt.effect = 'Allow'
|
||||
AND (
|
||||
any(action IN target_stmt.action WHERE action = '*')
|
||||
OR any(action IN target_stmt.action WHERE toLower(action) = 'iam:*')
|
||||
)
|
||||
|
||||
// Deduplicate before creating virtual nodes
|
||||
WITH DISTINCT escalation_outcome, aws, principal, effective_principal, target_role
|
||||
|
||||
// Create virtual Glue endpoint node (one per unique principal->target pair)
|
||||
CALL apoc.create.vNode(['GlueDevEndpoint'], {
|
||||
name: 'New Dev Endpoint',
|
||||
description: 'Glue endpoint with target role attached',
|
||||
id: effective_principal.arn + '->' + target_role.arn
|
||||
})
|
||||
YIELD node AS glue_endpoint
|
||||
|
||||
CALL apoc.create.vRelationship(effective_principal, 'CREATES_ENDPOINT', {
|
||||
permissions: ['iam:PassRole', 'glue:CreateDevEndpoint'],
|
||||
technique: 'new-passrole'
|
||||
}, glue_endpoint)
|
||||
YIELD rel AS create_rel
|
||||
|
||||
CALL apoc.create.vRelationship(glue_endpoint, 'RUNS_AS', {}, target_role)
|
||||
YIELD rel AS runs_rel
|
||||
|
||||
CALL apoc.create.vRelationship(target_role, 'GRANTS_ACCESS', {
|
||||
reference: 'https://pathfinding.cloud/paths/glue-001'
|
||||
}, escalation_outcome)
|
||||
YIELD rel AS grants_rel
|
||||
|
||||
// Re-match paths for visualization
|
||||
MATCH path_principal = (aws)--(principal)
|
||||
MATCH path_target = (aws)--(target_role)
|
||||
|
||||
RETURN path_principal, path_target,
|
||||
glue_endpoint, escalation_outcome, create_rel, runs_rel, grants_rel
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-iam-privesc-attach-role-policy-assume-role",
|
||||
name="Privilege Escalation: iam:AttachRolePolicy + sts:AssumeRole",
|
||||
description="Detect principals who can both attach policies to roles AND assume those roles. This two-step attack allows modifying a role's permissions then assuming it to gain elevated access. This is a principal-access escalation path (pathfinding.cloud: iam-014).",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
// Create a virtual escalation outcome node (styled like a finding)
|
||||
CALL apoc.create.vNode(['PrivilegeEscalation'], {
|
||||
id: 'effective-administrator',
|
||||
check_title: 'Privilege Escalation',
|
||||
name: 'Effective Administrator',
|
||||
status: 'FAIL',
|
||||
severity: 'critical'
|
||||
})
|
||||
YIELD node AS admin_outcome
|
||||
|
||||
WITH admin_outcome
|
||||
|
||||
// Find principals in the account
|
||||
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
|
||||
|
||||
// Find statements granting iam:AttachRolePolicy
|
||||
MATCH path_attach = (principal)--(attach_policy:AWSPolicy)--(stmt_attach:AWSPolicyStatement)
|
||||
WHERE stmt_attach.effect = 'Allow'
|
||||
AND any(action IN stmt_attach.action WHERE
|
||||
toLower(action) = 'iam:attachrolepolicy'
|
||||
OR toLower(action) = 'iam:*'
|
||||
OR action = '*'
|
||||
)
|
||||
|
||||
// Find statements granting sts:AssumeRole
|
||||
MATCH path_assume = (principal)--(assume_policy:AWSPolicy)--(stmt_assume:AWSPolicyStatement)
|
||||
WHERE stmt_assume.effect = 'Allow'
|
||||
AND any(action IN stmt_assume.action WHERE
|
||||
toLower(action) = 'sts:assumerole'
|
||||
OR toLower(action) = 'sts:*'
|
||||
OR action = '*'
|
||||
)
|
||||
|
||||
// Find target roles that the principal can both modify AND assume
|
||||
MATCH path_target = (aws)--(target_role:AWSRole)
|
||||
WHERE target_role.arn CONTAINS $provider_uid
|
||||
// Can attach policy to this role
|
||||
AND any(resource IN stmt_attach.resource WHERE
|
||||
resource = '*'
|
||||
OR target_role.arn CONTAINS resource
|
||||
OR resource CONTAINS target_role.name
|
||||
)
|
||||
// Can assume this role
|
||||
AND any(resource IN stmt_assume.resource WHERE
|
||||
resource = '*'
|
||||
OR target_role.arn CONTAINS resource
|
||||
OR resource CONTAINS target_role.name
|
||||
)
|
||||
|
||||
// Deduplicate before creating virtual relationships
|
||||
WITH DISTINCT admin_outcome, aws, principal, target_role
|
||||
|
||||
// Create virtual relationships showing the attack path
|
||||
CALL apoc.create.vRelationship(principal, 'CAN_MODIFY', {
|
||||
via: 'iam:AttachRolePolicy'
|
||||
}, target_role)
|
||||
YIELD rel AS modify_rel
|
||||
|
||||
CALL apoc.create.vRelationship(target_role, 'LEADS_TO', {
|
||||
technique: 'iam:AttachRolePolicy + sts:AssumeRole',
|
||||
via: 'sts:AssumeRole',
|
||||
reference: 'https://pathfinding.cloud/paths/iam-014'
|
||||
}, admin_outcome)
|
||||
YIELD rel AS escalation_rel
|
||||
|
||||
// Re-match paths for visualization
|
||||
MATCH path_principal = (aws)--(principal)
|
||||
MATCH path_target = (aws)--(target_role)
|
||||
|
||||
UNWIND nodes(path_principal) + nodes(path_target) as n
|
||||
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
|
||||
WHERE pf.status = 'FAIL'
|
||||
|
||||
RETURN path_principal, path_target,
|
||||
admin_outcome, modify_rel, escalation_rel,
|
||||
collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
AttackPathsQueryDefinition(
|
||||
id="aws-bedrock-privesc-passrole-code-interpreter",
|
||||
name="Privilege Escalation: Bedrock Code Interpreter with PassRole",
|
||||
description="Detect principals that can escalate privileges by passing a role to a Bedrock AgentCore Code Interpreter. The attacker creates a code interpreter with an arbitrary role, then invokes it to execute code with those credentials.",
|
||||
provider="aws",
|
||||
cypher="""
|
||||
CALL apoc.create.vNode(['PrivilegeEscalation'], {
|
||||
id: 'effective-administrator-bedrock',
|
||||
check_title: 'Privilege Escalation',
|
||||
name: 'Effective Administrator (Bedrock)',
|
||||
status: 'FAIL',
|
||||
severity: 'critical'
|
||||
})
|
||||
YIELD node AS escalation_outcome
|
||||
|
||||
WITH escalation_outcome
|
||||
|
||||
// Find principals in the account
|
||||
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
|
||||
|
||||
// Principal can assume roles (up to 2 hops)
|
||||
OPTIONAL MATCH path_assume = (principal)-[:STS_ASSUMEROLE_ALLOW*0..2]->(acting_as:AWSRole)
|
||||
WITH escalation_outcome, aws, principal, path_principal, path_assume,
|
||||
CASE WHEN path_assume IS NULL THEN principal ELSE acting_as END AS effective_principal
|
||||
|
||||
// Find iam:PassRole permission
|
||||
MATCH path_passrole = (effective_principal)--(passrole_policy:AWSPolicy)--(passrole_stmt:AWSPolicyStatement)
|
||||
WHERE passrole_stmt.effect = 'Allow'
|
||||
AND any(action IN passrole_stmt.action WHERE toLower(action) = 'iam:passrole' OR action = '*')
|
||||
|
||||
// Find Bedrock AgentCore permissions
|
||||
MATCH (effective_principal)--(bedrock_policy:AWSPolicy)--(bedrock_stmt:AWSPolicyStatement)
|
||||
WHERE bedrock_stmt.effect = 'Allow'
|
||||
AND (
|
||||
any(action IN bedrock_stmt.action WHERE toLower(action) = 'bedrock-agentcore:createcodeinterpreter' OR action = '*' OR toLower(action) = 'bedrock-agentcore:*')
|
||||
)
|
||||
AND (
|
||||
any(action IN bedrock_stmt.action WHERE toLower(action) = 'bedrock-agentcore:startsession' OR action = '*' OR toLower(action) = 'bedrock-agentcore:*')
|
||||
)
|
||||
AND (
|
||||
any(action IN bedrock_stmt.action WHERE toLower(action) = 'bedrock-agentcore:invoke' OR action = '*' OR toLower(action) = 'bedrock-agentcore:*')
|
||||
)
|
||||
|
||||
// Find target roles with elevated permissions that could be passed
|
||||
MATCH (aws)--(target_role:AWSRole)--(target_policy:AWSPolicy)--(target_stmt:AWSPolicyStatement)
|
||||
WHERE target_stmt.effect = 'Allow'
|
||||
AND (
|
||||
any(action IN target_stmt.action WHERE action = '*')
|
||||
OR any(action IN target_stmt.action WHERE toLower(action) = 'iam:*')
|
||||
)
|
||||
|
||||
// Deduplicate per (principal, target_role) pair
|
||||
WITH DISTINCT escalation_outcome, aws, principal, target_role
|
||||
|
||||
// Group by principal, collect target_roles
|
||||
WITH escalation_outcome, aws, principal,
|
||||
collect(DISTINCT target_role) AS target_roles,
|
||||
count(DISTINCT target_role) AS target_count
|
||||
|
||||
// Create single virtual Bedrock node per principal
|
||||
CALL apoc.create.vNode(['BedrockCodeInterpreter'], {
|
||||
name: 'New Code Interpreter',
|
||||
description: toString(target_count) + ' admin role(s) can be passed',
|
||||
id: principal.arn,
|
||||
target_role_count: target_count
|
||||
})
|
||||
YIELD node AS bedrock_agent
|
||||
|
||||
// Connect from principal (not effective_principal) to keep graph connected
|
||||
CALL apoc.create.vRelationship(principal, 'CREATES_INTERPRETER', {
|
||||
permissions: ['iam:PassRole', 'bedrock-agentcore:CreateCodeInterpreter', 'bedrock-agentcore:StartSession', 'bedrock-agentcore:Invoke'],
|
||||
technique: 'new-passrole'
|
||||
}, bedrock_agent)
|
||||
YIELD rel AS create_rel
|
||||
|
||||
// UNWIND target_roles to show which roles can be passed
|
||||
UNWIND target_roles AS target_role
|
||||
|
||||
CALL apoc.create.vRelationship(bedrock_agent, 'PASSES_ROLE', {}, target_role)
|
||||
YIELD rel AS pass_rel
|
||||
|
||||
CALL apoc.create.vRelationship(target_role, 'GRANTS_ACCESS', {
|
||||
reference: 'https://pathfinding.cloud/paths/bedrock-001'
|
||||
}, escalation_outcome)
|
||||
YIELD rel AS grants_rel
|
||||
|
||||
// Re-match path for visualization
|
||||
MATCH path_principal = (aws)--(principal)
|
||||
|
||||
RETURN path_principal,
|
||||
bedrock_agent, target_role, escalation_outcome, create_rel, pass_rel, grants_rel, target_count
|
||||
""",
|
||||
parameters=[],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
|
||||
definition.id: definition
|
||||
for definitions in _QUERY_DEFINITIONS.values()
|
||||
for definition in definitions
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetryableSession:
|
||||
"""
|
||||
Wrapper around `neo4j.Session` that retries `neo4j.exceptions.ServiceUnavailable` errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Callable[[], neo4j.Session],
|
||||
max_retries: int,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._max_retries = max(0, max_retries)
|
||||
self._session = self._session_factory()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._session is not None:
|
||||
self._session.close()
|
||||
self._session = None
|
||||
|
||||
def __enter__(self) -> "RetryableSession":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, _: Any, __: Any, ___: Any
|
||||
) -> None: # Unused args: exc_type, exc, exc_tb
|
||||
self.close()
|
||||
|
||||
def run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_with_retry("run", *args, **kwargs)
|
||||
|
||||
def write_transaction(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_with_retry("write_transaction", *args, **kwargs)
|
||||
|
||||
def read_transaction(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_with_retry("read_transaction", *args, **kwargs)
|
||||
|
||||
def execute_write(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_with_retry("execute_write", *args, **kwargs)
|
||||
|
||||
def execute_read(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_with_retry("execute_read", *args, **kwargs)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
return getattr(self._session, item)
|
||||
|
||||
def _call_with_retry(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
attempt = 0
|
||||
last_exc: Exception | None = None
|
||||
|
||||
while attempt <= self._max_retries:
|
||||
try:
|
||||
method = getattr(self._session, method_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
except (
|
||||
BrokenPipeError,
|
||||
ConnectionResetError,
|
||||
neo4j.exceptions.ServiceUnavailable,
|
||||
) as exc: # pragma: no cover - depends on infra
|
||||
last_exc = exc
|
||||
attempt += 1
|
||||
|
||||
if attempt > self._max_retries:
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
f"Neo4j session {method_name} failed with {type(exc).__name__} ({attempt}/{self._max_retries} attempts). Retrying..."
|
||||
)
|
||||
self._refresh_session()
|
||||
|
||||
raise last_exc if last_exc else RuntimeError("Unexpected retry loop exit")
|
||||
|
||||
def _refresh_session(self) -> None:
|
||||
if self._session is not None:
|
||||
try:
|
||||
self._session.close()
|
||||
except Exception:
|
||||
# Best-effort close; failures just mean we open a new session below
|
||||
pass
|
||||
|
||||
self._session = self._session_factory()
|
||||
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
|
||||
from typing import Any
|
||||
|
||||
from rest_framework.exceptions import APIException, ValidationError
|
||||
|
||||
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
|
||||
from api.models import AttackPathsScan
|
||||
from config.custom_logging import BackendLogger
|
||||
|
||||
logger = logging.getLogger(BackendLogger.API)
|
||||
|
||||
|
||||
def normalize_run_payload(raw_data):
|
||||
if not isinstance(raw_data, dict): # Let the serializer handle this
|
||||
return raw_data
|
||||
|
||||
if "data" in raw_data and isinstance(raw_data.get("data"), dict):
|
||||
data_section = raw_data.get("data") or {}
|
||||
attributes = data_section.get("attributes") or {}
|
||||
payload = {
|
||||
"id": attributes.get("id", data_section.get("id")),
|
||||
"parameters": attributes.get("parameters"),
|
||||
}
|
||||
|
||||
# Remove `None` parameters to allow defaults downstream
|
||||
if payload.get("parameters") is None:
|
||||
payload.pop("parameters")
|
||||
return payload
|
||||
|
||||
return raw_data
|
||||
|
||||
|
||||
def prepare_query_parameters(
|
||||
definition: AttackPathsQueryDefinition,
|
||||
provided_parameters: dict[str, Any],
|
||||
provider_uid: str,
|
||||
) -> dict[str, Any]:
|
||||
parameters = dict(provided_parameters or {})
|
||||
expected_names = {parameter.name for parameter in definition.parameters}
|
||||
provided_names = set(parameters.keys())
|
||||
|
||||
unexpected = provided_names - expected_names
|
||||
if unexpected:
|
||||
raise ValidationError(
|
||||
{"parameters": f"Unknown parameter(s): {', '.join(sorted(unexpected))}"}
|
||||
)
|
||||
|
||||
missing = expected_names - provided_names
|
||||
if missing:
|
||||
raise ValidationError(
|
||||
{
|
||||
"parameters": f"Missing required parameter(s): {', '.join(sorted(missing))}"
|
||||
}
|
||||
)
|
||||
|
||||
clean_parameters = {
|
||||
"provider_uid": str(provider_uid),
|
||||
}
|
||||
|
||||
for definition_parameter in definition.parameters:
|
||||
raw_value = provided_parameters[definition_parameter.name]
|
||||
|
||||
try:
|
||||
casted_value = definition_parameter.cast(raw_value)
|
||||
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValidationError(
|
||||
{
|
||||
"parameters": (
|
||||
f"Invalid value for parameter `{definition_parameter.name}`: {str(exc)}"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
clean_parameters[definition_parameter.name] = casted_value
|
||||
|
||||
return clean_parameters
|
||||
|
||||
|
||||
def execute_attack_paths_query(
|
||||
attack_paths_scan: AttackPathsScan,
|
||||
definition: AttackPathsQueryDefinition,
|
||||
parameters: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
with graph_database.get_session(attack_paths_scan.graph_database) as session:
|
||||
result = session.run(definition.cypher, parameters)
|
||||
return _serialize_graph(result.graph())
|
||||
|
||||
except graph_database.GraphDatabaseQueryException as exc:
|
||||
logger.error(f"Query failed for Attack Paths query `{definition.id}`: {exc}")
|
||||
raise APIException(
|
||||
"Attack Paths query execution failed due to a database error"
|
||||
)
|
||||
|
||||
|
||||
def _serialize_graph(graph):
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
{
|
||||
"id": node.element_id,
|
||||
"labels": list(node.labels),
|
||||
"properties": _serialize_properties(node._properties),
|
||||
},
|
||||
)
|
||||
|
||||
relationships = []
|
||||
for relationship in graph.relationships:
|
||||
relationships.append(
|
||||
{
|
||||
"id": relationship.element_id,
|
||||
"label": relationship.type,
|
||||
"source": relationship.start_node.element_id,
|
||||
"target": relationship.end_node.element_id,
|
||||
"properties": _serialize_properties(relationship._properties),
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"relationships": relationships,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert Neo4j property values into JSON-serializable primitives."""
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
# Neo4j temporal and spatial values expose `to_native` returning Python primitives
|
||||
if hasattr(value, "to_native") and callable(value.to_native):
|
||||
return _serialize_value(value.to_native())
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_serialize_value(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: _serialize_value(val) for key, val in value.items()}
|
||||
|
||||
return value
|
||||
|
||||
return {key: _serialize_value(val) for key, val in properties.items()}
|
||||
@@ -1,15 +1,99 @@
|
||||
from types import MappingProxyType
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
from api.models import Provider
|
||||
from prowler.config.config import get_available_compliance_frameworks
|
||||
from prowler.lib.check.compliance_models import Compliance
|
||||
from prowler.lib.check.models import CheckMetadata
|
||||
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE = {}
|
||||
PROWLER_CHECKS = {}
|
||||
AVAILABLE_COMPLIANCE_FRAMEWORKS = {}
|
||||
|
||||
|
||||
class LazyComplianceTemplate(Mapping):
|
||||
"""Lazy-load compliance templates per provider on first access."""
|
||||
|
||||
def __init__(self, provider_types: Iterable[str] | None = None) -> None:
|
||||
if provider_types is None:
|
||||
provider_types = Provider.ProviderChoices.values
|
||||
self._provider_types = tuple(provider_types)
|
||||
self._provider_types_set = set(self._provider_types)
|
||||
self._cache: dict[str, dict] = {}
|
||||
|
||||
def _load_provider(self, provider_type: str) -> dict:
|
||||
if provider_type not in self._provider_types_set:
|
||||
raise KeyError(provider_type)
|
||||
cached = self._cache.get(provider_type)
|
||||
if cached is not None:
|
||||
return cached
|
||||
_ensure_provider_loaded(provider_type)
|
||||
return self._cache[provider_type]
|
||||
|
||||
def __getitem__(self, key: str) -> dict:
|
||||
return self._load_provider(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._provider_types)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._provider_types)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self._provider_types_set
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
if key not in self._provider_types_set:
|
||||
return default
|
||||
return self._load_provider(key)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - debugging helper
|
||||
loaded = ", ".join(sorted(self._cache))
|
||||
return f"{self.__class__.__name__}(loaded=[{loaded}])"
|
||||
|
||||
|
||||
class LazyChecksMapping(Mapping):
|
||||
"""Lazy-load checks mapping per provider on first access."""
|
||||
|
||||
def __init__(self, provider_types: Iterable[str] | None = None) -> None:
|
||||
if provider_types is None:
|
||||
provider_types = Provider.ProviderChoices.values
|
||||
self._provider_types = tuple(provider_types)
|
||||
self._provider_types_set = set(self._provider_types)
|
||||
self._cache: dict[str, dict] = {}
|
||||
|
||||
def _load_provider(self, provider_type: str) -> dict:
|
||||
if provider_type not in self._provider_types_set:
|
||||
raise KeyError(provider_type)
|
||||
cached = self._cache.get(provider_type)
|
||||
if cached is not None:
|
||||
return cached
|
||||
_ensure_provider_loaded(provider_type)
|
||||
return self._cache[provider_type]
|
||||
|
||||
def __getitem__(self, key: str) -> dict:
|
||||
return self._load_provider(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._provider_types)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._provider_types)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self._provider_types_set
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
if key not in self._provider_types_set:
|
||||
return default
|
||||
return self._load_provider(key)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - debugging helper
|
||||
loaded = ", ".join(sorted(self._cache))
|
||||
return f"{self.__class__.__name__}(loaded=[{loaded}])"
|
||||
|
||||
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE = LazyComplianceTemplate()
|
||||
PROWLER_CHECKS = LazyChecksMapping()
|
||||
|
||||
|
||||
def get_compliance_frameworks(provider_type: Provider.ProviderChoices) -> list[str]:
|
||||
"""
|
||||
Retrieve and cache the list of available compliance frameworks for a specific cloud provider.
|
||||
@@ -70,28 +154,35 @@ def get_prowler_provider_compliance(provider_type: Provider.ProviderChoices) ->
|
||||
return Compliance.get_bulk(provider_type)
|
||||
|
||||
|
||||
def load_prowler_compliance():
|
||||
"""
|
||||
Load and initialize the Prowler compliance data and checks for all provider types.
|
||||
|
||||
This function retrieves compliance data for all supported provider types,
|
||||
generates a compliance overview template, and populates the global variables
|
||||
`PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE` and `PROWLER_CHECKS` with read-only mappings
|
||||
of the compliance templates and checks, respectively.
|
||||
"""
|
||||
global PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE
|
||||
global PROWLER_CHECKS
|
||||
|
||||
prowler_compliance = {
|
||||
provider_type: get_prowler_provider_compliance(provider_type)
|
||||
for provider_type in Provider.ProviderChoices.values
|
||||
}
|
||||
template = generate_compliance_overview_template(prowler_compliance)
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE = MappingProxyType(template)
|
||||
PROWLER_CHECKS = MappingProxyType(load_prowler_checks(prowler_compliance))
|
||||
def _load_provider_assets(provider_type: Provider.ProviderChoices) -> tuple[dict, dict]:
|
||||
prowler_compliance = {provider_type: get_prowler_provider_compliance(provider_type)}
|
||||
template = generate_compliance_overview_template(
|
||||
prowler_compliance, provider_types=[provider_type]
|
||||
)
|
||||
checks = load_prowler_checks(prowler_compliance, provider_types=[provider_type])
|
||||
return template.get(provider_type, {}), checks.get(provider_type, {})
|
||||
|
||||
|
||||
def load_prowler_checks(prowler_compliance):
|
||||
def _ensure_provider_loaded(provider_type: Provider.ProviderChoices) -> None:
|
||||
if (
|
||||
provider_type in PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE._cache
|
||||
and provider_type in PROWLER_CHECKS._cache
|
||||
):
|
||||
return
|
||||
template_cached = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE._cache.get(provider_type)
|
||||
checks_cached = PROWLER_CHECKS._cache.get(provider_type)
|
||||
if template_cached is not None and checks_cached is not None:
|
||||
return
|
||||
template, checks = _load_provider_assets(provider_type)
|
||||
if template_cached is None:
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE._cache[provider_type] = template
|
||||
if checks_cached is None:
|
||||
PROWLER_CHECKS._cache[provider_type] = checks
|
||||
|
||||
|
||||
def load_prowler_checks(
|
||||
prowler_compliance, provider_types: Iterable[str] | None = None
|
||||
):
|
||||
"""
|
||||
Generate a mapping of checks to the compliance frameworks that include them.
|
||||
|
||||
@@ -100,21 +191,25 @@ def load_prowler_checks(prowler_compliance):
|
||||
of compliance names that include that check.
|
||||
|
||||
Args:
|
||||
prowler_compliance (dict): The compliance data for all provider types,
|
||||
prowler_compliance (dict): The compliance data for provider types,
|
||||
as returned by `get_prowler_provider_compliance`.
|
||||
provider_types (Iterable[str] | None): Optional subset of provider types to
|
||||
process. Defaults to all providers.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary where the first-level keys are provider types,
|
||||
and the values are dictionaries mapping check IDs to sets of compliance names.
|
||||
"""
|
||||
checks = {}
|
||||
for provider_type in Provider.ProviderChoices.values:
|
||||
if provider_types is None:
|
||||
provider_types = Provider.ProviderChoices.values
|
||||
for provider_type in provider_types:
|
||||
checks[provider_type] = {
|
||||
check_id: set() for check_id in get_prowler_provider_checks(provider_type)
|
||||
}
|
||||
for compliance_name, compliance_data in prowler_compliance[
|
||||
provider_type
|
||||
].items():
|
||||
for compliance_name, compliance_data in prowler_compliance.get(
|
||||
provider_type, {}
|
||||
).items():
|
||||
for requirement in compliance_data.Requirements:
|
||||
for check in requirement.Checks:
|
||||
try:
|
||||
@@ -163,7 +258,9 @@ def generate_scan_compliance(
|
||||
] += 1
|
||||
|
||||
|
||||
def generate_compliance_overview_template(prowler_compliance: dict):
|
||||
def generate_compliance_overview_template(
|
||||
prowler_compliance: dict, provider_types: Iterable[str] | None = None
|
||||
):
|
||||
"""
|
||||
Generate a compliance overview template for all provider types.
|
||||
|
||||
@@ -173,17 +270,21 @@ def generate_compliance_overview_template(prowler_compliance: dict):
|
||||
counts for requirements status.
|
||||
|
||||
Args:
|
||||
prowler_compliance (dict): The compliance data for all provider types,
|
||||
prowler_compliance (dict): The compliance data for provider types,
|
||||
as returned by `get_prowler_provider_compliance`.
|
||||
provider_types (Iterable[str] | None): Optional subset of provider types to
|
||||
process. Defaults to all providers.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary representing the compliance overview template,
|
||||
structured by provider type and compliance framework.
|
||||
"""
|
||||
template = {}
|
||||
for provider_type in Provider.ProviderChoices.values:
|
||||
if provider_types is None:
|
||||
provider_types = Provider.ProviderChoices.values
|
||||
for provider_type in provider_types:
|
||||
provider_compliance = template.setdefault(provider_type, {})
|
||||
compliance_data_dict = prowler_compliance[provider_type]
|
||||
compliance_data_dict = prowler_compliance.get(provider_type, {})
|
||||
|
||||
for compliance_name, compliance_data in compliance_data_dict.items():
|
||||
compliance_requirements = {}
|
||||
|
||||
@@ -107,3 +107,105 @@ class ConflictException(APIException):
|
||||
error_detail["source"] = {"pointer": pointer}
|
||||
|
||||
super().__init__(detail=[error_detail])
|
||||
|
||||
|
||||
# Upstream Provider Errors (for external API calls like CloudTrail)
|
||||
# These indicate issues with the provider, not with the user's API authentication
|
||||
|
||||
|
||||
class UpstreamAuthenticationError(APIException):
|
||||
"""Provider credentials are invalid or expired (502 Bad Gateway).
|
||||
|
||||
Used when AWS/Azure/GCP credentials fail to authenticate with the upstream
|
||||
provider. This is NOT the user's API authentication failing.
|
||||
"""
|
||||
|
||||
status_code = status.HTTP_502_BAD_GATEWAY
|
||||
default_detail = (
|
||||
"Provider credentials are invalid or expired. Please reconnect the provider."
|
||||
)
|
||||
default_code = "upstream_auth_failed"
|
||||
|
||||
def __init__(self, detail=None):
|
||||
super().__init__(
|
||||
detail=[
|
||||
{
|
||||
"detail": detail or self.default_detail,
|
||||
"status": str(self.status_code),
|
||||
"code": self.default_code,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class UpstreamAccessDeniedError(APIException):
|
||||
"""Provider credentials lack required permissions (502 Bad Gateway).
|
||||
|
||||
Used when credentials are valid but don't have the IAM permissions
|
||||
needed for the requested operation (e.g., cloudtrail:LookupEvents).
|
||||
This is 502 (not 403) because it's an upstream/gateway error - the USER
|
||||
authenticated fine, but the PROVIDER's credentials are misconfigured.
|
||||
"""
|
||||
|
||||
status_code = status.HTTP_502_BAD_GATEWAY
|
||||
default_detail = (
|
||||
"Access denied. The provider credentials do not have the required permissions."
|
||||
)
|
||||
default_code = "upstream_access_denied"
|
||||
|
||||
def __init__(self, detail=None):
|
||||
super().__init__(
|
||||
detail=[
|
||||
{
|
||||
"detail": detail or self.default_detail,
|
||||
"status": str(self.status_code),
|
||||
"code": self.default_code,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class UpstreamServiceUnavailableError(APIException):
|
||||
"""Provider service is unavailable (503 Service Unavailable).
|
||||
|
||||
Used when the upstream provider API returns an error or is unreachable.
|
||||
"""
|
||||
|
||||
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
default_detail = "Unable to communicate with the provider. Please try again later."
|
||||
default_code = "service_unavailable"
|
||||
|
||||
def __init__(self, detail=None):
|
||||
super().__init__(
|
||||
detail=[
|
||||
{
|
||||
"detail": detail or self.default_detail,
|
||||
"status": str(self.status_code),
|
||||
"code": self.default_code,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class UpstreamInternalError(APIException):
|
||||
"""Unexpected error communicating with provider (500 Internal Server Error).
|
||||
|
||||
Used as a catch-all for unexpected errors during provider communication.
|
||||
"""
|
||||
|
||||
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
default_detail = (
|
||||
"An unexpected error occurred while communicating with the provider."
|
||||
)
|
||||
default_code = "internal_error"
|
||||
|
||||
def __init__(self, detail=None):
|
||||
super().__init__(
|
||||
detail=[
|
||||
{
|
||||
"detail": detail or self.default_detail,
|
||||
"status": str(self.status_code),
|
||||
"code": self.default_code,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from api.models import (
|
||||
Finding,
|
||||
Integration,
|
||||
Invitation,
|
||||
AttackPathsScan,
|
||||
LighthouseProviderConfiguration,
|
||||
LighthouseProviderModels,
|
||||
Membership,
|
||||
@@ -45,6 +46,7 @@ from api.models import (
|
||||
Role,
|
||||
Scan,
|
||||
ScanCategorySummary,
|
||||
ScanGroupSummary,
|
||||
ScanSummary,
|
||||
SeverityChoices,
|
||||
StateChoices,
|
||||
@@ -214,6 +216,9 @@ class CommonFindingFilters(FilterSet):
|
||||
category = CharFilter(method="filter_category")
|
||||
category__in = CharInFilter(field_name="categories", lookup_expr="overlap")
|
||||
|
||||
resource_groups = CharFilter(field_name="resource_groups", lookup_expr="exact")
|
||||
resource_groups__in = CharInFilter(field_name="resource_groups", lookup_expr="in")
|
||||
|
||||
# Temporarily disabled until we implement tag filtering in the UI
|
||||
# resource_tag_key = CharFilter(field_name="resources__tags__key")
|
||||
# resource_tag_key__in = CharInFilter(
|
||||
@@ -392,6 +397,23 @@ class ScanFilter(ProviderRelationshipFilterSet):
|
||||
}
|
||||
|
||||
|
||||
class AttackPathsScanFilter(ProviderRelationshipFilterSet):
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
completed_at = DateFilter(field_name="completed_at", lookup_expr="date")
|
||||
started_at = DateFilter(field_name="started_at", lookup_expr="date")
|
||||
state = ChoiceFilter(choices=StateChoices.choices)
|
||||
state__in = ChoiceInFilter(
|
||||
field_name="state", choices=StateChoices.choices, lookup_expr="in"
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = AttackPathsScan
|
||||
fields = {
|
||||
"provider": ["exact", "in"],
|
||||
"scan": ["exact", "in"],
|
||||
}
|
||||
|
||||
|
||||
class TaskFilter(FilterSet):
|
||||
name = CharFilter(field_name="task_runner_task__task_name", lookup_expr="exact")
|
||||
name__icontains = CharFilter(
|
||||
@@ -439,6 +461,8 @@ class ResourceFilter(ProviderRelationshipFilterSet):
|
||||
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
|
||||
scan = UUIDFilter(field_name="provider__scan", lookup_expr="exact")
|
||||
scan__in = UUIDInFilter(field_name="provider__scan", lookup_expr="in")
|
||||
groups = CharFilter(method="filter_groups")
|
||||
groups__in = CharInFilter(field_name="groups", lookup_expr="overlap")
|
||||
|
||||
class Meta:
|
||||
model = Resource
|
||||
@@ -453,6 +477,9 @@ class ResourceFilter(ProviderRelationshipFilterSet):
|
||||
"updated_at": ["gte", "lte"],
|
||||
}
|
||||
|
||||
def filter_groups(self, queryset, name, value):
|
||||
return queryset.filter(groups__contains=[value])
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
if not (self.data.get("scan") or self.data.get("scan__in")) and not (
|
||||
self.data.get("updated_at")
|
||||
@@ -517,6 +544,8 @@ class LatestResourceFilter(ProviderRelationshipFilterSet):
|
||||
tag_value = CharFilter(method="filter_tag_value")
|
||||
tag = CharFilter(method="filter_tag")
|
||||
tags = CharFilter(method="filter_tag")
|
||||
groups = CharFilter(method="filter_groups")
|
||||
groups__in = CharInFilter(field_name="groups", lookup_expr="overlap")
|
||||
|
||||
class Meta:
|
||||
model = Resource
|
||||
@@ -529,6 +558,9 @@ class LatestResourceFilter(ProviderRelationshipFilterSet):
|
||||
"type": ["exact", "icontains", "in"],
|
||||
}
|
||||
|
||||
def filter_groups(self, queryset, name, value):
|
||||
return queryset.filter(groups__contains=[value])
|
||||
|
||||
def filter_tag_key(self, queryset, name, value):
|
||||
return queryset.filter(Q(tags__key=value) | Q(tags__key__icontains=value))
|
||||
|
||||
@@ -1154,6 +1186,26 @@ class CategoryOverviewFilter(BaseScanProviderFilter):
|
||||
|
||||
class Meta(BaseScanProviderFilter.Meta):
|
||||
model = ScanCategorySummary
|
||||
fields = {}
|
||||
|
||||
|
||||
class ResourceGroupOverviewFilter(FilterSet):
|
||||
provider_id = UUIDFilter(field_name="scan__provider__id", lookup_expr="exact")
|
||||
provider_id__in = UUIDInFilter(field_name="scan__provider__id", lookup_expr="in")
|
||||
provider_type = ChoiceFilter(
|
||||
field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices
|
||||
)
|
||||
provider_type__in = ChoiceInFilter(
|
||||
field_name="scan__provider__provider",
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
resource_group = CharFilter(field_name="resource_group", lookup_expr="exact")
|
||||
resource_group__in = CharInFilter(field_name="resource_group", lookup_expr="in")
|
||||
|
||||
class Meta:
|
||||
model = ScanGroupSummary
|
||||
fields = {}
|
||||
|
||||
|
||||
class ComplianceWatchlistFilter(BaseProviderFilter):
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
[
|
||||
{
|
||||
"model": "api.attackpathsscan",
|
||||
"pk": "a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"provider": "b85601a8-4b45-4194-8135-03fb980ef428",
|
||||
"scan": "01920573-aa9c-73c9-bcda-f2e35c9b19d2",
|
||||
"state": "completed",
|
||||
"progress": 100,
|
||||
"update_tag": 1693586667,
|
||||
"graph_database": "db-a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
|
||||
"is_graph_database_deleted": false,
|
||||
"task": null,
|
||||
"inserted_at": "2024-09-01T17:24:37Z",
|
||||
"updated_at": "2024-09-01T17:44:37Z",
|
||||
"started_at": "2024-09-01T17:34:37Z",
|
||||
"completed_at": "2024-09-01T17:44:37Z",
|
||||
"duration": 269,
|
||||
"ingestion_exceptions": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.attackpathsscan",
|
||||
"pk": "4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"provider": "15fce1fa-ecaa-433f-a9dc-62553f3a2555",
|
||||
"scan": "01929f3b-ed2e-7623-ad63-7c37cd37828f",
|
||||
"state": "executing",
|
||||
"progress": 48,
|
||||
"update_tag": 1697625000,
|
||||
"graph_database": "db-4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
|
||||
"is_graph_database_deleted": false,
|
||||
"task": null,
|
||||
"inserted_at": "2024-10-18T10:55:57Z",
|
||||
"updated_at": "2024-10-18T10:56:15Z",
|
||||
"started_at": "2024-10-18T10:56:05Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,126 @@
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0067_tenant_compliance_summary"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="finding",
|
||||
name="resource_groups",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
help_text="Resource group from check metadata for efficient filtering",
|
||||
null=True,
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="ScanGroupSummary",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="api.tenant",
|
||||
),
|
||||
),
|
||||
(
|
||||
"inserted_at",
|
||||
models.DateTimeField(auto_now_add=True),
|
||||
),
|
||||
(
|
||||
"scan",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="resource_group_summaries",
|
||||
related_query_name="resource_group_summary",
|
||||
to="api.scan",
|
||||
),
|
||||
),
|
||||
(
|
||||
"resource_group",
|
||||
models.CharField(max_length=50),
|
||||
),
|
||||
(
|
||||
"severity",
|
||||
api.db_utils.SeverityEnumField(
|
||||
choices=[
|
||||
("critical", "Critical"),
|
||||
("high", "High"),
|
||||
("medium", "Medium"),
|
||||
("low", "Low"),
|
||||
("informational", "Informational"),
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
"total_findings",
|
||||
models.IntegerField(
|
||||
default=0, help_text="Non-muted findings (PASS + FAIL)"
|
||||
),
|
||||
),
|
||||
(
|
||||
"failed_findings",
|
||||
models.IntegerField(
|
||||
default=0,
|
||||
help_text="Non-muted FAIL findings (subset of total_findings)",
|
||||
),
|
||||
),
|
||||
(
|
||||
"new_failed_findings",
|
||||
models.IntegerField(
|
||||
default=0,
|
||||
help_text="Non-muted FAIL with delta='new' (subset of failed_findings)",
|
||||
),
|
||||
),
|
||||
(
|
||||
"resources_count",
|
||||
models.IntegerField(
|
||||
default=0, help_text="Count of distinct resource_uid values"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "scan_resource_group_summaries",
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="scangroupsummary",
|
||||
index=models.Index(
|
||||
fields=["tenant_id", "scan"], name="srgs_tenant_scan_idx"
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="scangroupsummary",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("tenant_id", "scan_id", "resource_group", "severity"),
|
||||
name="unique_resource_group_severity_per_scan",
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="scangroupsummary",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_scangroupsummary",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,21 @@
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0068_finding_resource_group_scangroupsummary"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="resource",
|
||||
name="groups",
|
||||
field=ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
help_text="Groups for categorization (e.g., compute, storage, IAM)",
|
||||
null=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,154 @@
|
||||
# Generated by Django 5.1.13 on 2025-11-06 16:20
|
||||
|
||||
import django.db.models.deletion
|
||||
|
||||
from django.db import migrations, models
|
||||
from uuid6 import uuid7
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0069_resource_resource_group"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="AttackPathsScan",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid7,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"state",
|
||||
api.db_utils.StateEnumField(
|
||||
choices=[
|
||||
("available", "Available"),
|
||||
("scheduled", "Scheduled"),
|
||||
("executing", "Executing"),
|
||||
("completed", "Completed"),
|
||||
("failed", "Failed"),
|
||||
("cancelled", "Cancelled"),
|
||||
],
|
||||
default="available",
|
||||
),
|
||||
),
|
||||
("progress", models.IntegerField(default=0)),
|
||||
("started_at", models.DateTimeField(blank=True, null=True)),
|
||||
("completed_at", models.DateTimeField(blank=True, null=True)),
|
||||
(
|
||||
"duration",
|
||||
models.IntegerField(
|
||||
blank=True, help_text="Duration in seconds", null=True
|
||||
),
|
||||
),
|
||||
(
|
||||
"update_tag",
|
||||
models.BigIntegerField(
|
||||
blank=True,
|
||||
help_text="Cartography update tag (epoch)",
|
||||
null=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"graph_database",
|
||||
models.CharField(blank=True, max_length=63, null=True),
|
||||
),
|
||||
(
|
||||
"is_graph_database_deleted",
|
||||
models.BooleanField(default=False),
|
||||
),
|
||||
(
|
||||
"ingestion_exceptions",
|
||||
models.JSONField(blank=True, default=dict, null=True),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
to="api.provider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"scan",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
to="api.scan",
|
||||
),
|
||||
),
|
||||
(
|
||||
"task",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
to="api.task",
|
||||
),
|
||||
),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "attack_paths_scans",
|
||||
"abstract": False,
|
||||
"indexes": [
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id", "-inserted_at"],
|
||||
name="aps_prov_ins_desc_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "state", "-inserted_at"],
|
||||
name="aps_state_ins_desc_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id"],
|
||||
name="aps_scan_lookup_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id"],
|
||||
name="aps_active_graph_idx",
|
||||
include=["graph_database", "id"],
|
||||
condition=models.Q(("is_graph_database_deleted", False)),
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id", "-completed_at"],
|
||||
name="aps_completed_graph_idx",
|
||||
include=["graph_database", "id"],
|
||||
condition=models.Q(
|
||||
("state", "completed"),
|
||||
("is_graph_database_deleted", False),
|
||||
),
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="attackpathsscan",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_attackpathsscan",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -626,6 +626,101 @@ class Scan(RowLevelSecurityProtectedModel):
|
||||
resource_name = "scans"
|
||||
|
||||
|
||||
class AttackPathsScan(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid7, editable=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
updated_at = models.DateTimeField(auto_now=True, editable=False)
|
||||
|
||||
state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE)
|
||||
progress = models.IntegerField(default=0)
|
||||
|
||||
# Timing
|
||||
started_at = models.DateTimeField(null=True, blank=True)
|
||||
completed_at = models.DateTimeField(null=True, blank=True)
|
||||
duration = models.IntegerField(
|
||||
null=True, blank=True, help_text="Duration in seconds"
|
||||
)
|
||||
|
||||
# Relationship to the provider and optional prowler Scan and celery Task
|
||||
provider = models.ForeignKey(
|
||||
"Provider",
|
||||
on_delete=models.CASCADE,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
)
|
||||
scan = models.ForeignKey(
|
||||
"Scan",
|
||||
on_delete=models.SET_NULL,
|
||||
null=True,
|
||||
blank=True,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
)
|
||||
task = models.ForeignKey(
|
||||
"Task",
|
||||
on_delete=models.SET_NULL,
|
||||
null=True,
|
||||
blank=True,
|
||||
related_name="attack_paths_scans",
|
||||
related_query_name="attack_paths_scan",
|
||||
)
|
||||
|
||||
# Cartography specific metadata
|
||||
update_tag = models.BigIntegerField(
|
||||
null=True, blank=True, help_text="Cartography update tag (epoch)"
|
||||
)
|
||||
graph_database = models.CharField(max_length=63, null=True, blank=True)
|
||||
is_graph_database_deleted = models.BooleanField(default=False)
|
||||
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
|
||||
|
||||
class Meta(RowLevelSecurityProtectedModel.Meta):
|
||||
db_table = "attack_paths_scans"
|
||||
|
||||
constraints = [
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
indexes = [
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id", "-inserted_at"],
|
||||
name="aps_prov_ins_desc_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "state", "-inserted_at"],
|
||||
name="aps_state_ins_desc_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id"],
|
||||
name="aps_scan_lookup_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id"],
|
||||
name="aps_active_graph_idx",
|
||||
include=["graph_database", "id"],
|
||||
condition=Q(is_graph_database_deleted=False),
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id", "-completed_at"],
|
||||
name="aps_completed_graph_idx",
|
||||
include=["graph_database", "id"],
|
||||
condition=Q(
|
||||
state=StateChoices.COMPLETED,
|
||||
is_graph_database_deleted=False,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-scans"
|
||||
|
||||
|
||||
class ResourceTag(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
@@ -704,6 +799,12 @@ class Resource(RowLevelSecurityProtectedModel):
|
||||
metadata = models.TextField(blank=True, null=True)
|
||||
details = models.TextField(blank=True, null=True)
|
||||
partition = models.TextField(blank=True, null=True)
|
||||
groups = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
null=True,
|
||||
help_text="Groups for categorization (e.g., compute, storage, IAM)",
|
||||
)
|
||||
|
||||
failed_findings_count = models.IntegerField(default=0)
|
||||
|
||||
@@ -890,6 +991,11 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel):
|
||||
null=True,
|
||||
help_text="Categories from check metadata for efficient filtering",
|
||||
)
|
||||
resource_groups = models.TextField(
|
||||
blank=True,
|
||||
null=True,
|
||||
help_text="Resource group from check metadata for efficient filtering",
|
||||
)
|
||||
|
||||
# Relationships
|
||||
scan = models.ForeignKey(to=Scan, related_name="findings", on_delete=models.CASCADE)
|
||||
@@ -2032,6 +2138,67 @@ class ScanCategorySummary(RowLevelSecurityProtectedModel):
|
||||
resource_name = "scan-category-summaries"
|
||||
|
||||
|
||||
class ScanGroupSummary(RowLevelSecurityProtectedModel):
|
||||
"""
|
||||
Pre-aggregated resource group metrics per scan by severity.
|
||||
|
||||
Stores one row per (resource_group, severity) combination per scan for efficient
|
||||
overview queries. Resource groups come from check_metadata.Group.
|
||||
|
||||
Count relationships (each is a subset of the previous):
|
||||
- total_findings >= failed_findings >= new_failed_findings
|
||||
"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
|
||||
scan = models.ForeignKey(
|
||||
Scan,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="resource_group_summaries",
|
||||
related_query_name="resource_group_summary",
|
||||
)
|
||||
|
||||
resource_group = models.CharField(max_length=50)
|
||||
severity = SeverityEnumField(choices=SeverityChoices)
|
||||
|
||||
total_findings = models.IntegerField(
|
||||
default=0, help_text="Non-muted findings (PASS + FAIL)"
|
||||
)
|
||||
failed_findings = models.IntegerField(
|
||||
default=0, help_text="Non-muted FAIL findings (subset of total_findings)"
|
||||
)
|
||||
new_failed_findings = models.IntegerField(
|
||||
default=0,
|
||||
help_text="Non-muted FAIL with delta='new' (subset of failed_findings)",
|
||||
)
|
||||
resources_count = models.IntegerField(
|
||||
default=0, help_text="Count of distinct resource_uid values"
|
||||
)
|
||||
|
||||
class Meta(RowLevelSecurityProtectedModel.Meta):
|
||||
db_table = "scan_resource_group_summaries"
|
||||
|
||||
indexes = [
|
||||
models.Index(fields=["tenant_id", "scan"], name="srgs_tenant_scan_idx"),
|
||||
]
|
||||
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=("tenant_id", "scan_id", "resource_group", "severity"),
|
||||
name="unique_resource_group_severity_per_scan",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "scan-resource-group-summaries"
|
||||
|
||||
|
||||
class LighthouseConfiguration(RowLevelSecurityProtectedModel):
|
||||
"""
|
||||
Stores configuration and API keys for LLM services.
|
||||
|
||||
+1986
-15
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,921 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.test import override_settings
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from api.models import Provider
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProviderBatchCreate:
|
||||
"""Tests for the batch provider creation endpoint."""
|
||||
|
||||
content_type = "application/json"
|
||||
|
||||
def test_batch_create_single_provider_success(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test creating a single provider via batch endpoint."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "111111111111",
|
||||
"alias": "Test AWS Account",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 1
|
||||
assert data[0]["attributes"]["provider"] == "aws"
|
||||
assert data[0]["attributes"]["uid"] == "111111111111"
|
||||
assert data[0]["attributes"]["alias"] == "Test AWS Account"
|
||||
|
||||
def test_batch_create_multiple_providers_mixed_types(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test creating multiple providers of different types in one batch."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "222222222222",
|
||||
"alias": "AWS Account 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "azure",
|
||||
"uid": "a1b2c3d4-e5f6-4890-abcd-ef1234567890",
|
||||
"alias": "Azure Subscription",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "gcp",
|
||||
"uid": "my-gcp-project-id",
|
||||
"alias": "GCP Project",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 3
|
||||
|
||||
providers_by_type = {p["attributes"]["provider"]: p for p in data}
|
||||
assert "aws" in providers_by_type
|
||||
assert "azure" in providers_by_type
|
||||
assert "gcp" in providers_by_type
|
||||
|
||||
def test_batch_create_duplicate_uid_in_batch_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that duplicate UIDs within same batch fails entire batch (all-or-nothing)."""
|
||||
initial_count = Provider.objects.count()
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "444444444444",
|
||||
"alias": "AWS Account 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "444444444444",
|
||||
"alias": "AWS Account 2 (duplicate)",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
# All-or-nothing: entire batch fails
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
|
||||
# Should have duplicate error
|
||||
assert any("Duplicate UID" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
# Verify no providers were created
|
||||
assert Provider.objects.count() == initial_count
|
||||
|
||||
def test_batch_create_existing_uid_error(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that UIDs already existing in tenant are rejected."""
|
||||
existing_provider = providers_fixture[0]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": existing_provider.provider,
|
||||
"uid": existing_provider.uid,
|
||||
"alias": "Duplicate of existing",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("already exists" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_create_invalid_uid_format_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that invalid UID formats are rejected with proper error messages."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "invalid-aws-uid",
|
||||
"alias": "Invalid AWS",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("/data/0/attributes" in str(e.get("source", {})) for e in errors)
|
||||
|
||||
def test_batch_create_permission_denied(
|
||||
self, authenticated_client_no_permissions_rbac, tenants_fixture
|
||||
):
|
||||
"""Test that users without MANAGE_PROVIDERS permission cannot batch create."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "555555555555",
|
||||
"alias": "Test",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_batch_create_exceeds_limit_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch requests exceeding the limit are rejected."""
|
||||
limit = settings.API_BATCH_MAX_SIZE
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": f"{i:012d}",
|
||||
"alias": f"Provider {i}",
|
||||
},
|
||||
}
|
||||
for i in range(limit + 1)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any(f"Maximum {limit}" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_create_empty_array_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that empty batch requests are rejected."""
|
||||
payload = {"data": []}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("At least one provider" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_create_invalid_data_format_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that non-array data is rejected."""
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "666666666666",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("Must be an array" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_create_sets_correct_tenant(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch-created providers have correct tenant assignment."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "777777777777",
|
||||
"alias": "Tenant 1 Provider",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
provider_id = response.json()["data"][0]["id"]
|
||||
|
||||
provider = Provider.objects.get(id=provider_id)
|
||||
assert provider.tenant_id == tenants_fixture[0].id
|
||||
|
||||
def test_batch_create_mixed_valid_invalid_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that mixed valid/invalid items fails entire batch (all-or-nothing)."""
|
||||
initial_count = Provider.objects.count()
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "888888888888",
|
||||
"alias": "Valid AWS",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "invalid-uid",
|
||||
"alias": "Invalid AWS",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
# All-or-nothing: entire batch fails
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
|
||||
# Should have error for invalid item
|
||||
assert len(errors) >= 1
|
||||
assert any(
|
||||
"/data/1" in str(e.get("source", {}).get("pointer", "")) for e in errors
|
||||
)
|
||||
|
||||
# No providers should have been created
|
||||
assert Provider.objects.count() == initial_count
|
||||
|
||||
def test_batch_create_multiple_errors_reported(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that all validation errors are reported, not just the first one."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": "invalid1",
|
||||
"alias": "Invalid 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "azure",
|
||||
"uid": "not-a-uuid",
|
||||
"alias": "Invalid 2",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
# Should have errors for both items
|
||||
error_pointers = [e.get("source", {}).get("pointer", "") for e in errors]
|
||||
assert any("/data/0" in p for p in error_pointers)
|
||||
assert any("/data/1" in p for p in error_pointers)
|
||||
|
||||
def test_batch_create_at_exact_limit_success(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch requests at exactly the limit are accepted."""
|
||||
limit = settings.API_BATCH_MAX_SIZE
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": f"{i:012d}",
|
||||
"alias": f"Provider {i}",
|
||||
},
|
||||
}
|
||||
for i in range(limit)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()["data"]
|
||||
assert len(data) == limit
|
||||
|
||||
@override_settings(API_BATCH_MAX_SIZE=5)
|
||||
def test_batch_create_respects_custom_limit_setting(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch endpoint respects custom API_BATCH_MAX_SIZE setting."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": f"{900000000000 + i}",
|
||||
"alias": f"Provider {i}",
|
||||
},
|
||||
}
|
||||
for i in range(6)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("Maximum 5" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
@override_settings(API_BATCH_MAX_SIZE=3)
|
||||
def test_batch_create_at_custom_limit_success(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch requests at exactly the custom limit are accepted."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {
|
||||
"provider": "aws",
|
||||
"uid": f"{800000000000 + i}",
|
||||
"alias": f"Provider {i}",
|
||||
},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 3
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProviderBatchUpdate:
|
||||
"""Tests for the batch provider update endpoint."""
|
||||
|
||||
content_type = "application/json"
|
||||
|
||||
def test_batch_update_single_provider_success(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test updating a single provider via batch endpoint."""
|
||||
provider = providers_fixture[0]
|
||||
new_alias = "Updated AWS Account"
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {
|
||||
"alias": new_alias,
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 1
|
||||
assert data[0]["attributes"]["alias"] == new_alias
|
||||
|
||||
# Verify in database
|
||||
provider.refresh_from_db()
|
||||
assert provider.alias == new_alias
|
||||
|
||||
def test_batch_update_multiple_providers_success(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test updating multiple providers in one batch."""
|
||||
provider1 = providers_fixture[0]
|
||||
provider2 = providers_fixture[1]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider1.id),
|
||||
"attributes": {"alias": "Updated Provider 1"},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider2.id),
|
||||
"attributes": {"alias": "Updated Provider 2"},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 2
|
||||
|
||||
# Verify in database
|
||||
provider1.refresh_from_db()
|
||||
provider2.refresh_from_db()
|
||||
assert provider1.alias == "Updated Provider 1"
|
||||
assert provider2.alias == "Updated Provider 2"
|
||||
|
||||
def test_batch_update_provider_not_found_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that non-existent providers are rejected."""
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": fake_id,
|
||||
"attributes": {"alias": "New Alias"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("not found" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_update_duplicate_id_in_batch_error(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that duplicate IDs within same batch fails entire batch (all-or-nothing)."""
|
||||
provider = providers_fixture[0]
|
||||
original_alias = provider.alias
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {"alias": "First Update"},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {"alias": "Second Update (duplicate)"},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
# All-or-nothing: entire batch fails
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
|
||||
# Should have duplicate error
|
||||
assert any("Duplicate provider ID" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
# Verify provider was not updated
|
||||
provider.refresh_from_db()
|
||||
assert provider.alias == original_alias
|
||||
|
||||
def test_batch_update_permission_denied(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
"""Test that users without MANAGE_PROVIDERS permission cannot batch update."""
|
||||
provider = providers_fixture[0]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {"alias": "New Alias"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client_no_permissions_rbac.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_batch_update_exceeds_limit_error(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that batch requests exceeding the limit are rejected."""
|
||||
limit = settings.API_BATCH_MAX_SIZE
|
||||
provider = providers_fixture[0]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {"alias": f"Provider {i}"},
|
||||
}
|
||||
for i in range(limit + 1)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any(f"Maximum {limit}" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_update_empty_array_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that empty batch requests are rejected."""
|
||||
payload = {"data": []}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("At least one provider" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_update_invalid_data_format_error(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that non-array data is rejected."""
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"id": str(uuid.uuid4()),
|
||||
"attributes": {"alias": "Test"},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("Must be an array" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
def test_batch_update_missing_id_error(self, authenticated_client, tenants_fixture):
|
||||
"""Test that missing ID is rejected."""
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"attributes": {"alias": "New Alias"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("required" in str(e.get("detail", "")).lower() for e in errors)
|
||||
|
||||
def test_batch_update_preserves_other_fields(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that updating alias doesn't change other fields."""
|
||||
provider = providers_fixture[0]
|
||||
original_uid = provider.uid
|
||||
original_provider_type = provider.provider
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(provider.id),
|
||||
"attributes": {"alias": "Updated Alias Only"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
provider.refresh_from_db()
|
||||
assert provider.alias == "Updated Alias Only"
|
||||
assert provider.uid == original_uid
|
||||
assert provider.provider == original_provider_type
|
||||
|
||||
def test_batch_update_multiple_errors_reported(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that all validation errors are reported, not just the first one."""
|
||||
fake_id1 = str(uuid.uuid4())
|
||||
fake_id2 = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": fake_id1,
|
||||
"attributes": {"alias": "Provider 1"},
|
||||
},
|
||||
{
|
||||
"type": "providers",
|
||||
"id": fake_id2,
|
||||
"attributes": {"alias": "Provider 2"},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
# Should have errors for both items
|
||||
error_pointers = [e.get("source", {}).get("pointer", "") for e in errors]
|
||||
assert any("/data/0" in p for p in error_pointers)
|
||||
assert any("/data/1" in p for p in error_pointers)
|
||||
|
||||
def test_batch_update_at_exact_limit_success(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch requests at exactly the limit are accepted."""
|
||||
limit = settings.API_BATCH_MAX_SIZE
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
providers = [
|
||||
Provider.objects.create(
|
||||
provider="aws",
|
||||
uid=f"{700000000000 + i}",
|
||||
alias=f"Provider {i}",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
for i in range(limit)
|
||||
]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(providers[i].id),
|
||||
"attributes": {"alias": f"Updated Provider {i}"},
|
||||
}
|
||||
for i in range(limit)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == limit
|
||||
|
||||
@override_settings(API_BATCH_MAX_SIZE=5)
|
||||
def test_batch_update_respects_custom_limit_setting(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch update endpoint respects custom API_BATCH_MAX_SIZE setting."""
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
providers = [
|
||||
Provider.objects.create(
|
||||
provider="aws",
|
||||
uid=f"{600000000000 + i}",
|
||||
alias=f"Provider {i}",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
for i in range(6)
|
||||
]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(providers[i].id),
|
||||
"attributes": {"alias": f"Updated Provider {i}"},
|
||||
}
|
||||
for i in range(6)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert any("Maximum 5" in str(e.get("detail", "")) for e in errors)
|
||||
|
||||
@override_settings(API_BATCH_MAX_SIZE=3)
|
||||
def test_batch_update_at_custom_limit_success(
|
||||
self, authenticated_client, tenants_fixture
|
||||
):
|
||||
"""Test that batch requests at exactly the custom limit are accepted."""
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
providers = [
|
||||
Provider.objects.create(
|
||||
provider="aws",
|
||||
uid=f"{500000000000 + i}",
|
||||
alias=f"Provider {i}",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"type": "providers",
|
||||
"id": str(providers[i].id),
|
||||
"attributes": {"alias": f"Updated Provider {i}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
}
|
||||
|
||||
response = authenticated_client.patch(
|
||||
reverse("provider-batch"),
|
||||
data=payload,
|
||||
content_type=self.content_type,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 3
|
||||
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
|
||||
import api
|
||||
import api.apps as api_apps_module
|
||||
from api.apps import (
|
||||
ApiConfig,
|
||||
@@ -150,3 +153,82 @@ def test_ensure_crypto_keys_skips_when_env_vars(monkeypatch, tmp_path):
|
||||
|
||||
# Assert: orchestrator did not trigger generation when env present
|
||||
assert called["ensure"] is False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_api_modules():
|
||||
"""Provide dummy modules imported during ApiConfig.ready()."""
|
||||
created = []
|
||||
for name in ("api.schema_extensions", "api.signals"):
|
||||
if name not in sys.modules:
|
||||
sys.modules[name] = types.ModuleType(name)
|
||||
created.append(name)
|
||||
|
||||
yield
|
||||
|
||||
for name in created:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
def _set_argv(monkeypatch, argv):
|
||||
monkeypatch.setattr(sys, "argv", argv, raising=False)
|
||||
|
||||
|
||||
def _set_testing(monkeypatch, value):
|
||||
monkeypatch.setattr(settings, "TESTING", value, raising=False)
|
||||
|
||||
|
||||
def _make_app():
|
||||
return ApiConfig("api", api)
|
||||
|
||||
|
||||
def test_ready_initializes_driver_for_api_process(monkeypatch):
|
||||
config = _make_app()
|
||||
_set_argv(monkeypatch, ["gunicorn"])
|
||||
_set_testing(monkeypatch, False)
|
||||
|
||||
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch(
|
||||
"api.attack_paths.database.init_driver"
|
||||
) as init_driver:
|
||||
config.ready()
|
||||
|
||||
init_driver.assert_called_once()
|
||||
|
||||
|
||||
def test_ready_skips_driver_for_celery(monkeypatch):
|
||||
config = _make_app()
|
||||
_set_argv(monkeypatch, ["celery", "-A", "api"])
|
||||
_set_testing(monkeypatch, False)
|
||||
|
||||
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch(
|
||||
"api.attack_paths.database.init_driver"
|
||||
) as init_driver:
|
||||
config.ready()
|
||||
|
||||
init_driver.assert_not_called()
|
||||
|
||||
|
||||
def test_ready_skips_driver_for_manage_py_skip_command(monkeypatch):
|
||||
config = _make_app()
|
||||
_set_argv(monkeypatch, ["manage.py", "migrate"])
|
||||
_set_testing(monkeypatch, False)
|
||||
|
||||
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch(
|
||||
"api.attack_paths.database.init_driver"
|
||||
) as init_driver:
|
||||
config.ready()
|
||||
|
||||
init_driver.assert_not_called()
|
||||
|
||||
|
||||
def test_ready_skips_driver_when_testing(monkeypatch):
|
||||
config = _make_app()
|
||||
_set_argv(monkeypatch, ["gunicorn"])
|
||||
_set_testing(monkeypatch, True)
|
||||
|
||||
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch(
|
||||
"api.attack_paths.database.init_driver"
|
||||
) as init_driver:
|
||||
config.ready()
|
||||
|
||||
init_driver.assert_not_called()
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rest_framework.exceptions import APIException, ValidationError
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.attack_paths import views_helpers
|
||||
|
||||
|
||||
def test_normalize_run_payload_extracts_attributes_section():
|
||||
payload = {
|
||||
"data": {
|
||||
"id": "ignored",
|
||||
"attributes": {
|
||||
"id": "aws-rds",
|
||||
"parameters": {"ip": "192.0.2.0"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result = views_helpers.normalize_run_payload(payload)
|
||||
|
||||
assert result == {"id": "aws-rds", "parameters": {"ip": "192.0.2.0"}}
|
||||
|
||||
|
||||
def test_normalize_run_payload_passthrough_for_non_dict():
|
||||
sentinel = "not-a-dict"
|
||||
assert views_helpers.normalize_run_payload(sentinel) is sentinel
|
||||
|
||||
|
||||
def test_prepare_query_parameters_includes_provider_and_casts(
|
||||
attack_paths_query_definition_factory,
|
||||
):
|
||||
definition = attack_paths_query_definition_factory(cast_type=int)
|
||||
result = views_helpers.prepare_query_parameters(
|
||||
definition,
|
||||
{"limit": "5"},
|
||||
provider_uid="123456789012",
|
||||
)
|
||||
|
||||
assert result["provider_uid"] == "123456789012"
|
||||
assert result["limit"] == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provided,expected_message",
|
||||
[
|
||||
({}, "Missing required parameter"),
|
||||
({"limit": 10, "extra": True}, "Unknown parameter"),
|
||||
],
|
||||
)
|
||||
def test_prepare_query_parameters_validates_names(
|
||||
attack_paths_query_definition_factory, provided, expected_message
|
||||
):
|
||||
definition = attack_paths_query_definition_factory()
|
||||
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
views_helpers.prepare_query_parameters(definition, provided, provider_uid="1")
|
||||
|
||||
assert expected_message in str(exc.value)
|
||||
|
||||
|
||||
def test_prepare_query_parameters_validates_cast(
|
||||
attack_paths_query_definition_factory,
|
||||
):
|
||||
definition = attack_paths_query_definition_factory(cast_type=int)
|
||||
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
views_helpers.prepare_query_parameters(
|
||||
definition,
|
||||
{"limit": "not-an-int"},
|
||||
provider_uid="1",
|
||||
)
|
||||
|
||||
assert "Invalid value" in str(exc.value)
|
||||
|
||||
|
||||
def test_execute_attack_paths_query_serializes_graph(
|
||||
attack_paths_query_definition_factory, attack_paths_graph_stub_classes
|
||||
):
|
||||
definition = attack_paths_query_definition_factory(
|
||||
id="aws-rds",
|
||||
name="RDS",
|
||||
description="",
|
||||
cypher="MATCH (n) RETURN n",
|
||||
parameters=[],
|
||||
)
|
||||
parameters = {"provider_uid": "123"}
|
||||
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
|
||||
|
||||
node = attack_paths_graph_stub_classes.Node(
|
||||
element_id="node-1",
|
||||
labels=["AWSAccount"],
|
||||
properties={
|
||||
"name": "account",
|
||||
"complex": {
|
||||
"items": [
|
||||
attack_paths_graph_stub_classes.NativeValue("value"),
|
||||
{"nested": 1},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
relationship = attack_paths_graph_stub_classes.Relationship(
|
||||
element_id="rel-1",
|
||||
rel_type="OWNS",
|
||||
start_node=node,
|
||||
end_node=attack_paths_graph_stub_classes.Node("node-2", ["RDSInstance"], {}),
|
||||
properties={"weight": 1},
|
||||
)
|
||||
graph = SimpleNamespace(nodes=[node], relationships=[relationship])
|
||||
|
||||
run_result = MagicMock()
|
||||
run_result.graph.return_value = graph
|
||||
|
||||
session = MagicMock()
|
||||
session.run.return_value = run_result
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
with patch(
|
||||
"api.attack_paths.views_helpers.graph_database.get_session",
|
||||
return_value=session_ctx,
|
||||
) as mock_get_session:
|
||||
result = views_helpers.execute_attack_paths_query(
|
||||
attack_paths_scan, definition, parameters
|
||||
)
|
||||
|
||||
mock_get_session.assert_called_once_with("tenant-db")
|
||||
session.run.assert_called_once_with(definition.cypher, parameters)
|
||||
assert result["nodes"][0]["id"] == "node-1"
|
||||
assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
|
||||
assert result["relationships"][0]["label"] == "OWNS"
|
||||
|
||||
|
||||
def test_execute_attack_paths_query_wraps_graph_errors(
|
||||
attack_paths_query_definition_factory,
|
||||
):
|
||||
definition = attack_paths_query_definition_factory(
|
||||
id="aws-rds",
|
||||
name="RDS",
|
||||
description="",
|
||||
cypher="MATCH (n) RETURN n",
|
||||
parameters=[],
|
||||
)
|
||||
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
|
||||
parameters = {"provider_uid": "123"}
|
||||
|
||||
class ExplodingContext:
|
||||
def __enter__(self):
|
||||
raise graph_database.GraphDatabaseQueryException("boom")
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.attack_paths.views_helpers.graph_database.get_session",
|
||||
return_value=ExplodingContext(),
|
||||
),
|
||||
patch("api.attack_paths.views_helpers.logger") as mock_logger,
|
||||
):
|
||||
with pytest.raises(APIException):
|
||||
views_helpers.execute_attack_paths_query(
|
||||
attack_paths_scan, definition, parameters
|
||||
)
|
||||
|
||||
mock_logger.error.assert_called_once()
|
||||
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Tests for Neo4j database lazy initialization.
|
||||
|
||||
The Neo4j driver connects on first use by default. API processes may
|
||||
eagerly initialize the driver during app startup, while Celery workers
|
||||
remain lazy. These tests validate the database module behavior itself.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLazyInitialization:
|
||||
"""Test that Neo4j driver is initialized lazily on first use."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
def test_driver_not_initialized_at_import(self):
|
||||
"""Driver should be None after module import (no eager connection)."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
assert db_module._driver is None
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_init_driver_creates_connection_on_first_call(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""init_driver() should create connection only when called."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
assert db_module._driver is None
|
||||
|
||||
result = db_module.init_driver()
|
||||
|
||||
mock_driver_factory.assert_called_once()
|
||||
mock_driver.verify_connectivity.assert_called_once()
|
||||
assert result is mock_driver
|
||||
assert db_module._driver is mock_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_init_driver_returns_cached_driver_on_subsequent_calls(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""Subsequent calls should return cached driver without reconnecting."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
first_result = db_module.init_driver()
|
||||
second_result = db_module.init_driver()
|
||||
third_result = db_module.init_driver()
|
||||
|
||||
# Only one connection attempt
|
||||
assert mock_driver_factory.call_count == 1
|
||||
assert mock_driver.verify_connectivity.call_count == 1
|
||||
|
||||
# All calls return same instance
|
||||
assert first_result is second_result is third_result
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_get_driver_delegates_to_init_driver(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""get_driver() should use init_driver() for lazy initialization."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
result = db_module.get_driver()
|
||||
|
||||
assert result is mock_driver
|
||||
mock_driver_factory.assert_called_once()
|
||||
|
||||
|
||||
class TestAtexitRegistration:
|
||||
"""Test that atexit cleanup handler is registered correctly."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.atexit.register")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_atexit_registered_on_first_init(
|
||||
self, mock_driver_factory, mock_atexit_register, mock_settings
|
||||
):
|
||||
"""atexit.register should be called on first initialization."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver_factory.return_value = MagicMock()
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
db_module.init_driver()
|
||||
|
||||
mock_atexit_register.assert_called_once_with(db_module.close_driver)
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.atexit.register")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_atexit_registered_only_once(
|
||||
self, mock_driver_factory, mock_atexit_register, mock_settings
|
||||
):
|
||||
"""atexit.register should only be called once across multiple inits.
|
||||
|
||||
The double-checked locking on _driver ensures the atexit registration
|
||||
block only executes once (when _driver is first created).
|
||||
"""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver_factory.return_value = MagicMock()
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
db_module.init_driver()
|
||||
db_module.init_driver()
|
||||
db_module.init_driver()
|
||||
|
||||
# Only registered once because subsequent calls hit the fast path
|
||||
assert mock_atexit_register.call_count == 1
|
||||
|
||||
|
||||
class TestCloseDriver:
|
||||
"""Test driver cleanup functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
def test_close_driver_closes_and_clears_driver(self):
|
||||
"""close_driver() should close the driver and set it to None."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
db_module._driver = mock_driver
|
||||
|
||||
db_module.close_driver()
|
||||
|
||||
mock_driver.close.assert_called_once()
|
||||
assert db_module._driver is None
|
||||
|
||||
def test_close_driver_handles_none_driver(self):
|
||||
"""close_driver() should handle case where driver is None."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
# Should not raise
|
||||
db_module.close_driver()
|
||||
|
||||
assert db_module._driver is None
|
||||
|
||||
def test_close_driver_clears_driver_even_on_close_error(self):
|
||||
"""Driver should be cleared even if close() raises an exception."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver.close.side_effect = Exception("Connection error")
|
||||
db_module._driver = mock_driver
|
||||
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
db_module.close_driver()
|
||||
|
||||
# Driver should still be cleared
|
||||
assert db_module._driver is None
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread-safe initialization."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_module_state(self):
|
||||
"""Reset module-level singleton state before each test."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
original_driver = db_module._driver
|
||||
|
||||
db_module._driver = None
|
||||
|
||||
yield
|
||||
|
||||
db_module._driver = original_driver
|
||||
|
||||
@patch("api.attack_paths.database.settings")
|
||||
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
|
||||
def test_concurrent_init_creates_single_driver(
|
||||
self, mock_driver_factory, mock_settings
|
||||
):
|
||||
"""Multiple threads calling init_driver() should create only one driver."""
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
mock_driver = MagicMock()
|
||||
mock_driver_factory.return_value = mock_driver
|
||||
mock_settings.DATABASES = {
|
||||
"neo4j": {
|
||||
"HOST": "localhost",
|
||||
"PORT": 7687,
|
||||
"USER": "neo4j",
|
||||
"PASSWORD": "password",
|
||||
}
|
||||
}
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def call_init():
|
||||
try:
|
||||
result = db_module.init_driver()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=call_init) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors, f"Threads raised errors: {errors}"
|
||||
|
||||
# Only one driver created
|
||||
assert mock_driver_factory.call_count == 1
|
||||
|
||||
# All threads got the same driver instance
|
||||
assert all(r is mock_driver for r in results)
|
||||
assert len(results) == 10
|
||||
@@ -6,7 +6,6 @@ from api.compliance import (
|
||||
get_prowler_provider_checks,
|
||||
get_prowler_provider_compliance,
|
||||
load_prowler_checks,
|
||||
load_prowler_compliance,
|
||||
)
|
||||
from api.models import Provider
|
||||
|
||||
@@ -35,55 +34,6 @@ class TestCompliance:
|
||||
assert compliance_data == mock_compliance.get_bulk.return_value
|
||||
mock_compliance.get_bulk.assert_called_once_with(provider_type)
|
||||
|
||||
@patch("api.models.Provider.ProviderChoices")
|
||||
@patch("api.compliance.get_prowler_provider_compliance")
|
||||
@patch("api.compliance.generate_compliance_overview_template")
|
||||
@patch("api.compliance.load_prowler_checks")
|
||||
def test_load_prowler_compliance(
|
||||
self,
|
||||
mock_load_prowler_checks,
|
||||
mock_generate_compliance_overview_template,
|
||||
mock_get_prowler_provider_compliance,
|
||||
mock_provider_choices,
|
||||
):
|
||||
mock_provider_choices.values = ["aws", "azure"]
|
||||
|
||||
compliance_data_aws = {"compliance_aws": MagicMock()}
|
||||
compliance_data_azure = {"compliance_azure": MagicMock()}
|
||||
|
||||
compliance_data_dict = {
|
||||
"aws": compliance_data_aws,
|
||||
"azure": compliance_data_azure,
|
||||
}
|
||||
|
||||
def mock_get_compliance(provider_type):
|
||||
return compliance_data_dict[provider_type]
|
||||
|
||||
mock_get_prowler_provider_compliance.side_effect = mock_get_compliance
|
||||
|
||||
mock_generate_compliance_overview_template.return_value = {
|
||||
"template_key": "template_value"
|
||||
}
|
||||
|
||||
mock_load_prowler_checks.return_value = {"checks_key": "checks_value"}
|
||||
|
||||
load_prowler_compliance()
|
||||
|
||||
from api.compliance import PROWLER_CHECKS, PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE
|
||||
|
||||
assert PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE == {
|
||||
"template_key": "template_value"
|
||||
}
|
||||
assert PROWLER_CHECKS == {"checks_key": "checks_value"}
|
||||
|
||||
expected_prowler_compliance = compliance_data_dict
|
||||
mock_get_prowler_provider_compliance.assert_any_call("aws")
|
||||
mock_get_prowler_provider_compliance.assert_any_call("azure")
|
||||
mock_generate_compliance_overview_template.assert_called_once_with(
|
||||
expected_prowler_compliance
|
||||
)
|
||||
mock_load_prowler_checks.assert_called_once_with(expected_prowler_compliance)
|
||||
|
||||
@patch("api.compliance.get_prowler_provider_checks")
|
||||
@patch("api.models.Provider.ProviderChoices")
|
||||
def test_load_prowler_checks(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Client
|
||||
from django.contrib.postgres.aggregates import ArrayAgg
|
||||
@@ -11,19 +14,25 @@ from api.exceptions import InvitationTokenExpiredException
|
||||
from api.models import Integration, Invitation, Processor, Provider, Resource
|
||||
from api.v1.serializers import FindingMetadataSerializer
|
||||
from prowler.lib.outputs.jira.jira import Jira, JiraBasicAuthError
|
||||
from prowler.providers.alibabacloud.alibabacloud_provider import AlibabacloudProvider
|
||||
from prowler.providers.aws.aws_provider import AwsProvider
|
||||
from prowler.providers.aws.lib.s3.s3 import S3
|
||||
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
|
||||
from prowler.providers.azure.azure_provider import AzureProvider
|
||||
from prowler.providers.common.models import Connection
|
||||
from prowler.providers.gcp.gcp_provider import GcpProvider
|
||||
from prowler.providers.github.github_provider import GithubProvider
|
||||
from prowler.providers.iac.iac_provider import IacProvider
|
||||
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
|
||||
from prowler.providers.m365.m365_provider import M365Provider
|
||||
from prowler.providers.mongodbatlas.mongodbatlas_provider import MongodbatlasProvider
|
||||
from prowler.providers.oraclecloud.oraclecloud_provider import OraclecloudProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from prowler.providers.alibabacloud.alibabacloud_provider import (
|
||||
AlibabacloudProvider,
|
||||
)
|
||||
from prowler.providers.aws.aws_provider import AwsProvider
|
||||
from prowler.providers.azure.azure_provider import AzureProvider
|
||||
from prowler.providers.gcp.gcp_provider import GcpProvider
|
||||
from prowler.providers.github.github_provider import GithubProvider
|
||||
from prowler.providers.iac.iac_provider import IacProvider
|
||||
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
|
||||
from prowler.providers.m365.m365_provider import M365Provider
|
||||
from prowler.providers.mongodbatlas.mongodbatlas_provider import (
|
||||
MongodbatlasProvider,
|
||||
)
|
||||
from prowler.providers.oraclecloud.oraclecloud_provider import OraclecloudProvider
|
||||
|
||||
|
||||
class CustomOAuth2Client(OAuth2Client):
|
||||
@@ -89,24 +98,52 @@ def return_prowler_provider(
|
||||
"""
|
||||
match provider.provider:
|
||||
case Provider.ProviderChoices.AWS.value:
|
||||
from prowler.providers.aws.aws_provider import AwsProvider
|
||||
|
||||
prowler_provider = AwsProvider
|
||||
case Provider.ProviderChoices.GCP.value:
|
||||
from prowler.providers.gcp.gcp_provider import GcpProvider
|
||||
|
||||
prowler_provider = GcpProvider
|
||||
case Provider.ProviderChoices.AZURE.value:
|
||||
from prowler.providers.azure.azure_provider import AzureProvider
|
||||
|
||||
prowler_provider = AzureProvider
|
||||
case Provider.ProviderChoices.KUBERNETES.value:
|
||||
from prowler.providers.kubernetes.kubernetes_provider import (
|
||||
KubernetesProvider,
|
||||
)
|
||||
|
||||
prowler_provider = KubernetesProvider
|
||||
case Provider.ProviderChoices.M365.value:
|
||||
from prowler.providers.m365.m365_provider import M365Provider
|
||||
|
||||
prowler_provider = M365Provider
|
||||
case Provider.ProviderChoices.GITHUB.value:
|
||||
from prowler.providers.github.github_provider import GithubProvider
|
||||
|
||||
prowler_provider = GithubProvider
|
||||
case Provider.ProviderChoices.MONGODBATLAS.value:
|
||||
from prowler.providers.mongodbatlas.mongodbatlas_provider import (
|
||||
MongodbatlasProvider,
|
||||
)
|
||||
|
||||
prowler_provider = MongodbatlasProvider
|
||||
case Provider.ProviderChoices.IAC.value:
|
||||
from prowler.providers.iac.iac_provider import IacProvider
|
||||
|
||||
prowler_provider = IacProvider
|
||||
case Provider.ProviderChoices.ORACLECLOUD.value:
|
||||
from prowler.providers.oraclecloud.oraclecloud_provider import (
|
||||
OraclecloudProvider,
|
||||
)
|
||||
|
||||
prowler_provider = OraclecloudProvider
|
||||
case Provider.ProviderChoices.ALIBABACLOUD.value:
|
||||
from prowler.providers.alibabacloud.alibabacloud_provider import (
|
||||
AlibabacloudProvider,
|
||||
)
|
||||
|
||||
prowler_provider = AlibabacloudProvider
|
||||
case _:
|
||||
raise ValueError(f"Provider type {provider.provider} not supported")
|
||||
@@ -393,11 +430,21 @@ def get_findings_metadata_no_aggregations(tenant_id: str, filtered_queryset):
|
||||
categories_set.update(categories_list)
|
||||
categories = sorted(categories_set)
|
||||
|
||||
# Aggregate groups from findings
|
||||
groups = list(
|
||||
filtered_queryset.exclude(resource_groups__isnull=True)
|
||||
.exclude(resource_groups__exact="")
|
||||
.values_list("resource_groups", flat=True)
|
||||
.distinct()
|
||||
.order_by("resource_groups")
|
||||
)
|
||||
|
||||
result = {
|
||||
"services": services,
|
||||
"regions": regions,
|
||||
"resource_types": resource_types,
|
||||
"categories": categories,
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
serializer = FindingMetadataSerializer(data=result)
|
||||
|
||||
@@ -21,6 +21,7 @@ from rest_framework_simplejwt.tokens import RefreshToken
|
||||
from api.db_router import MainRouter
|
||||
from api.exceptions import ConflictException
|
||||
from api.models import (
|
||||
AttackPathsScan,
|
||||
Finding,
|
||||
Integration,
|
||||
IntegrationProviderRelationship,
|
||||
@@ -978,6 +979,283 @@ class ProviderUpdateSerializer(BaseWriteSerializer):
|
||||
}
|
||||
|
||||
|
||||
class ProviderBatchItemSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
class Meta:
|
||||
model = Provider
|
||||
fields = ["alias", "provider", "uid"]
|
||||
|
||||
def validate(self, attrs):
|
||||
provider_type = attrs.get("provider")
|
||||
uid = attrs.get("uid")
|
||||
if provider_type and uid:
|
||||
validator_method = getattr(Provider, f"validate_{provider_type}_uid", None)
|
||||
if validator_method:
|
||||
validator_method(uid)
|
||||
return attrs
|
||||
|
||||
|
||||
class ProviderBatchCreateSerializer(BaseSerializerV1):
|
||||
"""Serializer for batch creation of providers with all-or-nothing semantics (JSON:API compliant)."""
|
||||
|
||||
class Meta:
|
||||
resource_name = "providers"
|
||||
|
||||
def validate(self, attrs):
|
||||
data = self.initial_data.get("data", [])
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError({"data": "Must be an array of provider objects"})
|
||||
|
||||
if len(data) > settings.API_BATCH_MAX_SIZE:
|
||||
raise ValidationError(
|
||||
{"data": f"Maximum {settings.API_BATCH_MAX_SIZE} providers per batch"}
|
||||
)
|
||||
|
||||
if len(data) == 0:
|
||||
raise ValidationError({"data": "At least one provider required"})
|
||||
|
||||
seen_uids = {}
|
||||
all_errors = []
|
||||
validated_items = []
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
for idx, item in enumerate(data):
|
||||
current_errors = []
|
||||
item_type = item.get("type")
|
||||
|
||||
if not item_type:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
elif item_type != "providers":
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Invalid type '{item_type}'. Expected 'providers'.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
|
||||
item_attrs = item.get("attributes", {})
|
||||
if not item_attrs:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes"},
|
||||
}
|
||||
)
|
||||
all_errors.extend(current_errors)
|
||||
continue
|
||||
|
||||
provider_type = item_attrs.get("provider")
|
||||
uid = item_attrs.get("uid")
|
||||
key = (provider_type, uid)
|
||||
|
||||
# Validate provider type before any DB queries
|
||||
valid_provider_types = [choice.value for choice in Provider.ProviderChoices]
|
||||
provider_type_valid = provider_type in valid_provider_types
|
||||
|
||||
if not provider_type:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes/provider"},
|
||||
}
|
||||
)
|
||||
elif not provider_type_valid:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Invalid provider type '{provider_type}'. Must be one of: {', '.join(valid_provider_types)}.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes/provider"},
|
||||
}
|
||||
)
|
||||
|
||||
if key in seen_uids:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Duplicate UID '{uid}' at index {idx} (first at {seen_uids[key]})",
|
||||
"source": {"pointer": f"/data/{idx}/attributes/uid"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
seen_uids[key] = idx
|
||||
|
||||
# Only check DB if provider type is valid (to avoid enum errors)
|
||||
if (
|
||||
provider_type_valid
|
||||
and uid
|
||||
and Provider.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_type,
|
||||
uid=uid,
|
||||
is_deleted=False,
|
||||
).exists()
|
||||
):
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Provider with uid '{uid}' already exists",
|
||||
"source": {"pointer": f"/data/{idx}/attributes/uid"},
|
||||
}
|
||||
)
|
||||
|
||||
item_serializer = ProviderBatchItemSerializer(
|
||||
data=item_attrs, context=self.context
|
||||
)
|
||||
if not item_serializer.is_valid():
|
||||
for field, field_errors in item_serializer.errors.items():
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": str(field_errors[0]),
|
||||
"source": {"pointer": f"/data/{idx}/attributes/{field}"},
|
||||
}
|
||||
)
|
||||
|
||||
if current_errors:
|
||||
all_errors.extend(current_errors)
|
||||
else:
|
||||
validated_items.append(
|
||||
{"index": idx, "data": item_serializer.validated_data}
|
||||
)
|
||||
|
||||
# All-or-nothing: if any errors, fail the entire batch
|
||||
if all_errors:
|
||||
raise ValidationError(all_errors)
|
||||
|
||||
attrs["_validated_items"] = validated_items
|
||||
return attrs
|
||||
|
||||
|
||||
class ProviderBatchUpdateItemSerializer(BaseWriteSerializer):
|
||||
"""Serializer for validating individual provider update items in batch."""
|
||||
|
||||
class Meta:
|
||||
model = Provider
|
||||
fields = ["alias"]
|
||||
|
||||
|
||||
class ProviderBatchUpdateSerializer(BaseSerializerV1):
|
||||
"""Serializer for batch update of providers with all-or-nothing semantics (JSON:API compliant)."""
|
||||
|
||||
class Meta:
|
||||
resource_name = "providers"
|
||||
|
||||
def validate(self, attrs):
|
||||
data = self.initial_data.get("data", [])
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError({"data": "Must be an array of provider objects"})
|
||||
|
||||
if len(data) > settings.API_BATCH_MAX_SIZE:
|
||||
raise ValidationError(
|
||||
{"data": f"Maximum {settings.API_BATCH_MAX_SIZE} providers per batch"}
|
||||
)
|
||||
|
||||
if len(data) == 0:
|
||||
raise ValidationError({"data": "At least one provider required"})
|
||||
|
||||
all_errors = []
|
||||
validated_items = []
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
seen_ids = {}
|
||||
|
||||
for idx, item in enumerate(data):
|
||||
current_errors = []
|
||||
item_type = item.get("type")
|
||||
item_id = item.get("id")
|
||||
|
||||
if not item_type:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
elif item_type != "providers":
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Invalid type '{item_type}'. Expected 'providers'.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
|
||||
if not item_id:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
all_errors.extend(current_errors)
|
||||
continue
|
||||
|
||||
if item_id in seen_ids:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Duplicate provider ID '{item_id}' at index {idx} (first at {seen_ids[item_id]})",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
seen_ids[item_id] = idx
|
||||
|
||||
try:
|
||||
provider = Provider.objects.get(
|
||||
id=item_id, tenant_id=tenant_id, is_deleted=False
|
||||
)
|
||||
except Provider.DoesNotExist:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": f"Provider '{item_id}' not found.",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
all_errors.extend(current_errors)
|
||||
continue
|
||||
|
||||
item_attrs = item.get("attributes", {})
|
||||
if not item_attrs:
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes"},
|
||||
}
|
||||
)
|
||||
all_errors.extend(current_errors)
|
||||
continue
|
||||
|
||||
item_serializer = ProviderBatchUpdateItemSerializer(
|
||||
data=item_attrs, context=self.context
|
||||
)
|
||||
if not item_serializer.is_valid():
|
||||
for field, field_errors in item_serializer.errors.items():
|
||||
current_errors.append(
|
||||
{
|
||||
"detail": str(field_errors[0]),
|
||||
"source": {"pointer": f"/data/{idx}/attributes/{field}"},
|
||||
}
|
||||
)
|
||||
|
||||
if current_errors:
|
||||
all_errors.extend(current_errors)
|
||||
else:
|
||||
validated_items.append(
|
||||
{
|
||||
"index": idx,
|
||||
"provider": provider,
|
||||
**item_serializer.validated_data,
|
||||
}
|
||||
)
|
||||
|
||||
# All-or-nothing: if any errors, fail the entire batch
|
||||
if all_errors:
|
||||
raise ValidationError(all_errors)
|
||||
|
||||
attrs["_validated_items"] = validated_items
|
||||
return attrs
|
||||
|
||||
|
||||
# Scans
|
||||
|
||||
|
||||
@@ -1132,6 +1410,109 @@ class ScanComplianceReportSerializer(BaseSerializerV1):
|
||||
fields = ["id", "name"]
|
||||
|
||||
|
||||
class AttackPathsScanSerializer(RLSSerializer):
|
||||
state = StateEnumSerializerField(read_only=True)
|
||||
provider_alias = serializers.SerializerMethodField(read_only=True)
|
||||
provider_type = serializers.SerializerMethodField(read_only=True)
|
||||
provider_uid = serializers.SerializerMethodField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = AttackPathsScan
|
||||
fields = [
|
||||
"id",
|
||||
"state",
|
||||
"progress",
|
||||
"provider",
|
||||
"provider_alias",
|
||||
"provider_type",
|
||||
"provider_uid",
|
||||
"scan",
|
||||
"task",
|
||||
"inserted_at",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"duration",
|
||||
]
|
||||
|
||||
included_serializers = {
|
||||
"provider": "api.v1.serializers.ProviderIncludeSerializer",
|
||||
"scan": "api.v1.serializers.ScanIncludeSerializer",
|
||||
"task": "api.v1.serializers.TaskSerializer",
|
||||
}
|
||||
|
||||
def get_provider_alias(self, obj):
|
||||
provider = getattr(obj, "provider", None)
|
||||
return provider.alias if provider else None
|
||||
|
||||
def get_provider_type(self, obj):
|
||||
provider = getattr(obj, "provider", None)
|
||||
return provider.provider if provider else None
|
||||
|
||||
def get_provider_uid(self, obj):
|
||||
provider = getattr(obj, "provider", None)
|
||||
return provider.uid if provider else None
|
||||
|
||||
|
||||
class AttackPathsQueryParameterSerializer(BaseSerializerV1):
|
||||
name = serializers.CharField()
|
||||
label = serializers.CharField()
|
||||
data_type = serializers.CharField(default="string")
|
||||
description = serializers.CharField(allow_null=True, required=False)
|
||||
placeholder = serializers.CharField(allow_null=True, required=False)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-query-parameters"
|
||||
|
||||
|
||||
class AttackPathsQuerySerializer(BaseSerializerV1):
|
||||
id = serializers.CharField()
|
||||
name = serializers.CharField()
|
||||
description = serializers.CharField()
|
||||
provider = serializers.CharField()
|
||||
parameters = AttackPathsQueryParameterSerializer(many=True)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-queries"
|
||||
|
||||
|
||||
class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
|
||||
id = serializers.CharField()
|
||||
parameters = serializers.DictField(
|
||||
child=serializers.JSONField(), allow_empty=True, required=False
|
||||
)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-query-run-requests"
|
||||
|
||||
|
||||
class AttackPathsNodeSerializer(BaseSerializerV1):
|
||||
id = serializers.CharField()
|
||||
labels = serializers.ListField(child=serializers.CharField())
|
||||
properties = serializers.DictField(child=serializers.JSONField())
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-query-result-nodes"
|
||||
|
||||
|
||||
class AttackPathsRelationshipSerializer(BaseSerializerV1):
|
||||
id = serializers.CharField()
|
||||
label = serializers.CharField()
|
||||
source = serializers.CharField()
|
||||
target = serializers.CharField()
|
||||
properties = serializers.DictField(child=serializers.JSONField())
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-query-result-relationships"
|
||||
|
||||
|
||||
class AttackPathsQueryResultSerializer(BaseSerializerV1):
|
||||
nodes = AttackPathsNodeSerializer(many=True)
|
||||
relationships = AttackPathsRelationshipSerializer(many=True)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-query-results"
|
||||
|
||||
|
||||
class ResourceTagSerializer(RLSSerializer):
|
||||
"""
|
||||
Serializer for the ResourceTag model
|
||||
@@ -1175,6 +1556,7 @@ class ResourceSerializer(RLSSerializer):
|
||||
"metadata",
|
||||
"details",
|
||||
"partition",
|
||||
"groups",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
@@ -1183,6 +1565,7 @@ class ResourceSerializer(RLSSerializer):
|
||||
"metadata": {"read_only": True},
|
||||
"details": {"read_only": True},
|
||||
"partition": {"read_only": True},
|
||||
"groups": {"read_only": True},
|
||||
}
|
||||
|
||||
included_serializers = {
|
||||
@@ -1276,6 +1659,7 @@ class ResourceMetadataSerializer(BaseSerializerV1):
|
||||
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
|
||||
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
|
||||
types = serializers.ListField(child=serializers.CharField(), allow_empty=True)
|
||||
groups = serializers.ListField(child=serializers.CharField(), allow_empty=True)
|
||||
# Temporarily disabled until we implement tag filtering in the UI
|
||||
# tags = serializers.JSONField(help_text="Tags are described as key-value pairs.")
|
||||
|
||||
@@ -1302,6 +1686,7 @@ class FindingSerializer(RLSSerializer):
|
||||
"check_id",
|
||||
"check_metadata",
|
||||
"categories",
|
||||
"resource_groups",
|
||||
"raw_result",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
@@ -1358,6 +1743,9 @@ class FindingMetadataSerializer(BaseSerializerV1):
|
||||
child=serializers.CharField(), allow_empty=True
|
||||
)
|
||||
categories = serializers.ListField(child=serializers.CharField(), allow_empty=True)
|
||||
groups = serializers.ListField(
|
||||
child=serializers.CharField(), allow_empty=True, required=False, default=list
|
||||
)
|
||||
# Temporarily disabled until we implement tag filtering in the UI
|
||||
# tags = serializers.JSONField(help_text="Tags are described as key-value pairs.")
|
||||
|
||||
@@ -1603,6 +1991,13 @@ class ProviderSecretSerializer(RLSSerializer):
|
||||
"url",
|
||||
]
|
||||
|
||||
def get_root_meta(self, _resource, _many):
|
||||
meta = super().get_root_meta(_resource, _many)
|
||||
skipped = self.context.get("_skipped_providers")
|
||||
if skipped:
|
||||
meta["skipped"] = skipped
|
||||
return meta
|
||||
|
||||
|
||||
class ProviderSecretCreateSerializer(RLSSerializer, BaseWriteProviderSecretSerializer):
|
||||
secret = ProviderSecretField(write_only=True)
|
||||
@@ -1664,6 +2059,446 @@ class ProviderSecretUpdateSerializer(BaseWriteProviderSecretSerializer):
|
||||
return validated_attrs
|
||||
|
||||
|
||||
class ProviderSecretBatchItemSerializer(BaseWriteProviderSecretSerializer):
|
||||
"""Serializer for an individual item in the batch of secrets."""
|
||||
|
||||
secret = ProviderSecretField(write_only=True)
|
||||
|
||||
class Meta:
|
||||
model = ProviderSecret
|
||||
fields = ["name", "secret_type", "secret"]
|
||||
|
||||
def validate(self, attrs):
|
||||
# Provider is passed via context since it's validated separately
|
||||
provider = self.context.get("provider")
|
||||
secret_type = attrs.get("secret_type")
|
||||
secret = attrs.get("secret")
|
||||
|
||||
if provider and secret_type and secret:
|
||||
self.validate_secret_based_on_provider(
|
||||
provider.provider, secret_type, secret
|
||||
)
|
||||
return attrs
|
||||
|
||||
|
||||
class ProviderSecretBatchCreateSerializer(BaseSerializerV1):
|
||||
"""
|
||||
Serializer for batch creation of provider secrets.
|
||||
|
||||
Supports to-many relationship format where one secret definition can be
|
||||
associated with multiple providers. Each provider creates a separate secret.
|
||||
|
||||
JSON:API compliant: all-or-nothing for hard errors, soft skips for providers
|
||||
that already have secrets (reported in meta.skipped).
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
resource_name = "provider-secrets"
|
||||
|
||||
def _extract_providers_data(self, relationships, idx):
|
||||
"""
|
||||
Extract providers data from relationships, supporting both formats:
|
||||
- to-one: relationships.provider.data (single object) - backwards compatible
|
||||
- to-many: relationships.providers.data (array) - new format
|
||||
|
||||
Returns (providers_list, errors) where providers_list is normalized to array.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Try to-many format first (providers plural)
|
||||
providers_rel = relationships.get("providers", {})
|
||||
providers_data = providers_rel.get("data")
|
||||
|
||||
if providers_data is not None:
|
||||
# to-many format
|
||||
if isinstance(providers_data, dict):
|
||||
# Single object in to-many field - normalize to array
|
||||
providers_data = [providers_data]
|
||||
elif not isinstance(providers_data, list):
|
||||
errors.append(
|
||||
{
|
||||
"detail": "Must be an array of provider resource identifiers.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data"
|
||||
},
|
||||
}
|
||||
)
|
||||
return None, errors
|
||||
return providers_data, errors
|
||||
|
||||
# Fall back to to-one format (provider singular) for backwards compatibility
|
||||
provider_rel = relationships.get("provider", {})
|
||||
provider_data = provider_rel.get("data")
|
||||
|
||||
if provider_data is not None:
|
||||
if isinstance(provider_data, dict):
|
||||
return [provider_data], errors
|
||||
else:
|
||||
errors.append(
|
||||
{
|
||||
"detail": "Must be a provider resource identifier object.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/provider/data"
|
||||
},
|
||||
}
|
||||
)
|
||||
return None, errors
|
||||
|
||||
# No providers relationship found
|
||||
errors.append(
|
||||
{
|
||||
"detail": "Providers relationship is required.",
|
||||
"source": {"pointer": f"/data/{idx}/relationships/providers"},
|
||||
}
|
||||
)
|
||||
return None, errors
|
||||
|
||||
def validate(self, attrs):
|
||||
data = self.initial_data.get("data", [])
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError(
|
||||
{"data": "Must be an array of provider-secret objects"}
|
||||
)
|
||||
|
||||
if len(data) > settings.API_BATCH_MAX_SIZE:
|
||||
raise ValidationError(
|
||||
{"data": f"Maximum {settings.API_BATCH_MAX_SIZE} secrets per batch"}
|
||||
)
|
||||
|
||||
if len(data) == 0:
|
||||
raise ValidationError({"data": "At least one secret required"})
|
||||
|
||||
hard_errors = [] # Will cause full batch failure
|
||||
skipped_providers = [] # Already have secrets, reported in meta
|
||||
validated_items = []
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
seen_providers = {} # Track duplicates across the entire batch
|
||||
|
||||
for idx, item in enumerate(data):
|
||||
item_type = item.get("type")
|
||||
|
||||
# Validate type
|
||||
if not item_type:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
elif item_type != "provider-secrets":
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Invalid type '{item_type}'. Expected 'provider-secrets'.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate attributes
|
||||
item_attrs = item.get("attributes", {})
|
||||
if not item_attrs:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Extract providers (supports both to-one and to-many formats)
|
||||
relationships = item.get("relationships", {})
|
||||
providers_data, extract_errors = self._extract_providers_data(
|
||||
relationships, idx
|
||||
)
|
||||
if extract_errors:
|
||||
hard_errors.extend(extract_errors)
|
||||
continue
|
||||
|
||||
if not providers_data:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "At least one provider is required.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data"
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Process each provider in the relationship
|
||||
for prov_idx, prov_data in enumerate(providers_data):
|
||||
provider_id = (
|
||||
prov_data.get("id") if isinstance(prov_data, dict) else None
|
||||
)
|
||||
provider_type = (
|
||||
prov_data.get("type") if isinstance(prov_data, dict) else None
|
||||
)
|
||||
|
||||
# Validate provider resource identifier
|
||||
if not provider_id:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "Provider id is required.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data/{prov_idx}/id"
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if provider_type and provider_type != "providers":
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Invalid type '{provider_type}'. Expected 'providers'.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data/{prov_idx}/type"
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check for duplicate provider in entire batch
|
||||
if provider_id in seen_providers:
|
||||
prev_idx, prev_prov_idx = seen_providers[provider_id]
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Duplicate provider '{provider_id}' (first at data/{prev_idx}/providers/{prev_prov_idx}).",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data/{prov_idx}"
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
seen_providers[provider_id] = (idx, prov_idx)
|
||||
|
||||
# Validate provider exists and belongs to tenant
|
||||
try:
|
||||
provider = Provider.objects.get(
|
||||
id=provider_id, tenant_id=tenant_id, is_deleted=False
|
||||
)
|
||||
except Provider.DoesNotExist:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Provider '{provider_id}' not found.",
|
||||
"source": {
|
||||
"pointer": f"/data/{idx}/relationships/providers/data/{prov_idx}"
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Soft skip: provider already has a secret
|
||||
if ProviderSecret.objects.filter(
|
||||
provider_id=provider_id, tenant_id=tenant_id
|
||||
).exists():
|
||||
skipped_providers.append(
|
||||
{
|
||||
"provider_id": str(provider_id),
|
||||
"source_index": idx,
|
||||
"reason": "Provider already has a secret.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate secret attributes for this specific provider
|
||||
item_context = {**self.context, "provider": provider}
|
||||
item_serializer = ProviderSecretBatchItemSerializer(
|
||||
data=item_attrs, context=item_context
|
||||
)
|
||||
|
||||
if not item_serializer.is_valid():
|
||||
for field, field_errors in item_serializer.errors.items():
|
||||
pointer = f"/data/{idx}/attributes/{field}"
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": str(field_errors[0]),
|
||||
"source": {"pointer": pointer},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
validated_items.append(
|
||||
{
|
||||
"source_index": idx,
|
||||
**item_serializer.validated_data,
|
||||
"provider": provider,
|
||||
}
|
||||
)
|
||||
|
||||
# All-or-nothing: if any hard errors, fail the entire batch
|
||||
if hard_errors:
|
||||
raise ValidationError(hard_errors)
|
||||
|
||||
attrs["_validated_items"] = validated_items
|
||||
attrs["_skipped_providers"] = skipped_providers
|
||||
return attrs
|
||||
|
||||
|
||||
class ProviderSecretBatchUpdateItemSerializer(BaseWriteProviderSecretSerializer):
|
||||
"""Serializer for validating individual provider secret update items in batch."""
|
||||
|
||||
secret = ProviderSecretField(write_only=True, required=False)
|
||||
|
||||
class Meta:
|
||||
model = ProviderSecret
|
||||
fields = ["name", "secret_type", "secret"]
|
||||
extra_kwargs = {
|
||||
"name": {"required": False},
|
||||
"secret_type": {"required": False},
|
||||
}
|
||||
|
||||
def validate(self, attrs):
|
||||
provider = self.context.get("provider")
|
||||
secret_type = attrs.get("secret_type")
|
||||
secret = attrs.get("secret")
|
||||
|
||||
if provider and secret_type and secret:
|
||||
self.validate_secret_based_on_provider(
|
||||
provider.provider, secret_type, secret
|
||||
)
|
||||
elif provider and secret and not secret_type:
|
||||
existing_secret = ProviderSecret.objects.filter(provider=provider).first()
|
||||
if existing_secret:
|
||||
self.validate_secret_based_on_provider(
|
||||
provider.provider, existing_secret.secret_type, secret
|
||||
)
|
||||
return attrs
|
||||
|
||||
|
||||
class ProviderSecretBatchUpdateSerializer(BaseSerializerV1):
|
||||
"""
|
||||
Serializer for batch update of provider secrets.
|
||||
|
||||
JSON:API compliant with all-or-nothing semantics for validation errors.
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
resource_name = "provider-secrets"
|
||||
|
||||
def validate(self, attrs):
|
||||
data = self.initial_data.get("data", [])
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValidationError(
|
||||
{"data": "Must be an array of provider-secret objects"}
|
||||
)
|
||||
|
||||
if len(data) > settings.API_BATCH_MAX_SIZE:
|
||||
raise ValidationError(
|
||||
{"data": f"Maximum {settings.API_BATCH_MAX_SIZE} secrets per batch"}
|
||||
)
|
||||
|
||||
if len(data) == 0:
|
||||
raise ValidationError({"data": "At least one secret required"})
|
||||
|
||||
hard_errors = [] # Will cause full batch failure
|
||||
validated_items = []
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
seen_ids = {}
|
||||
|
||||
for idx, item in enumerate(data):
|
||||
item_type = item.get("type")
|
||||
item_id = item.get("id")
|
||||
|
||||
# Validate type
|
||||
if not item_type:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
elif item_type != "provider-secrets":
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Invalid type '{item_type}'. Expected 'provider-secrets'.",
|
||||
"source": {"pointer": f"/data/{idx}/type"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate id
|
||||
if not item_id:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check for duplicate id in batch
|
||||
if item_id in seen_ids:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Duplicate secret ID '{item_id}' (first at data/{seen_ids[item_id]}).",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
seen_ids[item_id] = idx
|
||||
|
||||
# Validate secret exists and belongs to tenant
|
||||
try:
|
||||
provider_secret = ProviderSecret.objects.select_related("provider").get(
|
||||
id=item_id, tenant_id=tenant_id
|
||||
)
|
||||
except ProviderSecret.DoesNotExist:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": f"Provider secret '{item_id}' not found.",
|
||||
"source": {"pointer": f"/data/{idx}/id"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate attributes
|
||||
item_attrs = item.get("attributes", {})
|
||||
if not item_attrs:
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": "This field is required.",
|
||||
"source": {"pointer": f"/data/{idx}/attributes"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
item_context = {**self.context, "provider": provider_secret.provider}
|
||||
|
||||
item_serializer = ProviderSecretBatchUpdateItemSerializer(
|
||||
data=item_attrs, context=item_context
|
||||
)
|
||||
if not item_serializer.is_valid():
|
||||
for field, field_errors in item_serializer.errors.items():
|
||||
hard_errors.append(
|
||||
{
|
||||
"detail": str(field_errors[0]),
|
||||
"source": {"pointer": f"/data/{idx}/attributes/{field}"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
validated_items.append(
|
||||
{
|
||||
"source_index": idx,
|
||||
"provider_secret": provider_secret,
|
||||
**item_serializer.validated_data,
|
||||
}
|
||||
)
|
||||
|
||||
# All-or-nothing: if any hard errors, fail the entire batch
|
||||
if hard_errors:
|
||||
raise ValidationError(hard_errors)
|
||||
|
||||
attrs["_validated_items"] = validated_items
|
||||
return attrs
|
||||
|
||||
|
||||
# Invitations
|
||||
|
||||
|
||||
@@ -2303,6 +3138,22 @@ class CategoryOverviewSerializer(BaseSerializerV1):
|
||||
resource_name = "category-overviews"
|
||||
|
||||
|
||||
class ResourceGroupOverviewSerializer(BaseSerializerV1):
|
||||
"""Serializer for resource group overview aggregations."""
|
||||
|
||||
id = serializers.CharField(source="resource_group")
|
||||
total_findings = serializers.IntegerField()
|
||||
failed_findings = serializers.IntegerField()
|
||||
new_failed_findings = serializers.IntegerField()
|
||||
resources_count = serializers.IntegerField()
|
||||
severity = serializers.JSONField(
|
||||
help_text="Severity breakdown: {informational, low, medium, high, critical}"
|
||||
)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "resource-group-overviews"
|
||||
|
||||
|
||||
class ComplianceWatchlistOverviewSerializer(BaseSerializerV1):
|
||||
"""Serializer for compliance watchlist overview with FAIL-dominant aggregation."""
|
||||
|
||||
@@ -3848,3 +4699,31 @@ class ThreatScoreSnapshotSerializer(RLSSerializer):
|
||||
if getattr(obj, "_aggregated", False):
|
||||
return "n/a"
|
||||
return str(obj.id)
|
||||
|
||||
|
||||
# Resource Events Serializers
|
||||
|
||||
|
||||
class ResourceEventSerializer(BaseSerializerV1):
|
||||
"""Serializer for resource events (CloudTrail modification history).
|
||||
|
||||
NOTE: drf-spectacular auto-generates fields[resource-events] sparse fieldsets
|
||||
parameter in the OpenAPI schema. This endpoint does not support sparse fieldsets.
|
||||
"""
|
||||
|
||||
id = serializers.CharField(source="event_id")
|
||||
event_time = serializers.DateTimeField()
|
||||
event_name = serializers.CharField()
|
||||
event_source = serializers.CharField()
|
||||
actor = serializers.CharField()
|
||||
actor_uid = serializers.CharField(allow_null=True, required=False)
|
||||
actor_type = serializers.CharField(allow_null=True, required=False)
|
||||
source_ip_address = serializers.CharField(allow_null=True, required=False)
|
||||
user_agent = serializers.CharField(allow_null=True, required=False)
|
||||
request_data = serializers.JSONField(allow_null=True, required=False)
|
||||
response_data = serializers.JSONField(allow_null=True, required=False)
|
||||
error_code = serializers.CharField(allow_null=True, required=False)
|
||||
error_message = serializers.CharField(allow_null=True, required=False)
|
||||
|
||||
class Meta:
|
||||
resource_name = "resource-events"
|
||||
|
||||
@@ -4,6 +4,7 @@ from drf_spectacular.views import SpectacularRedocView
|
||||
from rest_framework_nested import routers
|
||||
|
||||
from api.v1.views import (
|
||||
AttackPathsScanViewSet,
|
||||
ComplianceOverviewViewSet,
|
||||
CustomSAMLLoginView,
|
||||
CustomTokenObtainView,
|
||||
@@ -53,6 +54,9 @@ router.register(r"tenants", TenantViewSet, basename="tenant")
|
||||
router.register(r"providers", ProviderViewSet, basename="provider")
|
||||
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
|
||||
router.register(r"scans", ScanViewSet, basename="scan")
|
||||
router.register(
|
||||
r"attack-paths-scans", AttackPathsScanViewSet, basename="attack-paths-scans"
|
||||
)
|
||||
router.register(r"tasks", TaskViewSet, basename="task")
|
||||
router.register(r"resources", ResourceViewSet, basename="resource")
|
||||
router.register(r"findings", FindingViewSet, basename="finding")
|
||||
@@ -107,6 +111,18 @@ urlpatterns = [
|
||||
ProviderSecretViewSet.as_view({"get": "list", "post": "create"}),
|
||||
name="providersecret-list",
|
||||
),
|
||||
path(
|
||||
"providers/secrets/batch",
|
||||
ProviderSecretViewSet.as_view(
|
||||
{"post": "batch_create", "patch": "batch_update"}
|
||||
),
|
||||
name="providersecret-batch",
|
||||
),
|
||||
path(
|
||||
"providers/batch",
|
||||
ProviderViewSet.as_view({"post": "batch_create", "patch": "batch_update"}),
|
||||
name="provider-batch",
|
||||
),
|
||||
path(
|
||||
"providers/secrets/<uuid:pk>",
|
||||
ProviderSecretViewSet.as_view(
|
||||
|
||||
+1071
-7
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
import warnings
|
||||
|
||||
from celery import Celery, Task
|
||||
|
||||
from config.env import env
|
||||
|
||||
# Suppress specific warnings from django-rest-auth: https://github.com/iMerica/dj-rest-auth/issues/684
|
||||
|
||||
@@ -276,7 +276,7 @@ FINDINGS_MAX_DAYS_IN_RANGE = env.int("DJANGO_FINDINGS_MAX_DAYS_IN_RANGE", 7)
|
||||
DJANGO_TMP_OUTPUT_DIRECTORY = env.str(
|
||||
"DJANGO_TMP_OUTPUT_DIRECTORY", "/tmp/prowler_api_output"
|
||||
)
|
||||
DJANGO_FINDINGS_BATCH_SIZE = env.str("DJANGO_FINDINGS_BATCH_SIZE", 1000)
|
||||
DJANGO_FINDINGS_BATCH_SIZE = env.int("DJANGO_FINDINGS_BATCH_SIZE", 1000)
|
||||
|
||||
DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET = env.str("DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET", "")
|
||||
DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID = env.str("DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID", "")
|
||||
@@ -293,6 +293,8 @@ SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin"
|
||||
|
||||
DJANGO_DELETION_BATCH_SIZE = env.int("DJANGO_DELETION_BATCH_SIZE", 5000)
|
||||
|
||||
API_BATCH_MAX_SIZE = env.int("DJANGO_API_BATCH_MAX_SIZE", 100)
|
||||
|
||||
# SAML requirement
|
||||
CSRF_COOKIE_SECURE = True
|
||||
SESSION_COOKIE_SECURE = True
|
||||
|
||||
@@ -44,6 +44,12 @@ DATABASES = {
|
||||
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
|
||||
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
|
||||
},
|
||||
"neo4j": {
|
||||
"HOST": env.str("NEO4J_HOST", "neo4j"),
|
||||
"PORT": env.str("NEO4J_PORT", "7687"),
|
||||
"USER": env.str("NEO4J_USER", "neo4j"),
|
||||
"PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"),
|
||||
},
|
||||
}
|
||||
|
||||
DATABASES["default"] = DATABASES["prowler_user"]
|
||||
|
||||
@@ -45,6 +45,12 @@ DATABASES = {
|
||||
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
|
||||
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
|
||||
},
|
||||
"neo4j": {
|
||||
"HOST": env.str("NEO4J_HOST"),
|
||||
"PORT": env.str("NEO4J_PORT"),
|
||||
"USER": env.str("NEO4J_USER"),
|
||||
"PASSWORD": env.str("NEO4J_PASSWORD"),
|
||||
},
|
||||
}
|
||||
|
||||
DATABASES["default"] = DATABASES["prowler_user"]
|
||||
|
||||
+219
-11
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
from django.conf import settings
|
||||
from django.db import connection as django_connection
|
||||
@@ -11,13 +14,14 @@ from django.urls import reverse
|
||||
from django_celery_results.models import TaskResult
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_resource_scan_summaries,
|
||||
backfill_scan_category_summaries,
|
||||
)
|
||||
|
||||
from api.attack_paths import (
|
||||
AttackPathsQueryDefinition,
|
||||
AttackPathsQueryParameterDefinition,
|
||||
)
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
AttackPathsScan,
|
||||
AttackSurfaceOverview,
|
||||
ComplianceOverview,
|
||||
ComplianceRequirementOverview,
|
||||
@@ -41,6 +45,7 @@ from api.models import (
|
||||
SAMLDomainIndex,
|
||||
Scan,
|
||||
ScanCategorySummary,
|
||||
ScanGroupSummary,
|
||||
ScanSummary,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
@@ -54,6 +59,11 @@ from api.rls import Tenant
|
||||
from api.v1.serializers import TokenSerializer
|
||||
from prowler.lib.check.models import Severity
|
||||
from prowler.lib.outputs.finding import Status
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_resource_scan_summaries,
|
||||
backfill_scan_category_summaries,
|
||||
backfill_scan_resource_group_summaries,
|
||||
)
|
||||
|
||||
TODAY = str(datetime.today().date())
|
||||
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
|
||||
@@ -166,22 +176,20 @@ def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_f
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
|
||||
def create_test_user_rbac_limited(django_db_setup, django_db_blocker, tenants_fixture):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing_limited",
|
||||
email="rbac_limited@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
tenant = tenants_fixture[0]
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
Role.objects.create(
|
||||
role = Role.objects.create(
|
||||
name="limited",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
@@ -194,7 +202,7 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=Role.objects.get(name="limited"),
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return user
|
||||
@@ -739,6 +747,7 @@ def resources_fixture(providers_fixture):
|
||||
region="us-east-1",
|
||||
service="ec2",
|
||||
type="prowler-test",
|
||||
groups=["compute"],
|
||||
)
|
||||
|
||||
resource1.upsert_or_delete_tags(tags)
|
||||
@@ -751,6 +760,7 @@ def resources_fixture(providers_fixture):
|
||||
region="eu-west-1",
|
||||
service="s3",
|
||||
type="prowler-test",
|
||||
groups=["storage"],
|
||||
)
|
||||
resource2.upsert_or_delete_tags(tags)
|
||||
|
||||
@@ -762,6 +772,7 @@ def resources_fixture(providers_fixture):
|
||||
region="us-east-1",
|
||||
service="ec2",
|
||||
type="test",
|
||||
groups=["compute"],
|
||||
)
|
||||
|
||||
tags = [
|
||||
@@ -1234,7 +1245,7 @@ def lighthouse_config_fixture(authenticated_client, tenants_fixture):
|
||||
return LighthouseConfiguration.objects.create(
|
||||
tenant_id=tenants_fixture[0].id,
|
||||
name="OpenAI",
|
||||
api_key_decoded="sk-test1234567890T3BlbkFJtest1234567890",
|
||||
api_key_decoded="sk-fake-test-key-for-unit-testing-only",
|
||||
model="gpt-4o",
|
||||
temperature=0,
|
||||
max_tokens=4000,
|
||||
@@ -1383,11 +1394,13 @@ def latest_scan_finding_with_categories(
|
||||
check_id="genai_iam_check",
|
||||
check_metadata={"CheckId": "genai_iam_check"},
|
||||
categories=["gen-ai", "iam"],
|
||||
resource_groups="ai_ml",
|
||||
first_seen_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
finding.add_resources([resource])
|
||||
backfill_resource_scan_summaries(tenant_id, str(scan.id))
|
||||
backfill_scan_category_summaries(tenant_id, str(scan.id))
|
||||
backfill_scan_resource_group_summaries(tenant_id, str(scan.id))
|
||||
return finding
|
||||
|
||||
|
||||
@@ -1590,6 +1603,104 @@ def mute_rules_fixture(tenants_fixture, create_test_user, findings_fixture):
|
||||
return mute_rule1, mute_rule2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_attack_paths_scan():
|
||||
"""Factory fixture to create Attack Paths scans for tests."""
|
||||
|
||||
def _create(
|
||||
provider,
|
||||
*,
|
||||
scan=None,
|
||||
state=StateChoices.COMPLETED,
|
||||
progress=0,
|
||||
graph_database="tenant-db",
|
||||
**extra_fields,
|
||||
):
|
||||
scan_instance = scan or Scan.objects.create(
|
||||
name=extra_fields.pop("scan_name", "Attack Paths Supporting Scan"),
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=extra_fields.pop("scan_state", StateChoices.COMPLETED),
|
||||
tenant_id=provider.tenant_id,
|
||||
)
|
||||
|
||||
payload = {
|
||||
"tenant_id": provider.tenant_id,
|
||||
"provider": provider,
|
||||
"scan": scan_instance,
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
"graph_database": graph_database,
|
||||
}
|
||||
payload.update(extra_fields)
|
||||
|
||||
return AttackPathsScan.objects.create(**payload)
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attack_paths_query_definition_factory():
|
||||
"""Factory fixture for building Attack Paths query definitions."""
|
||||
|
||||
def _create(**overrides):
|
||||
cast_type = overrides.pop("cast_type", str)
|
||||
parameters = overrides.pop(
|
||||
"parameters",
|
||||
[
|
||||
AttackPathsQueryParameterDefinition(
|
||||
name="limit",
|
||||
label="Limit",
|
||||
cast=cast_type,
|
||||
)
|
||||
],
|
||||
)
|
||||
definition_payload = {
|
||||
"id": "aws-test",
|
||||
"name": "Attack Paths Test Query",
|
||||
"description": "Synthetic Attack Paths definition for tests.",
|
||||
"provider": "aws",
|
||||
"cypher": "RETURN 1",
|
||||
"parameters": parameters,
|
||||
}
|
||||
definition_payload.update(overrides)
|
||||
return AttackPathsQueryDefinition(**definition_payload)
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attack_paths_graph_stub_classes():
|
||||
"""Provide lightweight graph element stubs for Attack Paths serialization tests."""
|
||||
|
||||
class AttackPathsNativeValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def to_native(self):
|
||||
return self._value
|
||||
|
||||
class AttackPathsNode:
|
||||
def __init__(self, element_id, labels, properties):
|
||||
self.element_id = element_id
|
||||
self.labels = labels
|
||||
self._properties = properties
|
||||
|
||||
class AttackPathsRelationship:
|
||||
def __init__(self, element_id, rel_type, start_node, end_node, properties):
|
||||
self.element_id = element_id
|
||||
self.type = rel_type
|
||||
self.start_node = start_node
|
||||
self.end_node = end_node
|
||||
self._properties = properties
|
||||
|
||||
return SimpleNamespace(
|
||||
NativeValue=AttackPathsNativeValue,
|
||||
Node=AttackPathsNode,
|
||||
Relationship=AttackPathsRelationship,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_attack_surface_overview():
|
||||
def _create(tenant, scan, attack_surface_type, total=10, failed=5, muted_failed=2):
|
||||
@@ -1629,6 +1740,103 @@ def create_scan_category_summary():
|
||||
return _create
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def findings_with_group(scans_fixture, resources_fixture):
|
||||
scan = scans_fixture[0]
|
||||
resource = resources_fixture[0]
|
||||
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
uid="finding_with_group_1",
|
||||
scan=scan,
|
||||
delta=None,
|
||||
status=Status.FAIL,
|
||||
status_extended="test status",
|
||||
impact=Severity.critical,
|
||||
impact_extended="test impact",
|
||||
severity=Severity.critical,
|
||||
raw_result={"status": Status.FAIL},
|
||||
check_id="storage_check",
|
||||
check_metadata={"CheckId": "storage_check"},
|
||||
resource_groups="storage",
|
||||
first_seen_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
finding.add_resources([resource])
|
||||
backfill_resource_scan_summaries(str(scan.tenant_id), str(scan.id))
|
||||
return finding
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def findings_with_multiple_groups(scans_fixture, resources_fixture):
|
||||
scan = scans_fixture[0]
|
||||
resource1, resource2 = resources_fixture[:2]
|
||||
|
||||
finding1 = Finding.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
uid="finding_multi_grp_1",
|
||||
scan=scan,
|
||||
delta=None,
|
||||
status=Status.FAIL,
|
||||
status_extended="test status",
|
||||
impact=Severity.critical,
|
||||
impact_extended="test impact",
|
||||
severity=Severity.critical,
|
||||
raw_result={"status": Status.FAIL},
|
||||
check_id="storage_check",
|
||||
check_metadata={"CheckId": "storage_check"},
|
||||
resource_groups="storage",
|
||||
first_seen_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
finding1.add_resources([resource1])
|
||||
|
||||
finding2 = Finding.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
uid="finding_multi_grp_2",
|
||||
scan=scan,
|
||||
delta=None,
|
||||
status=Status.FAIL,
|
||||
status_extended="test status 2",
|
||||
impact=Severity.high,
|
||||
impact_extended="test impact 2",
|
||||
severity=Severity.high,
|
||||
raw_result={"status": Status.FAIL},
|
||||
check_id="security_check",
|
||||
check_metadata={"CheckId": "security_check"},
|
||||
resource_groups="security",
|
||||
first_seen_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
finding2.add_resources([resource2])
|
||||
|
||||
backfill_resource_scan_summaries(str(scan.tenant_id), str(scan.id))
|
||||
return finding1, finding2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_scan_resource_group_summary():
|
||||
def _create(
|
||||
tenant,
|
||||
scan,
|
||||
resource_group,
|
||||
severity,
|
||||
total_findings=10,
|
||||
failed_findings=5,
|
||||
new_failed_findings=2,
|
||||
resources_count=3,
|
||||
):
|
||||
return ScanGroupSummary.objects.create(
|
||||
tenant=tenant,
|
||||
scan=scan,
|
||||
resource_group=resource_group,
|
||||
severity=severity,
|
||||
total_findings=total_findings,
|
||||
failed_findings=failed_findings,
|
||||
new_failed_findings=new_failed_findings,
|
||||
resources_count=resources_count,
|
||||
)
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
def get_authorization_header(access_token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from tasks.tasks import perform_scheduled_scan_task
|
||||
from api.db_utils import rls_transaction
|
||||
from api.exceptions import ConflictException
|
||||
from api.models import Provider, Scan, StateChoices
|
||||
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
|
||||
|
||||
|
||||
def schedule_provider_scan(provider_instance: Provider):
|
||||
@@ -39,6 +40,12 @@ def schedule_provider_scan(provider_instance: Provider):
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
attack_paths_db_utils.create_attack_paths_scan(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=str(scheduled_scan.id),
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Schedule the task
|
||||
periodic_task_instance = PeriodicTask.objects.create(
|
||||
interval=schedule,
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from tasks.jobs.attack_paths.db_utils import can_provider_run_attack_paths_scan
|
||||
from tasks.jobs.attack_paths.scan import run as attack_paths_scan
|
||||
|
||||
__all__ = [
|
||||
"attack_paths_scan",
|
||||
"can_provider_run_attack_paths_scan",
|
||||
]
|
||||
@@ -0,0 +1,253 @@
|
||||
# Portions of this file are based on code from the Cartography project
|
||||
# (https://github.com/cartography-cncf/cartography), which is licensed under the Apache 2.0 License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aioboto3
|
||||
import boto3
|
||||
import neo4j
|
||||
|
||||
from cartography.config import Config as CartographyConfig
|
||||
from cartography.intel import aws as cartography_aws
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from api.models import (
|
||||
AttackPathsScan as ProwlerAPIAttackPathsScan,
|
||||
Provider as ProwlerAPIProvider,
|
||||
)
|
||||
from prowler.providers.common.provider import Provider as ProwlerSDKProvider
|
||||
from tasks.jobs.attack_paths import db_utils, utils
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
def start_aws_ingestion(
|
||||
neo4j_session: neo4j.Session,
|
||||
cartography_config: CartographyConfig,
|
||||
prowler_api_provider: ProwlerAPIProvider,
|
||||
prowler_sdk_provider: ProwlerSDKProvider,
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Code based on Cartography version 0.122.0, specifically on `cartography.intel.aws.__init__.py`.
|
||||
|
||||
For the scan progress updates:
|
||||
- The caller of this function (`tasks.jobs.attack_paths.scan.run`) has set it to 2.
|
||||
- When the control returns to the caller, it will be set to 95.
|
||||
"""
|
||||
|
||||
# Initialize variables common to all jobs
|
||||
common_job_parameters = {
|
||||
"UPDATE_TAG": cartography_config.update_tag,
|
||||
"permission_relationships_file": cartography_config.permission_relationships_file,
|
||||
"aws_guardduty_severity_threshold": cartography_config.aws_guardduty_severity_threshold,
|
||||
"aws_cloudtrail_management_events_lookback_hours": cartography_config.aws_cloudtrail_management_events_lookback_hours,
|
||||
"experimental_aws_inspector_batch": cartography_config.experimental_aws_inspector_batch,
|
||||
}
|
||||
|
||||
boto3_session = get_boto3_session(prowler_api_provider, prowler_sdk_provider)
|
||||
regions: list[str] = list(prowler_sdk_provider._enabled_regions)
|
||||
requested_syncs = list(cartography_aws.RESOURCE_FUNCTIONS.keys())
|
||||
|
||||
sync_args = cartography_aws._build_aws_sync_kwargs(
|
||||
neo4j_session,
|
||||
boto3_session,
|
||||
regions,
|
||||
prowler_api_provider.uid,
|
||||
cartography_config.update_tag,
|
||||
common_job_parameters,
|
||||
)
|
||||
|
||||
# Starting with sync functions
|
||||
logger.info(f"Syncing organizations for AWS account {prowler_api_provider.uid}")
|
||||
cartography_aws.organizations.sync(
|
||||
neo4j_session,
|
||||
{prowler_api_provider.alias: prowler_api_provider.uid},
|
||||
cartography_config.update_tag,
|
||||
common_job_parameters,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 3)
|
||||
|
||||
# Adding an extra field
|
||||
common_job_parameters["AWS_ID"] = prowler_api_provider.uid
|
||||
|
||||
cartography_aws._autodiscover_accounts(
|
||||
neo4j_session,
|
||||
boto3_session,
|
||||
prowler_api_provider.uid,
|
||||
cartography_config.update_tag,
|
||||
common_job_parameters,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4)
|
||||
|
||||
failed_syncs = sync_aws_account(
|
||||
prowler_api_provider, requested_syncs, sync_args, attack_paths_scan
|
||||
)
|
||||
|
||||
if "permission_relationships" in requested_syncs:
|
||||
logger.info(
|
||||
f"Syncing function permission_relationships for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"](**sync_args)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 88)
|
||||
|
||||
if "resourcegroupstaggingapi" in requested_syncs:
|
||||
logger.info(
|
||||
f"Syncing function resourcegroupstaggingapi for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.RESOURCE_FUNCTIONS["resourcegroupstaggingapi"](**sync_args)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 89)
|
||||
|
||||
logger.info(
|
||||
f"Syncing ec2_iaminstanceprofile scoped analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.run_scoped_analysis_job(
|
||||
"aws_ec2_iaminstanceprofile.json",
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 90)
|
||||
|
||||
logger.info(
|
||||
f"Syncing lambda_ecr analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.run_analysis_job(
|
||||
"aws_lambda_ecr.json",
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 91)
|
||||
|
||||
logger.info(f"Syncing metadata for AWS account {prowler_api_provider.uid}")
|
||||
cartography_aws.merge_module_sync_metadata(
|
||||
neo4j_session,
|
||||
group_type="AWSAccount",
|
||||
group_id=prowler_api_provider.uid,
|
||||
synced_type="AWSAccount",
|
||||
update_tag=cartography_config.update_tag,
|
||||
stat_handler=cartography_aws.stat_handler,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 92)
|
||||
|
||||
# Removing the added extra field
|
||||
del common_job_parameters["AWS_ID"]
|
||||
|
||||
logger.info(f"Syncing cleanup_job for AWS account {prowler_api_provider.uid}")
|
||||
cartography_aws.run_cleanup_job(
|
||||
"aws_post_ingestion_principals_cleanup.json",
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 93)
|
||||
|
||||
logger.info(f"Syncing analysis for AWS account {prowler_api_provider.uid}")
|
||||
cartography_aws._perform_aws_analysis(
|
||||
requested_syncs, neo4j_session, common_job_parameters
|
||||
)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 94)
|
||||
|
||||
return failed_syncs
|
||||
|
||||
|
||||
def get_boto3_session(
|
||||
prowler_api_provider: ProwlerAPIProvider, prowler_sdk_provider: ProwlerSDKProvider
|
||||
) -> boto3.Session:
|
||||
boto3_session = prowler_sdk_provider.session.current_session
|
||||
|
||||
aws_accounts_from_session = cartography_aws.organizations.get_aws_account_default(
|
||||
boto3_session
|
||||
)
|
||||
if not aws_accounts_from_session:
|
||||
raise Exception(
|
||||
"No valid AWS credentials could be found. No AWS accounts can be synced."
|
||||
)
|
||||
|
||||
aws_account_id_from_session = list(aws_accounts_from_session.values())[0]
|
||||
if prowler_api_provider.uid != aws_account_id_from_session:
|
||||
raise Exception(
|
||||
f"Provider {prowler_api_provider.uid} doesn't match AWS account {aws_account_id_from_session}."
|
||||
)
|
||||
|
||||
if boto3_session.region_name is None:
|
||||
global_region = prowler_sdk_provider.get_global_region()
|
||||
boto3_session._session.set_config_variable("region", global_region)
|
||||
|
||||
return boto3_session
|
||||
|
||||
|
||||
def get_aioboto3_session(boto3_session: boto3.Session) -> aioboto3.Session:
|
||||
return aioboto3.Session(botocore_session=boto3_session._session)
|
||||
|
||||
|
||||
def sync_aws_account(
|
||||
prowler_api_provider: ProwlerAPIProvider,
|
||||
requested_syncs: list[str],
|
||||
sync_args: dict[str, Any],
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
) -> dict[str, str]:
|
||||
current_progress = 4 # `cartography_aws._autodiscover_accounts`
|
||||
max_progress = (
|
||||
87 # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1
|
||||
)
|
||||
n_steps = (
|
||||
len(requested_syncs) - 2
|
||||
) # Excluding `permission_relationships` and `resourcegroupstaggingapi`
|
||||
progress_step = (max_progress - current_progress) / n_steps
|
||||
|
||||
failed_syncs = {}
|
||||
|
||||
for func_name in requested_syncs:
|
||||
if func_name in cartography_aws.RESOURCE_FUNCTIONS:
|
||||
logger.info(
|
||||
f"Syncing function {func_name} for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
|
||||
# Updating progress, not really the right place but good enough
|
||||
current_progress += progress_step
|
||||
db_utils.update_attack_paths_scan_progress(
|
||||
attack_paths_scan, int(current_progress)
|
||||
)
|
||||
|
||||
try:
|
||||
# `ecr:image_layers` uses `aioboto3_session` instead of `boto3_session`
|
||||
if func_name == "ecr:image_layers":
|
||||
cartography_aws.RESOURCE_FUNCTIONS[func_name](
|
||||
neo4j_session=sync_args.get("neo4j_session"),
|
||||
aioboto3_session=get_aioboto3_session(
|
||||
sync_args.get("boto3_session")
|
||||
),
|
||||
regions=sync_args.get("regions"),
|
||||
current_aws_account_id=sync_args.get("current_aws_account_id"),
|
||||
update_tag=sync_args.get("update_tag"),
|
||||
common_job_parameters=sync_args.get("common_job_parameters"),
|
||||
)
|
||||
|
||||
# Skip permission relationships and tags for now because they rely on data already being in the graph
|
||||
elif func_name in [
|
||||
"permission_relationships",
|
||||
"resourcegroupstaggingapi",
|
||||
]:
|
||||
continue
|
||||
|
||||
else:
|
||||
cartography_aws.RESOURCE_FUNCTIONS[func_name](**sync_args)
|
||||
|
||||
except Exception as e:
|
||||
exception_message = utils.stringify_exception(
|
||||
e, f"Exception for AWS sync function: {func_name}"
|
||||
)
|
||||
failed_syncs[func_name] = exception_message
|
||||
|
||||
logger.warning(
|
||||
f"Caught exception syncing function {func_name} from AWS account {prowler_api_provider.uid}. We "
|
||||
"are continuing on to the next AWS sync function.",
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f'AWS sync function "{func_name}" was specified but does not exist. Did you misspell it?'
|
||||
)
|
||||
|
||||
return failed_syncs
|
||||
@@ -0,0 +1,168 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Q
|
||||
from cartography.config import Config as CartographyConfig
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
AttackPathsScan as ProwlerAPIAttackPathsScan,
|
||||
Provider as ProwlerAPIProvider,
|
||||
StateChoices,
|
||||
)
|
||||
from tasks.jobs.attack_paths.providers import is_provider_available
|
||||
|
||||
|
||||
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
|
||||
with rls_transaction(tenant_id):
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
|
||||
|
||||
return is_provider_available(prowler_api_provider.provider)
|
||||
|
||||
|
||||
def create_attack_paths_scan(
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
provider_id: int,
|
||||
) -> ProwlerAPIAttackPathsScan | None:
|
||||
if not can_provider_run_attack_paths_scan(tenant_id, provider_id):
|
||||
return None
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scan_id=scan_id,
|
||||
state=StateChoices.SCHEDULED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
attack_paths_scan.save()
|
||||
|
||||
return attack_paths_scan
|
||||
|
||||
|
||||
def retrieve_attack_paths_scan(
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
) -> ProwlerAPIAttackPathsScan | None:
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
|
||||
scan_id=scan_id,
|
||||
)
|
||||
|
||||
return attack_paths_scan
|
||||
|
||||
except ProwlerAPIAttackPathsScan.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def starting_attack_paths_scan(
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
task_id: str,
|
||||
cartography_config: CartographyConfig,
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
attack_paths_scan.task_id = task_id
|
||||
attack_paths_scan.state = StateChoices.EXECUTING
|
||||
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
|
||||
attack_paths_scan.update_tag = cartography_config.update_tag
|
||||
attack_paths_scan.graph_database = cartography_config.neo4j_database
|
||||
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"task_id",
|
||||
"state",
|
||||
"started_at",
|
||||
"update_tag",
|
||||
"graph_database",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def finish_attack_paths_scan(
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
state: StateChoices,
|
||||
ingestion_exceptions: dict[str, Any],
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
duration = int((now - attack_paths_scan.started_at).total_seconds())
|
||||
|
||||
attack_paths_scan.state = state
|
||||
attack_paths_scan.progress = 100
|
||||
attack_paths_scan.completed_at = now
|
||||
attack_paths_scan.duration = duration
|
||||
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
|
||||
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"state",
|
||||
"progress",
|
||||
"completed_at",
|
||||
"duration",
|
||||
"ingestion_exceptions",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def update_attack_paths_scan_progress(
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
progress: int,
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
attack_paths_scan.progress = progress
|
||||
attack_paths_scan.save(update_fields=["progress"])
|
||||
|
||||
|
||||
def get_old_attack_paths_scans(
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
attack_paths_scan_id: str,
|
||||
) -> list[ProwlerAPIAttackPathsScan]:
|
||||
"""
|
||||
An `old_attack_paths_scan` is any `completed` Attack Paths scan for the same provider,
|
||||
with its graph database not deleted, excluding the current Attack Paths scan.
|
||||
"""
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
completed_scans_qs = (
|
||||
ProwlerAPIAttackPathsScan.objects.filter(
|
||||
provider_id=provider_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
is_graph_database_deleted=False,
|
||||
)
|
||||
.exclude(id=attack_paths_scan_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return list(completed_scans_qs)
|
||||
|
||||
|
||||
def update_old_attack_paths_scan(
|
||||
old_attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
) -> None:
|
||||
with rls_transaction(old_attack_paths_scan.tenant_id):
|
||||
old_attack_paths_scan.is_graph_database_deleted = True
|
||||
old_attack_paths_scan.save(update_fields=["is_graph_database_deleted"])
|
||||
|
||||
|
||||
def get_provider_graph_database_names(tenant_id: str, provider_id: str) -> list[str]:
|
||||
"""
|
||||
Return existing graph database names for a tenant/provider.
|
||||
|
||||
Note: For accesing the `AttackPathsScan` we need to use `all_objects` manager because the provider is soft-deleted.
|
||||
"""
|
||||
with rls_transaction(tenant_id):
|
||||
graph_databases_names_qs = (
|
||||
ProwlerAPIAttackPathsScan.all_objects.filter(
|
||||
~Q(graph_database=""),
|
||||
graph_database__isnull=False,
|
||||
provider_id=provider_id,
|
||||
is_graph_database_deleted=False,
|
||||
)
|
||||
.values_list("graph_database", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return list(graph_databases_names_qs)
|
||||
@@ -0,0 +1,23 @@
|
||||
AVAILABLE_PROVIDERS: list[str] = [
|
||||
"aws",
|
||||
]
|
||||
|
||||
ROOT_NODE_LABELS: dict[str, str] = {
|
||||
"aws": "AWSAccount",
|
||||
}
|
||||
|
||||
NODE_UID_FIELDS: dict[str, str] = {
|
||||
"aws": "arn",
|
||||
}
|
||||
|
||||
|
||||
def is_provider_available(provider_type: str) -> bool:
|
||||
return provider_type in AVAILABLE_PROVIDERS
|
||||
|
||||
|
||||
def get_root_node_label(provider_type: str) -> str:
|
||||
return ROOT_NODE_LABELS.get(provider_type, "UnknownProviderAccount")
|
||||
|
||||
|
||||
def get_node_uid_field(provider_type: str) -> str:
|
||||
return NODE_UID_FIELDS.get(provider_type, "UnknownProviderUID")
|
||||
@@ -0,0 +1,290 @@
|
||||
from collections import defaultdict
|
||||
from typing import Generator
|
||||
|
||||
import neo4j
|
||||
from cartography.client.core.tx import run_write_query
|
||||
from cartography.config import Config as CartographyConfig
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.env import env
|
||||
from tasks.jobs.attack_paths.providers import get_node_uid_field, get_root_node_label
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Finding, Provider, ResourceFindingMapping
|
||||
from prowler.config import config as ProwlerConfig
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
BATCH_SIZE = env.int("ATTACK_PATHS_FINDINGS_BATCH_SIZE", 1000)
|
||||
|
||||
INDEX_STATEMENTS = [
|
||||
"CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.id);",
|
||||
"CREATE INDEX prowler_finding_provider_uid IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.provider_uid);",
|
||||
"CREATE INDEX prowler_finding_lastupdated IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.lastupdated);",
|
||||
"CREATE INDEX prowler_finding_check_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.status);",
|
||||
]
|
||||
|
||||
INSERT_STATEMENT_TEMPLATE = """
|
||||
MATCH (account:__ROOT_NODE_LABEL__ {id: $provider_uid})
|
||||
UNWIND $findings_data AS finding_data
|
||||
|
||||
OPTIONAL MATCH (account)-->(resource_by_uid)
|
||||
WHERE resource_by_uid.__NODE_UID_FIELD__ = finding_data.resource_uid
|
||||
WITH account, finding_data, resource_by_uid
|
||||
|
||||
OPTIONAL MATCH (account)-->(resource_by_id)
|
||||
WHERE resource_by_uid IS NULL
|
||||
AND resource_by_id.id = finding_data.resource_uid
|
||||
WITH account, finding_data, COALESCE(resource_by_uid, resource_by_id) AS resource
|
||||
WHERE resource IS NOT NULL
|
||||
|
||||
MERGE (finding:ProwlerFinding {id: finding_data.id})
|
||||
ON CREATE SET
|
||||
finding.id = finding_data.id,
|
||||
finding.uid = finding_data.uid,
|
||||
finding.inserted_at = finding_data.inserted_at,
|
||||
finding.updated_at = finding_data.updated_at,
|
||||
finding.first_seen_at = finding_data.first_seen_at,
|
||||
finding.scan_id = finding_data.scan_id,
|
||||
finding.delta = finding_data.delta,
|
||||
finding.status = finding_data.status,
|
||||
finding.status_extended = finding_data.status_extended,
|
||||
finding.severity = finding_data.severity,
|
||||
finding.check_id = finding_data.check_id,
|
||||
finding.check_title = finding_data.check_title,
|
||||
finding.muted = finding_data.muted,
|
||||
finding.muted_reason = finding_data.muted_reason,
|
||||
finding.provider_uid = $provider_uid,
|
||||
finding.firstseen = timestamp(),
|
||||
finding.lastupdated = $last_updated,
|
||||
finding._module_name = 'cartography:prowler',
|
||||
finding._module_version = $prowler_version
|
||||
ON MATCH SET
|
||||
finding.status = finding_data.status,
|
||||
finding.status_extended = finding_data.status_extended,
|
||||
finding.lastupdated = $last_updated
|
||||
|
||||
MERGE (resource)-[rel:HAS_FINDING]->(finding)
|
||||
ON CREATE SET
|
||||
rel.provider_uid = $provider_uid,
|
||||
rel.firstseen = timestamp(),
|
||||
rel.lastupdated = $last_updated,
|
||||
rel._module_name = 'cartography:prowler',
|
||||
rel._module_version = $prowler_version
|
||||
ON MATCH SET
|
||||
rel.lastupdated = $last_updated
|
||||
"""
|
||||
|
||||
CLEANUP_STATEMENT = """
|
||||
MATCH (finding:ProwlerFinding {provider_uid: $provider_uid})
|
||||
WHERE finding.lastupdated < $last_updated
|
||||
|
||||
WITH finding LIMIT $batch_size
|
||||
|
||||
DETACH DELETE finding
|
||||
|
||||
RETURN COUNT(finding) AS deleted_findings_count
|
||||
"""
|
||||
|
||||
|
||||
def create_indexes(neo4j_session: neo4j.Session) -> None:
|
||||
"""
|
||||
Code based on Cartography version 0.122.0, specifically on `cartography.intel.create_indexes.run`.
|
||||
"""
|
||||
|
||||
logger.info("Creating indexes for Prowler Findings node types")
|
||||
for statement in INDEX_STATEMENTS:
|
||||
run_write_query(neo4j_session, statement)
|
||||
|
||||
|
||||
def analysis(
|
||||
neo4j_session: neo4j.Session,
|
||||
prowler_api_provider: Provider,
|
||||
scan_id: str,
|
||||
config: CartographyConfig,
|
||||
) -> None:
|
||||
findings_data = get_provider_last_scan_findings(prowler_api_provider, scan_id)
|
||||
load_findings(neo4j_session, findings_data, prowler_api_provider, config)
|
||||
cleanup_findings(neo4j_session, prowler_api_provider, config)
|
||||
|
||||
|
||||
def get_provider_last_scan_findings(
|
||||
prowler_api_provider: Provider,
|
||||
scan_id: str,
|
||||
) -> Generator[list[dict[str, str]], None, None]:
|
||||
"""
|
||||
Generator that yields batches of finding-resource pairs.
|
||||
|
||||
Two-step query approach per batch:
|
||||
1. Paginate findings for scan (single table, indexed by scan_id)
|
||||
2. Batch-fetch resource UIDs via mapping table (single join)
|
||||
3. Merge and yield flat structure for Neo4j
|
||||
|
||||
Memory efficient: never holds more than BATCH_SIZE findings in memory.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"Starting findings fetch for scan {scan_id} (tenant {prowler_api_provider.tenant_id}) with batch size {BATCH_SIZE}"
|
||||
)
|
||||
|
||||
iteration = 0
|
||||
last_id = None
|
||||
|
||||
while True:
|
||||
iteration += 1
|
||||
|
||||
with rls_transaction(prowler_api_provider.tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Use all_objects to avoid the ActiveProviderManager's implicit JOIN
|
||||
# through Scan -> Provider (to check is_deleted=False).
|
||||
# The provider is already validated as active in this context.
|
||||
qs = Finding.all_objects.filter(scan_id=scan_id).order_by("id")
|
||||
if last_id is not None:
|
||||
qs = qs.filter(id__gt=last_id)
|
||||
|
||||
findings_batch = list(
|
||||
qs.values(
|
||||
"id",
|
||||
"uid",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"first_seen_at",
|
||||
"scan_id",
|
||||
"delta",
|
||||
"status",
|
||||
"status_extended",
|
||||
"severity",
|
||||
"check_id",
|
||||
"check_metadata__checktitle",
|
||||
"muted",
|
||||
"muted_reason",
|
||||
)[:BATCH_SIZE]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Iteration #{iteration} fetched {len(findings_batch)} findings"
|
||||
)
|
||||
|
||||
if not findings_batch:
|
||||
logger.info(
|
||||
f"No findings returned for iteration #{iteration}; stopping pagination"
|
||||
)
|
||||
break
|
||||
|
||||
last_id = findings_batch[-1]["id"]
|
||||
enriched_batch = _enrich_and_flatten_batch(findings_batch)
|
||||
|
||||
# Yield outside the transaction
|
||||
if enriched_batch:
|
||||
yield enriched_batch
|
||||
|
||||
logger.info(f"Finished fetching findings for scan {scan_id}")
|
||||
|
||||
|
||||
def _enrich_and_flatten_batch(
|
||||
findings_batch: list[dict],
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Fetch resource UIDs for a batch of findings and return flat structure.
|
||||
|
||||
One finding with 3 resources becomes 3 dicts (same output format as before).
|
||||
Must be called within an RLS transaction context.
|
||||
"""
|
||||
finding_ids = [f["id"] for f in findings_batch]
|
||||
|
||||
# Single join: mapping -> resource
|
||||
resource_mappings = ResourceFindingMapping.objects.filter(
|
||||
finding_id__in=finding_ids
|
||||
).values_list("finding_id", "resource__uid")
|
||||
|
||||
# Build finding_id -> [resource_uids] mapping
|
||||
finding_resources = defaultdict(list)
|
||||
for finding_id, resource_uid in resource_mappings:
|
||||
finding_resources[finding_id].append(resource_uid)
|
||||
|
||||
# Flatten: one dict per (finding, resource) pair
|
||||
results = []
|
||||
for f in findings_batch:
|
||||
resource_uids = finding_resources.get(f["id"], [])
|
||||
|
||||
if not resource_uids:
|
||||
continue
|
||||
|
||||
for resource_uid in resource_uids:
|
||||
results.append(
|
||||
{
|
||||
"resource_uid": str(resource_uid),
|
||||
"id": str(f["id"]),
|
||||
"uid": f["uid"],
|
||||
"inserted_at": f["inserted_at"],
|
||||
"updated_at": f["updated_at"],
|
||||
"first_seen_at": f["first_seen_at"],
|
||||
"scan_id": str(f["scan_id"]),
|
||||
"delta": f["delta"],
|
||||
"status": f["status"],
|
||||
"status_extended": f["status_extended"],
|
||||
"severity": f["severity"],
|
||||
"check_id": str(f["check_id"]),
|
||||
"check_title": f["check_metadata__checktitle"],
|
||||
"muted": f["muted"],
|
||||
"muted_reason": f["muted_reason"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_findings(
|
||||
neo4j_session: neo4j.Session,
|
||||
findings_batches: Generator[list[dict[str, str]], None, None],
|
||||
prowler_api_provider: Provider,
|
||||
config: CartographyConfig,
|
||||
) -> None:
|
||||
replacements = {
|
||||
"__ROOT_NODE_LABEL__": get_root_node_label(prowler_api_provider.provider),
|
||||
"__NODE_UID_FIELD__": get_node_uid_field(prowler_api_provider.provider),
|
||||
}
|
||||
query = INSERT_STATEMENT_TEMPLATE
|
||||
for replace_key, replace_value in replacements.items():
|
||||
query = query.replace(replace_key, replace_value)
|
||||
|
||||
parameters = {
|
||||
"provider_uid": str(prowler_api_provider.uid),
|
||||
"last_updated": config.update_tag,
|
||||
"prowler_version": ProwlerConfig.prowler_version,
|
||||
}
|
||||
|
||||
batch_num = 0
|
||||
total_records = 0
|
||||
for batch in findings_batches:
|
||||
batch_num += 1
|
||||
batch_size = len(batch)
|
||||
total_records += batch_size
|
||||
|
||||
parameters["findings_data"] = batch
|
||||
|
||||
logger.info(f"Loading findings batch {batch_num} ({batch_size} records)")
|
||||
neo4j_session.run(query, parameters)
|
||||
|
||||
logger.info(f"Finished loading {total_records} records in {batch_num} batches")
|
||||
|
||||
|
||||
def cleanup_findings(
|
||||
neo4j_session: neo4j.Session,
|
||||
prowler_api_provider: Provider,
|
||||
config: CartographyConfig,
|
||||
) -> None:
|
||||
parameters = {
|
||||
"provider_uid": str(prowler_api_provider.uid),
|
||||
"last_updated": config.update_tag,
|
||||
"batch_size": BATCH_SIZE,
|
||||
}
|
||||
|
||||
batch = 1
|
||||
deleted_count = 1
|
||||
while deleted_count > 0:
|
||||
logger.info(f"Cleaning findings batch {batch}")
|
||||
|
||||
result = neo4j_session.run(CLEANUP_STATEMENT, parameters)
|
||||
|
||||
deleted_count = result.single().get("deleted_findings_count", 0)
|
||||
batch += 1
|
||||
@@ -0,0 +1,197 @@
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from cartography.config import Config as CartographyConfig
|
||||
from cartography.intel import analysis as cartography_analysis
|
||||
from cartography.intel import create_indexes as cartography_create_indexes
|
||||
from cartography.intel import ontology as cartography_ontology
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
Provider as ProwlerAPIProvider,
|
||||
StateChoices,
|
||||
)
|
||||
from api.utils import initialize_prowler_provider
|
||||
from tasks.jobs.attack_paths import aws, db_utils, prowler, utils
|
||||
|
||||
# Without this Celery goes crazy with Cartography logging
|
||||
logging.getLogger("cartography").setLevel(logging.ERROR)
|
||||
logging.getLogger("neo4j").propagate = False
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
CARTOGRAPHY_INGESTION_FUNCTIONS: dict[str, Callable] = {
|
||||
"aws": aws.start_aws_ingestion,
|
||||
}
|
||||
|
||||
|
||||
def get_cartography_ingestion_function(provider_type: str) -> Callable | None:
|
||||
return CARTOGRAPHY_INGESTION_FUNCTIONS.get(provider_type)
|
||||
|
||||
|
||||
def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Code based on Cartography version 0.122.0, specifically on `cartography.cli.main`, `cartography.cli.CLI.main`,
|
||||
`cartography.sync.run_with_config` and `cartography.sync.Sync.run`.
|
||||
"""
|
||||
ingestion_exceptions = {} # This will hold any exceptions raised during ingestion
|
||||
|
||||
# Prowler necessary objects
|
||||
with rls_transaction(tenant_id):
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
|
||||
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
|
||||
|
||||
# Attack Paths Scan necessary objects
|
||||
cartography_ingestion_function = get_cartography_ingestion_function(
|
||||
prowler_api_provider.provider
|
||||
)
|
||||
attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)
|
||||
|
||||
# Checks before starting the scan
|
||||
if not cartography_ingestion_function:
|
||||
ingestion_exceptions = {
|
||||
"global_error": f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
|
||||
}
|
||||
if attack_paths_scan:
|
||||
db_utils.finish_attack_paths_scan(
|
||||
attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
|
||||
)
|
||||
return ingestion_exceptions
|
||||
|
||||
else:
|
||||
if not attack_paths_scan:
|
||||
logger.warning(
|
||||
f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
|
||||
)
|
||||
attack_paths_scan = db_utils.create_attack_paths_scan(
|
||||
tenant_id, scan_id, prowler_api_provider.id
|
||||
)
|
||||
|
||||
# While creating the Cartography configuration, attributes `neo4j_user` and `neo4j_password` are not really needed in this config object
|
||||
cartography_config = CartographyConfig(
|
||||
neo4j_uri=graph_database.get_uri(),
|
||||
neo4j_database=graph_database.get_database_name(attack_paths_scan.id),
|
||||
update_tag=int(time.time()),
|
||||
)
|
||||
|
||||
# Starting the Attack Paths scan
|
||||
db_utils.starting_attack_paths_scan(attack_paths_scan, task_id, cartography_config)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Creating Neo4j database {cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}"
|
||||
)
|
||||
|
||||
graph_database.create_database(cartography_config.neo4j_database)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 1)
|
||||
|
||||
logger.info(
|
||||
f"Starting Cartography ({attack_paths_scan.id}) for "
|
||||
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
|
||||
)
|
||||
with graph_database.get_session(
|
||||
cartography_config.neo4j_database
|
||||
) as neo4j_session:
|
||||
# Indexes creation
|
||||
cartography_create_indexes.run(neo4j_session, cartography_config)
|
||||
prowler.create_indexes(neo4j_session)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2)
|
||||
|
||||
# The real scan, where iterates over cloud services
|
||||
ingestion_exceptions = _call_within_event_loop(
|
||||
cartography_ingestion_function,
|
||||
neo4j_session,
|
||||
cartography_config,
|
||||
prowler_api_provider,
|
||||
prowler_sdk_provider,
|
||||
attack_paths_scan,
|
||||
)
|
||||
|
||||
# Post-processing: Just keeping it to be more Cartography compliant
|
||||
logger.info(
|
||||
f"Syncing Cartography ontology for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_ontology.run(neo4j_session, cartography_config)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95)
|
||||
|
||||
logger.info(
|
||||
f"Syncing Cartography analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_analysis.run(neo4j_session, cartography_config)
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 96)
|
||||
|
||||
# Adding Prowler nodes and relationships
|
||||
logger.info(
|
||||
f"Syncing Prowler analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
prowler.analysis(
|
||||
neo4j_session, prowler_api_provider, scan_id, cartography_config
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Clearing Neo4j cache for database {cartography_config.neo4j_database}"
|
||||
)
|
||||
graph_database.clear_cache(cartography_config.neo4j_database)
|
||||
|
||||
logger.info(
|
||||
f"Completed Cartography ({attack_paths_scan.id}) for "
|
||||
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
|
||||
)
|
||||
|
||||
# Handling databases changes
|
||||
old_attack_paths_scans = db_utils.get_old_attack_paths_scans(
|
||||
prowler_api_provider.tenant_id,
|
||||
prowler_api_provider.id,
|
||||
attack_paths_scan.id,
|
||||
)
|
||||
for old_attack_paths_scan in old_attack_paths_scans:
|
||||
graph_database.drop_database(old_attack_paths_scan.graph_database)
|
||||
db_utils.update_old_attack_paths_scan(old_attack_paths_scan)
|
||||
|
||||
db_utils.finish_attack_paths_scan(
|
||||
attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
|
||||
)
|
||||
return ingestion_exceptions
|
||||
|
||||
except Exception as e:
|
||||
exception_message = utils.stringify_exception(e, "Cartography failed")
|
||||
logger.error(exception_message)
|
||||
ingestion_exceptions["global_cartography_error"] = exception_message
|
||||
|
||||
# Handling databases changes
|
||||
graph_database.drop_database(cartography_config.neo4j_database)
|
||||
db_utils.finish_attack_paths_scan(
|
||||
attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _call_within_event_loop(fn, *args, **kwargs):
|
||||
"""
|
||||
Cartography needs a running event loop, so assuming there is none (Celery task or even regular DRF endpoint),
|
||||
let's create a new one and set it as the current event loop for this thread.
|
||||
"""
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
finally:
|
||||
try:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to shutdown async generators cleanly: {e}")
|
||||
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
@@ -0,0 +1,10 @@
|
||||
import traceback
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def stringify_exception(exception: Exception, context: str) -> str:
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
exception_traceback = traceback.TracebackException.from_exception(exception)
|
||||
traceback_string = "".join(exception_traceback.format())
|
||||
return f"{timestamp} - {context}\n{traceback_string}"
|
||||
@@ -2,13 +2,13 @@ from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from django.db.models import Sum
|
||||
from django.db.models import OuterRef, Subquery, Sum
|
||||
from django.utils import timezone
|
||||
from tasks.jobs.queries import (
|
||||
COMPLIANCE_UPSERT_PROVIDER_SCORE_SQL,
|
||||
COMPLIANCE_UPSERT_TENANT_SUMMARY_ALL_SQL,
|
||||
)
|
||||
from tasks.jobs.scan import aggregate_category_counts
|
||||
from tasks.jobs.scan import aggregate_category_counts, aggregate_resource_group_counts
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS, MainRouter
|
||||
from api.db_utils import (
|
||||
@@ -28,6 +28,7 @@ from api.models import (
|
||||
ResourceScanSummary,
|
||||
Scan,
|
||||
ScanCategorySummary,
|
||||
ScanGroupSummary,
|
||||
ScanSummary,
|
||||
StateChoices,
|
||||
)
|
||||
@@ -356,6 +357,92 @@ def backfill_scan_category_summaries(tenant_id: str, scan_id: str):
|
||||
return {"status": "backfilled", "categories_count": len(category_counts)}
|
||||
|
||||
|
||||
def backfill_scan_resource_group_summaries(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Backfill ScanGroupSummary for a completed scan.
|
||||
|
||||
Aggregates resource group counts from all findings in the scan and creates
|
||||
one ScanGroupSummary row per (resource_group, severity) combination.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID
|
||||
scan_id: Scan UUID to backfill
|
||||
|
||||
Returns:
|
||||
dict: Status indicating whether backfill was performed
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if ScanGroupSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
resource_group_counts: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
# Get findings with their first resource UID via annotation
|
||||
resource_uid_subquery = ResourceFindingMapping.objects.filter(
|
||||
finding_id=OuterRef("id"), tenant_id=tenant_id
|
||||
).values("resource__uid")[:1]
|
||||
|
||||
for finding in (
|
||||
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
.annotate(resource_uid=Subquery(resource_uid_subquery))
|
||||
.values(
|
||||
"resource_groups",
|
||||
"severity",
|
||||
"status",
|
||||
"delta",
|
||||
"muted",
|
||||
"resource_uid",
|
||||
)
|
||||
):
|
||||
aggregate_resource_group_counts(
|
||||
resource_group=finding.get("resource_groups"),
|
||||
severity=finding.get("severity"),
|
||||
status=finding.get("status"),
|
||||
delta=finding.get("delta"),
|
||||
muted=finding.get("muted", False),
|
||||
resource_uid=finding.get("resource_uid") or "",
|
||||
cache=resource_group_counts,
|
||||
group_resources_cache=group_resources_cache,
|
||||
)
|
||||
|
||||
if not resource_group_counts:
|
||||
return {"status": "no resource groups to backfill"}
|
||||
|
||||
# Compute group-level resource counts (same value for all severity rows in a group)
|
||||
group_resource_counts = {
|
||||
grp: len(uids) for grp, uids in group_resources_cache.items()
|
||||
}
|
||||
resource_group_summaries = [
|
||||
ScanGroupSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
resource_group=grp,
|
||||
severity=severity,
|
||||
total_findings=counts["total"],
|
||||
failed_findings=counts["failed"],
|
||||
new_failed_findings=counts["new_failed"],
|
||||
resources_count=group_resource_counts.get(grp, 0),
|
||||
)
|
||||
for (grp, severity), counts in resource_group_counts.items()
|
||||
]
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
ScanGroupSummary.objects.bulk_create(
|
||||
resource_group_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "resource_groups_count": len(resource_group_counts)}
|
||||
|
||||
|
||||
def backfill_provider_compliance_scores(tenant_id: str) -> dict:
|
||||
"""
|
||||
Backfill ProviderComplianceScore from latest completed scan per provider.
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
from celery.utils.log import get_task_logger
|
||||
from django.db import DatabaseError
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import batch_delete, rls_transaction
|
||||
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
|
||||
from api.models import (
|
||||
AttackPathsScan,
|
||||
Finding,
|
||||
Provider,
|
||||
Resource,
|
||||
Scan,
|
||||
ScanSummary,
|
||||
Tenant,
|
||||
)
|
||||
from tasks.jobs.attack_paths.db_utils import get_provider_graph_database_names
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
@@ -23,16 +33,27 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
Raises:
|
||||
Provider.DoesNotExist: If no instance with the provided primary key exists.
|
||||
"""
|
||||
# Delete the Attack Paths' graph databases related to the provider
|
||||
graph_database_names = get_provider_graph_database_names(tenant_id, pk)
|
||||
try:
|
||||
for graph_database_name in graph_database_names:
|
||||
graph_database.drop_database(graph_database_name)
|
||||
except graph_database.GraphDatabaseQueryException as gdb_error:
|
||||
logger.error(f"Error deleting Provider databases: {gdb_error}")
|
||||
raise
|
||||
|
||||
# Get all provider related data and delete them in batches
|
||||
with rls_transaction(tenant_id):
|
||||
instance = Provider.all_objects.get(pk=pk)
|
||||
deletion_summary = {}
|
||||
deletion_steps = [
|
||||
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
|
||||
("Findings", Finding.all_objects.filter(scan__provider=instance)),
|
||||
("Resources", Resource.all_objects.filter(provider=instance)),
|
||||
("Scans", Scan.all_objects.filter(provider=instance)),
|
||||
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
|
||||
]
|
||||
|
||||
deletion_summary = {}
|
||||
for step_name, queryset in deletion_steps:
|
||||
try:
|
||||
_, step_summary = batch_delete(tenant_id, queryset)
|
||||
@@ -48,6 +69,7 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
except DatabaseError as db_error:
|
||||
logger.error(f"Error deleting Provider: {db_error}")
|
||||
raise
|
||||
|
||||
return deletion_summary
|
||||
|
||||
|
||||
|
||||
+155
-3536
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,186 @@
|
||||
# Base classes and data structures
|
||||
from .base import (
|
||||
BaseComplianceReportGenerator,
|
||||
ComplianceData,
|
||||
RequirementData,
|
||||
create_pdf_styles,
|
||||
get_requirement_metadata,
|
||||
)
|
||||
|
||||
# Chart functions
|
||||
from .charts import (
|
||||
create_horizontal_bar_chart,
|
||||
create_pie_chart,
|
||||
create_radar_chart,
|
||||
create_stacked_bar_chart,
|
||||
create_vertical_bar_chart,
|
||||
get_chart_color_for_percentage,
|
||||
)
|
||||
|
||||
# Reusable components
|
||||
# Reusable components: Color helpers, Badge components, Risk component,
|
||||
# Table components, Section components
|
||||
from .components import (
|
||||
ColumnConfig,
|
||||
create_badge,
|
||||
create_data_table,
|
||||
create_findings_table,
|
||||
create_info_table,
|
||||
create_multi_badge_row,
|
||||
create_risk_component,
|
||||
create_section_header,
|
||||
create_status_badge,
|
||||
create_summary_table,
|
||||
get_color_for_compliance,
|
||||
get_color_for_risk_level,
|
||||
get_color_for_weight,
|
||||
get_status_color,
|
||||
)
|
||||
|
||||
# Framework configuration: Main configuration, Color constants, ENS colors,
|
||||
# NIS2 colors, Chart colors, ENS constants, Section constants, Layout constants
|
||||
from .config import (
|
||||
CHART_COLOR_BLUE,
|
||||
CHART_COLOR_GREEN_1,
|
||||
CHART_COLOR_GREEN_2,
|
||||
CHART_COLOR_ORANGE,
|
||||
CHART_COLOR_RED,
|
||||
CHART_COLOR_YELLOW,
|
||||
COL_WIDTH_LARGE,
|
||||
COL_WIDTH_MEDIUM,
|
||||
COL_WIDTH_SMALL,
|
||||
COL_WIDTH_XLARGE,
|
||||
COL_WIDTH_XXLARGE,
|
||||
COLOR_BG_BLUE,
|
||||
COLOR_BG_LIGHT_BLUE,
|
||||
COLOR_BLUE,
|
||||
COLOR_DARK_GRAY,
|
||||
COLOR_ENS_ALTO,
|
||||
COLOR_ENS_BAJO,
|
||||
COLOR_ENS_MEDIO,
|
||||
COLOR_ENS_OPCIONAL,
|
||||
COLOR_GRAY,
|
||||
COLOR_HIGH_RISK,
|
||||
COLOR_LIGHT_BLUE,
|
||||
COLOR_LIGHT_GRAY,
|
||||
COLOR_LIGHTER_BLUE,
|
||||
COLOR_LOW_RISK,
|
||||
COLOR_MEDIUM_RISK,
|
||||
COLOR_NIS2_PRIMARY,
|
||||
COLOR_NIS2_SECONDARY,
|
||||
COLOR_PROWLER_DARK_GREEN,
|
||||
COLOR_SAFE,
|
||||
COLOR_WHITE,
|
||||
DIMENSION_KEYS,
|
||||
DIMENSION_MAPPING,
|
||||
DIMENSION_NAMES,
|
||||
ENS_NIVEL_ORDER,
|
||||
ENS_TIPO_ORDER,
|
||||
FRAMEWORK_REGISTRY,
|
||||
NIS2_SECTION_TITLES,
|
||||
NIS2_SECTIONS,
|
||||
PADDING_LARGE,
|
||||
PADDING_MEDIUM,
|
||||
PADDING_SMALL,
|
||||
PADDING_XLARGE,
|
||||
THREATSCORE_SECTIONS,
|
||||
TIPO_ICONS,
|
||||
FrameworkConfig,
|
||||
get_framework_config,
|
||||
)
|
||||
|
||||
# Framework-specific generators
|
||||
from .ens import ENSReportGenerator
|
||||
from .nis2 import NIS2ReportGenerator
|
||||
from .threatscore import ThreatScoreReportGenerator
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"BaseComplianceReportGenerator",
|
||||
"ComplianceData",
|
||||
"RequirementData",
|
||||
"create_pdf_styles",
|
||||
"get_requirement_metadata",
|
||||
# Framework-specific generators
|
||||
"ThreatScoreReportGenerator",
|
||||
"ENSReportGenerator",
|
||||
"NIS2ReportGenerator",
|
||||
# Configuration
|
||||
"FrameworkConfig",
|
||||
"FRAMEWORK_REGISTRY",
|
||||
"get_framework_config",
|
||||
# Color constants
|
||||
"COLOR_BLUE",
|
||||
"COLOR_LIGHT_BLUE",
|
||||
"COLOR_LIGHTER_BLUE",
|
||||
"COLOR_BG_BLUE",
|
||||
"COLOR_BG_LIGHT_BLUE",
|
||||
"COLOR_GRAY",
|
||||
"COLOR_LIGHT_GRAY",
|
||||
"COLOR_DARK_GRAY",
|
||||
"COLOR_WHITE",
|
||||
"COLOR_HIGH_RISK",
|
||||
"COLOR_MEDIUM_RISK",
|
||||
"COLOR_LOW_RISK",
|
||||
"COLOR_SAFE",
|
||||
"COLOR_PROWLER_DARK_GREEN",
|
||||
"COLOR_ENS_ALTO",
|
||||
"COLOR_ENS_MEDIO",
|
||||
"COLOR_ENS_BAJO",
|
||||
"COLOR_ENS_OPCIONAL",
|
||||
"COLOR_NIS2_PRIMARY",
|
||||
"COLOR_NIS2_SECONDARY",
|
||||
"CHART_COLOR_BLUE",
|
||||
"CHART_COLOR_GREEN_1",
|
||||
"CHART_COLOR_GREEN_2",
|
||||
"CHART_COLOR_YELLOW",
|
||||
"CHART_COLOR_ORANGE",
|
||||
"CHART_COLOR_RED",
|
||||
# ENS constants
|
||||
"DIMENSION_MAPPING",
|
||||
"DIMENSION_NAMES",
|
||||
"DIMENSION_KEYS",
|
||||
"ENS_NIVEL_ORDER",
|
||||
"ENS_TIPO_ORDER",
|
||||
"TIPO_ICONS",
|
||||
# Section constants
|
||||
"THREATSCORE_SECTIONS",
|
||||
"NIS2_SECTIONS",
|
||||
"NIS2_SECTION_TITLES",
|
||||
# Layout constants
|
||||
"COL_WIDTH_SMALL",
|
||||
"COL_WIDTH_MEDIUM",
|
||||
"COL_WIDTH_LARGE",
|
||||
"COL_WIDTH_XLARGE",
|
||||
"COL_WIDTH_XXLARGE",
|
||||
"PADDING_SMALL",
|
||||
"PADDING_MEDIUM",
|
||||
"PADDING_LARGE",
|
||||
"PADDING_XLARGE",
|
||||
# Color helpers
|
||||
"get_color_for_risk_level",
|
||||
"get_color_for_weight",
|
||||
"get_color_for_compliance",
|
||||
"get_status_color",
|
||||
# Badge components
|
||||
"create_badge",
|
||||
"create_status_badge",
|
||||
"create_multi_badge_row",
|
||||
# Risk component
|
||||
"create_risk_component",
|
||||
# Table components
|
||||
"create_info_table",
|
||||
"create_data_table",
|
||||
"create_findings_table",
|
||||
"ColumnConfig",
|
||||
# Section components
|
||||
"create_section_header",
|
||||
"create_summary_table",
|
||||
# Chart functions
|
||||
"get_chart_color_for_percentage",
|
||||
"create_vertical_bar_chart",
|
||||
"create_horizontal_bar_chart",
|
||||
"create_radar_chart",
|
||||
"create_pie_chart",
|
||||
"create_stacked_bar_chart",
|
||||
]
|
||||
@@ -0,0 +1,911 @@
|
||||
import gc
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from reportlab.lib.enums import TA_CENTER
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.platypus import Image, PageBreak, Paragraph, SimpleDocTemplate, Spacer
|
||||
from tasks.jobs.threatscore_utils import (
|
||||
_aggregate_requirement_statistics_from_database,
|
||||
_calculate_requirements_data_from_statistics,
|
||||
_load_findings_for_requirement_checks,
|
||||
)
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Provider, StatusChoices
|
||||
from api.utils import initialize_prowler_provider
|
||||
from prowler.lib.check.compliance_models import Compliance
|
||||
from prowler.lib.outputs.finding import Finding as FindingOutput
|
||||
|
||||
from .components import (
|
||||
ColumnConfig,
|
||||
create_data_table,
|
||||
create_info_table,
|
||||
create_status_badge,
|
||||
)
|
||||
from .config import (
|
||||
COLOR_BG_BLUE,
|
||||
COLOR_BG_LIGHT_BLUE,
|
||||
COLOR_BLUE,
|
||||
COLOR_BORDER_GRAY,
|
||||
COLOR_GRAY,
|
||||
COLOR_LIGHT_BLUE,
|
||||
COLOR_LIGHTER_BLUE,
|
||||
COLOR_PROWLER_DARK_GREEN,
|
||||
PADDING_LARGE,
|
||||
PADDING_SMALL,
|
||||
FrameworkConfig,
|
||||
)
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
# Register fonts (done once at module load)
|
||||
_fonts_registered: bool = False
|
||||
|
||||
|
||||
def _register_fonts() -> None:
|
||||
"""Register custom fonts for PDF generation.
|
||||
|
||||
Uses a module-level flag to ensure fonts are only registered once,
|
||||
avoiding duplicate registration errors from reportlab.
|
||||
"""
|
||||
global _fonts_registered
|
||||
if _fonts_registered:
|
||||
return
|
||||
|
||||
fonts_dir = os.path.join(os.path.dirname(__file__), "../../assets/fonts")
|
||||
|
||||
pdfmetrics.registerFont(
|
||||
TTFont(
|
||||
"PlusJakartaSans",
|
||||
os.path.join(fonts_dir, "PlusJakartaSans-Regular.ttf"),
|
||||
)
|
||||
)
|
||||
|
||||
pdfmetrics.registerFont(
|
||||
TTFont(
|
||||
"FiraCode",
|
||||
os.path.join(fonts_dir, "FiraCode-Regular.ttf"),
|
||||
)
|
||||
)
|
||||
|
||||
_fonts_registered = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequirementData:
|
||||
"""Data for a single compliance requirement.
|
||||
|
||||
Attributes:
|
||||
id: Requirement identifier
|
||||
description: Requirement description
|
||||
status: Compliance status (PASS, FAIL, MANUAL)
|
||||
passed_findings: Number of passed findings
|
||||
failed_findings: Number of failed findings
|
||||
total_findings: Total number of findings
|
||||
checks: List of check IDs associated with this requirement
|
||||
attributes: Framework-specific requirement attributes
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
status: str
|
||||
passed_findings: int = 0
|
||||
failed_findings: int = 0
|
||||
total_findings: int = 0
|
||||
checks: list[str] = field(default_factory=list)
|
||||
attributes: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComplianceData:
|
||||
"""Aggregated compliance data for report generation.
|
||||
|
||||
This dataclass holds all the data needed to generate a compliance report,
|
||||
including compliance framework metadata, requirements, and findings.
|
||||
|
||||
Attributes:
|
||||
tenant_id: Tenant identifier
|
||||
scan_id: Scan identifier
|
||||
provider_id: Provider identifier
|
||||
compliance_id: Compliance framework identifier
|
||||
framework: Framework name (e.g., "CIS", "ENS")
|
||||
name: Full compliance framework name
|
||||
version: Framework version
|
||||
description: Framework description
|
||||
requirements: List of RequirementData objects
|
||||
attributes_by_requirement_id: Mapping of requirement IDs to their attributes
|
||||
findings_by_check_id: Mapping of check IDs to their findings
|
||||
provider_obj: Provider model object
|
||||
prowler_provider: Initialized Prowler provider
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
scan_id: str
|
||||
provider_id: str
|
||||
compliance_id: str
|
||||
framework: str
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
requirements: list[RequirementData] = field(default_factory=list)
|
||||
attributes_by_requirement_id: dict[str, dict] = field(default_factory=dict)
|
||||
findings_by_check_id: dict[str, list[FindingOutput]] = field(default_factory=dict)
|
||||
provider_obj: Provider | None = None
|
||||
prowler_provider: Any = None
|
||||
|
||||
|
||||
def get_requirement_metadata(
|
||||
requirement_id: str,
|
||||
attributes_by_requirement_id: dict[str, dict],
|
||||
) -> Any | None:
|
||||
"""Get the first requirement metadata object from attributes.
|
||||
|
||||
This helper function extracts the requirement metadata (req_attributes)
|
||||
from the attributes dictionary. It's a common pattern used across all
|
||||
report generators.
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement ID to look up.
|
||||
attributes_by_requirement_id: Mapping of requirement IDs to their attributes.
|
||||
|
||||
Returns:
|
||||
The first requirement attribute object, or None if not found.
|
||||
|
||||
Example:
|
||||
>>> meta = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
>>> if meta:
|
||||
... section = getattr(meta, "Section", "Unknown")
|
||||
"""
|
||||
req_attrs = attributes_by_requirement_id.get(requirement_id, {})
|
||||
meta_list = req_attrs.get("attributes", {}).get("req_attributes", [])
|
||||
if meta_list:
|
||||
return meta_list[0]
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PDF Styles Cache
|
||||
# =============================================================================
|
||||
|
||||
_PDF_STYLES_CACHE: dict[str, ParagraphStyle] | None = None
|
||||
|
||||
|
||||
def create_pdf_styles() -> dict[str, ParagraphStyle]:
|
||||
"""Create and return PDF paragraph styles used throughout the report.
|
||||
|
||||
Styles are cached on first call to improve performance.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the following styles:
|
||||
- 'title': Title style with prowler green color
|
||||
- 'h1': Heading 1 style with blue color and background
|
||||
- 'h2': Heading 2 style with light blue color
|
||||
- 'h3': Heading 3 style for sub-headings
|
||||
- 'normal': Normal text style with left indent
|
||||
- 'normal_center': Normal text style without indent
|
||||
"""
|
||||
global _PDF_STYLES_CACHE
|
||||
|
||||
if _PDF_STYLES_CACHE is not None:
|
||||
return _PDF_STYLES_CACHE
|
||||
|
||||
_register_fonts()
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
title_style = ParagraphStyle(
|
||||
"CustomTitle",
|
||||
parent=styles["Title"],
|
||||
fontSize=24,
|
||||
textColor=COLOR_PROWLER_DARK_GREEN,
|
||||
spaceAfter=20,
|
||||
fontName="PlusJakartaSans",
|
||||
alignment=TA_CENTER,
|
||||
)
|
||||
|
||||
h1 = ParagraphStyle(
|
||||
"CustomH1",
|
||||
parent=styles["Heading1"],
|
||||
fontSize=18,
|
||||
textColor=COLOR_BLUE,
|
||||
spaceBefore=20,
|
||||
spaceAfter=12,
|
||||
fontName="PlusJakartaSans",
|
||||
leftIndent=0,
|
||||
borderWidth=2,
|
||||
borderColor=COLOR_BLUE,
|
||||
borderPadding=PADDING_LARGE,
|
||||
backColor=COLOR_BG_BLUE,
|
||||
)
|
||||
|
||||
h2 = ParagraphStyle(
|
||||
"CustomH2",
|
||||
parent=styles["Heading2"],
|
||||
fontSize=14,
|
||||
textColor=COLOR_LIGHT_BLUE,
|
||||
spaceBefore=15,
|
||||
spaceAfter=8,
|
||||
fontName="PlusJakartaSans",
|
||||
leftIndent=10,
|
||||
borderWidth=1,
|
||||
borderColor=COLOR_BORDER_GRAY,
|
||||
borderPadding=5,
|
||||
backColor=COLOR_BG_LIGHT_BLUE,
|
||||
)
|
||||
|
||||
h3 = ParagraphStyle(
|
||||
"CustomH3",
|
||||
parent=styles["Heading3"],
|
||||
fontSize=12,
|
||||
textColor=COLOR_LIGHTER_BLUE,
|
||||
spaceBefore=10,
|
||||
spaceAfter=6,
|
||||
fontName="PlusJakartaSans",
|
||||
leftIndent=20,
|
||||
)
|
||||
|
||||
normal = ParagraphStyle(
|
||||
"CustomNormal",
|
||||
parent=styles["Normal"],
|
||||
fontSize=10,
|
||||
textColor=COLOR_GRAY,
|
||||
spaceBefore=PADDING_SMALL,
|
||||
spaceAfter=PADDING_SMALL,
|
||||
leftIndent=30,
|
||||
fontName="PlusJakartaSans",
|
||||
)
|
||||
|
||||
normal_center = ParagraphStyle(
|
||||
"CustomNormalCenter",
|
||||
parent=styles["Normal"],
|
||||
fontSize=10,
|
||||
textColor=COLOR_GRAY,
|
||||
fontName="PlusJakartaSans",
|
||||
)
|
||||
|
||||
_PDF_STYLES_CACHE = {
|
||||
"title": title_style,
|
||||
"h1": h1,
|
||||
"h2": h2,
|
||||
"h3": h3,
|
||||
"normal": normal,
|
||||
"normal_center": normal_center,
|
||||
}
|
||||
|
||||
return _PDF_STYLES_CACHE
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base Report Generator
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BaseComplianceReportGenerator(ABC):
|
||||
"""Abstract base class for compliance PDF report generators.
|
||||
|
||||
This class implements the Template Method pattern, providing a common
|
||||
structure for all compliance reports while allowing subclasses to
|
||||
customize specific sections.
|
||||
|
||||
Subclasses must implement:
|
||||
- create_executive_summary()
|
||||
- create_charts_section()
|
||||
- create_requirements_index()
|
||||
|
||||
Optionally, subclasses can override:
|
||||
- create_cover_page()
|
||||
- create_detailed_findings()
|
||||
- get_footer_text()
|
||||
"""
|
||||
|
||||
def __init__(self, config: FrameworkConfig):
|
||||
"""Initialize the report generator.
|
||||
|
||||
Args:
|
||||
config: Framework configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.styles = create_pdf_styles()
|
||||
|
||||
# =========================================================================
|
||||
# Template Method
|
||||
# =========================================================================
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
compliance_id: str,
|
||||
output_path: str,
|
||||
provider_id: str,
|
||||
provider_obj: Provider | None = None,
|
||||
requirement_statistics: dict[str, dict[str, int]] | None = None,
|
||||
findings_cache: dict[str, list[FindingOutput]] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Generate the PDF compliance report.
|
||||
|
||||
This is the template method that orchestrates the report generation.
|
||||
It calls abstract methods that subclasses must implement.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for RLS context
|
||||
scan_id: Scan identifier
|
||||
compliance_id: Compliance framework identifier
|
||||
output_path: Path where the PDF will be saved
|
||||
provider_id: Provider identifier
|
||||
provider_obj: Optional pre-fetched Provider object
|
||||
requirement_statistics: Optional pre-aggregated statistics
|
||||
findings_cache: Optional pre-loaded findings cache
|
||||
**kwargs: Additional framework-specific arguments
|
||||
"""
|
||||
logger.info(
|
||||
"Generating %s report for scan %s", self.config.display_name, scan_id
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. Load compliance data
|
||||
data = self._load_compliance_data(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
compliance_id=compliance_id,
|
||||
provider_id=provider_id,
|
||||
provider_obj=provider_obj,
|
||||
requirement_statistics=requirement_statistics,
|
||||
findings_cache=findings_cache,
|
||||
)
|
||||
|
||||
# 2. Create PDF document
|
||||
doc = self._create_document(output_path, data)
|
||||
|
||||
# 3. Build report elements incrementally to manage memory
|
||||
# We collect garbage after heavy sections to prevent OOM on large reports
|
||||
elements = []
|
||||
|
||||
# Cover page (lightweight)
|
||||
elements.extend(self.create_cover_page(data))
|
||||
elements.append(PageBreak())
|
||||
|
||||
# Executive summary (framework-specific)
|
||||
elements.extend(self.create_executive_summary(data))
|
||||
|
||||
# Body sections (charts + requirements index)
|
||||
# Override _build_body_sections() in subclasses to change section order
|
||||
elements.extend(self._build_body_sections(data))
|
||||
|
||||
# Detailed findings - heaviest section, loads findings on-demand
|
||||
logger.info("Building detailed findings section...")
|
||||
elements.extend(self.create_detailed_findings(data, **kwargs))
|
||||
gc.collect() # Free findings data after processing
|
||||
|
||||
# 4. Build the PDF
|
||||
logger.info("Building PDF document with %d elements...", len(elements))
|
||||
self._build_pdf(doc, elements, data)
|
||||
|
||||
# Final cleanup
|
||||
del elements
|
||||
gc.collect()
|
||||
|
||||
logger.info("Successfully generated report at %s", output_path)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
tb_lineno = e.__traceback__.tb_lineno if e.__traceback__ else "unknown"
|
||||
logger.error("Error generating report, line %s -- %s", tb_lineno, e)
|
||||
logger.error("Full traceback:\n%s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
def _build_body_sections(self, data: ComplianceData) -> list:
|
||||
"""Build the body sections between executive summary and detailed findings.
|
||||
|
||||
Override in subclasses to change section order.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
# Charts section (framework-specific) - heavy on memory due to matplotlib
|
||||
elements.extend(self.create_charts_section(data))
|
||||
elements.append(PageBreak())
|
||||
gc.collect() # Free matplotlib resources
|
||||
|
||||
# Requirements index (framework-specific)
|
||||
elements.extend(self.create_requirements_index(data))
|
||||
elements.append(PageBreak())
|
||||
|
||||
return elements
|
||||
|
||||
# =========================================================================
|
||||
# Abstract Methods (must be implemented by subclasses)
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
def create_executive_summary(self, data: ComplianceData) -> list:
|
||||
"""Create the executive summary section.
|
||||
|
||||
This section typically includes:
|
||||
- Overall compliance score/metrics
|
||||
- High-level statistics
|
||||
- Critical findings summary
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_charts_section(self, data: ComplianceData) -> list:
|
||||
"""Create the charts and visualizations section.
|
||||
|
||||
This section typically includes:
|
||||
- Compliance score charts by section
|
||||
- Distribution charts
|
||||
- Trend visualizations
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_requirements_index(self, data: ComplianceData) -> list:
|
||||
"""Create the requirements index/table of contents.
|
||||
|
||||
This section typically includes:
|
||||
- Hierarchical list of requirements
|
||||
- Status indicators
|
||||
- Section groupings
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Common Methods (can be overridden by subclasses)
|
||||
# =========================================================================
|
||||
|
||||
def create_cover_page(self, data: ComplianceData) -> list:
|
||||
"""Create the report cover page.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements
|
||||
"""
|
||||
elements = []
|
||||
|
||||
# Prowler logo
|
||||
logo_path = os.path.join(
|
||||
os.path.dirname(__file__), "../../assets/img/prowler_logo.png"
|
||||
)
|
||||
if os.path.exists(logo_path):
|
||||
logo = Image(logo_path, width=5 * inch, height=1 * inch)
|
||||
elements.append(logo)
|
||||
|
||||
elements.append(Spacer(1, 0.5 * inch))
|
||||
|
||||
# Title
|
||||
title_text = f"{self.config.display_name} Report"
|
||||
elements.append(Paragraph(title_text, self.styles["title"]))
|
||||
elements.append(Spacer(1, 0.5 * inch))
|
||||
|
||||
# Compliance info table
|
||||
info_rows = self._build_info_rows(data, language=self.config.language)
|
||||
|
||||
info_table = create_info_table(
|
||||
rows=info_rows,
|
||||
label_width=2 * inch,
|
||||
value_width=4 * inch,
|
||||
normal_style=self.styles["normal_center"],
|
||||
)
|
||||
elements.append(info_table)
|
||||
|
||||
return elements
|
||||
|
||||
def _build_info_rows(
|
||||
self, data: ComplianceData, language: str = "en"
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Build the standard info rows for the cover page table.
|
||||
|
||||
This helper method creates the common metadata rows used in all
|
||||
report cover pages. Subclasses can use this to maintain consistency
|
||||
while customizing other aspects of the cover page.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
language: Language for labels ("en" or "es").
|
||||
|
||||
Returns:
|
||||
List of (label, value) tuples for the info table.
|
||||
"""
|
||||
# Labels based on language
|
||||
labels = {
|
||||
"en": {
|
||||
"framework": "Framework:",
|
||||
"id": "ID:",
|
||||
"name": "Name:",
|
||||
"version": "Version:",
|
||||
"provider": "Provider:",
|
||||
"account_id": "Account ID:",
|
||||
"alias": "Alias:",
|
||||
"scan_id": "Scan ID:",
|
||||
"description": "Description:",
|
||||
},
|
||||
"es": {
|
||||
"framework": "Framework:",
|
||||
"id": "ID:",
|
||||
"name": "Nombre:",
|
||||
"version": "Versión:",
|
||||
"provider": "Proveedor:",
|
||||
"account_id": "Account ID:",
|
||||
"alias": "Alias:",
|
||||
"scan_id": "Scan ID:",
|
||||
"description": "Descripción:",
|
||||
},
|
||||
}
|
||||
lang_labels = labels.get(language, labels["en"])
|
||||
|
||||
info_rows = [
|
||||
(lang_labels["framework"], data.framework),
|
||||
(lang_labels["id"], data.compliance_id),
|
||||
(lang_labels["name"], data.name),
|
||||
(lang_labels["version"], data.version),
|
||||
]
|
||||
|
||||
# Add provider info if available
|
||||
if data.provider_obj:
|
||||
info_rows.append(
|
||||
(lang_labels["provider"], data.provider_obj.provider.upper())
|
||||
)
|
||||
info_rows.append(
|
||||
(lang_labels["account_id"], data.provider_obj.uid or "N/A")
|
||||
)
|
||||
info_rows.append((lang_labels["alias"], data.provider_obj.alias or "N/A"))
|
||||
|
||||
info_rows.append((lang_labels["scan_id"], data.scan_id))
|
||||
|
||||
if data.description:
|
||||
info_rows.append((lang_labels["description"], data.description))
|
||||
|
||||
return info_rows
|
||||
|
||||
def create_detailed_findings(self, data: ComplianceData, **kwargs) -> list:
|
||||
"""Create the detailed findings section.
|
||||
|
||||
This default implementation creates a requirement-by-requirement
|
||||
breakdown with findings tables. Subclasses can override for
|
||||
framework-specific presentation.
|
||||
|
||||
This method implements on-demand loading of findings using the shared
|
||||
findings cache to minimize database queries and memory usage.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data
|
||||
**kwargs: Framework-specific options (e.g., only_failed)
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements
|
||||
"""
|
||||
elements = []
|
||||
only_failed = kwargs.get("only_failed", True)
|
||||
include_manual = kwargs.get("include_manual", False)
|
||||
|
||||
# Filter requirements if needed
|
||||
requirements = data.requirements
|
||||
if only_failed:
|
||||
# Include FAIL requirements, and optionally MANUAL if include_manual is True
|
||||
if include_manual:
|
||||
requirements = [
|
||||
r
|
||||
for r in requirements
|
||||
if r.status in (StatusChoices.FAIL, StatusChoices.MANUAL)
|
||||
]
|
||||
else:
|
||||
requirements = [
|
||||
r for r in requirements if r.status == StatusChoices.FAIL
|
||||
]
|
||||
|
||||
# Collect all check IDs for requirements that will be displayed
|
||||
# This allows us to load only the findings we actually need (memory optimization)
|
||||
check_ids_to_load = []
|
||||
for req in requirements:
|
||||
check_ids_to_load.extend(req.checks)
|
||||
|
||||
# Load findings on-demand only for the checks that will be displayed
|
||||
# Uses the shared findings cache to avoid duplicate queries across reports
|
||||
logger.info("Loading findings on-demand for %d requirements", len(requirements))
|
||||
findings_by_check_id = _load_findings_for_requirement_checks(
|
||||
data.tenant_id,
|
||||
data.scan_id,
|
||||
check_ids_to_load,
|
||||
data.prowler_provider,
|
||||
data.findings_by_check_id, # Pass the cache to update it
|
||||
)
|
||||
|
||||
for req in requirements:
|
||||
# Requirement header
|
||||
elements.append(
|
||||
Paragraph(
|
||||
f"{req.id}: {req.description}",
|
||||
self.styles["h1"],
|
||||
)
|
||||
)
|
||||
|
||||
# Status badge
|
||||
elements.append(create_status_badge(req.status))
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
# Findings for this requirement
|
||||
for check_id in req.checks:
|
||||
elements.append(Paragraph(f"Check: {check_id}", self.styles["h2"]))
|
||||
|
||||
findings = findings_by_check_id.get(check_id, [])
|
||||
if not findings:
|
||||
elements.append(
|
||||
Paragraph(
|
||||
"- No information for this finding currently",
|
||||
self.styles["normal"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Create findings table
|
||||
findings_table = self._create_findings_table(findings)
|
||||
elements.append(findings_table)
|
||||
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
elements.append(PageBreak())
|
||||
|
||||
return elements
|
||||
|
||||
def get_footer_text(self, page_num: int) -> tuple[str, str]:
|
||||
"""Get footer text for a page.
|
||||
|
||||
Args:
|
||||
page_num: Current page number
|
||||
|
||||
Returns:
|
||||
Tuple of (left_text, right_text) for the footer
|
||||
"""
|
||||
if self.config.language == "es":
|
||||
page_text = f"Página {page_num}"
|
||||
else:
|
||||
page_text = f"Page {page_num}"
|
||||
|
||||
return page_text, "Powered by Prowler"
|
||||
|
||||
# =========================================================================
|
||||
# Private Helper Methods
|
||||
# =========================================================================
|
||||
|
||||
def _load_compliance_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
compliance_id: str,
|
||||
provider_id: str,
|
||||
provider_obj: Provider | None,
|
||||
requirement_statistics: dict | None,
|
||||
findings_cache: dict | None,
|
||||
) -> ComplianceData:
|
||||
"""Load and aggregate compliance data from the database.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
scan_id: Scan identifier
|
||||
compliance_id: Compliance framework identifier
|
||||
provider_id: Provider identifier
|
||||
provider_obj: Optional pre-fetched Provider
|
||||
requirement_statistics: Optional pre-aggregated statistics
|
||||
findings_cache: Optional pre-loaded findings
|
||||
|
||||
Returns:
|
||||
Aggregated ComplianceData object
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Load provider
|
||||
if provider_obj is None:
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
|
||||
prowler_provider = initialize_prowler_provider(provider_obj)
|
||||
provider_type = provider_obj.provider
|
||||
|
||||
# Load compliance framework
|
||||
frameworks_bulk = Compliance.get_bulk(provider_type)
|
||||
compliance_obj = frameworks_bulk.get(compliance_id)
|
||||
|
||||
if not compliance_obj:
|
||||
raise ValueError(f"Compliance framework not found: {compliance_id}")
|
||||
|
||||
framework = getattr(compliance_obj, "Framework", "N/A")
|
||||
name = getattr(compliance_obj, "Name", "N/A")
|
||||
version = getattr(compliance_obj, "Version", "N/A")
|
||||
description = getattr(compliance_obj, "Description", "")
|
||||
|
||||
# Aggregate requirement statistics
|
||||
if requirement_statistics is None:
|
||||
logger.info("Aggregating requirement statistics for scan %s", scan_id)
|
||||
requirement_statistics = _aggregate_requirement_statistics_from_database(
|
||||
tenant_id, scan_id
|
||||
)
|
||||
else:
|
||||
logger.info("Reusing pre-aggregated statistics for scan %s", scan_id)
|
||||
|
||||
# Calculate requirements data
|
||||
attributes_by_requirement_id, requirements_list = (
|
||||
_calculate_requirements_data_from_statistics(
|
||||
compliance_obj, requirement_statistics
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to RequirementData objects
|
||||
requirements = []
|
||||
for req_dict in requirements_list:
|
||||
req = RequirementData(
|
||||
id=req_dict["id"],
|
||||
description=req_dict["attributes"].get("description", ""),
|
||||
status=req_dict["attributes"].get("status", StatusChoices.MANUAL),
|
||||
passed_findings=req_dict["attributes"].get("passed_findings", 0),
|
||||
failed_findings=req_dict["attributes"].get("failed_findings", 0),
|
||||
total_findings=req_dict["attributes"].get("total_findings", 0),
|
||||
checks=attributes_by_requirement_id.get(req_dict["id"], {})
|
||||
.get("attributes", {})
|
||||
.get("checks", []),
|
||||
)
|
||||
requirements.append(req)
|
||||
|
||||
return ComplianceData(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
provider_id=provider_id,
|
||||
compliance_id=compliance_id,
|
||||
framework=framework,
|
||||
name=name,
|
||||
version=version,
|
||||
description=description,
|
||||
requirements=requirements,
|
||||
attributes_by_requirement_id=attributes_by_requirement_id,
|
||||
findings_by_check_id=findings_cache if findings_cache is not None else {},
|
||||
provider_obj=provider_obj,
|
||||
prowler_provider=prowler_provider,
|
||||
)
|
||||
|
||||
def _create_document(
|
||||
self, output_path: str, data: ComplianceData
|
||||
) -> SimpleDocTemplate:
|
||||
"""Create the PDF document template.
|
||||
|
||||
Args:
|
||||
output_path: Path for the output PDF
|
||||
data: Compliance data for metadata
|
||||
|
||||
Returns:
|
||||
Configured SimpleDocTemplate
|
||||
"""
|
||||
return SimpleDocTemplate(
|
||||
output_path,
|
||||
pagesize=letter,
|
||||
title=f"{self.config.display_name} Report - {data.framework}",
|
||||
author="Prowler",
|
||||
subject=f"Compliance Report for {data.framework}",
|
||||
creator="Prowler Engineering Team",
|
||||
keywords=f"compliance,{data.framework},security,framework,prowler",
|
||||
)
|
||||
|
||||
def _build_pdf(
|
||||
self,
|
||||
doc: SimpleDocTemplate,
|
||||
elements: list,
|
||||
data: ComplianceData,
|
||||
) -> None:
|
||||
"""Build the final PDF with footers.
|
||||
|
||||
Args:
|
||||
doc: Document template
|
||||
elements: List of ReportLab elements
|
||||
data: Compliance data
|
||||
"""
|
||||
|
||||
def add_footer(
|
||||
canvas_obj: canvas.Canvas,
|
||||
doc_template: SimpleDocTemplate,
|
||||
) -> None:
|
||||
canvas_obj.saveState()
|
||||
width, _ = doc_template.pagesize
|
||||
left_text, right_text = self.get_footer_text(doc_template.page)
|
||||
|
||||
canvas_obj.setFont("PlusJakartaSans", 9)
|
||||
canvas_obj.setFillColorRGB(0.4, 0.4, 0.4)
|
||||
canvas_obj.drawString(30, 20, left_text)
|
||||
|
||||
text_width = canvas_obj.stringWidth(right_text, "PlusJakartaSans", 9)
|
||||
canvas_obj.drawString(width - text_width - 30, 20, right_text)
|
||||
canvas_obj.restoreState()
|
||||
|
||||
doc.build(
|
||||
elements,
|
||||
onFirstPage=add_footer,
|
||||
onLaterPages=add_footer,
|
||||
)
|
||||
|
||||
def _create_findings_table(self, findings: list[FindingOutput]) -> Any:
|
||||
"""Create a findings table.
|
||||
|
||||
Args:
|
||||
findings: List of finding objects
|
||||
|
||||
Returns:
|
||||
ReportLab Table element
|
||||
"""
|
||||
|
||||
def get_finding_title(f):
|
||||
metadata = getattr(f, "metadata", None)
|
||||
if metadata:
|
||||
return getattr(metadata, "CheckTitle", getattr(f, "check_id", ""))
|
||||
return getattr(f, "check_id", "")
|
||||
|
||||
def get_resource_name(f):
|
||||
name = getattr(f, "resource_name", "")
|
||||
if not name:
|
||||
name = getattr(f, "resource_uid", "")
|
||||
return name
|
||||
|
||||
def get_severity(f):
|
||||
metadata = getattr(f, "metadata", None)
|
||||
if metadata:
|
||||
return getattr(metadata, "Severity", "").capitalize()
|
||||
return ""
|
||||
|
||||
# Convert findings to dicts for the table
|
||||
data = []
|
||||
for f in findings:
|
||||
item = {
|
||||
"title": get_finding_title(f),
|
||||
"resource_name": get_resource_name(f),
|
||||
"severity": get_severity(f),
|
||||
"status": getattr(f, "status", "").upper(),
|
||||
"region": getattr(f, "region", "global"),
|
||||
}
|
||||
data.append(item)
|
||||
|
||||
columns = [
|
||||
ColumnConfig("Finding", 2.5 * inch, "title"),
|
||||
ColumnConfig("Resource", 3 * inch, "resource_name"),
|
||||
ColumnConfig("Severity", 0.9 * inch, "severity"),
|
||||
ColumnConfig("Status", 0.9 * inch, "status"),
|
||||
ColumnConfig("Region", 0.9 * inch, "region"),
|
||||
]
|
||||
|
||||
return create_data_table(
|
||||
data=data,
|
||||
columns=columns,
|
||||
header_color=self.config.primary_color,
|
||||
normal_style=self.styles["normal_center"],
|
||||
)
|
||||
@@ -0,0 +1,404 @@
|
||||
import gc
|
||||
import io
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import matplotlib
|
||||
|
||||
# Use non-interactive Agg backend for memory efficiency in server environments
|
||||
# This MUST be set before importing pyplot
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt # noqa: E402
|
||||
|
||||
from .config import ( # noqa: E402
|
||||
CHART_COLOR_BLUE,
|
||||
CHART_COLOR_GREEN_1,
|
||||
CHART_COLOR_GREEN_2,
|
||||
CHART_COLOR_ORANGE,
|
||||
CHART_COLOR_RED,
|
||||
CHART_COLOR_YELLOW,
|
||||
CHART_DPI_DEFAULT,
|
||||
)
|
||||
|
||||
# Use centralized DPI setting from config
|
||||
DEFAULT_CHART_DPI = CHART_DPI_DEFAULT
|
||||
|
||||
|
||||
def get_chart_color_for_percentage(percentage: float) -> str:
|
||||
"""Get chart color string based on percentage.
|
||||
|
||||
Args:
|
||||
percentage: Value between 0 and 100
|
||||
|
||||
Returns:
|
||||
Hex color string for matplotlib
|
||||
"""
|
||||
if percentage >= 80:
|
||||
return CHART_COLOR_GREEN_1
|
||||
if percentage >= 60:
|
||||
return CHART_COLOR_GREEN_2
|
||||
if percentage >= 40:
|
||||
return CHART_COLOR_YELLOW
|
||||
if percentage >= 20:
|
||||
return CHART_COLOR_ORANGE
|
||||
return CHART_COLOR_RED
|
||||
|
||||
|
||||
def create_vertical_bar_chart(
|
||||
labels: list[str],
|
||||
values: list[float],
|
||||
ylabel: str = "Compliance Score (%)",
|
||||
xlabel: str = "Section",
|
||||
title: str | None = None,
|
||||
color_func: Callable[[float], str] | None = None,
|
||||
colors: list[str] | None = None,
|
||||
figsize: tuple[int, int] = (10, 6),
|
||||
dpi: int = DEFAULT_CHART_DPI,
|
||||
y_limit: tuple[float, float] = (0, 100),
|
||||
show_labels: bool = True,
|
||||
rotation: int = 45,
|
||||
) -> io.BytesIO:
|
||||
"""Create a vertical bar chart.
|
||||
|
||||
Args:
|
||||
labels: X-axis labels
|
||||
values: Bar heights (numeric values)
|
||||
ylabel: Y-axis label
|
||||
xlabel: X-axis label
|
||||
title: Optional chart title
|
||||
color_func: Function to determine bar color based on value
|
||||
colors: Explicit list of colors (overrides color_func)
|
||||
figsize: Figure size (width, height) in inches
|
||||
dpi: Resolution for output image
|
||||
y_limit: Y-axis limits (min, max)
|
||||
show_labels: Whether to show value labels on bars
|
||||
rotation: X-axis label rotation angle
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the PNG image
|
||||
"""
|
||||
if color_func is None:
|
||||
color_func = get_chart_color_for_percentage
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# Determine colors
|
||||
if colors is None:
|
||||
colors_list = [color_func(v) for v in values]
|
||||
else:
|
||||
colors_list = colors
|
||||
|
||||
bars = ax.bar(labels, values, color=colors_list)
|
||||
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylim(*y_limit)
|
||||
|
||||
if title:
|
||||
ax.set_title(title, fontsize=14, fontweight="bold")
|
||||
|
||||
# Add value labels on bars
|
||||
if show_labels:
|
||||
for bar_item, value in zip(bars, values):
|
||||
height = bar_item.get_height()
|
||||
ax.text(
|
||||
bar_item.get_x() + bar_item.get_width() / 2.0,
|
||||
height + 1,
|
||||
f"{value:.1f}%",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.xticks(rotation=rotation, ha="right")
|
||||
ax.grid(True, alpha=0.3, axis="y")
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
try:
|
||||
fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
|
||||
buffer.seek(0)
|
||||
finally:
|
||||
plt.close(fig)
|
||||
gc.collect() # Force garbage collection after heavy matplotlib operation
|
||||
|
||||
return buffer
|
||||
|
||||
|
||||
def create_horizontal_bar_chart(
|
||||
labels: list[str],
|
||||
values: list[float],
|
||||
xlabel: str = "Compliance (%)",
|
||||
title: str | None = None,
|
||||
color_func: Callable[[float], str] | None = None,
|
||||
colors: list[str] | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
dpi: int = DEFAULT_CHART_DPI,
|
||||
x_limit: tuple[float, float] = (0, 100),
|
||||
show_labels: bool = True,
|
||||
label_fontsize: int = 16,
|
||||
) -> io.BytesIO:
|
||||
"""Create a horizontal bar chart.
|
||||
|
||||
Args:
|
||||
labels: Y-axis labels (bar names)
|
||||
values: Bar widths (numeric values)
|
||||
xlabel: X-axis label
|
||||
title: Optional chart title
|
||||
color_func: Function to determine bar color based on value
|
||||
colors: Explicit list of colors (overrides color_func)
|
||||
figsize: Figure size (auto-calculated if None based on label count)
|
||||
dpi: Resolution for output image
|
||||
x_limit: X-axis limits (min, max)
|
||||
show_labels: Whether to show value labels on bars
|
||||
label_fontsize: Font size for y-axis labels
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the PNG image
|
||||
"""
|
||||
if color_func is None:
|
||||
color_func = get_chart_color_for_percentage
|
||||
|
||||
# Auto-calculate figure size based on number of items
|
||||
if figsize is None:
|
||||
figsize = (10, max(6, int(len(labels) * 0.4)))
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# Determine colors
|
||||
if colors is None:
|
||||
colors_list = [color_func(v) for v in values]
|
||||
else:
|
||||
colors_list = colors
|
||||
|
||||
y_pos = range(len(labels))
|
||||
bars = ax.barh(y_pos, values, color=colors_list)
|
||||
|
||||
ax.set_yticks(y_pos)
|
||||
ax.set_yticklabels(labels, fontsize=label_fontsize)
|
||||
ax.set_xlabel(xlabel, fontsize=14)
|
||||
ax.set_xlim(*x_limit)
|
||||
|
||||
if title:
|
||||
ax.set_title(title, fontsize=14, fontweight="bold")
|
||||
|
||||
# Add value labels
|
||||
if show_labels:
|
||||
for bar_item, value in zip(bars, values):
|
||||
width = bar_item.get_width()
|
||||
ax.text(
|
||||
width + 1,
|
||||
bar_item.get_y() + bar_item.get_height() / 2.0,
|
||||
f"{value:.1f}%",
|
||||
ha="left",
|
||||
va="center",
|
||||
fontweight="bold",
|
||||
fontsize=10,
|
||||
)
|
||||
|
||||
ax.grid(True, alpha=0.3, axis="x")
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
try:
|
||||
fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
|
||||
buffer.seek(0)
|
||||
finally:
|
||||
plt.close(fig)
|
||||
gc.collect() # Force garbage collection after heavy matplotlib operation
|
||||
|
||||
return buffer
|
||||
|
||||
|
||||
def create_radar_chart(
|
||||
labels: list[str],
|
||||
values: list[float],
|
||||
color: str = CHART_COLOR_BLUE,
|
||||
fill_alpha: float = 0.25,
|
||||
figsize: tuple[int, int] = (8, 8),
|
||||
dpi: int = DEFAULT_CHART_DPI,
|
||||
y_limit: tuple[float, float] = (0, 100),
|
||||
y_ticks: list[int] | None = None,
|
||||
label_fontsize: int = 14,
|
||||
title: str | None = None,
|
||||
) -> io.BytesIO:
|
||||
"""Create a radar/spider chart.
|
||||
|
||||
Args:
|
||||
labels: Category names around the chart
|
||||
values: Values for each category (should have same length as labels)
|
||||
color: Line and fill color
|
||||
fill_alpha: Transparency of the fill (0-1)
|
||||
figsize: Figure size (width, height) in inches
|
||||
dpi: Resolution for output image
|
||||
y_limit: Radial axis limits (min, max)
|
||||
y_ticks: Custom tick values for radial axis
|
||||
label_fontsize: Font size for category labels
|
||||
title: Optional chart title
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the PNG image
|
||||
"""
|
||||
num_vars = len(labels)
|
||||
angles = [n / float(num_vars) * 2 * math.pi for n in range(num_vars)]
|
||||
|
||||
# Close the polygon
|
||||
values_closed = list(values) + [values[0]]
|
||||
angles_closed = angles + [angles[0]]
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "polar"})
|
||||
|
||||
ax.plot(angles_closed, values_closed, "o-", linewidth=2, color=color)
|
||||
ax.fill(angles_closed, values_closed, alpha=fill_alpha, color=color)
|
||||
|
||||
ax.set_xticks(angles)
|
||||
ax.set_xticklabels(labels, fontsize=label_fontsize)
|
||||
ax.set_ylim(*y_limit)
|
||||
|
||||
if y_ticks is None:
|
||||
y_ticks = [20, 40, 60, 80, 100]
|
||||
ax.set_yticks(y_ticks)
|
||||
ax.set_yticklabels([f"{t}%" for t in y_ticks], fontsize=12)
|
||||
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
if title:
|
||||
ax.set_title(title, fontsize=14, fontweight="bold", y=1.08)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
try:
|
||||
fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
|
||||
buffer.seek(0)
|
||||
finally:
|
||||
plt.close(fig)
|
||||
gc.collect() # Force garbage collection after heavy matplotlib operation
|
||||
|
||||
return buffer
|
||||
|
||||
|
||||
def create_pie_chart(
|
||||
labels: list[str],
|
||||
values: list[float],
|
||||
colors: list[str] | None = None,
|
||||
figsize: tuple[int, int] = (6, 6),
|
||||
dpi: int = DEFAULT_CHART_DPI,
|
||||
autopct: str = "%1.1f%%",
|
||||
startangle: int = 90,
|
||||
title: str | None = None,
|
||||
) -> io.BytesIO:
|
||||
"""Create a pie chart.
|
||||
|
||||
Args:
|
||||
labels: Slice labels
|
||||
values: Slice values
|
||||
colors: Optional list of colors for slices
|
||||
figsize: Figure size (width, height) in inches
|
||||
dpi: Resolution for output image
|
||||
autopct: Format string for percentage labels
|
||||
startangle: Starting angle for first slice
|
||||
title: Optional chart title
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the PNG image
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
_, _, autotexts = ax.pie(
|
||||
values,
|
||||
labels=labels,
|
||||
colors=colors,
|
||||
autopct=autopct,
|
||||
startangle=startangle,
|
||||
)
|
||||
|
||||
# Style the text
|
||||
for autotext in autotexts:
|
||||
autotext.set_fontweight("bold")
|
||||
|
||||
if title:
|
||||
ax.set_title(title, fontsize=14, fontweight="bold")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
try:
|
||||
fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
|
||||
buffer.seek(0)
|
||||
finally:
|
||||
plt.close(fig)
|
||||
gc.collect() # Force garbage collection after heavy matplotlib operation
|
||||
|
||||
return buffer
|
||||
|
||||
|
||||
def create_stacked_bar_chart(
|
||||
labels: list[str],
|
||||
data_series: dict[str, list[float]],
|
||||
colors: dict[str, str] | None = None,
|
||||
xlabel: str = "",
|
||||
ylabel: str = "Count",
|
||||
title: str | None = None,
|
||||
figsize: tuple[int, int] = (10, 6),
|
||||
dpi: int = DEFAULT_CHART_DPI,
|
||||
rotation: int = 45,
|
||||
show_legend: bool = True,
|
||||
) -> io.BytesIO:
|
||||
"""Create a stacked bar chart.
|
||||
|
||||
Args:
|
||||
labels: X-axis labels
|
||||
data_series: Dictionary mapping series name to list of values
|
||||
colors: Dictionary mapping series name to color
|
||||
xlabel: X-axis label
|
||||
ylabel: Y-axis label
|
||||
title: Optional chart title
|
||||
figsize: Figure size (width, height) in inches
|
||||
dpi: Resolution for output image
|
||||
rotation: X-axis label rotation angle
|
||||
show_legend: Whether to show the legend
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the PNG image
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# Default colors if not provided
|
||||
default_colors = {
|
||||
"Pass": CHART_COLOR_GREEN_1,
|
||||
"Fail": CHART_COLOR_RED,
|
||||
"Manual": CHART_COLOR_YELLOW,
|
||||
}
|
||||
if colors is None:
|
||||
colors = default_colors
|
||||
|
||||
bottom = [0] * len(labels)
|
||||
for series_name, values in data_series.items():
|
||||
color = colors.get(series_name, CHART_COLOR_BLUE)
|
||||
ax.bar(labels, values, bottom=bottom, label=series_name, color=color)
|
||||
bottom = [b + v for b, v in zip(bottom, values)]
|
||||
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
|
||||
if title:
|
||||
ax.set_title(title, fontsize=14, fontweight="bold")
|
||||
|
||||
plt.xticks(rotation=rotation, ha="right")
|
||||
|
||||
if show_legend:
|
||||
ax.legend()
|
||||
|
||||
ax.grid(True, alpha=0.3, axis="y")
|
||||
plt.tight_layout()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
try:
|
||||
fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
|
||||
buffer.seek(0)
|
||||
finally:
|
||||
plt.close(fig)
|
||||
gc.collect() # Force garbage collection after heavy matplotlib operation
|
||||
|
||||
return buffer
|
||||
@@ -0,0 +1,599 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.platypus import LongTable, Paragraph, Spacer, Table, TableStyle
|
||||
|
||||
from .config import (
|
||||
ALTERNATE_ROWS_MAX_SIZE,
|
||||
COLOR_BLUE,
|
||||
COLOR_BORDER_GRAY,
|
||||
COLOR_DARK_GRAY,
|
||||
COLOR_GRID_GRAY,
|
||||
COLOR_HIGH_RISK,
|
||||
COLOR_LIGHT_GRAY,
|
||||
COLOR_LOW_RISK,
|
||||
COLOR_MEDIUM_RISK,
|
||||
COLOR_SAFE,
|
||||
COLOR_WHITE,
|
||||
LONG_TABLE_THRESHOLD,
|
||||
PADDING_LARGE,
|
||||
PADDING_MEDIUM,
|
||||
PADDING_SMALL,
|
||||
PADDING_XLARGE,
|
||||
)
|
||||
|
||||
|
||||
def get_color_for_risk_level(risk_level: int) -> colors.Color:
|
||||
"""
|
||||
Get color based on risk level.
|
||||
|
||||
Args:
|
||||
risk_level (int): Numeric risk level (0-5).
|
||||
|
||||
Returns:
|
||||
colors.Color: Appropriate color for the risk level.
|
||||
"""
|
||||
if risk_level >= 4:
|
||||
return COLOR_HIGH_RISK
|
||||
if risk_level >= 3:
|
||||
return COLOR_MEDIUM_RISK
|
||||
if risk_level >= 2:
|
||||
return COLOR_LOW_RISK
|
||||
return COLOR_SAFE
|
||||
|
||||
|
||||
def get_color_for_weight(weight: int) -> colors.Color:
|
||||
"""
|
||||
Get color based on weight value.
|
||||
|
||||
Args:
|
||||
weight (int): Numeric weight value.
|
||||
|
||||
Returns:
|
||||
colors.Color: Appropriate color for the weight.
|
||||
"""
|
||||
if weight > 100:
|
||||
return COLOR_HIGH_RISK
|
||||
if weight > 50:
|
||||
return COLOR_LOW_RISK
|
||||
return COLOR_SAFE
|
||||
|
||||
|
||||
def get_color_for_compliance(percentage: float) -> colors.Color:
|
||||
"""
|
||||
Get color based on compliance percentage.
|
||||
|
||||
Args:
|
||||
percentage (float): Compliance percentage (0-100).
|
||||
|
||||
Returns:
|
||||
colors.Color: Appropriate color for the compliance level.
|
||||
"""
|
||||
if percentage >= 80:
|
||||
return COLOR_SAFE
|
||||
if percentage >= 60:
|
||||
return COLOR_LOW_RISK
|
||||
return COLOR_HIGH_RISK
|
||||
|
||||
|
||||
def get_status_color(status: str) -> colors.Color:
|
||||
"""
|
||||
Get color for a status value.
|
||||
|
||||
Args:
|
||||
status (str): Status string (PASS, FAIL, MANUAL, etc.).
|
||||
|
||||
Returns:
|
||||
colors.Color: Appropriate color for the status.
|
||||
"""
|
||||
status_upper = status.upper()
|
||||
if status_upper == "PASS":
|
||||
return COLOR_SAFE
|
||||
if status_upper == "FAIL":
|
||||
return COLOR_HIGH_RISK
|
||||
return COLOR_DARK_GRAY
|
||||
|
||||
|
||||
def create_badge(
|
||||
text: str,
|
||||
bg_color: colors.Color,
|
||||
text_color: colors.Color = COLOR_WHITE,
|
||||
width: float = 1.4 * inch,
|
||||
font: str = "FiraCode",
|
||||
font_size: int = 11,
|
||||
) -> Table:
|
||||
"""
|
||||
Create a generic colored badge component.
|
||||
|
||||
Args:
|
||||
text (str): Text to display in the badge.
|
||||
bg_color (colors.Color): Background color.
|
||||
text_color (colors.Color): Text color (default white).
|
||||
width (float): Badge width in inches.
|
||||
font (str): Font name to use.
|
||||
font_size (int): Font size.
|
||||
|
||||
Returns:
|
||||
Table: A Table object styled as a badge.
|
||||
"""
|
||||
data = [[text]]
|
||||
table = Table(data, colWidths=[width])
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), bg_color),
|
||||
("TEXTCOLOR", (0, 0), (0, 0), text_color),
|
||||
("FONTNAME", (0, 0), (0, 0), font),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), font_size),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, colors.black),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_status_badge(status: str) -> Table:
|
||||
"""
|
||||
Create a PASS/FAIL/MANUAL status badge.
|
||||
|
||||
Args:
|
||||
status (str): Status value (e.g., "PASS", "FAIL", "MANUAL").
|
||||
|
||||
Returns:
|
||||
Table: A styled Table badge for the status.
|
||||
"""
|
||||
status_upper = status.upper()
|
||||
status_color = get_status_color(status_upper)
|
||||
|
||||
data = [["State:", status_upper]]
|
||||
table = Table(data, colWidths=[0.6 * inch, 0.8 * inch])
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), COLOR_LIGHT_GRAY),
|
||||
("FONTNAME", (0, 0), (0, 0), "PlusJakartaSans"),
|
||||
("BACKGROUND", (1, 0), (1, 0), status_color),
|
||||
("TEXTCOLOR", (1, 0), (1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (1, 0), (1, 0), "FiraCode"),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 12),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, colors.black),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_XLARGE),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_XLARGE),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_multi_badge_row(
|
||||
badges: list[tuple[str, colors.Color]],
|
||||
badge_width: float = 0.4 * inch,
|
||||
font: str = "FiraCode",
|
||||
) -> Table:
|
||||
"""
|
||||
Create a row of multiple small badges.
|
||||
|
||||
Args:
|
||||
badges (list[tuple[str, colors.Color]]): List of (text, color) tuples for each badge.
|
||||
badge_width (float): Width of each badge.
|
||||
font (str): Font name to use.
|
||||
|
||||
Returns:
|
||||
Table: A Table with multiple colored badges in a row.
|
||||
"""
|
||||
if not badges:
|
||||
data = [["N/A"]]
|
||||
table = Table(data, colWidths=[1 * inch])
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), COLOR_LIGHT_GRAY),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 10),
|
||||
]
|
||||
)
|
||||
)
|
||||
return table
|
||||
|
||||
data = [[text for text, _ in badges]]
|
||||
col_widths = [badge_width] * len(badges)
|
||||
table = Table(data, colWidths=col_widths)
|
||||
|
||||
styles = [
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("FONTNAME", (0, 0), (-1, -1), font),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 10),
|
||||
("TEXTCOLOR", (0, 0), (-1, -1), COLOR_WHITE),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, colors.black),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_SMALL),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_SMALL),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
]
|
||||
|
||||
for idx, (_, badge_color) in enumerate(badges):
|
||||
styles.append(("BACKGROUND", (idx, 0), (idx, 0), badge_color))
|
||||
|
||||
table.setStyle(TableStyle(styles))
|
||||
return table
|
||||
|
||||
|
||||
def create_risk_component(
|
||||
risk_level: int,
|
||||
weight: int,
|
||||
score: int = 0,
|
||||
) -> Table:
|
||||
"""
|
||||
Create a visual risk component showing risk level, weight, and score.
|
||||
|
||||
Args:
|
||||
risk_level (int): The risk level (0-5).
|
||||
weight (int): The weight value.
|
||||
score (int): The calculated score (default 0).
|
||||
|
||||
Returns:
|
||||
Table: A styled Table showing risk metrics.
|
||||
"""
|
||||
risk_color = get_color_for_risk_level(risk_level)
|
||||
weight_color = get_color_for_weight(weight)
|
||||
|
||||
data = [
|
||||
[
|
||||
"Risk Level:",
|
||||
str(risk_level),
|
||||
"Weight:",
|
||||
str(weight),
|
||||
"Score:",
|
||||
str(score),
|
||||
]
|
||||
]
|
||||
|
||||
table = Table(
|
||||
data,
|
||||
colWidths=[
|
||||
0.8 * inch,
|
||||
0.4 * inch,
|
||||
0.6 * inch,
|
||||
0.4 * inch,
|
||||
0.5 * inch,
|
||||
0.4 * inch,
|
||||
],
|
||||
)
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), COLOR_LIGHT_GRAY),
|
||||
("BACKGROUND", (1, 0), (1, 0), risk_color),
|
||||
("TEXTCOLOR", (1, 0), (1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (1, 0), (1, 0), "FiraCode"),
|
||||
("BACKGROUND", (2, 0), (2, 0), COLOR_LIGHT_GRAY),
|
||||
("BACKGROUND", (3, 0), (3, 0), weight_color),
|
||||
("TEXTCOLOR", (3, 0), (3, 0), COLOR_WHITE),
|
||||
("FONTNAME", (3, 0), (3, 0), "FiraCode"),
|
||||
("BACKGROUND", (4, 0), (4, 0), COLOR_LIGHT_GRAY),
|
||||
("BACKGROUND", (5, 0), (5, 0), COLOR_DARK_GRAY),
|
||||
("TEXTCOLOR", (5, 0), (5, 0), COLOR_WHITE),
|
||||
("FONTNAME", (5, 0), (5, 0), "FiraCode"),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 10),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, colors.black),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_info_table(
|
||||
rows: list[tuple[str, Any]],
|
||||
label_width: float = 2 * inch,
|
||||
value_width: float = 4 * inch,
|
||||
label_color: colors.Color = COLOR_BLUE,
|
||||
value_bg_color: colors.Color | None = None,
|
||||
normal_style: ParagraphStyle | None = None,
|
||||
) -> Table:
|
||||
"""
|
||||
Create a key-value information table.
|
||||
|
||||
Args:
|
||||
rows (list[tuple[str, Any]]): List of (label, value) tuples.
|
||||
label_width (float): Width of the label column.
|
||||
value_width (float): Width of the value column.
|
||||
label_color (colors.Color): Background color for labels.
|
||||
value_bg_color (colors.Color | None): Background color for values (optional).
|
||||
normal_style (ParagraphStyle | None): ParagraphStyle for wrapping long values.
|
||||
|
||||
Returns:
|
||||
Table: A styled Table with key-value pairs.
|
||||
"""
|
||||
from .config import COLOR_BG_BLUE
|
||||
|
||||
if value_bg_color is None:
|
||||
value_bg_color = COLOR_BG_BLUE
|
||||
|
||||
# Handle empty rows case - Table requires at least one row
|
||||
if not rows:
|
||||
table = Table([["", ""]], colWidths=[label_width, value_width])
|
||||
table.setStyle(TableStyle([("FONTSIZE", (0, 0), (-1, -1), 0)]))
|
||||
return table
|
||||
|
||||
# Process rows - wrap long values in Paragraph if style provided
|
||||
table_data = []
|
||||
for label, value in rows:
|
||||
if normal_style and isinstance(value, str) and len(value) > 50:
|
||||
value = Paragraph(value, normal_style)
|
||||
table_data.append([label, value])
|
||||
|
||||
table = Table(table_data, colWidths=[label_width, value_width])
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, -1), label_color),
|
||||
("TEXTCOLOR", (0, 0), (0, -1), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (0, -1), "FiraCode"),
|
||||
("BACKGROUND", (1, 0), (1, -1), value_bg_color),
|
||||
("TEXTCOLOR", (1, 0), (1, -1), COLOR_DARK_GRAY),
|
||||
("FONTNAME", (1, 0), (1, -1), "PlusJakartaSans"),
|
||||
("ALIGN", (0, 0), (-1, -1), "LEFT"),
|
||||
("VALIGN", (0, 0), (-1, -1), "TOP"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 11),
|
||||
("GRID", (0, 0), (-1, -1), 1, COLOR_BORDER_GRAY),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_XLARGE),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_XLARGE),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_LARGE),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
@dataclass
|
||||
class ColumnConfig:
|
||||
"""
|
||||
Configuration for a table column.
|
||||
|
||||
Attributes:
|
||||
header (str): Column header text.
|
||||
width (float): Column width in inches.
|
||||
field (str | Callable[[Any], str]): Field name or callable to extract value from data.
|
||||
align (str): Text alignment (LEFT, CENTER, RIGHT).
|
||||
"""
|
||||
|
||||
header: str
|
||||
width: float
|
||||
field: str | Callable[[Any], str]
|
||||
align: str = "CENTER"
|
||||
|
||||
|
||||
def create_data_table(
|
||||
data: list[dict[str, Any]],
|
||||
columns: list[ColumnConfig],
|
||||
header_color: colors.Color = COLOR_BLUE,
|
||||
alternate_rows: bool = True,
|
||||
normal_style: ParagraphStyle | None = None,
|
||||
) -> Table | LongTable:
|
||||
"""
|
||||
Create a data table with configurable columns.
|
||||
|
||||
Uses LongTable for large datasets (>50 rows) for better memory efficiency
|
||||
and page splitting. LongTable repeats headers on each page and has
|
||||
optimized memory handling for large tables.
|
||||
|
||||
Args:
|
||||
data (list[dict[str, Any]]): List of data dictionaries.
|
||||
columns (list[ColumnConfig]): Column configuration list.
|
||||
header_color (colors.Color): Background color for header row.
|
||||
alternate_rows (bool): Whether to alternate row backgrounds.
|
||||
normal_style (ParagraphStyle | None): ParagraphStyle for cell values.
|
||||
|
||||
Returns:
|
||||
Table or LongTable: A styled table with data.
|
||||
"""
|
||||
# Build header row
|
||||
header_row = [col.header for col in columns]
|
||||
table_data = [header_row]
|
||||
|
||||
# Build data rows
|
||||
for item in data:
|
||||
row = []
|
||||
for col in columns:
|
||||
if callable(col.field):
|
||||
value = col.field(item)
|
||||
else:
|
||||
value = item.get(col.field, "")
|
||||
|
||||
if normal_style and isinstance(value, str):
|
||||
value = Paragraph(value, normal_style)
|
||||
row.append(value)
|
||||
table_data.append(row)
|
||||
|
||||
col_widths = [col.width for col in columns]
|
||||
|
||||
# Use LongTable for large datasets - it handles page breaks better
|
||||
# and has optimized memory handling for tables with many rows
|
||||
use_long_table = len(data) > LONG_TABLE_THRESHOLD
|
||||
if use_long_table:
|
||||
table = LongTable(table_data, colWidths=col_widths, repeatRows=1)
|
||||
else:
|
||||
table = Table(table_data, colWidths=col_widths)
|
||||
|
||||
styles = [
|
||||
("BACKGROUND", (0, 0), (-1, 0), header_color),
|
||||
("TEXTCOLOR", (0, 0), (-1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (-1, 0), "FiraCode"),
|
||||
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||
("FONTSIZE", (0, 1), (-1, -1), 9),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("GRID", (0, 0), (-1, -1), 1, COLOR_GRID_GRAY),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("TOPPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), PADDING_MEDIUM),
|
||||
]
|
||||
|
||||
# Apply column alignments
|
||||
for idx, col in enumerate(columns):
|
||||
styles.append(("ALIGN", (idx, 0), (idx, -1), col.align))
|
||||
|
||||
# Alternate row backgrounds - skip for very large tables as it adds memory overhead
|
||||
if (
|
||||
alternate_rows
|
||||
and len(table_data) > 1
|
||||
and len(table_data) <= ALTERNATE_ROWS_MAX_SIZE
|
||||
):
|
||||
for i in range(1, len(table_data)):
|
||||
if i % 2 == 0:
|
||||
styles.append(
|
||||
("BACKGROUND", (0, i), (-1, i), colors.Color(0.98, 0.98, 0.98))
|
||||
)
|
||||
|
||||
table.setStyle(TableStyle(styles))
|
||||
return table
|
||||
|
||||
|
||||
def create_findings_table(
|
||||
findings: list[Any],
|
||||
columns: list[ColumnConfig] | None = None,
|
||||
header_color: colors.Color = COLOR_BLUE,
|
||||
normal_style: ParagraphStyle | None = None,
|
||||
) -> Table:
|
||||
"""
|
||||
Create a findings table with default or custom columns.
|
||||
|
||||
Args:
|
||||
findings (list[Any]): List of finding objects.
|
||||
columns (list[ColumnConfig] | None): Optional column configuration (defaults to standard columns).
|
||||
header_color (colors.Color): Background color for header row.
|
||||
normal_style (ParagraphStyle | None): ParagraphStyle for cell values.
|
||||
|
||||
Returns:
|
||||
Table: A styled Table with findings data.
|
||||
"""
|
||||
if columns is None:
|
||||
columns = [
|
||||
ColumnConfig("Finding", 2.5 * inch, "title"),
|
||||
ColumnConfig("Resource", 3 * inch, "resource_name"),
|
||||
ColumnConfig("Severity", 0.9 * inch, "severity"),
|
||||
ColumnConfig("Status", 0.9 * inch, "status"),
|
||||
ColumnConfig("Region", 0.9 * inch, "region"),
|
||||
]
|
||||
|
||||
# Convert findings to dicts
|
||||
data = []
|
||||
for finding in findings:
|
||||
item = {}
|
||||
for col in columns:
|
||||
if callable(col.field):
|
||||
item[col.header.lower()] = col.field(finding)
|
||||
elif hasattr(finding, col.field):
|
||||
item[col.field] = getattr(finding, col.field, "")
|
||||
elif isinstance(finding, dict):
|
||||
item[col.field] = finding.get(col.field, "")
|
||||
data.append(item)
|
||||
|
||||
return create_data_table(
|
||||
data=data,
|
||||
columns=columns,
|
||||
header_color=header_color,
|
||||
alternate_rows=True,
|
||||
normal_style=normal_style,
|
||||
)
|
||||
|
||||
|
||||
def create_section_header(
|
||||
text: str,
|
||||
style: ParagraphStyle,
|
||||
add_spacer: bool = True,
|
||||
spacer_height: float = 0.2,
|
||||
) -> list:
|
||||
"""
|
||||
Create a section header with optional spacer.
|
||||
|
||||
Args:
|
||||
text (str): Header text.
|
||||
style (ParagraphStyle): ParagraphStyle to apply.
|
||||
add_spacer (bool): Whether to add a spacer after the header.
|
||||
spacer_height (float): Height of the spacer in inches.
|
||||
|
||||
Returns:
|
||||
list: List of elements (Paragraph and optional Spacer).
|
||||
"""
|
||||
elements = [Paragraph(text, style)]
|
||||
if add_spacer:
|
||||
elements.append(Spacer(1, spacer_height * inch))
|
||||
return elements
|
||||
|
||||
|
||||
def create_summary_table(
|
||||
label: str,
|
||||
value: str,
|
||||
value_color: colors.Color,
|
||||
label_width: float = 2.5 * inch,
|
||||
value_width: float = 2 * inch,
|
||||
) -> Table:
|
||||
"""
|
||||
Create a summary metric table (e.g., for ThreatScore display).
|
||||
|
||||
Args:
|
||||
label (str): Label text (e.g., "ThreatScore:").
|
||||
value (str): Value text (e.g., "85.5%").
|
||||
value_color (colors.Color): Background color for the value cell.
|
||||
label_width (float): Width of the label column.
|
||||
value_width (float): Width of the value column.
|
||||
|
||||
Returns:
|
||||
Table: A styled summary Table.
|
||||
"""
|
||||
data = [[label, value]]
|
||||
table = Table(data, colWidths=[label_width, value_width])
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), colors.Color(0.1, 0.3, 0.5)),
|
||||
("TEXTCOLOR", (0, 0), (0, 0), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (0, 0), "FiraCode"),
|
||||
("FONTSIZE", (0, 0), (0, 0), 12),
|
||||
("BACKGROUND", (1, 0), (1, 0), value_color),
|
||||
("TEXTCOLOR", (1, 0), (1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (1, 0), (1, 0), "FiraCode"),
|
||||
("FONTSIZE", (1, 0), (1, 0), 16),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("GRID", (0, 0), (-1, -1), 1.5, colors.Color(0.5, 0.6, 0.7)),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), 12),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), 12),
|
||||
("TOPPADDING", (0, 0), (-1, -1), 10),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), 10),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
@@ -0,0 +1,286 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# =============================================================================
|
||||
# Performance & Memory Optimization Settings
|
||||
# =============================================================================
|
||||
# These settings control memory usage and performance for large reports.
|
||||
# Adjust these values if workers are running out of memory.
|
||||
|
||||
# Chart settings - lower DPI = less memory, 150 is good quality for PDF
|
||||
CHART_DPI_DEFAULT = 150
|
||||
|
||||
# LongTable threshold - use LongTable for tables with more rows than this
|
||||
# LongTable handles page breaks better and has optimized memory for large tables
|
||||
LONG_TABLE_THRESHOLD = 50
|
||||
|
||||
# Skip alternating row colors for tables larger than this (reduces memory)
|
||||
ALTERNATE_ROWS_MAX_SIZE = 200
|
||||
|
||||
# Database query batch size for findings (matches Django settings)
|
||||
# Larger = fewer queries but more memory per batch
|
||||
FINDINGS_BATCH_SIZE = 2000
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base colors
|
||||
# =============================================================================
|
||||
COLOR_PROWLER_DARK_GREEN = colors.Color(0.1, 0.5, 0.2)
|
||||
COLOR_BLUE = colors.Color(0.2, 0.4, 0.6)
|
||||
COLOR_LIGHT_BLUE = colors.Color(0.3, 0.5, 0.7)
|
||||
COLOR_LIGHTER_BLUE = colors.Color(0.4, 0.6, 0.8)
|
||||
COLOR_BG_BLUE = colors.Color(0.95, 0.97, 1.0)
|
||||
COLOR_BG_LIGHT_BLUE = colors.Color(0.98, 0.99, 1.0)
|
||||
COLOR_GRAY = colors.Color(0.2, 0.2, 0.2)
|
||||
COLOR_LIGHT_GRAY = colors.Color(0.9, 0.9, 0.9)
|
||||
COLOR_BORDER_GRAY = colors.Color(0.7, 0.8, 0.9)
|
||||
COLOR_GRID_GRAY = colors.Color(0.7, 0.7, 0.7)
|
||||
COLOR_DARK_GRAY = colors.Color(0.4, 0.4, 0.4)
|
||||
COLOR_HEADER_DARK = colors.Color(0.1, 0.3, 0.5)
|
||||
COLOR_HEADER_MEDIUM = colors.Color(0.15, 0.35, 0.55)
|
||||
COLOR_WHITE = colors.white
|
||||
|
||||
# Risk and status colors
|
||||
COLOR_HIGH_RISK = colors.Color(0.8, 0.2, 0.2)
|
||||
COLOR_MEDIUM_RISK = colors.Color(0.9, 0.6, 0.2)
|
||||
COLOR_LOW_RISK = colors.Color(0.9, 0.9, 0.2)
|
||||
COLOR_SAFE = colors.Color(0.2, 0.8, 0.2)
|
||||
|
||||
# ENS specific colors
|
||||
COLOR_ENS_ALTO = colors.Color(0.8, 0.2, 0.2)
|
||||
COLOR_ENS_MEDIO = colors.Color(0.98, 0.75, 0.13)
|
||||
COLOR_ENS_BAJO = colors.Color(0.06, 0.72, 0.51)
|
||||
COLOR_ENS_OPCIONAL = colors.Color(0.42, 0.45, 0.50)
|
||||
COLOR_ENS_TIPO = colors.Color(0.2, 0.4, 0.6)
|
||||
COLOR_ENS_AUTO = colors.Color(0.30, 0.69, 0.31)
|
||||
COLOR_ENS_MANUAL = colors.Color(0.96, 0.60, 0.0)
|
||||
|
||||
# NIS2 specific colors
|
||||
COLOR_NIS2_PRIMARY = colors.Color(0.12, 0.23, 0.54)
|
||||
COLOR_NIS2_SECONDARY = colors.Color(0.23, 0.51, 0.96)
|
||||
COLOR_NIS2_BG_BLUE = colors.Color(0.96, 0.97, 0.99)
|
||||
|
||||
# Chart colors (hex strings for matplotlib)
|
||||
CHART_COLOR_GREEN_1 = "#4CAF50"
|
||||
CHART_COLOR_GREEN_2 = "#8BC34A"
|
||||
CHART_COLOR_YELLOW = "#FFEB3B"
|
||||
CHART_COLOR_ORANGE = "#FF9800"
|
||||
CHART_COLOR_RED = "#F44336"
|
||||
CHART_COLOR_BLUE = "#2196F3"
|
||||
|
||||
# ENS dimension mappings: dimension name -> (abbreviation, color)
|
||||
DIMENSION_MAPPING = {
|
||||
"trazabilidad": ("T", colors.Color(0.26, 0.52, 0.96)),
|
||||
"autenticidad": ("A", colors.Color(0.30, 0.69, 0.31)),
|
||||
"integridad": ("I", colors.Color(0.61, 0.15, 0.69)),
|
||||
"confidencialidad": ("C", colors.Color(0.96, 0.26, 0.21)),
|
||||
"disponibilidad": ("D", colors.Color(1.0, 0.60, 0.0)),
|
||||
}
|
||||
|
||||
# ENS tipo icons
|
||||
TIPO_ICONS = {
|
||||
"requisito": "\u26a0\ufe0f",
|
||||
"refuerzo": "\U0001f6e1\ufe0f",
|
||||
"recomendacion": "\U0001f4a1",
|
||||
"medida": "\U0001f4cb",
|
||||
}
|
||||
|
||||
# Dimension names for charts (Spanish)
|
||||
DIMENSION_NAMES = [
|
||||
"Trazabilidad",
|
||||
"Autenticidad",
|
||||
"Integridad",
|
||||
"Confidencialidad",
|
||||
"Disponibilidad",
|
||||
]
|
||||
|
||||
DIMENSION_KEYS = [
|
||||
"trazabilidad",
|
||||
"autenticidad",
|
||||
"integridad",
|
||||
"confidencialidad",
|
||||
"disponibilidad",
|
||||
]
|
||||
|
||||
# ENS nivel and tipo order
|
||||
ENS_NIVEL_ORDER = ["alto", "medio", "bajo", "opcional"]
|
||||
ENS_TIPO_ORDER = ["requisito", "refuerzo", "recomendacion", "medida"]
|
||||
|
||||
# ThreatScore sections
|
||||
THREATSCORE_SECTIONS = [
|
||||
"1. IAM",
|
||||
"2. Attack Surface",
|
||||
"3. Logging and Monitoring",
|
||||
"4. Encryption",
|
||||
]
|
||||
|
||||
# NIS2 sections
|
||||
NIS2_SECTIONS = [
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"9",
|
||||
"11",
|
||||
"12",
|
||||
]
|
||||
|
||||
NIS2_SECTION_TITLES = {
|
||||
"1": "1. Policy on Security",
|
||||
"2": "2. Risk Management",
|
||||
"3": "3. Incident Handling",
|
||||
"4": "4. Business Continuity",
|
||||
"5": "5. Supply Chain",
|
||||
"6": "6. Acquisition & Dev",
|
||||
"7": "7. Effectiveness",
|
||||
"9": "9. Cryptography",
|
||||
"11": "11. Access Control",
|
||||
"12": "12. Asset Management",
|
||||
}
|
||||
|
||||
# Table column widths
|
||||
COL_WIDTH_SMALL = 0.4 * inch
|
||||
COL_WIDTH_MEDIUM = 0.9 * inch
|
||||
COL_WIDTH_LARGE = 1.5 * inch
|
||||
COL_WIDTH_XLARGE = 2 * inch
|
||||
COL_WIDTH_XXLARGE = 3 * inch
|
||||
|
||||
# Common padding values
|
||||
PADDING_SMALL = 4
|
||||
PADDING_MEDIUM = 6
|
||||
PADDING_LARGE = 8
|
||||
PADDING_XLARGE = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameworkConfig:
|
||||
"""
|
||||
Configuration for a compliance framework PDF report.
|
||||
|
||||
This dataclass defines all the configurable aspects of a compliance framework
|
||||
report, including visual styling, metadata fields, and feature flags.
|
||||
|
||||
Attributes:
|
||||
name (str): Internal framework identifier (e.g., "prowler_threatscore").
|
||||
display_name (str): Human-readable framework name for the report title.
|
||||
logo_filename (str | None): Optional filename of the framework logo in assets/img/.
|
||||
primary_color (colors.Color): Main color used for headers and important elements.
|
||||
secondary_color (colors.Color): Secondary color for sub-headers and accents.
|
||||
bg_color (colors.Color): Background color for highlighted sections.
|
||||
attribute_fields (list[str]): List of metadata field names to extract from requirements.
|
||||
sections (list[str] | None): Optional ordered list of section names for grouping.
|
||||
language (str): Report language ("en" for English, "es" for Spanish).
|
||||
has_risk_levels (bool): Whether the framework uses numeric risk levels.
|
||||
has_dimensions (bool): Whether the framework uses security dimensions (ENS).
|
||||
has_niveles (bool): Whether the framework uses nivel classification (ENS).
|
||||
has_weight (bool): Whether requirements have weight values.
|
||||
"""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
logo_filename: str | None = None
|
||||
primary_color: colors.Color = field(default_factory=lambda: COLOR_BLUE)
|
||||
secondary_color: colors.Color = field(default_factory=lambda: COLOR_LIGHT_BLUE)
|
||||
bg_color: colors.Color = field(default_factory=lambda: COLOR_BG_BLUE)
|
||||
attribute_fields: list[str] = field(default_factory=list)
|
||||
sections: list[str] | None = None
|
||||
language: str = "en"
|
||||
has_risk_levels: bool = False
|
||||
has_dimensions: bool = False
|
||||
has_niveles: bool = False
|
||||
has_weight: bool = False
|
||||
|
||||
|
||||
FRAMEWORK_REGISTRY: dict[str, FrameworkConfig] = {
|
||||
"prowler_threatscore": FrameworkConfig(
|
||||
name="prowler_threatscore",
|
||||
display_name="Prowler ThreatScore",
|
||||
logo_filename=None,
|
||||
primary_color=COLOR_BLUE,
|
||||
secondary_color=COLOR_LIGHT_BLUE,
|
||||
bg_color=COLOR_BG_BLUE,
|
||||
attribute_fields=[
|
||||
"Title",
|
||||
"Section",
|
||||
"SubSection",
|
||||
"LevelOfRisk",
|
||||
"Weight",
|
||||
"AttributeDescription",
|
||||
"AdditionalInformation",
|
||||
],
|
||||
sections=THREATSCORE_SECTIONS,
|
||||
language="en",
|
||||
has_risk_levels=True,
|
||||
has_weight=True,
|
||||
),
|
||||
"ens": FrameworkConfig(
|
||||
name="ens",
|
||||
display_name="ENS RD2022",
|
||||
logo_filename="ens_logo.png",
|
||||
primary_color=COLOR_ENS_ALTO,
|
||||
secondary_color=COLOR_ENS_MEDIO,
|
||||
bg_color=COLOR_BG_BLUE,
|
||||
attribute_fields=[
|
||||
"IdGrupoControl",
|
||||
"Marco",
|
||||
"Categoria",
|
||||
"DescripcionControl",
|
||||
"Tipo",
|
||||
"Nivel",
|
||||
"Dimensiones",
|
||||
"ModoEjecucion",
|
||||
],
|
||||
sections=None,
|
||||
language="es",
|
||||
has_risk_levels=False,
|
||||
has_dimensions=True,
|
||||
has_niveles=True,
|
||||
has_weight=False,
|
||||
),
|
||||
"nis2": FrameworkConfig(
|
||||
name="nis2",
|
||||
display_name="NIS2 Directive",
|
||||
logo_filename="nis2_logo.png",
|
||||
primary_color=COLOR_NIS2_PRIMARY,
|
||||
secondary_color=COLOR_NIS2_SECONDARY,
|
||||
bg_color=COLOR_NIS2_BG_BLUE,
|
||||
attribute_fields=[
|
||||
"Section",
|
||||
"SubSection",
|
||||
"Description",
|
||||
],
|
||||
sections=NIS2_SECTIONS,
|
||||
language="en",
|
||||
has_risk_levels=False,
|
||||
has_dimensions=False,
|
||||
has_niveles=False,
|
||||
has_weight=False,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_framework_config(compliance_id: str) -> FrameworkConfig | None:
|
||||
"""
|
||||
Get framework configuration based on compliance ID.
|
||||
|
||||
Args:
|
||||
compliance_id (str): The compliance framework identifier (e.g., "prowler_threatscore_aws").
|
||||
|
||||
Returns:
|
||||
FrameworkConfig | None: The framework configuration if found, None otherwise.
|
||||
"""
|
||||
compliance_lower = compliance_id.lower()
|
||||
|
||||
if "threatscore" in compliance_lower:
|
||||
return FRAMEWORK_REGISTRY["prowler_threatscore"]
|
||||
if "ens" in compliance_lower:
|
||||
return FRAMEWORK_REGISTRY["ens"]
|
||||
if "nis2" in compliance_lower:
|
||||
return FRAMEWORK_REGISTRY["nis2"]
|
||||
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,471 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.platypus import Image, PageBreak, Paragraph, Spacer, Table, TableStyle
|
||||
|
||||
from api.models import StatusChoices
|
||||
|
||||
from .base import (
|
||||
BaseComplianceReportGenerator,
|
||||
ComplianceData,
|
||||
get_requirement_metadata,
|
||||
)
|
||||
from .charts import create_horizontal_bar_chart, get_chart_color_for_percentage
|
||||
from .config import (
|
||||
COLOR_BORDER_GRAY,
|
||||
COLOR_DARK_GRAY,
|
||||
COLOR_GRAY,
|
||||
COLOR_GRID_GRAY,
|
||||
COLOR_HIGH_RISK,
|
||||
COLOR_NIS2_BG_BLUE,
|
||||
COLOR_NIS2_PRIMARY,
|
||||
COLOR_SAFE,
|
||||
COLOR_WHITE,
|
||||
NIS2_SECTION_TITLES,
|
||||
NIS2_SECTIONS,
|
||||
)
|
||||
|
||||
|
||||
def _extract_section_number(section_string: str) -> str:
|
||||
"""Extract the section number from a full NIS2 section title.
|
||||
|
||||
NIS2 section strings are formatted like:
|
||||
"1 POLICY ON THE SECURITY OF NETWORK AND INFORMATION SYSTEMS..."
|
||||
|
||||
This function extracts just the leading number.
|
||||
|
||||
Args:
|
||||
section_string: Full section title string.
|
||||
|
||||
Returns:
|
||||
Section number as string (e.g., "1", "2", "11").
|
||||
"""
|
||||
if not section_string:
|
||||
return "Other"
|
||||
parts = section_string.split()
|
||||
if parts and parts[0].isdigit():
|
||||
return parts[0]
|
||||
return "Other"
|
||||
|
||||
|
||||
class NIS2ReportGenerator(BaseComplianceReportGenerator):
|
||||
"""
|
||||
PDF report generator for NIS2 Directive (EU) 2022/2555.
|
||||
|
||||
This generator creates comprehensive PDF reports containing:
|
||||
- Cover page with both Prowler and NIS2 logos
|
||||
- Executive summary with overall compliance score
|
||||
- Section analysis with horizontal bar chart
|
||||
- SubSection breakdown table
|
||||
- Critical failed requirements
|
||||
- Requirements index organized by section and subsection
|
||||
- Detailed findings for failed requirements
|
||||
"""
|
||||
|
||||
def create_cover_page(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the NIS2 report cover page with both logos.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
# Create logos side by side
|
||||
prowler_logo_path = os.path.join(
|
||||
os.path.dirname(__file__), "../../assets/img/prowler_logo.png"
|
||||
)
|
||||
nis2_logo_path = os.path.join(
|
||||
os.path.dirname(__file__), "../../assets/img/nis2_logo.png"
|
||||
)
|
||||
|
||||
prowler_logo = Image(prowler_logo_path, width=3.5 * inch, height=0.7 * inch)
|
||||
nis2_logo = Image(nis2_logo_path, width=2.3 * inch, height=1.5 * inch)
|
||||
|
||||
logos_table = Table(
|
||||
[[prowler_logo, nis2_logo]], colWidths=[4 * inch, 2.5 * inch]
|
||||
)
|
||||
logos_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("ALIGN", (0, 0), (0, 0), "LEFT"),
|
||||
("ALIGN", (1, 0), (1, 0), "RIGHT"),
|
||||
("VALIGN", (0, 0), (0, 0), "MIDDLE"),
|
||||
("VALIGN", (1, 0), (1, 0), "MIDDLE"),
|
||||
]
|
||||
)
|
||||
)
|
||||
elements.append(logos_table)
|
||||
elements.append(Spacer(1, 0.3 * inch))
|
||||
|
||||
# Title
|
||||
title = Paragraph(
|
||||
"NIS2 Compliance Report<br/>Directive (EU) 2022/2555",
|
||||
self.styles["title"],
|
||||
)
|
||||
elements.append(title)
|
||||
elements.append(Spacer(1, 0.3 * inch))
|
||||
|
||||
# Compliance metadata table - use base class helper for consistency
|
||||
info_rows = self._build_info_rows(data, language="en")
|
||||
# Convert tuples to lists and wrap long text in Paragraphs
|
||||
metadata_data = []
|
||||
for label, value in info_rows:
|
||||
if label in ("Name:", "Description:") and value:
|
||||
metadata_data.append(
|
||||
[label, Paragraph(value, self.styles["normal_center"])]
|
||||
)
|
||||
else:
|
||||
metadata_data.append([label, value])
|
||||
|
||||
metadata_table = Table(metadata_data, colWidths=[2 * inch, 4 * inch])
|
||||
metadata_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, -1), COLOR_NIS2_PRIMARY),
|
||||
("TEXTCOLOR", (0, 0), (0, -1), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (0, -1), "FiraCode"),
|
||||
("BACKGROUND", (1, 0), (1, -1), COLOR_NIS2_BG_BLUE),
|
||||
("TEXTCOLOR", (1, 0), (1, -1), COLOR_GRAY),
|
||||
("FONTNAME", (1, 0), (1, -1), "PlusJakartaSans"),
|
||||
("ALIGN", (0, 0), (-1, -1), "LEFT"),
|
||||
("VALIGN", (0, 0), (-1, -1), "TOP"),
|
||||
("FONTSIZE", (0, 0), (-1, -1), 11),
|
||||
("GRID", (0, 0), (-1, -1), 1, COLOR_BORDER_GRAY),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), 10),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), 10),
|
||||
("TOPPADDING", (0, 0), (-1, -1), 8),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), 8),
|
||||
]
|
||||
)
|
||||
)
|
||||
elements.append(metadata_table)
|
||||
|
||||
return elements
|
||||
|
||||
def create_executive_summary(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the executive summary with compliance metrics.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
elements.append(Paragraph("Executive Summary", self.styles["h1"]))
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
# Calculate statistics
|
||||
total = len(data.requirements)
|
||||
passed = sum(1 for r in data.requirements if r.status == StatusChoices.PASS)
|
||||
failed = sum(1 for r in data.requirements if r.status == StatusChoices.FAIL)
|
||||
manual = sum(1 for r in data.requirements if r.status == StatusChoices.MANUAL)
|
||||
|
||||
# Calculate compliance excluding manual
|
||||
evaluated = passed + failed
|
||||
overall_compliance = (passed / evaluated * 100) if evaluated > 0 else 100
|
||||
|
||||
# Summary statistics table
|
||||
summary_data = [
|
||||
["Metric", "Value"],
|
||||
["Total Requirements", str(total)],
|
||||
["Passed ✓", str(passed)],
|
||||
["Failed ✗", str(failed)],
|
||||
["Manual ⊙", str(manual)],
|
||||
["Overall Compliance", f"{overall_compliance:.1f}%"],
|
||||
]
|
||||
|
||||
summary_table = Table(summary_data, colWidths=[3 * inch, 2 * inch])
|
||||
summary_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (-1, 0), COLOR_NIS2_PRIMARY),
|
||||
("TEXTCOLOR", (0, 0), (-1, 0), COLOR_WHITE),
|
||||
("BACKGROUND", (0, 2), (0, 2), COLOR_SAFE),
|
||||
("TEXTCOLOR", (0, 2), (0, 2), COLOR_WHITE),
|
||||
("BACKGROUND", (0, 3), (0, 3), COLOR_HIGH_RISK),
|
||||
("TEXTCOLOR", (0, 3), (0, 3), COLOR_WHITE),
|
||||
("BACKGROUND", (0, 4), (0, 4), COLOR_DARK_GRAY),
|
||||
("TEXTCOLOR", (0, 4), (0, 4), COLOR_WHITE),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("FONTNAME", (0, 0), (-1, 0), "PlusJakartaSans"),
|
||||
("FONTSIZE", (0, 0), (-1, 0), 12),
|
||||
("FONTSIZE", (0, 1), (-1, -1), 10),
|
||||
("BOTTOMPADDING", (0, 0), (-1, 0), 10),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, COLOR_BORDER_GRAY),
|
||||
(
|
||||
"ROWBACKGROUNDS",
|
||||
(1, 1),
|
||||
(1, -1),
|
||||
[COLOR_WHITE, COLOR_NIS2_BG_BLUE],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
elements.append(summary_table)
|
||||
|
||||
return elements
|
||||
|
||||
def create_charts_section(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the charts section with section analysis.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
# Section chart
|
||||
elements.append(Paragraph("Compliance by Section", self.styles["h1"]))
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
elements.append(
|
||||
Paragraph(
|
||||
"The following chart shows compliance percentage for each main section "
|
||||
"of the NIS2 directive:",
|
||||
self.styles["normal_center"],
|
||||
)
|
||||
)
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
chart_buffer = self._create_section_chart(data)
|
||||
chart_buffer.seek(0)
|
||||
chart_image = Image(chart_buffer, width=6.5 * inch, height=5 * inch)
|
||||
elements.append(chart_image)
|
||||
elements.append(PageBreak())
|
||||
|
||||
# SubSection breakdown table
|
||||
elements.append(Paragraph("SubSection Breakdown", self.styles["h1"]))
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
subsection_table = self._create_subsection_table(data)
|
||||
elements.append(subsection_table)
|
||||
|
||||
return elements
|
||||
|
||||
def create_requirements_index(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the requirements index organized by section and subsection.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
elements.append(Paragraph("Requirements Index", self.styles["h1"]))
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
# Organize by section number and subsection
|
||||
sections = {}
|
||||
for req in data.requirements:
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
if m:
|
||||
full_section = getattr(m, "Section", "Other")
|
||||
# Extract section number from full title (e.g., "1 POLICY..." -> "1")
|
||||
section_num = _extract_section_number(full_section)
|
||||
subsection = getattr(m, "SubSection", "")
|
||||
description = getattr(m, "Description", req.description)
|
||||
|
||||
if section_num not in sections:
|
||||
sections[section_num] = {}
|
||||
if subsection not in sections[section_num]:
|
||||
sections[section_num][subsection] = []
|
||||
|
||||
sections[section_num][subsection].append(
|
||||
{
|
||||
"id": req.id,
|
||||
"description": description,
|
||||
"status": req.status,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by NIS2 section order
|
||||
for section in NIS2_SECTIONS:
|
||||
if section not in sections:
|
||||
continue
|
||||
|
||||
section_title = NIS2_SECTION_TITLES.get(section, f"Section {section}")
|
||||
elements.append(Paragraph(section_title, self.styles["h2"]))
|
||||
|
||||
for subsection_name, reqs in sections[section].items():
|
||||
if subsection_name:
|
||||
# Truncate long subsection names for display
|
||||
display_subsection = (
|
||||
subsection_name[:80] + "..."
|
||||
if len(subsection_name) > 80
|
||||
else subsection_name
|
||||
)
|
||||
elements.append(Paragraph(display_subsection, self.styles["h3"]))
|
||||
|
||||
for req in reqs:
|
||||
status_indicator = (
|
||||
"✓" if req["status"] == StatusChoices.PASS else "✗"
|
||||
)
|
||||
if req["status"] == StatusChoices.MANUAL:
|
||||
status_indicator = "⊙"
|
||||
|
||||
desc = (
|
||||
req["description"][:60] + "..."
|
||||
if len(req["description"]) > 60
|
||||
else req["description"]
|
||||
)
|
||||
elements.append(
|
||||
Paragraph(
|
||||
f"{status_indicator} {req['id']}: {desc}",
|
||||
self.styles["normal"],
|
||||
)
|
||||
)
|
||||
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
return elements
|
||||
|
||||
def _create_section_chart(self, data: ComplianceData):
|
||||
"""
|
||||
Create the section compliance chart.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the chart image.
|
||||
"""
|
||||
section_scores = defaultdict(lambda: {"passed": 0, "total": 0})
|
||||
|
||||
for req in data.requirements:
|
||||
if req.status == StatusChoices.MANUAL:
|
||||
continue
|
||||
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
if m:
|
||||
full_section = getattr(m, "Section", "Other")
|
||||
# Extract section number from full title (e.g., "1 POLICY..." -> "1")
|
||||
section_num = _extract_section_number(full_section)
|
||||
section_scores[section_num]["total"] += 1
|
||||
if req.status == StatusChoices.PASS:
|
||||
section_scores[section_num]["passed"] += 1
|
||||
|
||||
# Build labels and values in NIS2 section order
|
||||
labels = []
|
||||
values = []
|
||||
for section in NIS2_SECTIONS:
|
||||
if section in section_scores and section_scores[section]["total"] > 0:
|
||||
scores = section_scores[section]
|
||||
pct = (scores["passed"] / scores["total"]) * 100
|
||||
section_title = NIS2_SECTION_TITLES.get(section, f"Section {section}")
|
||||
labels.append(section_title)
|
||||
values.append(pct)
|
||||
|
||||
return create_horizontal_bar_chart(
|
||||
labels=labels,
|
||||
values=values,
|
||||
xlabel="Compliance (%)",
|
||||
color_func=get_chart_color_for_percentage,
|
||||
)
|
||||
|
||||
def _create_subsection_table(self, data: ComplianceData) -> Table:
|
||||
"""
|
||||
Create the subsection breakdown table.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
ReportLab Table element.
|
||||
"""
|
||||
subsection_scores = defaultdict(lambda: {"passed": 0, "failed": 0, "manual": 0})
|
||||
|
||||
for req in data.requirements:
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
if m:
|
||||
full_section = getattr(m, "Section", "")
|
||||
subsection = getattr(m, "SubSection", "")
|
||||
# Use section number + subsection for grouping
|
||||
section_num = _extract_section_number(full_section)
|
||||
# Create a shorter key using section number
|
||||
if subsection:
|
||||
# Extract subsection number if present (e.g., "1.1 Policy..." -> "1.1")
|
||||
subsection_parts = subsection.split()
|
||||
if subsection_parts:
|
||||
key = subsection_parts[0] # Just the number like "1.1"
|
||||
else:
|
||||
key = f"{section_num}"
|
||||
else:
|
||||
key = section_num
|
||||
|
||||
if req.status == StatusChoices.PASS:
|
||||
subsection_scores[key]["passed"] += 1
|
||||
elif req.status == StatusChoices.FAIL:
|
||||
subsection_scores[key]["failed"] += 1
|
||||
else:
|
||||
subsection_scores[key]["manual"] += 1
|
||||
|
||||
table_data = [["Section", "Passed", "Failed", "Manual", "Compliance"]]
|
||||
for key, scores in sorted(
|
||||
subsection_scores.items(), key=lambda x: self._sort_section_key(x[0])
|
||||
):
|
||||
total = scores["passed"] + scores["failed"]
|
||||
pct = (scores["passed"] / total * 100) if total > 0 else 100
|
||||
table_data.append(
|
||||
[
|
||||
key,
|
||||
str(scores["passed"]),
|
||||
str(scores["failed"]),
|
||||
str(scores["manual"]),
|
||||
f"{pct:.1f}%",
|
||||
]
|
||||
)
|
||||
|
||||
table = Table(
|
||||
table_data,
|
||||
colWidths=[1.2 * inch, 0.9 * inch, 0.9 * inch, 0.9 * inch, 1.2 * inch],
|
||||
)
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (-1, 0), COLOR_NIS2_PRIMARY),
|
||||
("TEXTCOLOR", (0, 0), (-1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (-1, 0), "FiraCode"),
|
||||
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("FONTSIZE", (0, 1), (-1, -1), 9),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, COLOR_GRID_GRAY),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), 6),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), 6),
|
||||
("TOPPADDING", (0, 0), (-1, -1), 4),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), 4),
|
||||
(
|
||||
"ROWBACKGROUNDS",
|
||||
(0, 1),
|
||||
(-1, -1),
|
||||
[COLOR_WHITE, COLOR_NIS2_BG_BLUE],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
def _sort_section_key(self, key: str) -> tuple:
|
||||
"""Sort section keys numerically (e.g., 1, 1.1, 1.2, 2, 11)."""
|
||||
parts = key.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
try:
|
||||
result.append(int(part))
|
||||
except ValueError:
|
||||
result.append(float("inf"))
|
||||
return tuple(result)
|
||||
@@ -0,0 +1,509 @@
|
||||
import gc
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.platypus import Image, PageBreak, Paragraph, Spacer, Table, TableStyle
|
||||
|
||||
from api.models import StatusChoices
|
||||
|
||||
from .base import (
|
||||
BaseComplianceReportGenerator,
|
||||
ComplianceData,
|
||||
get_requirement_metadata,
|
||||
)
|
||||
from .charts import create_vertical_bar_chart, get_chart_color_for_percentage
|
||||
from .components import get_color_for_compliance, get_color_for_weight
|
||||
from .config import COLOR_HIGH_RISK, COLOR_WHITE
|
||||
|
||||
|
||||
class ThreatScoreReportGenerator(BaseComplianceReportGenerator):
|
||||
"""
|
||||
PDF report generator for Prowler ThreatScore framework.
|
||||
|
||||
This generator creates comprehensive PDF reports containing:
|
||||
- Compliance overview and metadata
|
||||
- Section-by-section compliance scores with charts
|
||||
- Overall ThreatScore calculation
|
||||
- Critical failed requirements
|
||||
- Detailed findings for each requirement
|
||||
"""
|
||||
|
||||
def create_executive_summary(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the executive summary section with ThreatScore calculation.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
elements.append(Paragraph("Compliance Score by Sections", self.styles["h1"]))
|
||||
elements.append(Spacer(1, 0.2 * inch))
|
||||
|
||||
# Create section score chart
|
||||
chart_buffer = self._create_section_score_chart(data)
|
||||
chart_image = Image(chart_buffer, width=7 * inch, height=5.5 * inch)
|
||||
elements.append(chart_image)
|
||||
|
||||
# Calculate overall ThreatScore
|
||||
overall_compliance = self._calculate_threatscore(data)
|
||||
|
||||
elements.append(Spacer(1, 0.3 * inch))
|
||||
|
||||
# Summary table
|
||||
summary_data = [["ThreatScore:", f"{overall_compliance:.2f}%"]]
|
||||
compliance_color = get_color_for_compliance(overall_compliance)
|
||||
|
||||
summary_table = Table(summary_data, colWidths=[2.5 * inch, 2 * inch])
|
||||
summary_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (0, 0), colors.Color(0.1, 0.3, 0.5)),
|
||||
("TEXTCOLOR", (0, 0), (0, 0), colors.white),
|
||||
("FONTNAME", (0, 0), (0, 0), "FiraCode"),
|
||||
("FONTSIZE", (0, 0), (0, 0), 12),
|
||||
("BACKGROUND", (1, 0), (1, 0), compliance_color),
|
||||
("TEXTCOLOR", (1, 0), (1, 0), colors.white),
|
||||
("FONTNAME", (1, 0), (1, 0), "FiraCode"),
|
||||
("FONTSIZE", (1, 0), (1, 0), 16),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("GRID", (0, 0), (-1, -1), 1.5, colors.Color(0.5, 0.6, 0.7)),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), 12),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), 12),
|
||||
("TOPPADDING", (0, 0), (-1, -1), 10),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), 10),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
elements.append(summary_table)
|
||||
|
||||
return elements
|
||||
|
||||
def _build_body_sections(self, data: ComplianceData) -> list:
|
||||
"""Override section order: Requirements Index before Critical Requirements."""
|
||||
elements = []
|
||||
|
||||
# Page break to separate from executive summary
|
||||
elements.append(PageBreak())
|
||||
|
||||
# Requirements index first
|
||||
elements.extend(self.create_requirements_index(data))
|
||||
|
||||
# Critical requirements section (already starts with PageBreak internally)
|
||||
elements.extend(self.create_charts_section(data))
|
||||
elements.append(PageBreak())
|
||||
gc.collect()
|
||||
|
||||
return elements
|
||||
|
||||
def create_charts_section(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the critical failed requirements section.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
min_risk_level = getattr(self, "_min_risk_level", 4)
|
||||
|
||||
# Start on a new page
|
||||
elements.append(PageBreak())
|
||||
elements.append(
|
||||
Paragraph("Top Requirements by Level of Risk", self.styles["h1"])
|
||||
)
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
elements.append(
|
||||
Paragraph(
|
||||
f"Critical Failed Requirements (Risk Level ≥ {min_risk_level})",
|
||||
self.styles["h2"],
|
||||
)
|
||||
)
|
||||
elements.append(Spacer(1, 0.2 * inch))
|
||||
|
||||
critical_failed = self._get_critical_failed_requirements(data, min_risk_level)
|
||||
|
||||
if not critical_failed:
|
||||
elements.append(
|
||||
Paragraph(
|
||||
"✅ No critical failed requirements found. Great job!",
|
||||
self.styles["normal"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
elements.append(
|
||||
Paragraph(
|
||||
f"Found {len(critical_failed)} critical failed requirements "
|
||||
"that require immediate attention:",
|
||||
self.styles["normal"],
|
||||
)
|
||||
)
|
||||
elements.append(Spacer(1, 0.5 * inch))
|
||||
|
||||
table = self._create_critical_requirements_table(critical_failed)
|
||||
elements.append(table)
|
||||
|
||||
# Immediate action required banner
|
||||
elements.append(Spacer(1, 0.3 * inch))
|
||||
elements.append(self._create_action_required_banner())
|
||||
|
||||
return elements
|
||||
|
||||
def create_requirements_index(self, data: ComplianceData) -> list:
|
||||
"""
|
||||
Create the requirements index organized by section and subsection.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
List of ReportLab elements.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
elements.append(Paragraph("Requirements Index", self.styles["h1"]))
|
||||
|
||||
# Organize requirements by section and subsection
|
||||
sections = {}
|
||||
for req_id in data.attributes_by_requirement_id:
|
||||
m = get_requirement_metadata(req_id, data.attributes_by_requirement_id)
|
||||
if m:
|
||||
section = getattr(m, "Section", "N/A")
|
||||
subsection = getattr(m, "SubSection", "N/A")
|
||||
title = getattr(m, "Title", "N/A")
|
||||
|
||||
if section not in sections:
|
||||
sections[section] = {}
|
||||
if subsection not in sections[section]:
|
||||
sections[section][subsection] = []
|
||||
|
||||
sections[section][subsection].append({"id": req_id, "title": title})
|
||||
|
||||
section_num = 1
|
||||
for section_name, subsections in sections.items():
|
||||
elements.append(
|
||||
Paragraph(f"{section_num}. {section_name}", self.styles["h2"])
|
||||
)
|
||||
|
||||
for subsection_name, requirements in subsections.items():
|
||||
elements.append(Paragraph(f"{subsection_name}", self.styles["h3"]))
|
||||
|
||||
for req in requirements:
|
||||
elements.append(
|
||||
Paragraph(
|
||||
f"{req['id']} - {req['title']}", self.styles["normal"]
|
||||
)
|
||||
)
|
||||
|
||||
section_num += 1
|
||||
elements.append(Spacer(1, 0.1 * inch))
|
||||
|
||||
return elements
|
||||
|
||||
def _create_section_score_chart(self, data: ComplianceData):
|
||||
"""
|
||||
Create the section compliance score chart using weighted ThreatScore formula.
|
||||
|
||||
The section score uses the same weighted formula as the overall ThreatScore:
|
||||
Score = Σ(rate_i * total_findings_i * weight_i * rfac_i) / Σ(total_findings_i * weight_i * rfac_i)
|
||||
Where rfac_i = 1 + 0.25 * risk_level
|
||||
|
||||
Sections without findings are shown with 100% score.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
BytesIO buffer containing the chart image.
|
||||
"""
|
||||
# First, collect ALL sections from requirements (including those without findings)
|
||||
all_sections = set()
|
||||
sections_data = {}
|
||||
|
||||
for req in data.requirements:
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
if m:
|
||||
section = getattr(m, "Section", "Other")
|
||||
all_sections.add(section)
|
||||
|
||||
# Only calculate scores for requirements with findings
|
||||
if req.total_findings == 0:
|
||||
continue
|
||||
|
||||
risk_level_raw = getattr(m, "LevelOfRisk", 0)
|
||||
weight_raw = getattr(m, "Weight", 0)
|
||||
# Ensure numeric types for calculations (compliance data may have str)
|
||||
try:
|
||||
risk_level = int(risk_level_raw) if risk_level_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
risk_level = 0
|
||||
try:
|
||||
weight = int(weight_raw) if weight_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
weight = 0
|
||||
|
||||
# ThreatScore formula components
|
||||
rate_i = req.passed_findings / req.total_findings
|
||||
rfac_i = 1 + 0.25 * risk_level
|
||||
|
||||
if section not in sections_data:
|
||||
sections_data[section] = {
|
||||
"numerator": 0,
|
||||
"denominator": 0,
|
||||
}
|
||||
|
||||
sections_data[section]["numerator"] += (
|
||||
rate_i * req.total_findings * weight * rfac_i
|
||||
)
|
||||
sections_data[section]["denominator"] += (
|
||||
req.total_findings * weight * rfac_i
|
||||
)
|
||||
|
||||
# Calculate percentages for all sections
|
||||
labels = []
|
||||
values = []
|
||||
for section in sorted(all_sections):
|
||||
if section in sections_data and sections_data[section]["denominator"] > 0:
|
||||
pct = (
|
||||
sections_data[section]["numerator"]
|
||||
/ sections_data[section]["denominator"]
|
||||
) * 100
|
||||
else:
|
||||
# Sections without findings get 100%
|
||||
pct = 100.0
|
||||
labels.append(section)
|
||||
values.append(pct)
|
||||
|
||||
return create_vertical_bar_chart(
|
||||
labels=labels,
|
||||
values=values,
|
||||
ylabel="Compliance Score (%)",
|
||||
xlabel="",
|
||||
color_func=get_chart_color_for_percentage,
|
||||
rotation=0,
|
||||
)
|
||||
|
||||
def _calculate_threatscore(self, data: ComplianceData) -> float:
|
||||
"""
|
||||
Calculate the overall ThreatScore using the weighted formula.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
|
||||
Returns:
|
||||
Overall ThreatScore percentage.
|
||||
"""
|
||||
numerator = 0
|
||||
denominator = 0
|
||||
has_findings = False
|
||||
|
||||
for req in data.requirements:
|
||||
if req.total_findings == 0:
|
||||
continue
|
||||
|
||||
has_findings = True
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
|
||||
if m:
|
||||
risk_level_raw = getattr(m, "LevelOfRisk", 0)
|
||||
weight_raw = getattr(m, "Weight", 0)
|
||||
# Ensure numeric types for calculations (compliance data may have str)
|
||||
try:
|
||||
risk_level = int(risk_level_raw) if risk_level_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
risk_level = 0
|
||||
try:
|
||||
weight = int(weight_raw) if weight_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
weight = 0
|
||||
|
||||
rate_i = req.passed_findings / req.total_findings
|
||||
rfac_i = 1 + 0.25 * risk_level
|
||||
|
||||
numerator += rate_i * req.total_findings * weight * rfac_i
|
||||
denominator += req.total_findings * weight * rfac_i
|
||||
|
||||
if not has_findings:
|
||||
return 100.0
|
||||
if denominator > 0:
|
||||
return (numerator / denominator) * 100
|
||||
return 0.0
|
||||
|
||||
def _get_critical_failed_requirements(
|
||||
self, data: ComplianceData, min_risk_level: int
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get critical failed requirements sorted by risk level and weight.
|
||||
|
||||
Args:
|
||||
data: Aggregated compliance data.
|
||||
min_risk_level: Minimum risk level threshold.
|
||||
|
||||
Returns:
|
||||
List of critical failed requirement dictionaries.
|
||||
"""
|
||||
critical = []
|
||||
|
||||
for req in data.requirements:
|
||||
if req.status != StatusChoices.FAIL:
|
||||
continue
|
||||
|
||||
m = get_requirement_metadata(req.id, data.attributes_by_requirement_id)
|
||||
|
||||
if m:
|
||||
risk_level_raw = getattr(m, "LevelOfRisk", 0)
|
||||
weight_raw = getattr(m, "Weight", 0)
|
||||
# Ensure numeric types for calculations (compliance data may have str)
|
||||
try:
|
||||
risk_level = int(risk_level_raw) if risk_level_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
risk_level = 0
|
||||
try:
|
||||
weight = int(weight_raw) if weight_raw else 0
|
||||
except (ValueError, TypeError):
|
||||
weight = 0
|
||||
|
||||
if risk_level >= min_risk_level:
|
||||
critical.append(
|
||||
{
|
||||
"id": req.id,
|
||||
"risk_level": risk_level,
|
||||
"weight": weight,
|
||||
"title": getattr(m, "Title", "N/A"),
|
||||
"section": getattr(m, "Section", "N/A"),
|
||||
}
|
||||
)
|
||||
|
||||
critical.sort(key=lambda x: (x["risk_level"], x["weight"]), reverse=True)
|
||||
return critical
|
||||
|
||||
def _create_critical_requirements_table(self, critical_requirements: list) -> Table:
|
||||
"""
|
||||
Create the critical requirements table.
|
||||
|
||||
Args:
|
||||
critical_requirements: List of critical requirement dictionaries.
|
||||
|
||||
Returns:
|
||||
ReportLab Table element.
|
||||
"""
|
||||
table_data = [["Risk", "Weight", "Requirement ID", "Title", "Section"]]
|
||||
|
||||
for req in critical_requirements:
|
||||
title = req["title"]
|
||||
if len(title) > 50:
|
||||
title = title[:47] + "..."
|
||||
|
||||
table_data.append(
|
||||
[
|
||||
str(req["risk_level"]),
|
||||
str(req["weight"]),
|
||||
req["id"],
|
||||
title,
|
||||
req["section"],
|
||||
]
|
||||
)
|
||||
|
||||
table = Table(
|
||||
table_data,
|
||||
colWidths=[0.7 * inch, 0.9 * inch, 1.3 * inch, 3.1 * inch, 1.5 * inch],
|
||||
)
|
||||
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (0, 0), (-1, 0), COLOR_HIGH_RISK),
|
||||
("TEXTCOLOR", (0, 0), (-1, 0), COLOR_WHITE),
|
||||
("FONTNAME", (0, 0), (-1, 0), "FiraCode"),
|
||||
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||
("BACKGROUND", (0, 1), (0, -1), COLOR_HIGH_RISK),
|
||||
("TEXTCOLOR", (0, 1), (0, -1), COLOR_WHITE),
|
||||
("FONTNAME", (0, 1), (0, -1), "FiraCode"),
|
||||
("ALIGN", (0, 1), (0, -1), "CENTER"),
|
||||
("FONTSIZE", (0, 1), (0, -1), 12),
|
||||
("ALIGN", (1, 1), (1, -1), "CENTER"),
|
||||
("FONTNAME", (1, 1), (1, -1), "FiraCode"),
|
||||
("FONTNAME", (2, 1), (2, -1), "FiraCode"),
|
||||
("FONTSIZE", (2, 1), (2, -1), 9),
|
||||
("FONTNAME", (3, 1), (-1, -1), "PlusJakartaSans"),
|
||||
("FONTSIZE", (3, 1), (-1, -1), 8),
|
||||
("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
|
||||
("GRID", (0, 0), (-1, -1), 1, colors.Color(0.7, 0.7, 0.7)),
|
||||
("LEFTPADDING", (0, 0), (-1, -1), 6),
|
||||
("RIGHTPADDING", (0, 0), (-1, -1), 6),
|
||||
("TOPPADDING", (0, 0), (-1, -1), 8),
|
||||
("BOTTOMPADDING", (0, 0), (-1, -1), 8),
|
||||
("BACKGROUND", (1, 1), (-1, -1), colors.Color(0.98, 0.98, 0.98)),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Color weight column based on value
|
||||
for idx, req in enumerate(critical_requirements):
|
||||
row_idx = idx + 1
|
||||
weight_color = get_color_for_weight(req["weight"])
|
||||
table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
("BACKGROUND", (1, row_idx), (1, row_idx), weight_color),
|
||||
("TEXTCOLOR", (1, row_idx), (1, row_idx), COLOR_WHITE),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
def _create_action_required_banner(self) -> Table:
|
||||
"""
|
||||
Create the 'Immediate Action Required' banner for critical requirements.
|
||||
|
||||
Returns:
|
||||
ReportLab Table element styled as a red-bordered alert banner.
|
||||
"""
|
||||
banner_style = ParagraphStyle(
|
||||
"ActionRequired",
|
||||
fontName="PlusJakartaSans",
|
||||
fontSize=11,
|
||||
textColor=COLOR_HIGH_RISK,
|
||||
leading=16,
|
||||
)
|
||||
|
||||
banner_content = Paragraph(
|
||||
"<b>IMMEDIATE ACTION REQUIRED:</b><br/>"
|
||||
"These requirements have the highest risk levels and have failed "
|
||||
"compliance checks. Please prioritize addressing these issues to "
|
||||
"improve your security posture.",
|
||||
banner_style,
|
||||
)
|
||||
|
||||
banner_table = Table(
|
||||
[[banner_content]],
|
||||
colWidths=[6.5 * inch],
|
||||
)
|
||||
banner_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
(
|
||||
"BACKGROUND",
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
colors.Color(0.98, 0.92, 0.92),
|
||||
),
|
||||
("BOX", (0, 0), (0, 0), 2, COLOR_HIGH_RISK),
|
||||
("LEFTPADDING", (0, 0), (0, 0), 20),
|
||||
("RIGHTPADDING", (0, 0), (0, 0), 20),
|
||||
("TOPPADDING", (0, 0), (0, 0), 15),
|
||||
("BOTTOMPADDING", (0, 0), (0, 0), 15),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return banner_table
|
||||
@@ -45,6 +45,7 @@ from api.models import (
|
||||
ResourceTag,
|
||||
Scan,
|
||||
ScanCategorySummary,
|
||||
ScanGroupSummary,
|
||||
ScanSummary,
|
||||
StateChoices,
|
||||
)
|
||||
@@ -127,6 +128,50 @@ def aggregate_category_counts(
|
||||
cache[key]["new_failed"] += 1
|
||||
|
||||
|
||||
def aggregate_resource_group_counts(
|
||||
resource_group: str | None,
|
||||
severity: str,
|
||||
status: str,
|
||||
delta: str | None,
|
||||
muted: bool,
|
||||
resource_uid: str,
|
||||
cache: dict[tuple[str, str], dict[str, int]],
|
||||
group_resources_cache: dict[str, set],
|
||||
) -> None:
|
||||
"""
|
||||
Increment resource group counters in-place for a finding.
|
||||
|
||||
Args:
|
||||
resource_group: Resource group from check metadata (e.g., "database", "compute").
|
||||
severity: Severity level (e.g., "high", "medium").
|
||||
status: Finding status as string ("FAIL", "PASS").
|
||||
delta: Delta value as string ("new", "changed") or None.
|
||||
muted: Whether the finding is muted.
|
||||
resource_uid: Unique identifier for the resource to count distinct resources.
|
||||
cache: Dict {(resource_group, severity): {"total", "failed", "new_failed"}} to update.
|
||||
group_resources_cache: Dict {resource_group: set(resource_uids)} for group-level resource tracking.
|
||||
"""
|
||||
if not resource_group:
|
||||
return
|
||||
|
||||
is_failed = status == "FAIL" and not muted
|
||||
is_new_failed = is_failed and delta == "new"
|
||||
|
||||
key = (resource_group, severity)
|
||||
if key not in cache:
|
||||
cache[key] = {"total": 0, "failed": 0, "new_failed": 0}
|
||||
if not muted:
|
||||
cache[key]["total"] += 1
|
||||
if is_failed:
|
||||
cache[key]["failed"] += 1
|
||||
if is_new_failed:
|
||||
cache[key]["new_failed"] += 1
|
||||
|
||||
# Track resources at GROUP level (not per-severity) to avoid over-counting
|
||||
if resource_uid and not muted:
|
||||
group_resources_cache.setdefault(resource_group, set()).add(resource_uid)
|
||||
|
||||
|
||||
def _get_attack_surface_mapping_from_provider(provider_type: str) -> dict:
|
||||
global _ATTACK_SURFACE_MAPPING_CACHE
|
||||
|
||||
@@ -438,6 +483,8 @@ def _process_finding_micro_batch(
|
||||
scan_resource_cache: set,
|
||||
mute_rules_cache: dict,
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]],
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]],
|
||||
group_resources_cache: dict[str, set],
|
||||
) -> None:
|
||||
"""
|
||||
Process a micro-batch of findings and persist them using bulk operations.
|
||||
@@ -459,6 +506,8 @@ def _process_finding_micro_batch(
|
||||
scan_resource_cache: Set of tuples used to create `ResourceScanSummary` rows.
|
||||
mute_rules_cache: Map of finding UID -> mute reason gathered before the scan.
|
||||
scan_categories_cache: Dict tracking category counts {(category, severity): {"total", "failed", "new_failed"}}.
|
||||
scan_resource_groups_cache: Dict tracking resource group counts {(resource_group, severity): {"total", "failed", "new_failed"}}.
|
||||
group_resources_cache: Dict tracking unique resources per group {resource_group: set(resource_uids)}.
|
||||
"""
|
||||
# Accumulate objects for bulk operations
|
||||
findings_to_create = []
|
||||
@@ -499,6 +548,8 @@ def _process_finding_micro_batch(
|
||||
with rls_transaction(tenant_id):
|
||||
resource_uid = finding.resource_uid
|
||||
if resource_uid not in resource_cache:
|
||||
check_metadata = finding.get_metadata()
|
||||
group = check_metadata.get("resourcegroup") or None
|
||||
resource_instance, _ = Resource.objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_instance,
|
||||
@@ -508,6 +559,7 @@ def _process_finding_micro_batch(
|
||||
"service": finding.service_name,
|
||||
"type": finding.resource_type,
|
||||
"name": finding.resource_name,
|
||||
"groups": [group] if group else None,
|
||||
},
|
||||
)
|
||||
resource_cache[resource_uid] = resource_instance
|
||||
@@ -528,6 +580,8 @@ def _process_finding_micro_batch(
|
||||
|
||||
# Track resource field changes (defer save)
|
||||
updated = False
|
||||
check_metadata = finding.get_metadata()
|
||||
group = check_metadata.get("resourcegroup") or None
|
||||
if finding.region and resource_instance.region != finding.region:
|
||||
resource_instance.region = finding.region
|
||||
updated = True
|
||||
@@ -548,6 +602,11 @@ def _process_finding_micro_batch(
|
||||
if resource_instance.partition != finding.partition:
|
||||
resource_instance.partition = finding.partition
|
||||
updated = True
|
||||
if group and (
|
||||
not resource_instance.groups or group not in resource_instance.groups
|
||||
):
|
||||
resource_instance.groups = (resource_instance.groups or []) + [group]
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
dirty_resources[resource_uid] = resource_instance
|
||||
@@ -633,6 +692,7 @@ def _process_finding_micro_batch(
|
||||
muted_reason=muted_reason,
|
||||
compliance=finding.compliance,
|
||||
categories=check_metadata.get("categories", []) or [],
|
||||
resource_groups=check_metadata.get("resourcegroup") or None,
|
||||
)
|
||||
findings_to_create.append(finding_instance)
|
||||
resource_denormalized_data.append((finding_instance, resource_instance))
|
||||
@@ -657,6 +717,18 @@ def _process_finding_micro_batch(
|
||||
cache=scan_categories_cache,
|
||||
)
|
||||
|
||||
# Track resource groups with counts for ScanGroupSummary
|
||||
aggregate_resource_group_counts(
|
||||
resource_group=check_metadata.get("resourcegroup") or None,
|
||||
severity=finding.severity.value,
|
||||
status=status.value,
|
||||
delta=delta.value if delta else None,
|
||||
muted=is_muted,
|
||||
resource_uid=resource_instance.uid if resource_instance else "",
|
||||
cache=scan_resource_groups_cache,
|
||||
group_resources_cache=group_resources_cache,
|
||||
)
|
||||
|
||||
# Bulk operations within single transaction
|
||||
with rls_transaction(tenant_id):
|
||||
# Bulk create findings
|
||||
@@ -714,7 +786,15 @@ def _process_finding_micro_batch(
|
||||
tenant_id=tenant_id,
|
||||
model=Resource,
|
||||
objects=list(dirty_resources.values()),
|
||||
fields=["metadata", "details", "partition", "region", "service", "type"],
|
||||
fields=[
|
||||
"metadata",
|
||||
"details",
|
||||
"partition",
|
||||
"region",
|
||||
"service",
|
||||
"type",
|
||||
"groups",
|
||||
],
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
@@ -757,6 +837,8 @@ def perform_prowler_scan(
|
||||
unique_resources = set()
|
||||
scan_resource_cache: set[tuple[str, str, str, str]] = set()
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
start_time = time.time()
|
||||
exc = None
|
||||
|
||||
@@ -847,6 +929,8 @@ def perform_prowler_scan(
|
||||
scan_resource_cache=scan_resource_cache,
|
||||
mute_rules_cache=mute_rules_cache,
|
||||
scan_categories_cache=scan_categories_cache,
|
||||
scan_resource_groups_cache=scan_resource_groups_cache,
|
||||
group_resources_cache=group_resources_cache,
|
||||
)
|
||||
|
||||
# Update scan progress
|
||||
@@ -933,6 +1017,38 @@ def perform_prowler_scan(
|
||||
sentry_sdk.capture_exception(cat_exception)
|
||||
logger.error(f"Error storing categories for scan {scan_id}: {cat_exception}")
|
||||
|
||||
try:
|
||||
if scan_resource_groups_cache:
|
||||
# Compute group-level resource counts (same value for all severity rows in a group)
|
||||
group_resource_counts = {
|
||||
grp: len(uids) for grp, uids in group_resources_cache.items()
|
||||
}
|
||||
resource_group_summaries = [
|
||||
ScanGroupSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
resource_group=grp,
|
||||
severity=severity,
|
||||
total_findings=counts["total"],
|
||||
failed_findings=counts["failed"],
|
||||
new_failed_findings=counts["new_failed"],
|
||||
resources_count=group_resource_counts.get(grp, 0),
|
||||
)
|
||||
for (
|
||||
grp,
|
||||
severity,
|
||||
), counts in scan_resource_groups_cache.items()
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ScanGroupSummary.objects.bulk_create(
|
||||
resource_group_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
except Exception as rg_exception:
|
||||
sentry_sdk.capture_exception(rg_exception)
|
||||
logger.error(
|
||||
f"Error storing resource groups for scan {scan_id}: {rg_exception}"
|
||||
)
|
||||
|
||||
serializer = ScanTaskSerializer(instance=scan_instance)
|
||||
return serializer.data
|
||||
|
||||
|
||||
@@ -131,9 +131,11 @@ def compute_threatscore_metrics(
|
||||
continue
|
||||
|
||||
m = metadata[0]
|
||||
risk_level = getattr(m, "LevelOfRisk", 0)
|
||||
weight = getattr(m, "Weight", 0)
|
||||
risk_level_raw = getattr(m, "LevelOfRisk", 0)
|
||||
weight_raw = getattr(m, "Weight", 0)
|
||||
section = getattr(m, "Section", "Unknown")
|
||||
risk_level = int(risk_level_raw) if risk_level_raw else 0
|
||||
weight = int(weight_raw) if weight_raw else 0
|
||||
|
||||
# Calculate ThreatScore components using formula from UI
|
||||
rate_i = req_passed_findings / req_total_findings
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
|
||||
from django.db.models import Count, Q
|
||||
from tasks.utils import batched
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
@@ -154,6 +151,12 @@ def _load_findings_for_requirement_checks(
|
||||
Supports optional caching to avoid duplicate queries when generating multiple
|
||||
reports for the same scan.
|
||||
|
||||
Memory optimizations:
|
||||
- Uses database iterator with chunk_size for streaming large result sets
|
||||
- Shares references between cache and return dict (no duplication)
|
||||
- Only selects required fields from database
|
||||
- Processes findings in batches to reduce memory pressure
|
||||
|
||||
Args:
|
||||
tenant_id (str): The tenant ID for Row-Level Security context.
|
||||
scan_id (str): The ID of the scan to retrieve findings for.
|
||||
@@ -171,69 +174,73 @@ def _load_findings_for_requirement_checks(
|
||||
'aws_s3_bucket_public_access': [FindingOutput(...)]
|
||||
}
|
||||
"""
|
||||
findings_by_check_id = defaultdict(list)
|
||||
|
||||
if not check_ids:
|
||||
return dict(findings_by_check_id)
|
||||
return {}
|
||||
|
||||
# Initialize cache if not provided
|
||||
if findings_cache is None:
|
||||
findings_cache = {}
|
||||
|
||||
# Deduplicate check_ids to avoid redundant processing
|
||||
unique_check_ids = list(set(check_ids))
|
||||
|
||||
# Separate cached and non-cached check_ids
|
||||
check_ids_to_load = []
|
||||
cache_hits = 0
|
||||
cache_misses = 0
|
||||
|
||||
for check_id in check_ids:
|
||||
for check_id in unique_check_ids:
|
||||
if check_id in findings_cache:
|
||||
# Reuse from cache
|
||||
findings_by_check_id[check_id] = findings_cache[check_id]
|
||||
cache_hits += 1
|
||||
else:
|
||||
# Need to load from database
|
||||
check_ids_to_load.append(check_id)
|
||||
cache_misses += 1
|
||||
|
||||
if cache_hits > 0:
|
||||
total_checks = len(unique_check_ids)
|
||||
logger.info(
|
||||
f"Findings cache: {cache_hits} hits, {cache_misses} misses "
|
||||
f"({cache_hits / (cache_hits + cache_misses) * 100:.1f}% hit rate)"
|
||||
f"Findings cache: {cache_hits}/{total_checks} hits "
|
||||
f"({cache_hits / total_checks * 100:.1f}% hit rate)"
|
||||
)
|
||||
|
||||
# If all check_ids were in cache, return early
|
||||
if not check_ids_to_load:
|
||||
return dict(findings_by_check_id)
|
||||
|
||||
logger.info(f"Loading findings for {len(check_ids_to_load)} checks on-demand")
|
||||
|
||||
findings_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id, check_id__in=check_ids_to_load
|
||||
# Load missing check_ids from database
|
||||
if check_ids_to_load:
|
||||
logger.info(
|
||||
f"Loading findings for {len(check_ids_to_load)} checks from database"
|
||||
)
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
for batch, is_last_batch in batched(
|
||||
findings_queryset, DJANGO_FINDINGS_BATCH_SIZE
|
||||
):
|
||||
for finding_model in batch:
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Use iterator with chunk_size for memory-efficient streaming
|
||||
# chunk_size controls how many rows Django fetches from DB at once
|
||||
findings_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
check_id__in=check_ids_to_load,
|
||||
)
|
||||
.order_by("check_id", "uid")
|
||||
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
|
||||
)
|
||||
|
||||
# Pre-initialize empty lists for all check_ids to load
|
||||
# This avoids repeated dict lookups and 'if not in' checks
|
||||
for check_id in check_ids_to_load:
|
||||
findings_cache[check_id] = []
|
||||
|
||||
findings_count = 0
|
||||
for finding_model in findings_queryset:
|
||||
finding_output = FindingOutput.transform_api_finding(
|
||||
finding_model, prowler_provider
|
||||
)
|
||||
findings_by_check_id[finding_output.check_id].append(finding_output)
|
||||
# Update cache with newly loaded findings
|
||||
if finding_output.check_id not in findings_cache:
|
||||
findings_cache[finding_output.check_id] = []
|
||||
findings_cache[finding_output.check_id].append(finding_output)
|
||||
findings_count += 1
|
||||
|
||||
total_findings_loaded = sum(
|
||||
len(findings) for findings in findings_by_check_id.values()
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {total_findings_loaded} findings for {len(findings_by_check_id)} checks"
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {findings_count} findings for {len(check_ids_to_load)} checks"
|
||||
)
|
||||
|
||||
return dict(findings_by_check_id)
|
||||
# Build result dict using cache references (no data duplication)
|
||||
# This shares the same list objects between cache and result
|
||||
result = {
|
||||
check_id: findings_cache.get(check_id, []) for check_id in unique_check_ids
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -8,12 +8,17 @@ from celery.utils.log import get_task_logger
|
||||
from config.celery import RLSTask
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from tasks.jobs.attack_paths import (
|
||||
attack_paths_scan,
|
||||
can_provider_run_attack_paths_scan,
|
||||
)
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_compliance_summaries,
|
||||
backfill_daily_severity_summaries,
|
||||
backfill_provider_compliance_scores,
|
||||
backfill_resource_scan_summaries,
|
||||
backfill_scan_category_summaries,
|
||||
backfill_scan_resource_group_summaries,
|
||||
)
|
||||
from tasks.jobs.connection import (
|
||||
check_integration_connection,
|
||||
@@ -47,7 +52,11 @@ from tasks.jobs.scan import (
|
||||
perform_prowler_scan,
|
||||
update_provider_compliance_scores,
|
||||
)
|
||||
from tasks.utils import batched, get_next_execution_datetime
|
||||
from tasks.utils import (
|
||||
_get_or_create_scheduled_scan,
|
||||
batched,
|
||||
get_next_execution_datetime,
|
||||
)
|
||||
|
||||
from api.compliance import get_compliance_frameworks
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
@@ -152,6 +161,11 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
|
||||
),
|
||||
).apply_async()
|
||||
|
||||
if can_provider_run_attack_paths_scan(tenant_id, provider_id):
|
||||
perform_attack_paths_scan_task.apply_async(
|
||||
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
|
||||
)
|
||||
|
||||
|
||||
@shared_task(base=RLSTask, name="provider-connection-check")
|
||||
@set_tenant
|
||||
@@ -264,44 +278,38 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
periodic_task_instance = PeriodicTask.objects.get(
|
||||
name=f"scan-perform-scheduled-{provider_id}"
|
||||
)
|
||||
|
||||
executed_scan = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
task__task_runner_task__task_id=task_id,
|
||||
).order_by("completed_at")
|
||||
|
||||
if (
|
||||
executing_scan = (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.EXECUTING,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at__date=datetime.now(timezone.utc).date(),
|
||||
).exists()
|
||||
or executed_scan.exists()
|
||||
):
|
||||
# Duplicated task execution due to visibility timeout or scan is already running
|
||||
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
||||
try:
|
||||
affected_scan = executed_scan.first()
|
||||
if not affected_scan:
|
||||
raise ValueError(
|
||||
"Error retrieving affected scan details after detecting duplicated scheduled "
|
||||
"scan."
|
||||
)
|
||||
# Return the affected scan details to avoid losing data
|
||||
serializer = ScanTaskSerializer(instance=affected_scan)
|
||||
except Exception as duplicated_scan_exception:
|
||||
logger.error(
|
||||
f"Duplicated scheduled scan for provider {provider_id}. Error retrieving affected scan details: "
|
||||
f"{str(duplicated_scan_exception)}"
|
||||
)
|
||||
raise duplicated_scan_exception
|
||||
return serializer.data
|
||||
)
|
||||
.order_by("-started_at")
|
||||
.first()
|
||||
)
|
||||
if executing_scan:
|
||||
logger.warning(
|
||||
f"Scheduled scan already executing for provider {provider_id}. Skipping."
|
||||
)
|
||||
return ScanTaskSerializer(instance=executing_scan).data
|
||||
|
||||
executed_scan = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
task__task_runner_task__task_id=task_id,
|
||||
).first()
|
||||
|
||||
if executed_scan:
|
||||
# Duplicated task execution due to visibility timeout
|
||||
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
||||
return ScanTaskSerializer(instance=executed_scan).data
|
||||
|
||||
interval = periodic_task_instance.interval
|
||||
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
|
||||
current_scan_datetime = next_scan_datetime - timedelta(
|
||||
**{interval.period: interval.every}
|
||||
)
|
||||
|
||||
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
|
||||
_cleanup_orphan_scheduled_scans(
|
||||
@@ -310,19 +318,12 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
)
|
||||
|
||||
scan_instance, _ = Scan.objects.get_or_create(
|
||||
scan_instance = _get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
defaults={
|
||||
"state": StateChoices.SCHEDULED,
|
||||
"name": "Daily scheduled scan",
|
||||
"scheduled_at": next_scan_datetime - timedelta(days=1),
|
||||
},
|
||||
scheduled_at=current_scan_datetime,
|
||||
)
|
||||
|
||||
scan_instance.task_id = task_id
|
||||
scan_instance.save()
|
||||
|
||||
@@ -332,18 +333,19 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
scan_id=str(scan_instance.id),
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
with rls_transaction(tenant_id):
|
||||
Scan.objects.get_or_create(
|
||||
now = datetime.now(timezone.utc)
|
||||
if next_scan_datetime <= now:
|
||||
interval_delta = timedelta(**{interval.period: interval.every})
|
||||
while next_scan_datetime <= now:
|
||||
next_scan_datetime += interval_delta
|
||||
_get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
name="Daily scheduled scan",
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
scheduled_at=next_scan_datetime,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at=next_scan_datetime,
|
||||
update_state=True,
|
||||
)
|
||||
|
||||
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
|
||||
@@ -357,6 +359,29 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
|
||||
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
|
||||
@shared_task(
|
||||
base=RLSTask,
|
||||
bind=True,
|
||||
name="attack-paths-scan-perform",
|
||||
queue="attack-paths-scans",
|
||||
)
|
||||
def perform_attack_paths_scan_task(self, tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Execute an Attack Paths scan for the given provider within the current tenant RLS context.
|
||||
|
||||
Args:
|
||||
self: The task instance (automatically passed when bind=True).
|
||||
tenant_id (str): The tenant identifier for RLS context.
|
||||
scan_id (str): The Prowler scan identifier for obtaining the tenant and provider context.
|
||||
|
||||
Returns:
|
||||
Any: The result from `attack_paths_scan`, including any per-scan failure details.
|
||||
"""
|
||||
return attack_paths_scan(
|
||||
tenant_id=tenant_id, scan_id=scan_id, task_id=self.request.id
|
||||
)
|
||||
|
||||
|
||||
@shared_task(name="tenant-deletion", queue="deletion", autoretry_for=(Exception,))
|
||||
def delete_tenant_task(tenant_id: str):
|
||||
return delete_tenant(pk=tenant_id)
|
||||
@@ -613,6 +638,21 @@ def backfill_scan_category_summaries_task(tenant_id: str, scan_id: str):
|
||||
return backfill_scan_category_summaries(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
|
||||
@shared_task(name="backfill-scan-resource-group-summaries", queue="backfill")
|
||||
@handle_provider_deletion
|
||||
def backfill_scan_resource_group_summaries_task(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Backfill ScanGroupSummary for a completed scan.
|
||||
|
||||
Aggregates unique resource groups from findings and creates a summary row.
|
||||
|
||||
Args:
|
||||
tenant_id (str): The tenant identifier.
|
||||
scan_id (str): The scan identifier.
|
||||
"""
|
||||
return backfill_scan_resource_group_summaries(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
|
||||
@shared_task(name="backfill-provider-compliance-scores", queue="backfill")
|
||||
def backfill_provider_compliance_scores_task(tenant_id: str):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,708 @@
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from tasks.jobs.attack_paths import prowler as prowler_module
|
||||
from tasks.jobs.attack_paths.scan import run as attack_paths_run
|
||||
|
||||
from api.models import (
|
||||
AttackPathsScan,
|
||||
Finding,
|
||||
Provider,
|
||||
Resource,
|
||||
ResourceFindingMapping,
|
||||
Scan,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
)
|
||||
from prowler.lib.check.models import Severity
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAttackPathsRun:
|
||||
def test_run_success_flow(self, tenants_fixture, providers_fixture, scans_fixture):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
scan = scans_fixture[0]
|
||||
scan.provider = provider
|
||||
scan.save()
|
||||
|
||||
attack_paths_scan = AttackPathsScan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
scan=scan,
|
||||
state=StateChoices.SCHEDULED,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = mock_session
|
||||
session_ctx.__exit__.return_value = False
|
||||
ingestion_result = {"organizations": "warning"}
|
||||
ingestion_fn = MagicMock(return_value=ingestion_result)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
|
||||
return_value=MagicMock(_enabled_regions=["us-east-1"]),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
|
||||
return_value="bolt://neo4j",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
|
||||
return_value="db-scan-id",
|
||||
) as mock_get_db_name,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.create_database"
|
||||
) as mock_create_db,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.get_session",
|
||||
return_value=session_ctx,
|
||||
) as mock_get_session,
|
||||
patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache"),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.cartography_create_indexes.run"
|
||||
) as mock_cartography_indexes,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.cartography_analysis.run"
|
||||
) as mock_cartography_analysis,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.cartography_ontology.run"
|
||||
) as mock_cartography_ontology,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.prowler.create_indexes"
|
||||
) as mock_prowler_indexes,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.prowler.analysis"
|
||||
) as mock_prowler_analysis,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
|
||||
return_value=attack_paths_scan,
|
||||
) as mock_retrieve_scan,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"
|
||||
) as mock_starting,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
|
||||
) as mock_update_progress,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
|
||||
) as mock_finish,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
|
||||
return_value=ingestion_fn,
|
||||
) as mock_get_ingestion,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan._call_within_event_loop",
|
||||
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
|
||||
) as mock_event_loop,
|
||||
):
|
||||
result = attack_paths_run(str(tenant.id), str(scan.id), "task-123")
|
||||
|
||||
assert result == ingestion_result
|
||||
mock_retrieve_scan.assert_called_once_with(str(tenant.id), str(scan.id))
|
||||
mock_starting.assert_called_once()
|
||||
config = mock_starting.call_args[0][2]
|
||||
assert config.neo4j_database == "db-scan-id"
|
||||
|
||||
mock_create_db.assert_called_once_with("db-scan-id")
|
||||
mock_get_session.assert_called_once_with("db-scan-id")
|
||||
mock_cartography_indexes.assert_called_once_with(mock_session, config)
|
||||
mock_prowler_indexes.assert_called_once_with(mock_session)
|
||||
mock_cartography_analysis.assert_called_once_with(mock_session, config)
|
||||
mock_cartography_ontology.assert_called_once_with(mock_session, config)
|
||||
mock_prowler_analysis.assert_called_once_with(
|
||||
mock_session,
|
||||
provider,
|
||||
str(scan.id),
|
||||
config,
|
||||
)
|
||||
mock_get_ingestion.assert_called_once_with(provider.provider)
|
||||
mock_event_loop.assert_called_once()
|
||||
mock_update_progress.assert_any_call(attack_paths_scan, 1)
|
||||
mock_update_progress.assert_any_call(attack_paths_scan, 2)
|
||||
mock_update_progress.assert_any_call(attack_paths_scan, 95)
|
||||
mock_finish.assert_called_once_with(
|
||||
attack_paths_scan, StateChoices.COMPLETED, ingestion_result
|
||||
)
|
||||
mock_get_db_name.assert_called_once_with(attack_paths_scan.id)
|
||||
|
||||
def test_run_failure_marks_scan_failed(
|
||||
self, tenants_fixture, providers_fixture, scans_fixture
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
scan = scans_fixture[0]
|
||||
scan.provider = provider
|
||||
scan.save()
|
||||
|
||||
attack_paths_scan = AttackPathsScan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
scan=scan,
|
||||
state=StateChoices.SCHEDULED,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = mock_session
|
||||
session_ctx.__exit__.return_value = False
|
||||
ingestion_fn = MagicMock(side_effect=RuntimeError("ingestion boom"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
|
||||
return_value=MagicMock(_enabled_regions=["us-east-1"]),
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.scan.graph_database.get_uri"),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
|
||||
return_value="db-scan-id",
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.scan.graph_database.create_database"),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.graph_database.get_session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run"),
|
||||
patch("tasks.jobs.attack_paths.scan.cartography_analysis.run"),
|
||||
patch("tasks.jobs.attack_paths.scan.prowler.create_indexes"),
|
||||
patch("tasks.jobs.attack_paths.scan.prowler.analysis"),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
|
||||
return_value=attack_paths_scan,
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
|
||||
) as mock_finish,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
|
||||
return_value=ingestion_fn,
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan._call_within_event_loop",
|
||||
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.utils.stringify_exception",
|
||||
return_value="Cartography failed: ingestion boom",
|
||||
),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="ingestion boom"):
|
||||
attack_paths_run(str(tenant.id), str(scan.id), "task-456")
|
||||
|
||||
failure_args = mock_finish.call_args[0]
|
||||
assert failure_args[0] is attack_paths_scan
|
||||
assert failure_args[1] == StateChoices.FAILED
|
||||
assert failure_args[2] == {
|
||||
"global_cartography_error": "Cartography failed: ingestion boom"
|
||||
}
|
||||
|
||||
def test_run_returns_early_for_unsupported_provider(self, tenants_fixture):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = Provider.objects.create(
|
||||
provider=Provider.ProviderChoices.GCP,
|
||||
uid="gcp-account",
|
||||
alias="gcp",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
scan = Scan.objects.create(
|
||||
name="GCP Scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.AVAILABLE,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
|
||||
return_value=None,
|
||||
) as mock_get_ingestion,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan"
|
||||
) as mock_retrieve,
|
||||
):
|
||||
mock_retrieve.return_value = None
|
||||
result = attack_paths_run(str(tenant.id), str(scan.id), "task-789")
|
||||
|
||||
assert result == {
|
||||
"global_error": "Provider gcp is not supported for Attack Paths scans"
|
||||
}
|
||||
mock_get_ingestion.assert_called_once_with(provider.provider)
|
||||
mock_retrieve.assert_called_once_with(str(tenant.id), str(scan.id))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAttackPathsProwlerHelpers:
|
||||
def test_create_indexes_executes_all_statements(self):
|
||||
mock_session = MagicMock()
|
||||
with patch("tasks.jobs.attack_paths.prowler.run_write_query") as mock_run_write:
|
||||
prowler_module.create_indexes(mock_session)
|
||||
|
||||
assert mock_run_write.call_count == len(prowler_module.INDEX_STATEMENTS)
|
||||
mock_run_write.assert_has_calls(
|
||||
[call(mock_session, stmt) for stmt in prowler_module.INDEX_STATEMENTS]
|
||||
)
|
||||
|
||||
def test_load_findings_batches_requests(self, providers_fixture):
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
# Create a generator that yields two batches
|
||||
def findings_generator():
|
||||
yield [{"id": "1", "resource_uid": "r-1"}]
|
||||
yield [{"id": "2", "resource_uid": "r-2"}]
|
||||
|
||||
config = SimpleNamespace(update_tag=12345)
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.prowler.get_root_node_label",
|
||||
return_value="AWSAccount",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.prowler.get_node_uid_field",
|
||||
return_value="arn",
|
||||
),
|
||||
):
|
||||
prowler_module.load_findings(
|
||||
mock_session, findings_generator(), provider, config
|
||||
)
|
||||
|
||||
assert mock_session.run.call_count == 2
|
||||
for call_args in mock_session.run.call_args_list:
|
||||
params = call_args.args[1]
|
||||
assert params["provider_uid"] == str(provider.uid)
|
||||
assert params["last_updated"] == config.update_tag
|
||||
assert "findings_data" in params
|
||||
|
||||
def test_cleanup_findings_runs_batches(self, providers_fixture):
|
||||
provider = providers_fixture[0]
|
||||
config = SimpleNamespace(update_tag=1024)
|
||||
mock_session = MagicMock()
|
||||
|
||||
first_batch = MagicMock()
|
||||
first_batch.single.return_value = {"deleted_findings_count": 3}
|
||||
second_batch = MagicMock()
|
||||
second_batch.single.return_value = {"deleted_findings_count": 0}
|
||||
mock_session.run.side_effect = [first_batch, second_batch]
|
||||
|
||||
prowler_module.cleanup_findings(mock_session, provider, config)
|
||||
|
||||
assert mock_session.run.call_count == 2
|
||||
params = mock_session.run.call_args.args[1]
|
||||
assert params["provider_uid"] == str(provider.uid)
|
||||
assert params["last_updated"] == config.update_tag
|
||||
|
||||
def test_get_provider_last_scan_findings_returns_latest_scan_data(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
resource = Resource.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
uid="resource-uid",
|
||||
name="Resource",
|
||||
region="us-east-1",
|
||||
service="ec2",
|
||||
type="instance",
|
||||
)
|
||||
|
||||
older_scan = Scan.objects.create(
|
||||
name="Older",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
old_finding = Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
uid="older-finding",
|
||||
scan=older_scan,
|
||||
delta=Finding.DeltaChoices.NEW,
|
||||
status=StatusChoices.PASS,
|
||||
status_extended="ok",
|
||||
severity=Severity.low,
|
||||
impact=Severity.low,
|
||||
impact_extended="",
|
||||
raw_result={},
|
||||
check_id="check-old",
|
||||
check_metadata={"checktitle": "Old"},
|
||||
first_seen_at=older_scan.inserted_at,
|
||||
)
|
||||
ResourceFindingMapping.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
resource=resource,
|
||||
finding=old_finding,
|
||||
)
|
||||
|
||||
latest_scan = Scan.objects.create(
|
||||
name="Latest",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
uid="finding-uid",
|
||||
scan=latest_scan,
|
||||
delta=Finding.DeltaChoices.NEW,
|
||||
status=StatusChoices.FAIL,
|
||||
status_extended="failed",
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
impact_extended="",
|
||||
raw_result={},
|
||||
check_id="check-1",
|
||||
check_metadata={"checktitle": "Check title"},
|
||||
first_seen_at=latest_scan.inserted_at,
|
||||
)
|
||||
ResourceFindingMapping.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
resource=resource,
|
||||
finding=finding,
|
||||
)
|
||||
|
||||
latest_scan.refresh_from_db()
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.prowler.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
), patch(
|
||||
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
|
||||
"default",
|
||||
):
|
||||
# Generator yields batches, collect all findings from all batches
|
||||
findings_batches = prowler_module.get_provider_last_scan_findings(
|
||||
provider,
|
||||
str(latest_scan.id),
|
||||
)
|
||||
findings_data = []
|
||||
for batch in findings_batches:
|
||||
findings_data.extend(batch)
|
||||
|
||||
assert len(findings_data) == 1
|
||||
finding_dict = findings_data[0]
|
||||
assert finding_dict["id"] == str(finding.id)
|
||||
assert finding_dict["resource_uid"] == resource.uid
|
||||
assert finding_dict["check_title"] == "Check title"
|
||||
assert finding_dict["scan_id"] == str(latest_scan.id)
|
||||
|
||||
def test_enrich_and_flatten_batch_single_resource(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""One finding + one resource = one output dict"""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
resource = Resource.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
uid="resource-uid-1",
|
||||
name="Resource 1",
|
||||
region="us-east-1",
|
||||
service="ec2",
|
||||
type="instance",
|
||||
)
|
||||
|
||||
scan = Scan.objects.create(
|
||||
name="Test Scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
uid="finding-uid",
|
||||
scan=scan,
|
||||
delta=Finding.DeltaChoices.NEW,
|
||||
status=StatusChoices.FAIL,
|
||||
status_extended="failed",
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
impact_extended="",
|
||||
raw_result={},
|
||||
check_id="check-1",
|
||||
check_metadata={"checktitle": "Check title"},
|
||||
first_seen_at=scan.inserted_at,
|
||||
)
|
||||
ResourceFindingMapping.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
resource=resource,
|
||||
finding=finding,
|
||||
)
|
||||
|
||||
# Simulate the dict returned by .values()
|
||||
finding_dict = {
|
||||
"id": finding.id,
|
||||
"uid": finding.uid,
|
||||
"inserted_at": finding.inserted_at,
|
||||
"updated_at": finding.updated_at,
|
||||
"first_seen_at": finding.first_seen_at,
|
||||
"scan_id": scan.id,
|
||||
"delta": finding.delta,
|
||||
"status": finding.status,
|
||||
"status_extended": finding.status_extended,
|
||||
"severity": finding.severity,
|
||||
"check_id": finding.check_id,
|
||||
"check_metadata__checktitle": finding.check_metadata["checktitle"],
|
||||
"muted": finding.muted,
|
||||
"muted_reason": finding.muted_reason,
|
||||
}
|
||||
|
||||
# _enrich_and_flatten_batch queries ResourceFindingMapping directly
|
||||
# No RLS mock needed - test DB doesn't enforce RLS policies
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
|
||||
"default",
|
||||
):
|
||||
result = prowler_module._enrich_and_flatten_batch([finding_dict])
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["resource_uid"] == resource.uid
|
||||
assert result[0]["id"] == str(finding.id)
|
||||
assert result[0]["status"] == "FAIL"
|
||||
|
||||
def test_enrich_and_flatten_batch_multiple_resources(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""One finding + three resources = three output dicts"""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
resources = []
|
||||
for i in range(3):
|
||||
resource = Resource.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
uid=f"resource-uid-{i}",
|
||||
name=f"Resource {i}",
|
||||
region="us-east-1",
|
||||
service="ec2",
|
||||
type="instance",
|
||||
)
|
||||
resources.append(resource)
|
||||
|
||||
scan = Scan.objects.create(
|
||||
name="Test Scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
uid="finding-uid",
|
||||
scan=scan,
|
||||
delta=Finding.DeltaChoices.NEW,
|
||||
status=StatusChoices.FAIL,
|
||||
status_extended="failed",
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
impact_extended="",
|
||||
raw_result={},
|
||||
check_id="check-1",
|
||||
check_metadata={"checktitle": "Check title"},
|
||||
first_seen_at=scan.inserted_at,
|
||||
)
|
||||
|
||||
# Map finding to all 3 resources
|
||||
for resource in resources:
|
||||
ResourceFindingMapping.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
resource=resource,
|
||||
finding=finding,
|
||||
)
|
||||
|
||||
finding_dict = {
|
||||
"id": finding.id,
|
||||
"uid": finding.uid,
|
||||
"inserted_at": finding.inserted_at,
|
||||
"updated_at": finding.updated_at,
|
||||
"first_seen_at": finding.first_seen_at,
|
||||
"scan_id": scan.id,
|
||||
"delta": finding.delta,
|
||||
"status": finding.status,
|
||||
"status_extended": finding.status_extended,
|
||||
"severity": finding.severity,
|
||||
"check_id": finding.check_id,
|
||||
"check_metadata__checktitle": finding.check_metadata["checktitle"],
|
||||
"muted": finding.muted,
|
||||
"muted_reason": finding.muted_reason,
|
||||
}
|
||||
|
||||
# _enrich_and_flatten_batch queries ResourceFindingMapping directly
|
||||
# No RLS mock needed - test DB doesn't enforce RLS policies
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
|
||||
"default",
|
||||
):
|
||||
result = prowler_module._enrich_and_flatten_batch([finding_dict])
|
||||
|
||||
assert len(result) == 3
|
||||
result_resource_uids = {r["resource_uid"] for r in result}
|
||||
assert result_resource_uids == {r.uid for r in resources}
|
||||
|
||||
# All should have same finding data
|
||||
for r in result:
|
||||
assert r["id"] == str(finding.id)
|
||||
assert r["status"] == "FAIL"
|
||||
|
||||
def test_enrich_and_flatten_batch_no_resources_skips(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Finding without resources should be skipped"""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
scan = Scan.objects.create(
|
||||
name="Test Scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
uid="orphan-finding",
|
||||
scan=scan,
|
||||
delta=Finding.DeltaChoices.NEW,
|
||||
status=StatusChoices.FAIL,
|
||||
status_extended="failed",
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
impact_extended="",
|
||||
raw_result={},
|
||||
check_id="check-1",
|
||||
check_metadata={"checktitle": "Check title"},
|
||||
first_seen_at=scan.inserted_at,
|
||||
)
|
||||
# Note: No ResourceFindingMapping created
|
||||
|
||||
finding_dict = {
|
||||
"id": finding.id,
|
||||
"uid": finding.uid,
|
||||
"inserted_at": finding.inserted_at,
|
||||
"updated_at": finding.updated_at,
|
||||
"first_seen_at": finding.first_seen_at,
|
||||
"scan_id": scan.id,
|
||||
"delta": finding.delta,
|
||||
"status": finding.status,
|
||||
"status_extended": finding.status_extended,
|
||||
"severity": finding.severity,
|
||||
"check_id": finding.check_id,
|
||||
"check_metadata__checktitle": finding.check_metadata["checktitle"],
|
||||
"muted": finding.muted,
|
||||
"muted_reason": finding.muted_reason,
|
||||
}
|
||||
|
||||
# Mock logger to verify no warning is emitted
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
|
||||
"default",
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.prowler.logger") as mock_logger,
|
||||
):
|
||||
result = prowler_module._enrich_and_flatten_batch([finding_dict])
|
||||
|
||||
assert len(result) == 0
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_generator_is_lazy(self, providers_fixture):
|
||||
"""Generator should not execute queries until iterated"""
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
scan_id = "some-scan-id"
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.attack_paths.prowler.rls_transaction") as mock_rls,
|
||||
patch("tasks.jobs.attack_paths.prowler.Finding") as mock_finding,
|
||||
):
|
||||
# Create generator but don't iterate
|
||||
prowler_module.get_provider_last_scan_findings(provider, scan_id)
|
||||
|
||||
# Nothing should be called yet
|
||||
mock_rls.assert_not_called()
|
||||
mock_finding.objects.filter.assert_not_called()
|
||||
|
||||
def test_load_findings_empty_generator(self, providers_fixture):
|
||||
"""Empty generator should not call neo4j"""
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
mock_session = MagicMock()
|
||||
config = SimpleNamespace(update_tag=12345)
|
||||
|
||||
def empty_gen():
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.prowler.get_root_node_label",
|
||||
return_value="AWSAccount",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.prowler.get_node_uid_field",
|
||||
return_value="arn",
|
||||
),
|
||||
):
|
||||
prowler_module.load_findings(mock_session, empty_gen(), provider, config)
|
||||
|
||||
mock_session.run.assert_not_called()
|
||||
@@ -8,6 +8,7 @@ from tasks.jobs.backfill import (
|
||||
backfill_provider_compliance_scores,
|
||||
backfill_resource_scan_summaries,
|
||||
backfill_scan_category_summaries,
|
||||
backfill_scan_resource_group_summaries,
|
||||
)
|
||||
|
||||
from api.models import (
|
||||
@@ -16,6 +17,7 @@ from api.models import (
|
||||
ResourceScanSummary,
|
||||
Scan,
|
||||
ScanCategorySummary,
|
||||
ScanGroupSummary,
|
||||
StateChoices,
|
||||
)
|
||||
from prowler.lib.check.models import Severity
|
||||
@@ -265,6 +267,94 @@ class TestBackfillScanCategorySummaries:
|
||||
assert summary.new_failed_findings == 1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def findings_with_group_fixture(scans_fixture, resources_fixture):
|
||||
scan = scans_fixture[0]
|
||||
resource = resources_fixture[0]
|
||||
|
||||
finding = Finding.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
uid="finding_with_group",
|
||||
scan=scan,
|
||||
delta="new",
|
||||
status=Status.FAIL,
|
||||
status_extended="test status",
|
||||
impact=Severity.high,
|
||||
impact_extended="test impact",
|
||||
severity=Severity.high,
|
||||
raw_result={"status": Status.FAIL},
|
||||
check_id="test_check",
|
||||
check_metadata={"CheckId": "test_check"},
|
||||
resource_groups="ai_ml",
|
||||
first_seen_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
finding.add_resources([resource])
|
||||
return finding
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def scan_resource_group_summary_fixture(scans_fixture):
|
||||
scan = scans_fixture[0]
|
||||
return ScanGroupSummary.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan=scan,
|
||||
resource_group="existing-group",
|
||||
severity=Severity.high,
|
||||
total_findings=1,
|
||||
failed_findings=0,
|
||||
new_failed_findings=0,
|
||||
resources_count=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBackfillScanGroupSummaries:
|
||||
def test_already_backfilled(self, scan_resource_group_summary_fixture):
|
||||
tenant_id = scan_resource_group_summary_fixture.tenant_id
|
||||
scan_id = scan_resource_group_summary_fixture.scan_id
|
||||
|
||||
result = backfill_scan_resource_group_summaries(str(tenant_id), str(scan_id))
|
||||
|
||||
assert result == {"status": "already backfilled"}
|
||||
|
||||
def test_not_completed_scan(self, get_not_completed_scans):
|
||||
for scan in get_not_completed_scans:
|
||||
result = backfill_scan_resource_group_summaries(
|
||||
str(scan.tenant_id), str(scan.id)
|
||||
)
|
||||
assert result == {"status": "scan is not completed"}
|
||||
|
||||
def test_no_resource_groups_to_backfill(self, scans_fixture):
|
||||
scan = scans_fixture[1] # Failed scan with no findings
|
||||
result = backfill_scan_resource_group_summaries(
|
||||
str(scan.tenant_id), str(scan.id)
|
||||
)
|
||||
assert result == {"status": "no resource groups to backfill"}
|
||||
|
||||
def test_successful_backfill(self, findings_with_group_fixture):
|
||||
finding = findings_with_group_fixture
|
||||
tenant_id = str(finding.tenant_id)
|
||||
scan_id = str(finding.scan_id)
|
||||
|
||||
result = backfill_scan_resource_group_summaries(tenant_id, scan_id)
|
||||
|
||||
# 1 resource group × 1 severity = 1 row
|
||||
assert result == {"status": "backfilled", "resource_groups_count": 1}
|
||||
|
||||
summaries = ScanGroupSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
)
|
||||
assert summaries.count() == 1
|
||||
|
||||
summary = summaries.first()
|
||||
assert summary.resource_group == "ai_ml"
|
||||
assert summary.severity == Severity.high
|
||||
assert summary.total_findings == 1
|
||||
assert summary.failed_findings == 1
|
||||
assert summary.new_failed_findings == 1
|
||||
assert summary.resources_count == 1
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBackfillProviderComplianceScores:
|
||||
def test_no_completed_scans(self, tenants_fixture):
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_check_provider_connection_exception(
|
||||
[
|
||||
{
|
||||
"name": "OpenAI",
|
||||
"api_key_decoded": "sk-test1234567890T3BlbkFJtest1234567890",
|
||||
"api_key_decoded": "sk-fake-test-key-for-unit-testing-only",
|
||||
"model": "gpt-4o",
|
||||
"temperature": 0,
|
||||
"max_tokens": 4000,
|
||||
|
||||
@@ -1,27 +1,60 @@
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
|
||||
from api.models import Provider, Tenant
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestDeleteProvider:
|
||||
def test_delete_provider_success(self, providers_fixture):
|
||||
instance = providers_fixture[0]
|
||||
tenant_id = str(instance.tenant_id)
|
||||
result = delete_provider(tenant_id, instance.id)
|
||||
with patch(
|
||||
"tasks.jobs.deletion.get_provider_graph_database_names"
|
||||
) as mock_get_provider_graph_database_names, patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database"
|
||||
) as mock_drop_database:
|
||||
graph_db_names = ["graph-db-1", "graph-db-2"]
|
||||
mock_get_provider_graph_database_names.return_value = graph_db_names
|
||||
|
||||
assert result
|
||||
with pytest.raises(ObjectDoesNotExist):
|
||||
Provider.objects.get(pk=instance.id)
|
||||
instance = providers_fixture[0]
|
||||
tenant_id = str(instance.tenant_id)
|
||||
result = delete_provider(tenant_id, instance.id)
|
||||
|
||||
assert result
|
||||
with pytest.raises(ObjectDoesNotExist):
|
||||
Provider.objects.get(pk=instance.id)
|
||||
|
||||
mock_get_provider_graph_database_names.assert_called_once_with(
|
||||
tenant_id, instance.id
|
||||
)
|
||||
mock_drop_database.assert_has_calls(
|
||||
[call(graph_db_name) for graph_db_name in graph_db_names]
|
||||
)
|
||||
|
||||
def test_delete_provider_does_not_exist(self, tenants_fixture):
|
||||
tenant_id = str(tenants_fixture[0].id)
|
||||
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
|
||||
with patch(
|
||||
"tasks.jobs.deletion.get_provider_graph_database_names"
|
||||
) as mock_get_provider_graph_database_names, patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database"
|
||||
) as mock_drop_database:
|
||||
graph_db_names = ["graph-db-1"]
|
||||
mock_get_provider_graph_database_names.return_value = graph_db_names
|
||||
|
||||
with pytest.raises(ObjectDoesNotExist):
|
||||
delete_provider(tenant_id, non_existent_pk)
|
||||
tenant_id = str(tenants_fixture[0].id)
|
||||
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
|
||||
|
||||
with pytest.raises(ObjectDoesNotExist):
|
||||
delete_provider(tenant_id, non_existent_pk)
|
||||
|
||||
mock_get_provider_graph_database_names.assert_called_once_with(
|
||||
tenant_id, non_existent_pk
|
||||
)
|
||||
mock_drop_database.assert_has_calls(
|
||||
[call(graph_db_name) for graph_db_name in graph_db_names]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -30,33 +63,68 @@ class TestDeleteTenant:
|
||||
"""
|
||||
Test successful deletion of a tenant and its related data.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
providers = Provider.objects.filter(tenant_id=tenant.id)
|
||||
with patch(
|
||||
"tasks.jobs.deletion.get_provider_graph_database_names"
|
||||
) as mock_get_provider_graph_database_names, patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database"
|
||||
) as mock_drop_database:
|
||||
tenant = tenants_fixture[0]
|
||||
providers = list(Provider.objects.filter(tenant_id=tenant.id))
|
||||
|
||||
# Ensure the tenant and related providers exist before deletion
|
||||
assert Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert providers.exists()
|
||||
graph_db_names_per_provider = [
|
||||
[f"graph-db-{provider.id}"] for provider in providers
|
||||
]
|
||||
mock_get_provider_graph_database_names.side_effect = (
|
||||
graph_db_names_per_provider
|
||||
)
|
||||
|
||||
# Call the function and validate the result
|
||||
deletion_summary = delete_tenant(tenant.id)
|
||||
# Ensure the tenant and related providers exist before deletion
|
||||
assert Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert providers
|
||||
|
||||
assert deletion_summary is not None
|
||||
assert not Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
|
||||
# Call the function and validate the result
|
||||
deletion_summary = delete_tenant(tenant.id)
|
||||
|
||||
assert deletion_summary is not None
|
||||
assert not Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
|
||||
|
||||
expected_calls = [
|
||||
call(provider.tenant_id, provider.id) for provider in providers
|
||||
]
|
||||
mock_get_provider_graph_database_names.assert_has_calls(
|
||||
expected_calls, any_order=True
|
||||
)
|
||||
assert mock_get_provider_graph_database_names.call_count == len(
|
||||
expected_calls
|
||||
)
|
||||
expected_drop_calls = [
|
||||
call(graph_db_name[0]) for graph_db_name in graph_db_names_per_provider
|
||||
]
|
||||
mock_drop_database.assert_has_calls(expected_drop_calls, any_order=True)
|
||||
assert mock_drop_database.call_count == len(expected_drop_calls)
|
||||
|
||||
def test_delete_tenant_with_no_providers(self, tenants_fixture):
|
||||
"""
|
||||
Test deletion of a tenant with no related providers.
|
||||
"""
|
||||
tenant = tenants_fixture[1] # Assume this tenant has no providers
|
||||
providers = Provider.objects.filter(tenant_id=tenant.id)
|
||||
with patch(
|
||||
"tasks.jobs.deletion.get_provider_graph_database_names"
|
||||
) as mock_get_provider_graph_database_names, patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database"
|
||||
) as mock_drop_database:
|
||||
tenant = tenants_fixture[1] # Assume this tenant has no providers
|
||||
providers = Provider.objects.filter(tenant_id=tenant.id)
|
||||
|
||||
# Ensure the tenant exists but has no related providers
|
||||
assert Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert not providers.exists()
|
||||
# Ensure the tenant exists but has no related providers
|
||||
assert Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert not providers.exists()
|
||||
|
||||
# Call the function and validate the result
|
||||
deletion_summary = delete_tenant(tenant.id)
|
||||
# Call the function and validate the result
|
||||
deletion_summary = delete_tenant(tenant.id)
|
||||
|
||||
assert deletion_summary == {} # No providers, so empty summary
|
||||
assert not Tenant.objects.filter(id=tenant.id).exists()
|
||||
assert deletion_summary == {} # No providers, so empty summary
|
||||
assert not Tenant.objects.filter(id=tenant.id).exists()
|
||||
|
||||
mock_get_provider_graph_database_names.assert_not_called()
|
||||
mock_drop_database.assert_not_called()
|
||||
|
||||
@@ -417,9 +417,8 @@ class TestProwlerIntegrationConnectionTest:
|
||||
raise_on_exception=False,
|
||||
)
|
||||
|
||||
@patch("api.utils.AwsProvider")
|
||||
@patch("api.utils.S3")
|
||||
def test_s3_integration_connection_failure(self, mock_s3_class, mock_aws_provider):
|
||||
def test_s3_integration_connection_failure(self, mock_s3_class):
|
||||
"""Test S3 integration connection failure."""
|
||||
integration = MagicMock()
|
||||
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
|
||||
@@ -429,9 +428,6 @@ class TestProwlerIntegrationConnectionTest:
|
||||
}
|
||||
integration.configuration = {"bucket_name": "test-bucket"}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_aws_provider.return_value.session.current_session = mock_session
|
||||
|
||||
mock_connection = Connection(
|
||||
is_connected=False, error=Exception("Bucket not found")
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,410 @@
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import matplotlib
|
||||
import pytest
|
||||
from reportlab.lib import colors
|
||||
from tasks.jobs.report import generate_compliance_reports, generate_threatscore_report
|
||||
from tasks.jobs.reports import (
|
||||
CHART_COLOR_GREEN_1,
|
||||
CHART_COLOR_GREEN_2,
|
||||
CHART_COLOR_ORANGE,
|
||||
CHART_COLOR_RED,
|
||||
CHART_COLOR_YELLOW,
|
||||
COLOR_BLUE,
|
||||
COLOR_ENS_ALTO,
|
||||
COLOR_HIGH_RISK,
|
||||
COLOR_LOW_RISK,
|
||||
COLOR_MEDIUM_RISK,
|
||||
COLOR_NIS2_PRIMARY,
|
||||
COLOR_SAFE,
|
||||
create_pdf_styles,
|
||||
get_chart_color_for_percentage,
|
||||
get_color_for_compliance,
|
||||
get_color_for_risk_level,
|
||||
get_color_for_weight,
|
||||
)
|
||||
from tasks.jobs.threatscore_utils import (
|
||||
_aggregate_requirement_statistics_from_database,
|
||||
_load_findings_for_requirement_checks,
|
||||
)
|
||||
|
||||
from api.models import Finding, StatusChoices
|
||||
from prowler.lib.check.models import Severity
|
||||
|
||||
matplotlib.use("Agg") # Use non-interactive backend for tests
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAggregateRequirementStatistics:
|
||||
"""Test suite for _aggregate_requirement_statistics_from_database function."""
|
||||
|
||||
def test_aggregates_findings_correctly(self, tenants_fixture, scans_fixture):
|
||||
"""Verify correct pass/total counts per check are aggregated from database."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-1",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.PASS,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-2",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.FAIL,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-3",
|
||||
check_id="check_2",
|
||||
status=StatusChoices.PASS,
|
||||
severity=Severity.medium,
|
||||
impact=Severity.medium,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
|
||||
result = _aggregate_requirement_statistics_from_database(
|
||||
str(tenant.id), str(scan.id)
|
||||
)
|
||||
|
||||
assert "check_1" in result
|
||||
assert result["check_1"]["passed"] == 1
|
||||
assert result["check_1"]["total"] == 2
|
||||
|
||||
assert "check_2" in result
|
||||
assert result["check_2"]["passed"] == 1
|
||||
assert result["check_2"]["total"] == 1
|
||||
|
||||
def test_handles_empty_scan(self, tenants_fixture, scans_fixture):
|
||||
"""Verify empty result is returned for scan with no findings."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
result = _aggregate_requirement_statistics_from_database(
|
||||
str(tenant.id), str(scan.id)
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_only_failed_findings(self, tenants_fixture, scans_fixture):
|
||||
"""Verify correct counts when all findings are FAIL."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-1",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.FAIL,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-2",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.FAIL,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
|
||||
result = _aggregate_requirement_statistics_from_database(
|
||||
str(tenant.id), str(scan.id)
|
||||
)
|
||||
|
||||
assert result["check_1"]["passed"] == 0
|
||||
assert result["check_1"]["total"] == 2
|
||||
|
||||
def test_multiple_findings_same_check(self, tenants_fixture, scans_fixture):
|
||||
"""Verify multiple findings for same check are correctly aggregated."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
for i in range(5):
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid=f"finding-{i}",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.PASS if i % 2 == 0 else StatusChoices.FAIL,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
|
||||
result = _aggregate_requirement_statistics_from_database(
|
||||
str(tenant.id), str(scan.id)
|
||||
)
|
||||
|
||||
assert result["check_1"]["passed"] == 3
|
||||
assert result["check_1"]["total"] == 5
|
||||
|
||||
def test_mixed_statuses(self, tenants_fixture, scans_fixture):
|
||||
"""Verify MANUAL status is counted in total but not passed."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-1",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.PASS,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
Finding.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
scan=scan,
|
||||
uid="finding-2",
|
||||
check_id="check_1",
|
||||
status=StatusChoices.MANUAL,
|
||||
severity=Severity.high,
|
||||
impact=Severity.high,
|
||||
check_metadata={},
|
||||
raw_result={},
|
||||
)
|
||||
|
||||
result = _aggregate_requirement_statistics_from_database(
|
||||
str(tenant.id), str(scan.id)
|
||||
)
|
||||
|
||||
# MANUAL findings are excluded from the aggregation query
|
||||
# since it only counts PASS and FAIL statuses
|
||||
assert result["check_1"]["passed"] == 1
|
||||
assert result["check_1"]["total"] == 1
|
||||
|
||||
|
||||
class TestColorHelperFunctions:
|
||||
"""Test suite for color helper functions."""
|
||||
|
||||
def test_get_color_for_risk_level_high(self):
|
||||
"""Test high risk level returns correct color."""
|
||||
result = get_color_for_risk_level(5)
|
||||
assert result == COLOR_HIGH_RISK
|
||||
|
||||
def test_get_color_for_risk_level_medium_high(self):
|
||||
"""Test risk level 4 returns high risk color."""
|
||||
result = get_color_for_risk_level(4)
|
||||
assert result == COLOR_HIGH_RISK # >= 4 is high risk
|
||||
|
||||
def test_get_color_for_risk_level_medium(self):
|
||||
"""Test risk level 3 returns medium risk color."""
|
||||
result = get_color_for_risk_level(3)
|
||||
assert result == COLOR_MEDIUM_RISK # >= 3 is medium risk
|
||||
|
||||
def test_get_color_for_risk_level_low(self):
|
||||
"""Test low risk level returns safe color."""
|
||||
result = get_color_for_risk_level(1)
|
||||
assert result == COLOR_SAFE # < 2 is safe
|
||||
|
||||
def test_get_color_for_weight_high(self):
|
||||
"""Test high weight returns correct color."""
|
||||
result = get_color_for_weight(150)
|
||||
assert result == COLOR_HIGH_RISK # > 100 is high risk
|
||||
|
||||
def test_get_color_for_weight_medium(self):
|
||||
"""Test medium weight returns low risk color."""
|
||||
result = get_color_for_weight(100)
|
||||
assert result == COLOR_LOW_RISK # 51-100 is low risk
|
||||
|
||||
def test_get_color_for_weight_low(self):
|
||||
"""Test low weight returns safe color."""
|
||||
result = get_color_for_weight(50)
|
||||
assert result == COLOR_SAFE # <= 50 is safe
|
||||
|
||||
def test_get_color_for_compliance_high(self):
|
||||
"""Test high compliance returns green color."""
|
||||
result = get_color_for_compliance(85)
|
||||
assert result == COLOR_SAFE
|
||||
|
||||
def test_get_color_for_compliance_medium(self):
|
||||
"""Test medium compliance returns yellow color."""
|
||||
result = get_color_for_compliance(70)
|
||||
assert result == COLOR_LOW_RISK
|
||||
|
||||
def test_get_color_for_compliance_low(self):
|
||||
"""Test low compliance returns red color."""
|
||||
result = get_color_for_compliance(50)
|
||||
assert result == COLOR_HIGH_RISK
|
||||
|
||||
def test_get_chart_color_for_percentage_excellent(self):
|
||||
"""Test excellent percentage returns correct chart color."""
|
||||
result = get_chart_color_for_percentage(90)
|
||||
assert result == CHART_COLOR_GREEN_1
|
||||
|
||||
def test_get_chart_color_for_percentage_good(self):
|
||||
"""Test good percentage returns correct chart color."""
|
||||
result = get_chart_color_for_percentage(70)
|
||||
assert result == CHART_COLOR_GREEN_2
|
||||
|
||||
def test_get_chart_color_for_percentage_fair(self):
|
||||
"""Test fair percentage returns correct chart color."""
|
||||
result = get_chart_color_for_percentage(50)
|
||||
assert result == CHART_COLOR_YELLOW
|
||||
|
||||
def test_get_chart_color_for_percentage_poor(self):
|
||||
"""Test poor percentage returns correct chart color."""
|
||||
result = get_chart_color_for_percentage(30)
|
||||
assert result == CHART_COLOR_ORANGE
|
||||
|
||||
def test_get_chart_color_for_percentage_critical(self):
|
||||
"""Test critical percentage returns correct chart color."""
|
||||
result = get_chart_color_for_percentage(10)
|
||||
assert result == CHART_COLOR_RED
|
||||
|
||||
|
||||
class TestPDFStylesCreation:
|
||||
"""Test suite for PDF styles creation."""
|
||||
|
||||
def test_create_pdf_styles_returns_dict(self):
|
||||
"""Test that create_pdf_styles returns a dictionary."""
|
||||
result = create_pdf_styles()
|
||||
assert isinstance(result, dict)
|
||||
|
||||
def test_create_pdf_styles_caches_result(self):
|
||||
"""Test that create_pdf_styles caches the result."""
|
||||
result1 = create_pdf_styles()
|
||||
result2 = create_pdf_styles()
|
||||
assert result1 is result2
|
||||
|
||||
def test_pdf_styles_have_correct_keys(self):
|
||||
"""Test that PDF styles dictionary has expected keys."""
|
||||
result = create_pdf_styles()
|
||||
expected_keys = ["title", "h1", "h2", "h3", "normal", "normal_center"]
|
||||
for key in expected_keys:
|
||||
assert key in result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestLoadFindingsForChecks:
|
||||
"""Test suite for _load_findings_for_requirement_checks function."""
|
||||
|
||||
def test_empty_check_ids_returns_empty(self, tenants_fixture, providers_fixture):
|
||||
"""Test that empty check_ids list returns empty dict."""
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
mock_prowler_provider = Mock()
|
||||
mock_prowler_provider.identity.account = "test-account"
|
||||
|
||||
result = _load_findings_for_requirement_checks(
|
||||
str(tenant.id), str(uuid.uuid4()), [], mock_prowler_provider
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateThreatscoreReportFunction:
|
||||
"""Test suite for generate_threatscore_report function."""
|
||||
|
||||
@patch("tasks.jobs.reports.base.initialize_prowler_provider")
|
||||
def test_generate_threatscore_report_exception_handling(
|
||||
self,
|
||||
mock_initialize_provider,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that exceptions during report generation are properly handled."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
mock_initialize_provider.side_effect = Exception("Test exception")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
generate_threatscore_report(
|
||||
tenant_id=str(tenant.id),
|
||||
scan_id=str(scan.id),
|
||||
compliance_id="prowler_threatscore_aws",
|
||||
output_path="/tmp/test_report.pdf",
|
||||
provider_id=str(provider.id),
|
||||
)
|
||||
|
||||
assert "Test exception" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateComplianceReportsOptimized:
|
||||
"""Test suite for generate_compliance_reports function."""
|
||||
|
||||
@patch("tasks.jobs.report._upload_to_s3")
|
||||
@patch("tasks.jobs.report.generate_threatscore_report")
|
||||
@patch("tasks.jobs.report.generate_ens_report")
|
||||
@patch("tasks.jobs.report.generate_nis2_report")
|
||||
def test_no_findings_returns_early_for_both_reports(
|
||||
self,
|
||||
mock_nis2,
|
||||
mock_ens,
|
||||
mock_threatscore,
|
||||
mock_upload,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that function returns early when scan has no findings."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
result = generate_compliance_reports(
|
||||
tenant_id=str(tenant.id),
|
||||
scan_id=str(scan.id),
|
||||
provider_id=str(provider.id),
|
||||
generate_threatscore=True,
|
||||
generate_ens=True,
|
||||
generate_nis2=True,
|
||||
)
|
||||
|
||||
assert result["threatscore"]["upload"] is False
|
||||
assert result["ens"]["upload"] is False
|
||||
assert result["nis2"]["upload"] is False
|
||||
|
||||
mock_threatscore.assert_not_called()
|
||||
mock_ens.assert_not_called()
|
||||
mock_nis2.assert_not_called()
|
||||
|
||||
|
||||
class TestOptimizationImprovements:
|
||||
"""Test suite for optimization-related functionality."""
|
||||
|
||||
def test_chart_color_constants_are_strings(self):
|
||||
"""Verify chart color constants are valid hex color strings."""
|
||||
assert CHART_COLOR_GREEN_1.startswith("#")
|
||||
assert CHART_COLOR_GREEN_2.startswith("#")
|
||||
assert CHART_COLOR_YELLOW.startswith("#")
|
||||
assert CHART_COLOR_ORANGE.startswith("#")
|
||||
assert CHART_COLOR_RED.startswith("#")
|
||||
|
||||
def test_color_constants_are_color_objects(self):
|
||||
"""Verify color constants are Color objects."""
|
||||
assert isinstance(COLOR_BLUE, colors.Color)
|
||||
assert isinstance(COLOR_HIGH_RISK, colors.Color)
|
||||
assert isinstance(COLOR_SAFE, colors.Color)
|
||||
assert isinstance(COLOR_ENS_ALTO, colors.Color)
|
||||
assert isinstance(COLOR_NIS2_PRIMARY, colors.Color)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1380,6 +1380,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache: set[tuple[str, str, str, str]] = set()
|
||||
mute_rules_cache = {}
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.scan.rls_transaction", new=noop_rls_transaction),
|
||||
@@ -1398,6 +1400,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache,
|
||||
mute_rules_cache,
|
||||
scan_categories_cache,
|
||||
scan_resource_groups_cache,
|
||||
group_resources_cache,
|
||||
)
|
||||
|
||||
created_finding = Finding.objects.get(uid=finding.uid)
|
||||
@@ -1491,6 +1495,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache: set[tuple[str, str, str, str]] = set()
|
||||
mute_rules_cache = {finding.uid: "Muted via rule"}
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.scan.rls_transaction", new=noop_rls_transaction),
|
||||
@@ -1509,6 +1515,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache,
|
||||
mute_rules_cache,
|
||||
scan_categories_cache,
|
||||
scan_resource_groups_cache,
|
||||
group_resources_cache,
|
||||
)
|
||||
|
||||
existing_resource.refresh_from_db()
|
||||
@@ -1617,6 +1625,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache: set[tuple[str, str, str, str]] = set()
|
||||
mute_rules_cache = {}
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.scan.rls_transaction", new=noop_rls_transaction),
|
||||
@@ -1636,6 +1646,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache,
|
||||
mute_rules_cache,
|
||||
scan_categories_cache,
|
||||
scan_resource_groups_cache,
|
||||
group_resources_cache,
|
||||
)
|
||||
|
||||
# Verify the long UID finding was NOT created
|
||||
@@ -1713,6 +1725,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache: set[tuple[str, str, str, str]] = set()
|
||||
mute_rules_cache = {}
|
||||
scan_categories_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
scan_resource_groups_cache: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.scan.rls_transaction", new=noop_rls_transaction),
|
||||
@@ -1731,6 +1745,8 @@ class TestProcessFindingMicroBatch:
|
||||
scan_resource_cache,
|
||||
mute_rules_cache,
|
||||
scan_categories_cache,
|
||||
scan_resource_groups_cache,
|
||||
group_resources_cache,
|
||||
)
|
||||
|
||||
# finding1: PASS, severity=low, categories=["gen-ai", "security"]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from django_celery_beat.models import IntervalSchedule, PeriodicTask
|
||||
from django_celery_results.models import TaskResult
|
||||
from tasks.jobs.lighthouse_providers import (
|
||||
_create_bedrock_client,
|
||||
_extract_bedrock_credentials,
|
||||
@@ -15,6 +18,8 @@ from tasks.tasks import (
|
||||
check_integrations_task,
|
||||
check_lighthouse_provider_connection_task,
|
||||
generate_outputs_task,
|
||||
perform_attack_paths_scan_task,
|
||||
perform_scheduled_scan_task,
|
||||
refresh_lighthouse_provider_models_task,
|
||||
s3_integration_task,
|
||||
security_hub_integration_task,
|
||||
@@ -26,6 +31,7 @@ from api.models import (
|
||||
LighthouseProviderModels,
|
||||
Scan,
|
||||
StateChoices,
|
||||
Task,
|
||||
)
|
||||
|
||||
|
||||
@@ -737,8 +743,12 @@ class TestScanCompleteTasks:
|
||||
@patch("tasks.tasks.generate_outputs_task.si")
|
||||
@patch("tasks.tasks.generate_compliance_reports_task.si")
|
||||
@patch("tasks.tasks.check_integrations_task.si")
|
||||
@patch("tasks.tasks.perform_attack_paths_scan_task.apply_async")
|
||||
@patch("tasks.tasks.can_provider_run_attack_paths_scan", return_value=False)
|
||||
def test_scan_complete_tasks(
|
||||
self,
|
||||
mock_can_run_attack_paths,
|
||||
mock_attack_paths_task,
|
||||
mock_check_integrations_task,
|
||||
mock_compliance_reports_task,
|
||||
mock_outputs_task,
|
||||
@@ -793,6 +803,67 @@ class TestScanCompleteTasks:
|
||||
scan_id="scan-id",
|
||||
)
|
||||
|
||||
# Attack Paths task should be skipped when provider cannot run it
|
||||
mock_attack_paths_task.assert_not_called()
|
||||
|
||||
|
||||
class TestAttackPathsTasks:
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _override_task_request(task, **attrs):
|
||||
request = task.request
|
||||
sentinel = object()
|
||||
previous = {key: getattr(request, key, sentinel) for key in attrs}
|
||||
for key, value in attrs.items():
|
||||
setattr(request, key, value)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key, prev in previous.items():
|
||||
if prev is sentinel:
|
||||
if hasattr(request, key):
|
||||
delattr(request, key)
|
||||
else:
|
||||
setattr(request, key, prev)
|
||||
|
||||
def test_perform_attack_paths_scan_task_calls_runner(self):
|
||||
with (
|
||||
patch("tasks.tasks.attack_paths_scan") as mock_attack_paths_scan,
|
||||
self._override_task_request(
|
||||
perform_attack_paths_scan_task, id="celery-task-id"
|
||||
),
|
||||
):
|
||||
mock_attack_paths_scan.return_value = {"status": "ok"}
|
||||
|
||||
result = perform_attack_paths_scan_task.run(
|
||||
tenant_id="tenant-id", scan_id="scan-id"
|
||||
)
|
||||
|
||||
mock_attack_paths_scan.assert_called_once_with(
|
||||
tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-id"
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
def test_perform_attack_paths_scan_task_propagates_exception(self):
|
||||
with (
|
||||
patch(
|
||||
"tasks.tasks.attack_paths_scan",
|
||||
side_effect=RuntimeError("Exception to propagate"),
|
||||
) as mock_attack_paths_scan,
|
||||
self._override_task_request(
|
||||
perform_attack_paths_scan_task, id="celery-task-error"
|
||||
),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Exception to propagate"):
|
||||
perform_attack_paths_scan_task.run(
|
||||
tenant_id="tenant-id", scan_id="scan-id"
|
||||
)
|
||||
|
||||
mock_attack_paths_scan.assert_called_once_with(
|
||||
tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestCheckIntegrationsTask:
|
||||
@@ -2068,3 +2139,215 @@ class TestCleanupOrphanScheduledScans:
|
||||
assert not Scan.objects.filter(id=orphan_scan.id).exists()
|
||||
assert Scan.objects.filter(id=scheduled_scan.id).exists()
|
||||
assert Scan.objects.filter(id=available_scan_other_task.id).exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestPerformScheduledScanTask:
|
||||
"""Unit tests for perform_scheduled_scan_task."""
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _override_task_request(task, **attrs):
|
||||
request = task.request
|
||||
sentinel = object()
|
||||
previous = {key: getattr(request, key, sentinel) for key in attrs}
|
||||
for key, value in attrs.items():
|
||||
setattr(request, key, value)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key, prev in previous.items():
|
||||
if prev is sentinel:
|
||||
if hasattr(request, key):
|
||||
delattr(request, key)
|
||||
else:
|
||||
setattr(request, key, prev)
|
||||
|
||||
def _create_periodic_task(self, provider_id, tenant_id, interval_hours=24):
|
||||
interval, _ = IntervalSchedule.objects.get_or_create(
|
||||
every=interval_hours, period="hours"
|
||||
)
|
||||
return PeriodicTask.objects.create(
|
||||
name=f"scan-perform-scheduled-{provider_id}",
|
||||
task="scan-perform-scheduled",
|
||||
interval=interval,
|
||||
kwargs=f'{{"tenant_id": "{tenant_id}", "provider_id": "{provider_id}"}}',
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def _create_task_result(self, tenant_id, task_id):
|
||||
task_result = TaskResult.objects.create(
|
||||
task_id=task_id,
|
||||
task_name="scan-perform-scheduled",
|
||||
status="STARTED",
|
||||
date_created=datetime.now(timezone.utc),
|
||||
)
|
||||
Task.objects.create(
|
||||
id=task_id, task_runner_task=task_result, tenant_id=tenant_id
|
||||
)
|
||||
return task_result
|
||||
|
||||
def test_skip_when_scheduled_scan_executing(
|
||||
self, tenants_fixture, providers_fixture
|
||||
):
|
||||
"""Skip a scheduled run when another scheduled scan is already executing."""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
periodic_task = self._create_periodic_task(provider.id, tenant.id)
|
||||
task_id = str(uuid.uuid4())
|
||||
self._create_task_result(tenant.id, task_id)
|
||||
|
||||
executing_scan = Scan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
name="Daily scheduled scan",
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.EXECUTING,
|
||||
scheduler_task_id=periodic_task.id,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("tasks.tasks.perform_prowler_scan") as mock_scan,
|
||||
patch("tasks.tasks._perform_scan_complete_tasks") as mock_complete_tasks,
|
||||
self._override_task_request(perform_scheduled_scan_task, id=task_id),
|
||||
):
|
||||
result = perform_scheduled_scan_task.run(
|
||||
tenant_id=str(tenant.id), provider_id=str(provider.id)
|
||||
)
|
||||
|
||||
mock_scan.assert_not_called()
|
||||
mock_complete_tasks.assert_not_called()
|
||||
assert result["id"] == str(executing_scan.id)
|
||||
assert result["state"] == StateChoices.EXECUTING
|
||||
assert (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
).count()
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_creates_next_scheduled_scan_after_completion(
|
||||
self, tenants_fixture, providers_fixture
|
||||
):
|
||||
"""Create a next scheduled scan after a successful run completes."""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
self._create_periodic_task(provider.id, tenant.id)
|
||||
task_id = str(uuid.uuid4())
|
||||
self._create_task_result(tenant.id, task_id)
|
||||
|
||||
def _complete_scan(tenant_id, scan_id, provider_id):
|
||||
other_scheduled = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
).exclude(id=scan_id)
|
||||
assert not other_scheduled.exists()
|
||||
scan_instance = Scan.objects.get(id=scan_id)
|
||||
scan_instance.state = StateChoices.COMPLETED
|
||||
scan_instance.save()
|
||||
return {"status": "ok"}
|
||||
|
||||
with (
|
||||
patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
|
||||
patch("tasks.tasks._perform_scan_complete_tasks"),
|
||||
self._override_task_request(perform_scheduled_scan_task, id=task_id),
|
||||
):
|
||||
perform_scheduled_scan_task.run(
|
||||
tenant_id=str(tenant.id), provider_id=str(provider.id)
|
||||
)
|
||||
|
||||
scheduled_scans = Scan.objects.filter(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
)
|
||||
assert scheduled_scans.count() == 1
|
||||
assert scheduled_scans.first().scheduled_at > datetime.now(timezone.utc)
|
||||
assert (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
||||
).count()
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.COMPLETED,
|
||||
).count()
|
||||
== 1
|
||||
)
|
||||
|
||||
def test_dedupes_multiple_scheduled_scans_before_run(
|
||||
self, tenants_fixture, providers_fixture
|
||||
):
|
||||
"""Ensure duplicated scheduled scans are removed before executing."""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
periodic_task = self._create_periodic_task(provider.id, tenant.id)
|
||||
task_id = str(uuid.uuid4())
|
||||
self._create_task_result(tenant.id, task_id)
|
||||
|
||||
scheduled_scan = Scan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
name="Daily scheduled scan",
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
scheduler_task_id=periodic_task.id,
|
||||
)
|
||||
duplicate_scan = Scan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
name="Daily scheduled scan",
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.AVAILABLE,
|
||||
scheduled_at=scheduled_scan.scheduled_at,
|
||||
scheduler_task_id=periodic_task.id,
|
||||
)
|
||||
|
||||
def _complete_scan(tenant_id, scan_id, provider_id):
|
||||
other_scheduled = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
||||
).exclude(id=scan_id)
|
||||
assert not other_scheduled.exists()
|
||||
scan_instance = Scan.objects.get(id=scan_id)
|
||||
scan_instance.state = StateChoices.COMPLETED
|
||||
scan_instance.save()
|
||||
return {"status": "ok"}
|
||||
|
||||
with (
|
||||
patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
|
||||
patch("tasks.tasks._perform_scan_complete_tasks"),
|
||||
self._override_task_request(perform_scheduled_scan_task, id=task_id),
|
||||
):
|
||||
perform_scheduled_scan_task.run(
|
||||
tenant_id=str(tenant.id), provider_id=str(provider.id)
|
||||
)
|
||||
|
||||
assert not Scan.objects.filter(id=duplicate_scan.id).exists()
|
||||
assert Scan.objects.filter(id=scheduled_scan.id).exists()
|
||||
assert (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
||||
).count()
|
||||
== 1
|
||||
)
|
||||
|
||||
@@ -5,6 +5,10 @@ from enum import Enum
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from django_celery_results.models import TaskResult
|
||||
|
||||
from api.models import Scan, StateChoices
|
||||
|
||||
SCHEDULED_SCAN_NAME = "Daily scheduled scan"
|
||||
|
||||
|
||||
class CustomEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
@@ -71,3 +75,58 @@ def batched(iterable, batch_size):
|
||||
batch = []
|
||||
|
||||
yield batch, True
|
||||
|
||||
|
||||
def _get_or_create_scheduled_scan(
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
scheduler_task_id: int,
|
||||
scheduled_at: datetime,
|
||||
update_state: bool = False,
|
||||
) -> Scan:
|
||||
"""
|
||||
Get or create a scheduled scan, cleaning up duplicates if found.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID.
|
||||
provider_id: The provider ID.
|
||||
scheduler_task_id: The PeriodicTask ID.
|
||||
scheduled_at: The scheduled datetime for the scan.
|
||||
update_state: If True, also reset state to SCHEDULED when updating.
|
||||
|
||||
Returns:
|
||||
The scan instance to use.
|
||||
"""
|
||||
scheduled_scans = list(
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
||||
scheduler_task_id=scheduler_task_id,
|
||||
).order_by("scheduled_at", "inserted_at")
|
||||
)
|
||||
|
||||
if scheduled_scans:
|
||||
scan_instance = scheduled_scans[0]
|
||||
if len(scheduled_scans) > 1:
|
||||
Scan.objects.filter(id__in=[s.id for s in scheduled_scans[1:]]).delete()
|
||||
needs_update = scan_instance.scheduled_at != scheduled_at
|
||||
if update_state and scan_instance.state != StateChoices.SCHEDULED:
|
||||
scan_instance.state = StateChoices.SCHEDULED
|
||||
scan_instance.name = SCHEDULED_SCAN_NAME
|
||||
needs_update = True
|
||||
if needs_update:
|
||||
scan_instance.scheduled_at = scheduled_at
|
||||
scan_instance.save()
|
||||
return scan_instance
|
||||
|
||||
return Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
name=SCHEDULED_SCAN_NAME,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
scheduled_at=scheduled_at,
|
||||
scheduler_task_id=scheduler_task_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# prowler/contrib/aws/simulate_policy_client.py
|
||||
from typing import Optional
|
||||
|
||||
from prowler.contrib.aws.simulate_policy.simulate_policy_service import IamSimulator
|
||||
from prowler.providers.common.provider import Provider
|
||||
|
||||
_iam_simulator_client: Optional[IamSimulator] = None
|
||||
|
||||
|
||||
def get_iam_simulator_client() -> IamSimulator:
|
||||
global _iam_simulator_client
|
||||
if _iam_simulator_client is None:
|
||||
provider = Provider.get_global_provider()
|
||||
if provider is None:
|
||||
# Fail fast with a clear message if somehow called too early
|
||||
raise RuntimeError(
|
||||
"Global Provider is not initialized yet for IAM simulator."
|
||||
)
|
||||
_iam_simulator_client = IamSimulator(provider)
|
||||
return _iam_simulator_client
|
||||
@@ -0,0 +1,200 @@
|
||||
# prowler/contrib/aws/simulate_policy_service.py
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from prowler.providers.common.provider import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# PURPOSE
|
||||
# ----------------------------------------------------------------------
|
||||
# This module provides a precise way to test IAM actions programmatically.
|
||||
# It replicates the behaviour of the AWS CLI command:
|
||||
# aws iam simulate-principal-policy --policy-source-arn arn:aws:iam::<account>:role/<role> --action-names <action>
|
||||
#
|
||||
# Use this when you need to validate whether a specific IAM role allows or denies
|
||||
# certain actions against given resources.
|
||||
#
|
||||
# ======================================================================
|
||||
# CLI ANALOGUE
|
||||
# ----------------------------------------------------------------------
|
||||
# Example equivalent CLI command:
|
||||
# aws iam simulate-principal-policy \
|
||||
# --policy-source-arn arn:aws:iam::278419598935:role/your-role \
|
||||
# --action-names datazone:AcceptPredictions
|
||||
#
|
||||
# ======================================================================
|
||||
# DOCUMENTATION
|
||||
# ----------------------------------------------------------------------
|
||||
# AWS IAM Policy Simulator:
|
||||
# https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies_testing-policies.html
|
||||
#
|
||||
# IAM Condition Keys:
|
||||
# https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_condition-keys.html
|
||||
#
|
||||
# Related AWS SDK discussion:
|
||||
# https://github.com/aws/aws-sdk/issues/102
|
||||
#
|
||||
# ======================================================================
|
||||
# LIMITATIONS
|
||||
# ----------------------------------------------------------------------
|
||||
# - The IAM Policy Simulator does NOT evaluate Service Control Policies (SCPs)
|
||||
# that include conditions. This is a limitation of the API.
|
||||
# - In environments where SCPs contain conditions, use
|
||||
# `is_action_allowed_simulate_custom_policy` instead.
|
||||
# - In environments without SCP conditions, `is_action_allowed_simulate_principal_policy`
|
||||
# works as expected.
|
||||
#
|
||||
# ======================================================================
|
||||
# USAGE
|
||||
# ----------------------------------------------------------------------
|
||||
# In your custom check:
|
||||
#
|
||||
# from prowler.contrib.aws.simulate_policy.simulate_policy_client import get_iam_simulator_client
|
||||
#
|
||||
# iam_sim = get_iam_simulator_client()
|
||||
# policy_data = iam_sim.get_role_policy_data(role_name=role_name)
|
||||
# iam_sim.is_action_allowed_simulate_custom_policy(
|
||||
# policy_data=policy_data,
|
||||
# action_names=[action],
|
||||
# resource_arns=["*"]
|
||||
# )
|
||||
#
|
||||
#
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class IamSimulator:
|
||||
"""
|
||||
Helper for IAM Policy Simulator:
|
||||
- simulate_principal_policy
|
||||
- simulate_custom_policy
|
||||
- collect role inline/managed policies
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Provider) -> None:
|
||||
|
||||
boto3_session = provider.session.current_session
|
||||
|
||||
# IAM is a global service. Region is optional; we can use the provider's global region
|
||||
# to stay consistent across partitions.
|
||||
try:
|
||||
region_name = provider.get_global_region()
|
||||
except AttributeError:
|
||||
# Fallback if provider lacks the helper (older trees)
|
||||
region_name = boto3_session.region_name or "us-east-1"
|
||||
|
||||
self.iam = boto3_session.client("iam", region_name=region_name)
|
||||
|
||||
def is_action_allowed_simulate_principal_policy(
|
||||
self,
|
||||
principal_arn: str,
|
||||
action_names: List[str],
|
||||
resource_arns: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Dict]:
|
||||
if resource_arns is None:
|
||||
resource_arns = ["*"]
|
||||
try:
|
||||
resp = self.iam.simulate_principal_policy(
|
||||
PolicySourceArn=principal_arn,
|
||||
ActionNames=action_names,
|
||||
ResourceArns=resource_arns,
|
||||
)
|
||||
allowed = any(
|
||||
r.get("EvalDecision") == "allowed"
|
||||
for r in resp.get("EvaluationResults", [])
|
||||
)
|
||||
return allowed, resp
|
||||
except ClientError as e:
|
||||
logger.error("simulate_principal_policy failed: %s", e, exc_info=True)
|
||||
return False, {"error": str(e)}
|
||||
|
||||
def get_role_policy_data(self, role_name: str) -> Dict[str, List]:
|
||||
inline_names: List[str] = []
|
||||
inline_docs: List[Dict] = []
|
||||
managed_names: List[str] = []
|
||||
managed_docs: List[Dict] = []
|
||||
|
||||
# Inline policies
|
||||
inline_resp = self.iam.list_role_policies(RoleName=role_name)
|
||||
inline_names = inline_resp.get("PolicyNames", [])
|
||||
for pname in inline_names:
|
||||
pol_resp = self.iam.get_role_policy(RoleName=role_name, PolicyName=pname)
|
||||
inline_docs.append(pol_resp["PolicyDocument"]) # dict
|
||||
|
||||
# Managed policies
|
||||
managed_resp = self.iam.list_attached_role_policies(RoleName=role_name)
|
||||
for attached in managed_resp.get("AttachedPolicies", []):
|
||||
managed_names.append(attached["PolicyName"])
|
||||
pol_meta = self.iam.get_policy(PolicyArn=attached["PolicyArn"])["Policy"]
|
||||
pol_ver = self.iam.get_policy_version(
|
||||
PolicyArn=attached["PolicyArn"], VersionId=pol_meta["DefaultVersionId"]
|
||||
)
|
||||
managed_docs.append(pol_ver["PolicyVersion"]["Document"]) # dict
|
||||
|
||||
return {
|
||||
"inline_policy_names": inline_names,
|
||||
"inline_policy_data": inline_docs,
|
||||
"managed_policy_names": managed_names,
|
||||
"managed_policy_data": managed_docs,
|
||||
}
|
||||
|
||||
def is_action_allowed_simulate_custom_policy(
|
||||
self,
|
||||
policy_data: Dict[str, List],
|
||||
action_names: List[str],
|
||||
resource_arns: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Dict]:
|
||||
names = policy_data.get("inline_policy_names", []) + policy_data.get(
|
||||
"managed_policy_names", []
|
||||
)
|
||||
docs = policy_data.get("inline_policy_data", []) + policy_data.get(
|
||||
"managed_policy_data", []
|
||||
)
|
||||
|
||||
results: Dict[str, List] = {"policies": []}
|
||||
any_allowed = False
|
||||
if resource_arns is None:
|
||||
resource_arns = ["*"]
|
||||
|
||||
for idx, doc in enumerate(docs):
|
||||
name = names[idx] if idx < len(names) else f"policy_{idx}"
|
||||
try:
|
||||
sim_resp = self.iam.simulate_custom_policy(
|
||||
PolicyInputList=[json.dumps(doc)],
|
||||
ActionNames=action_names,
|
||||
ResourceArns=resource_arns,
|
||||
)
|
||||
except ClientError as e:
|
||||
logger.error(
|
||||
"simulate_custom_policy failed for %s: %s", name, e, exc_info=True
|
||||
)
|
||||
results["policies"].append({"policy_name": name, "error": str(e)})
|
||||
continue
|
||||
|
||||
per_action = []
|
||||
for ev in sim_resp.get("EvaluationResults", []):
|
||||
decision = ev.get(
|
||||
"EvalDecision"
|
||||
) # allowed | explicitDeny | implicitDeny
|
||||
per_action.append(
|
||||
{
|
||||
"action": ev.get("EvalActionName"),
|
||||
"decision": decision,
|
||||
"matching_statements": ev.get("MatchedStatements", []),
|
||||
"missing_context_values": ev.get("MissingContextValues", []),
|
||||
}
|
||||
)
|
||||
if decision == "allowed":
|
||||
any_allowed = True
|
||||
|
||||
results["policies"].append({"policy_name": name, "evaluations": per_action})
|
||||
|
||||
return any_allowed, results
|
||||
@@ -0,0 +1,24 @@
|
||||
import warnings
|
||||
|
||||
from dashboard.common_methods import get_section_containers_cis
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def get_table(data):
|
||||
aux = data[
|
||||
[
|
||||
"REQUIREMENTS_ID",
|
||||
"REQUIREMENTS_DESCRIPTION",
|
||||
"REQUIREMENTS_ATTRIBUTES_SECTION",
|
||||
"CHECKID",
|
||||
"STATUS",
|
||||
"REGION",
|
||||
"ACCOUNTID",
|
||||
"RESOURCEID",
|
||||
]
|
||||
].copy()
|
||||
|
||||
return get_section_containers_cis(
|
||||
aux, "REQUIREMENTS_ID", "REQUIREMENTS_ATTRIBUTES_SECTION"
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
import warnings
|
||||
|
||||
from dashboard.common_methods import get_section_containers_cis
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def get_table(data):
|
||||
|
||||
aux = data[
|
||||
[
|
||||
"REQUIREMENTS_ID",
|
||||
"REQUIREMENTS_DESCRIPTION",
|
||||
"REQUIREMENTS_ATTRIBUTES_SECTION",
|
||||
"CHECKID",
|
||||
"STATUS",
|
||||
"REGION",
|
||||
"ACCOUNTID",
|
||||
"RESOURCEID",
|
||||
]
|
||||
].copy()
|
||||
|
||||
return get_section_containers_cis(
|
||||
aux, "REQUIREMENTS_ID", "REQUIREMENTS_ATTRIBUTES_SECTION"
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
import warnings
|
||||
|
||||
from dashboard.common_methods import get_section_containers_cis
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def get_table(data):
|
||||
aux = data[
|
||||
[
|
||||
"REQUIREMENTS_ID",
|
||||
"REQUIREMENTS_DESCRIPTION",
|
||||
"REQUIREMENTS_ATTRIBUTES_SECTION",
|
||||
"CHECKID",
|
||||
"STATUS",
|
||||
"REGION",
|
||||
"ACCOUNTID",
|
||||
"RESOURCEID",
|
||||
]
|
||||
].copy()
|
||||
|
||||
return get_section_containers_cis(
|
||||
aux, "REQUIREMENTS_ID", "REQUIREMENTS_ATTRIBUTES_SECTION"
|
||||
)
|
||||
+46
-1
@@ -1,6 +1,7 @@
|
||||
services:
|
||||
api-dev:
|
||||
hostname: "prowler-api"
|
||||
image: prowler-api-dev
|
||||
build:
|
||||
context: ./api
|
||||
dockerfile: Dockerfile
|
||||
@@ -24,6 +25,8 @@ services:
|
||||
condition: service_healthy
|
||||
valkey:
|
||||
condition: service_healthy
|
||||
neo4j:
|
||||
condition: service_healthy
|
||||
entrypoint:
|
||||
- "/home/prowler/docker-entrypoint.sh"
|
||||
- "dev"
|
||||
@@ -85,7 +88,41 @@ services:
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
neo4j:
|
||||
image: graphstack/dozerdb:5.26.3.0
|
||||
hostname: "neo4j"
|
||||
volumes:
|
||||
- ./_data/neo4j:/data
|
||||
environment:
|
||||
# We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
|
||||
# Auth
|
||||
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
|
||||
# Memory limits
|
||||
- NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000}
|
||||
- NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
|
||||
- NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
|
||||
- NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
|
||||
# APOC
|
||||
- apoc.export.file.enabled=${NEO4J_POC_EXPORT_FILE_ENABLED:-true}
|
||||
- apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
|
||||
- apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
|
||||
- "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
|
||||
- "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
|
||||
- "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
|
||||
# Networking
|
||||
- "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
|
||||
# 7474 is the UI port
|
||||
ports:
|
||||
- 7474:7474
|
||||
- ${NEO4J_PORT:-7687}:7687
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
|
||||
worker-dev:
|
||||
image: prowler-api-dev
|
||||
build:
|
||||
context: ./api
|
||||
dockerfile: Dockerfile
|
||||
@@ -96,17 +133,23 @@ services:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- "outputs:/tmp/prowler_api_output"
|
||||
- ./api/src/backend:/home/prowler/backend
|
||||
- ./api/pyproject.toml:/home/prowler/pyproject.toml
|
||||
- ./api/docker-entrypoint.sh:/home/prowler/docker-entrypoint.sh
|
||||
- outputs:/tmp/prowler_api_output
|
||||
depends_on:
|
||||
valkey:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
neo4j:
|
||||
condition: service_healthy
|
||||
entrypoint:
|
||||
- "/home/prowler/docker-entrypoint.sh"
|
||||
- "worker"
|
||||
|
||||
worker-beat:
|
||||
image: prowler-api-dev
|
||||
build:
|
||||
context: ./api
|
||||
dockerfile: Dockerfile
|
||||
@@ -121,6 +164,8 @@ services:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
neo4j:
|
||||
condition: service_healthy
|
||||
entrypoint:
|
||||
- "../docker-entrypoint.sh"
|
||||
- "beat"
|
||||
|
||||
@@ -21,6 +21,8 @@ services:
|
||||
condition: service_healthy
|
||||
valkey:
|
||||
condition: service_healthy
|
||||
neo4j:
|
||||
condition: service_healthy
|
||||
entrypoint:
|
||||
- "/home/prowler/docker-entrypoint.sh"
|
||||
- "prod"
|
||||
@@ -72,6 +74,37 @@ services:
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
neo4j:
|
||||
image: graphstack/dozerdb:5.26.3.0
|
||||
hostname: "neo4j"
|
||||
volumes:
|
||||
- ./_data/neo4j:/data
|
||||
environment:
|
||||
# We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
|
||||
# Auth
|
||||
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
|
||||
# Memory limits
|
||||
- NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000}
|
||||
- NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
|
||||
- NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
|
||||
- NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
|
||||
# APOC
|
||||
- apoc.export.file.enabled=${NEO4J_POC_EXPORT_FILE_ENABLED:-true}
|
||||
- apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
|
||||
- apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
|
||||
- "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
|
||||
- "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
|
||||
- "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
|
||||
# Networking
|
||||
- "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
|
||||
ports:
|
||||
- ${NEO4J_PORT:-7687}:7687
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
|
||||
worker:
|
||||
image: prowlercloud/prowler-api:${PROWLER_API_VERSION:-stable}
|
||||
env_file:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user