Compare commits

..

10 Commits

Author SHA1 Message Date
César Arroba ef39a500b4 chore: rollback last change 2025-07-31 15:44:15 +02:00
César Arroba 9b399409d5 chore: build code with preview image tag 2025-07-31 15:36:39 +02:00
Pablo Lara ccb1ca53ff chore: some SSR tweaks 2025-07-09 14:23:58 +02:00
Chandrapal Badshah 2225cc48a5 feat: add lighthouse caching 2025-07-08 12:45:09 +05:30
Chandrapal Badshah e3c63c01bb feat: add caching recommendations in valkey 2025-07-04 11:36:25 +05:30
Chandrapal Badshah 5fa44572d5 feat: get summary of all scans in last 24 hours 2025-07-03 14:37:40 +05:30
Chandrapal Badshah 0da9bd80fb feat: add tool to detect new failed findings in a scan 2025-07-02 14:36:00 +05:30
Chandrapal Badshah b03a807d33 feat: add lighthouse summary generator 2025-07-02 14:33:25 +05:30
Chandrapal Badshah c3e55b383b fix: update schema to get check details 2025-07-02 11:24:22 +05:30
Chandrapal Badshah 5daed282d5 feat: add getLatestFindings tool 2025-06-27 15:10:13 +05:30
728 changed files with 18349 additions and 51907 deletions
+3 -5
@@ -6,12 +6,9 @@
PROWLER_UI_VERSION="stable"
AUTH_URL=http://localhost:3000
API_BASE_URL=http://prowler-api:8080/api/v1
NEXT_PUBLIC_API_BASE_URL=${API_BASE_URL}
NEXT_PUBLIC_API_DOCS_URL=http://prowler-api:8080/api/v1/docs
AUTH_TRUST_HOST=true
UI_PORT=3000
# Temporary URL for feeds; replace with the actual feed URL
RSS_FEED_URL=https://prowler.com/blog/rss
# openssl rand -base64 32
AUTH_SECRET="N/c6mnaS5+SWq81+819OrzQZlmx1Vxtp/orjttJSmw8="
# Google Tag Manager ID
@@ -133,7 +130,7 @@ SENTRY_ENVIRONMENT=local
SENTRY_RELEASE=local
#### Prowler release version ####
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.10.0
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.6.0
# Social login credentials
SOCIAL_GOOGLE_OAUTH_CALLBACK_URL="${AUTH_URL}/api/auth/callback/google"
@@ -145,7 +142,8 @@ SOCIAL_GITHUB_OAUTH_CLIENT_ID=""
SOCIAL_GITHUB_OAUTH_CLIENT_SECRET=""
# Single Sign-On (SSO)
SAML_SSO_CALLBACK_URL="${AUTH_URL}/api/auth/callback/saml"
SAML_PUBLIC_CERT=""
SAML_PRIVATE_KEY=""
# Lighthouse tracing
LANGSMITH_TRACING=false
-5
@@ -22,11 +22,6 @@ provider/kubernetes:
- any-glob-to-any-file: "prowler/providers/kubernetes/**"
- any-glob-to-any-file: "tests/providers/kubernetes/**"
provider/m365:
- changed-files:
- any-glob-to-any-file: "prowler/providers/m365/**"
- any-glob-to-any-file: "tests/providers/m365/**"
provider/github:
- changed-files:
- any-glob-to-any-file: "prowler/providers/github/**"
@@ -6,7 +6,6 @@ on:
- "master"
paths:
- "api/**"
- "prowler/**"
- ".github/workflows/api-build-lint-push-containers.yml"
# Uncomment the code below to test this action on PRs
@@ -77,7 +76,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
# Comment the following line for testing
+2 -2
@@ -48,12 +48,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/api-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
+9 -4
@@ -136,6 +136,12 @@ jobs:
run: |
poetry check --lock
- name: Prevent known compatibility error between lxml and libxml2/libxmlsec versions - https://github.com/xmlsec/python-xmlsec/issues/320
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
run: |
poetry run pip install --force-reinstall --no-binary lxml lxml
- name: Lint with ruff
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
@@ -163,10 +169,9 @@ jobs:
- name: Safety
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
# 76352, 76353 and 77323 come from the SDK, which cannot be upgraded yet; they do not affect the API
# TODO: Botocore needs urllib3 1.X, so we need to ignore vulnerabilities 77744 and 77745. Remove this once we upgrade to urllib3 2.X
# 76352 and 76353 come from the SDK, which cannot be upgraded yet; they do not affect the API
run: |
poetry run safety check --ignore 70612,66963,74429,76352,76353,77323,77744,77745
poetry run safety check --ignore 70612,66963,74429,76352,76353
- name: Vulture
working-directory: ./api
@@ -206,7 +211,7 @@ jobs:
files_ignore: ${{ env.IGNORE_FILES }}
- name: Set up Docker Buildx
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build Container
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
+1 -1
@@ -11,7 +11,7 @@ jobs:
with:
fetch-depth: 0
- name: TruffleHog OSS
uses: trufflesecurity/trufflehog@a05cf0859455b5b16317ee22d809887a4043cdf0 # v3.90.2
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
with:
path: ./
base: ${{ github.event.repository.default_branch }}
@@ -1,266 +0,0 @@
name: Prowler Release Preparation
run-name: Prowler Release Preparation for ${{ inputs.prowler_version }}
on:
workflow_dispatch:
inputs:
prowler_version:
description: 'Prowler version to release (e.g., 5.9.0)'
required: true
type: string
env:
PROWLER_VERSION: ${{ github.event.inputs.prowler_version }}
jobs:
prepare-release:
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'
- name: Install Poetry
run: |
python3 -m pip install --user poetry
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Configure Git
run: |
git config --global user.name "prowler-bot"
git config --global user.email "179230569+prowler-bot@users.noreply.github.com"
- name: Parse version and determine branch
run: |
# Validate version format (reusing pattern from sdk-bump-version.yml)
if [[ $PROWLER_VERSION =~ ^([0-9]+)\.([0-9]+)\.([0-9]+)$ ]]; then
MAJOR_VERSION=${BASH_REMATCH[1]}
MINOR_VERSION=${BASH_REMATCH[2]}
PATCH_VERSION=${BASH_REMATCH[3]}
# Export version components to environment
echo "MAJOR_VERSION=${MAJOR_VERSION}" >> "${GITHUB_ENV}"
echo "MINOR_VERSION=${MINOR_VERSION}" >> "${GITHUB_ENV}"
echo "PATCH_VERSION=${PATCH_VERSION}" >> "${GITHUB_ENV}"
# Determine branch name (format: v5.9)
BRANCH_NAME="v${MAJOR_VERSION}.${MINOR_VERSION}"
echo "BRANCH_NAME=${BRANCH_NAME}" >> "${GITHUB_ENV}"
# Calculate UI version (1.X.X format - matches Prowler minor version)
UI_VERSION="1.${MINOR_VERSION}.${PATCH_VERSION}"
echo "UI_VERSION=${UI_VERSION}" >> "${GITHUB_ENV}"
# Calculate API version (1.X.X format - one minor version ahead)
API_MINOR_VERSION=$((MINOR_VERSION + 1))
API_VERSION="1.${API_MINOR_VERSION}.${PATCH_VERSION}"
echo "API_VERSION=${API_VERSION}" >> "${GITHUB_ENV}"
echo "Prowler version: $PROWLER_VERSION"
echo "Branch name: $BRANCH_NAME"
echo "UI version: $UI_VERSION"
echo "API version: $API_VERSION"
echo "Is minor release: $([ $PATCH_VERSION -eq 0 ] && echo 'true' || echo 'false')"
else
echo "Invalid version syntax: '$PROWLER_VERSION' (must be N.N.N)" >&2
exit 1
fi
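As an aside, the version scheme this step encodes is easier to see in a compact sketch (Python used purely for illustration; it is not part of the workflow):

```python
import re

def derive_versions(prowler_version: str) -> dict:
    # Same N.N.N validation as the bash regex above.
    if not re.fullmatch(r"\d+\.\d+\.\d+", prowler_version):
        raise ValueError(f"Invalid version syntax: {prowler_version!r} (must be N.N.N)")
    major, minor, patch = (int(p) for p in prowler_version.split("."))
    return {
        "branch": f"v{major}.{minor}",    # release branch, e.g. v5.9
        "ui": f"1.{minor}.{patch}",       # UI matches the Prowler minor version
        "api": f"1.{minor + 1}.{patch}",  # API is one minor version ahead
        "is_minor_release": patch == 0,
    }

# derive_versions("5.9.0")
# -> {'branch': 'v5.9', 'ui': '1.9.0', 'api': '1.10.0', 'is_minor_release': True}
```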
- name: Checkout existing branch for patch release
if: ${{ env.PATCH_VERSION != '0' }}
run: |
echo "Patch release detected, checking out existing branch $BRANCH_NAME..."
if git show-ref --verify --quiet "refs/heads/$BRANCH_NAME"; then
echo "Branch $BRANCH_NAME exists locally, checking out..."
git checkout "$BRANCH_NAME"
elif git show-ref --verify --quiet "refs/remotes/origin/$BRANCH_NAME"; then
echo "Branch $BRANCH_NAME exists remotely, checking out..."
git checkout -b "$BRANCH_NAME" "origin/$BRANCH_NAME"
else
echo "ERROR: Branch $BRANCH_NAME should exist for patch release $PROWLER_VERSION"
exit 1
fi
- name: Verify version in pyproject.toml
run: |
CURRENT_VERSION=$(grep '^version = ' pyproject.toml | sed -E 's/version = "([^"]+)"/\1/' | tr -d '[:space:]')
PROWLER_VERSION_TRIMMED=$(echo "$PROWLER_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_VERSION" != "$PROWLER_VERSION_TRIMMED" ]; then
echo "ERROR: Version mismatch in pyproject.toml (expected: '$PROWLER_VERSION_TRIMMED', found: '$CURRENT_VERSION')"
exit 1
fi
echo "✓ pyproject.toml version: $CURRENT_VERSION"
- name: Verify version in prowler/config/config.py
run: |
CURRENT_VERSION=$(grep '^prowler_version = ' prowler/config/config.py | sed -E 's/prowler_version = "([^"]+)"/\1/' | tr -d '[:space:]')
PROWLER_VERSION_TRIMMED=$(echo "$PROWLER_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_VERSION" != "$PROWLER_VERSION_TRIMMED" ]; then
echo "ERROR: Version mismatch in prowler/config/config.py (expected: '$PROWLER_VERSION_TRIMMED', found: '$CURRENT_VERSION')"
exit 1
fi
echo "✓ prowler/config/config.py version: $CURRENT_VERSION"
- name: Verify version in api/pyproject.toml
run: |
CURRENT_API_VERSION=$(grep '^version = ' api/pyproject.toml | sed -E 's/version = "([^"]+)"/\1/' | tr -d '[:space:]')
API_VERSION_TRIMMED=$(echo "$API_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_API_VERSION" != "$API_VERSION_TRIMMED" ]; then
echo "ERROR: API version mismatch in api/pyproject.toml (expected: '$API_VERSION_TRIMMED', found: '$CURRENT_API_VERSION')"
exit 1
fi
echo "✓ api/pyproject.toml version: $CURRENT_API_VERSION"
- name: Verify prowler dependency in api/pyproject.toml
if: ${{ env.PATCH_VERSION != '0' }}
run: |
CURRENT_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
BRANCH_NAME_TRIMMED=$(echo "$BRANCH_NAME" | tr -d '[:space:]')
if [ "$CURRENT_PROWLER_REF" != "$BRANCH_NAME_TRIMMED" ]; then
echo "ERROR: Prowler dependency mismatch in api/pyproject.toml (expected: '$BRANCH_NAME_TRIMMED', found: '$CURRENT_PROWLER_REF')"
exit 1
fi
echo "✓ api/pyproject.toml prowler dependency: $CURRENT_PROWLER_REF"
- name: Verify version in api/src/backend/api/v1/views.py
run: |
CURRENT_API_VERSION=$(grep 'spectacular_settings.VERSION = ' api/src/backend/api/v1/views.py | sed -E 's/.*spectacular_settings.VERSION = "([^"]+)".*/\1/' | tr -d '[:space:]')
API_VERSION_TRIMMED=$(echo "$API_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_API_VERSION" != "$API_VERSION_TRIMMED" ]; then
echo "ERROR: API version mismatch in views.py (expected: '$API_VERSION_TRIMMED', found: '$CURRENT_API_VERSION')"
exit 1
fi
echo "✓ api/src/backend/api/v1/views.py version: $CURRENT_API_VERSION"
- name: Create release branch for minor release
if: ${{ env.PATCH_VERSION == '0' }}
run: |
echo "Minor release detected (patch = 0), creating new branch $BRANCH_NAME..."
if git show-ref --verify --quiet "refs/heads/$BRANCH_NAME" || git show-ref --verify --quiet "refs/remotes/origin/$BRANCH_NAME"; then
echo "ERROR: Branch $BRANCH_NAME already exists for minor release $PROWLER_VERSION"
exit 1
fi
git checkout -b "$BRANCH_NAME"
# Push the new branch first so it exists remotely
git push origin "$BRANCH_NAME"
- name: Update prowler dependency in api/pyproject.toml
if: ${{ env.PATCH_VERSION == '0' }}
run: |
CURRENT_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
BRANCH_NAME_TRIMMED=$(echo "$BRANCH_NAME" | tr -d '[:space:]')
# Minor release: update the dependency to use the new branch
echo "Minor release detected - updating prowler dependency from '$CURRENT_PROWLER_REF' to '$BRANCH_NAME_TRIMMED'"
sed -i "s|prowler @ git+https://github.com/prowler-cloud/prowler.git@[^\"]*\"|prowler @ git+https://github.com/prowler-cloud/prowler.git@$BRANCH_NAME_TRIMMED\"|" api/pyproject.toml
# Verify the change was made
UPDATED_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
if [ "$UPDATED_PROWLER_REF" != "$BRANCH_NAME_TRIMMED" ]; then
echo "ERROR: Failed to update prowler dependency in api/pyproject.toml"
exit 1
fi
# Update poetry lock file
echo "Updating poetry.lock file..."
cd api
poetry lock
cd ..
# Commit and push the changes
git add api/pyproject.toml api/poetry.lock
git commit -m "chore(api): update prowler dependency to $BRANCH_NAME_TRIMMED for release $PROWLER_VERSION"
git push origin "$BRANCH_NAME"
echo "✓ api/pyproject.toml prowler dependency updated to: $UPDATED_PROWLER_REF"
- name: Extract changelog entries
run: |
set -e
# Function to extract changelog for a specific version
extract_changelog() {
local file="$1"
local version="$2"
local output_file="$3"
if [ ! -f "$file" ]; then
echo "Warning: $file not found, skipping..."
touch "$output_file"
return
fi
# Extract changelog section for this version
awk -v version="$version" '
/^## \[v?'"$version"'\]/ { found=1; next }
found && /^## \[v?[0-9]+\.[0-9]+\.[0-9]+\]/ { found=0 }
found && !/^## \[v?'"$version"'\]/ { print }
' "$file" > "$output_file"
# Remove --- separators
sed -i '/^---$/d' "$output_file"
# Remove trailing empty lines
sed -i '/^$/d' "$output_file"
}
# Extract changelogs
echo "Extracting changelog entries..."
extract_changelog "prowler/CHANGELOG.md" "$PROWLER_VERSION" "prowler_changelog.md"
extract_changelog "api/CHANGELOG.md" "$API_VERSION" "api_changelog.md"
extract_changelog "ui/CHANGELOG.md" "$UI_VERSION" "ui_changelog.md"
# Combine changelogs in order: UI, API, SDK
> combined_changelog.md
if [ -s "ui_changelog.md" ]; then
echo "## UI" >> combined_changelog.md
echo "" >> combined_changelog.md
cat ui_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
if [ -s "api_changelog.md" ]; then
echo "## API" >> combined_changelog.md
echo "" >> combined_changelog.md
cat api_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
if [ -s "prowler_changelog.md" ]; then
echo "## SDK" >> combined_changelog.md
echo "" >> combined_changelog.md
cat prowler_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
echo "Combined changelog preview:"
cat combined_changelog.md
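For readers less fluent in awk, here is a rough Python equivalent of `extract_changelog` (an illustration only, not the project's tooling):

```python
import re

def extract_changelog(text: str, version: str) -> str:
    # Collect the lines under "## [vX.Y.Z]" for the requested version,
    # stopping at the next version heading; drop "---" separators and blank
    # lines, mirroring the awk + sed pipeline above.
    target = re.compile(rf"^## \[v?{re.escape(version)}\]")
    any_version = re.compile(r"^## \[v?\d+\.\d+\.\d+\]")
    out, found = [], False
    for line in text.splitlines():
        if target.match(line):
            found = True
            continue
        if found and any_version.match(line):
            break
        if found and line.strip() not in ("---", ""):
            out.append(line)
    return "\n".join(out)
```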
- name: Create draft release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: ${{ env.PROWLER_VERSION }}
name: Prowler ${{ env.PROWLER_VERSION }}
body_path: combined_changelog.md
draft: true
target_commitish: ${{ env.PATCH_VERSION == '0' && 'master' || env.BRANCH_NAME }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Clean up temporary files
run: |
rm -f prowler_changelog.md api_changelog.md ui_changelog.md combined_changelog.md
@@ -55,20 +55,29 @@ jobs:
comment-author: 'github-actions[bot]'
body-includes: '<!-- changelog-check -->'
- name: Update PR comment with changelog status
if: github.event.pull_request.head.repo.full_name == github.repository
- name: Comment on PR if changelog is missing
if: github.event.pull_request.head.repo.full_name == github.repository && steps.check_folders.outputs.missing_changelogs != ''
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find_comment.outputs.comment-id }}
edit-mode: replace
body: |
<!-- changelog-check -->
${{ steps.check_folders.outputs.missing_changelogs != '' && format('⚠️ **Changes detected in the following folders without a corresponding update to the `CHANGELOG.md`:**
⚠️ **Changes detected in the following folders without a corresponding update to the `CHANGELOG.md`:**
{0}
${{ steps.check_folders.outputs.missing_changelogs }}
Please add an entry to the corresponding `CHANGELOG.md` file to maintain a clear history of changes.', steps.check_folders.outputs.missing_changelogs) || '✅ All necessary `CHANGELOG.md` files have been updated. Great job! 🎉' }}
Please add an entry to the corresponding `CHANGELOG.md` file to maintain a clear history of changes.
- name: Comment on PR if all changelogs are present
if: github.event.pull_request.head.repo.full_name == github.repository && steps.check_folders.outputs.missing_changelogs == ''
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find_comment.outputs.comment-id }}
body: |
<!-- changelog-check -->
✅ All necessary `CHANGELOG.md` files have been updated. Great job! 🎉
- name: Fail if changelog is missing
if: steps.check_folders.outputs.missing_changelogs != ''
+8 -9
@@ -27,12 +27,11 @@ jobs:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
repository: ${{ secrets.CLOUD_DISPATCH }}
event-type: prowler-pull-request-merged
client-payload: |
{
"PROWLER_COMMIT_SHA": "${{ github.event.pull_request.merge_commit_sha }}",
"PROWLER_COMMIT_SHORT_SHA": "${{ env.SHORT_SHA }}",
"PROWLER_PR_TITLE": ${{ toJson(github.event.pull_request.title) }},
"PROWLER_PR_LABELS": ${{ toJson(github.event.pull_request.labels.*.name) }},
"PROWLER_PR_BODY": ${{ toJson(github.event.pull_request.body) }},
"PROWLER_PR_URL": ${{ toJson(github.event.pull_request.html_url) }}
}
client-payload: '{
"PROWLER_COMMIT_SHA": "${{ github.event.pull_request.merge_commit_sha }}",
"PROWLER_COMMIT_SHORT_SHA": "${{ env.SHORT_SHA }}",
"PROWLER_PR_TITLE": "${{ github.event.pull_request.title }}",
"PROWLER_PR_LABELS": ${{ toJson(github.event.pull_request.labels.*.name) }},
"PROWLER_PR_BODY": ${{ toJson(github.event.pull_request.body) }},
"PROWLER_PR_URL":${{ toJson(github.event.pull_request.html_url) }}
}'
@@ -123,7 +123,7 @@ jobs:
AWS_REGION: ${{ env.AWS_REGION }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
if: github.event_name == 'push'
+1 -2
@@ -12,6 +12,7 @@ env:
jobs:
bump-version:
name: Bump Version
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -96,7 +97,6 @@ jobs:
commit-message: "chore(release): Bump version to v${{ env.BUMP_VERSION_TO }}"
branch: "version-bump-to-v${{ env.BUMP_VERSION_TO }}"
title: "chore(release): Bump version to v${{ env.BUMP_VERSION_TO }}"
labels: no-changelog
body: |
### Description
@@ -135,7 +135,6 @@ jobs:
commit-message: "chore(release): Bump version to v${{ env.PATCH_VERSION_TO }}"
branch: "version-bump-to-v${{ env.PATCH_VERSION_TO }}"
title: "chore(release): Bump version to v${{ env.PATCH_VERSION_TO }}"
labels: no-changelog
body: |
### Description
+2 -2
@@ -56,12 +56,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/sdk-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
+1 -8
@@ -102,15 +102,8 @@ jobs:
run: |
poetry run vulture --exclude "contrib,api,ui" --min-confidence 100 .
- name: Dockerfile - Check if Dockerfile has changed
id: dockerfile-changed-files
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
with:
files: |
Dockerfile
- name: Hadolint
if: steps.dockerfile-changed-files.outputs.any_changed == 'true'
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
run: |
/tmp/hadolint Dockerfile --ignore=DL3013
@@ -30,7 +30,6 @@ env:
# Container Registries
PROWLERCLOUD_DOCKERHUB_REPOSITORY: prowlercloud
PROWLERCLOUD_DOCKERHUB_IMAGE: prowler-ui
NEXT_PUBLIC_API_BASE_URL: http://prowler-api:8080/api/v1
jobs:
repository-check:
@@ -77,7 +76,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
# Comment the following line for testing
@@ -87,7 +86,6 @@ jobs:
context: ${{ env.WORKING_DIRECTORY }}
build-args: |
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=${{ env.SHORT_SHA }}
NEXT_PUBLIC_API_BASE_URL=${{ env.NEXT_PUBLIC_API_BASE_URL }}
# Set push: false for testing
push: true
tags: |
@@ -103,7 +101,6 @@ jobs:
context: ${{ env.WORKING_DIRECTORY }}
build-args: |
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v${{ env.RELEASE_TAG }}
NEXT_PUBLIC_API_BASE_URL=${{ env.NEXT_PUBLIC_API_BASE_URL }}
push: true
tags: |
${{ env.PROWLERCLOUD_DOCKERHUB_REPOSITORY }}/${{ env.PROWLERCLOUD_DOCKERHUB_IMAGE }}:${{ env.RELEASE_TAG }}
+2 -2
@@ -48,12 +48,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/ui-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
-98
@@ -1,98 +0,0 @@
name: UI - E2E Tests
on:
pull_request:
branches:
- master
- "v5.*"
paths:
- '.github/workflows/ui-e2e-tests.yml'
- 'ui/**'
jobs:
e2e-tests:
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
env:
AUTH_SECRET: 'fallback-ci-secret-for-testing'
AUTH_TRUST_HOST: true
NEXTAUTH_URL: 'http://localhost:3000'
NEXT_PUBLIC_API_BASE_URL: 'http://localhost:8080/api/v1'
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Start API services
run: |
# Override docker-compose image tag to use latest instead of stable
# This overrides any PROWLER_API_VERSION set in .env file
export PROWLER_API_VERSION=latest
echo "Using PROWLER_API_VERSION=${PROWLER_API_VERSION}"
docker compose up -d api worker worker-beat
- name: Wait for API to be ready
run: |
echo "Waiting for prowler-api..."
timeout=150 # 2.5 minutes max
elapsed=0
while [ $elapsed -lt $timeout ]; do
if curl -s ${NEXT_PUBLIC_API_BASE_URL}/docs >/dev/null 2>&1; then
echo "Prowler API is ready!"
exit 0
fi
echo "Waiting for prowler-api... (${elapsed}s elapsed)"
sleep 5
elapsed=$((elapsed + 5))
done
echo "Timeout waiting for prowler-api to start"
exit 1
- name: Load database fixtures for E2E tests
run: |
docker compose exec -T api sh -c '
echo "Loading all fixtures from api/fixtures/dev/..."
for fixture in api/fixtures/dev/*.json; do
if [ -f "$fixture" ]; then
echo "Loading $fixture"
poetry run python manage.py loaddata "$fixture" --database admin
fi
done
echo "All database fixtures loaded successfully!"
'
- name: Setup Node.js environment
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '20.x'
cache: 'npm'
cache-dependency-path: './ui/package-lock.json'
- name: Install UI dependencies
working-directory: ./ui
run: npm ci
- name: Build UI application
working-directory: ./ui
run: npm run build
- name: Cache Playwright browsers
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('ui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Install Playwright browsers
working-directory: ./ui
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npm run test:e2e:install
- name: Run E2E tests
working-directory: ./ui
run: npm run test:e2e
- name: Upload test reports
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
if: failure()
with:
name: playwright-report
path: ui/playwright-report/
retention-days: 30
- name: Cleanup services
if: always()
run: |
echo "Shutting down services..."
docker compose down -v || true
echo "Cleanup completed"
+2 -5
@@ -34,24 +34,21 @@ jobs:
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: './ui/package-lock.json'
- name: Install dependencies
working-directory: ./ui
run: npm ci
run: npm install
- name: Run Healthcheck
working-directory: ./ui
run: npm run healthcheck
- name: Build the application
working-directory: ./ui
run: npm run build
test-container-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build Container
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
with:
-10
@@ -44,16 +44,6 @@ junit-reports/
# Cursor files
.cursorignore
.cursor/
# RooCode files
.roo/
.rooignore
.roomodes
# Cline files
.cline/
.clineignore
# Terraform
.terraform*
+1 -2
@@ -115,8 +115,7 @@ repos:
- id: safety
name: safety
description: "Safety is a tool that checks your installed dependencies for known security vulnerabilities"
# TODO: Botocore needs urllib3 1.X, so we need to ignore vulnerabilities 77744 and 77745. Remove this once we upgrade to urllib3 2.X
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745'
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353'
language: system
- id: vulture
+6 -3
@@ -1,4 +1,4 @@
FROM python:3.12.11-slim-bookworm AS build
FROM python:3.12.10-slim-bookworm AS build
LABEL maintainer="https://github.com/prowler-cloud/prowler"
LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
@@ -6,8 +6,7 @@ LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
ARG POWERSHELL_VERSION=7.5.0
# hadolint ignore=DL3008
RUN apt-get update && apt-get install -y --no-install-recommends \
wget libicu72 libunwind8 libssl3 libcurl4 ca-certificates apt-transport-https gnupg \
RUN apt-get update && apt-get install -y --no-install-recommends wget libicu72 \
&& rm -rf /var/lib/apt/lists/*
# Install PowerShell
@@ -47,6 +46,10 @@ ENV PATH="${HOME}/.local/bin:${PATH}"
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir poetry
# By default poetry does not compile Python source files to bytecode during installation.
# This speeds up the installation process, but the first execution may take a little more
# time because Python then compiles source files to bytecode automatically. If you want to
# compile source files to bytecode during installation, you can use the --compile option
RUN poetry install --compile && \
rm -rf ~/.cache/pip
+1 -9
@@ -88,7 +88,7 @@ prowler dashboard
|---|---|---|---|---|
| AWS | 567 | 82 | 36 | 10 |
| GCP | 79 | 13 | 10 | 3 |
| Azure | 142 | 18 | 11 | 3 |
| Azure | 142 | 18 | 10 | 3 |
| Kubernetes | 83 | 7 | 5 | 7 |
| GitHub | 16 | 2 | 1 | 0 |
| M365 | 69 | 7 | 3 | 2 |
@@ -136,14 +136,6 @@ If your workstation's architecture is incompatible, you can resolve this by:
> Once configured, access the Prowler App at http://localhost:3000. Sign up using your email and password to get started.
### Common Issues with Docker Pull Installation
> [!Note]
If you want to use AWS role assumption (e.g., with the "Connect assuming IAM Role" option), you may need to mount your local `.aws` directory into the container as a volume (e.g., `- "${HOME}/.aws:/home/prowler/.aws:ro"`). There are several ways to configure credentials for Docker containers. See the [Troubleshooting](./docs/troubleshooting.md) section for more details and examples.
You can find more information in the [Troubleshooting](./docs/troubleshooting.md) section.
### From GitHub
**Requirements**
+2 -68
@@ -2,82 +2,16 @@
All notable changes to the **Prowler API** are documented in this file.
## [1.11.0] (Prowler 5.10.0)
### Added
- Github provider support [(#8271)](https://github.com/prowler-cloud/prowler/pull/8271)
- Integration with Amazon S3, enabling storage and retrieval of scan data via S3 buckets [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
### Fixed
- Avoid sending errors to Sentry in M365 provider when user authentication fails [(#8420)](https://github.com/prowler-cloud/prowler/pull/8420)
---
## [1.10.2] (Prowler v5.9.2)
### Changed
- Optimized queries for resources views [(#8336)](https://github.com/prowler-cloud/prowler/pull/8336)
---
## [v1.10.1] (Prowler v5.9.1)
### Fixed
- Calculate failed findings during scans to prevent heavy database queries [(#8322)](https://github.com/prowler-cloud/prowler/pull/8322)
---
## [v1.10.0] (Prowler v5.9.0)
### Added
- SSO with SAML support [(#8175)](https://github.com/prowler-cloud/prowler/pull/8175)
- `GET /resources/metadata`, `GET /resources/metadata/latest` and `GET /resources/latest` to expose resource metadata and latest scan results [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
### Changed
- `/processors` endpoints to post-process findings. Currently, only the Mutelist processor is supported, allowing findings to be muted.
- Optimized the underlying queries for resources endpoints [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
- Optimized include parameters for resources view [(#8229)](https://github.com/prowler-cloud/prowler/pull/8229)
- Optimized overview background tasks [(#8300)](https://github.com/prowler-cloud/prowler/pull/8300)
- `POST /schedules/daily` returns a `409 CONFLICT` if already created [(#8258)](https://github.com/prowler-cloud/prowler/pull/8258)
### Fixed
- Search filter for findings and resources [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
- RBAC is now applied to `GET /overviews/providers` [(#8277)](https://github.com/prowler-cloud/prowler/pull/8277)
### Security
- Enhanced password validation to enforce 12+ character passwords with special characters, uppercase, lowercase, and numbers [(#8225)](https://github.com/prowler-cloud/prowler/pull/8225)
---
## [v1.9.1] (Prowler v5.8.1)
### Added
- Custom exception for provider connection errors during scans [(#8234)](https://github.com/prowler-cloud/prowler/pull/8234)
### Changed
- Summary and overview tasks now use a dedicated queue and no longer propagate errors to compliance tasks [(#8214)](https://github.com/prowler-cloud/prowler/pull/8214)
### Fixed
- Scans with no resources no longer trigger legacy code for findings metadata [(#8183)](https://github.com/prowler-cloud/prowler/pull/8183)
- Invitation email comparison is now case-insensitive [(#8206)](https://github.com/prowler-cloud/prowler/pull/8206)
### Removed
- Validation of the provider's secret type during updates [(#8197)](https://github.com/prowler-cloud/prowler/pull/8197)
---
## [v1.9.0] (Prowler v5.8.0)
## [v1.9.0] (Prowler UNRELEASED)
### Added
- SSO with SAML support [(#7822)](https://github.com/prowler-cloud/prowler/pull/7822)
- Support GCP Service Account key [(#7824)](https://github.com/prowler-cloud/prowler/pull/7824)
- `GET /compliance-overviews` endpoints to retrieve compliance metadata and specific requirements statuses [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
- Lighthouse configuration support [(#7848)](https://github.com/prowler-cloud/prowler/pull/7848)
### Changed
- Reworked `GET /compliance-overviews` to return proper requirement metrics [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
- Optional `user` and `password` for M365 provider [(#7992)](https://github.com/prowler-cloud/prowler/pull/7992)
### Fixed
- Scheduled scans are no longer deleted when their daily schedule run is disabled [(#8082)](https://github.com/prowler-cloud/prowler/pull/8082)
+4 -3
@@ -44,9 +44,6 @@ USER prowler
WORKDIR /home/prowler
# Ensure output directory exists
RUN mkdir -p /tmp/prowler_api_output
COPY pyproject.toml ./
RUN pip install --no-cache-dir --upgrade pip && \
@@ -60,6 +57,10 @@ RUN poetry install --no-root && \
RUN poetry run python "$(poetry env info --path)/src/prowler/prowler/providers/m365/lib/powershell/m365_powershell.py"
# Prevents known compatibility error between lxml and libxml2/libxmlsec versions.
# See: https://github.com/xmlsec/python-xmlsec/issues/320
RUN poetry run pip install --force-reinstall --no-binary lxml lxml
COPY src/backend/ ./backend/
COPY docker-entrypoint.sh ./docker-entrypoint.sh
+1 -1
@@ -257,7 +257,7 @@ cd src/backend
python manage.py loaddata api/fixtures/0_dev_users.json --database admin
```
> The default credentials are `dev@prowler.com:Thisisapassword123@` or `dev2@prowler.com:Thisisapassword123@`
> The default credentials are `dev@prowler.com:thisisapassword123` or `dev2@prowler.com:thisisapassword123`
## Run tests
+1 -1
@@ -32,7 +32,7 @@ start_prod_server() {
start_worker() {
echo "Starting the worker..."
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview,integrations -E --max-tasks-per-child 1
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill -E --max-tasks-per-child 1
}
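The extra `overview` and `integrations` queues only receive tasks that are explicitly routed to them; a minimal Celery routing sketch (the task name here is an assumption, not taken from the codebase):

```python
from celery import Celery

app = Celery("config.celery")

# Route a hypothetical task to the dedicated "overview" queue; the worker
# above consumes it because "overview" is listed in its -Q flag.
app.conf.task_routes = {"tasks.generate_overview": {"queue": "overview"}}

@app.task
def generate_overview(tenant_id: str) -> None:
    ...

# The queue can also be chosen per call:
# generate_overview.apply_async(args=[tenant_id], queue="overview")
```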
start_worker_beat() {
+411 -712
File diff suppressed because it is too large
+3 -5
@@ -23,14 +23,12 @@ dependencies = [
"drf-spectacular==0.27.2",
"drf-spectacular-jsonapi==0.5.1",
"gunicorn==23.0.0",
"lxml==5.3.2",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@v5.10",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
"psycopg2-binary==2.9.9",
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
"uuid6==2024.7.10",
"openai (>=1.82.0,<2.0.0)",
"xmlsec==1.3.14"
"openai (>=1.82.0,<2.0.0)"
]
description = "Prowler's API (Django/DRF)"
license = "Apache-2.0"
@@ -38,7 +36,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.11.2"
version = "1.9.0"
[project.scripts]
celery = "src.backend.config.settings.celery"
+72 -6
@@ -3,7 +3,14 @@ from django.db import transaction
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.models import Membership, Role, Tenant, User, UserRoleRelationship
from api.models import (
Membership,
Role,
SAMLConfiguration,
Tenant,
User,
UserRoleRelationship,
)
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
@@ -17,7 +24,7 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
def pre_social_login(self, request, sociallogin):
# Link existing accounts with the same email address
email = sociallogin.account.extra_data.get("email")
if sociallogin.provider.id == "saml":
if sociallogin.account.provider == "saml":
email = sociallogin.user.email
if email:
existing_user = self.get_user_by_email(email)
@@ -31,10 +38,71 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
"""
with transaction.atomic(using=MainRouter.admin_db):
user = super().save_user(request, sociallogin, form)
provider = sociallogin.provider.id
provider = sociallogin.account.provider
extra = sociallogin.account.extra_data
if provider != "saml":
if provider == "saml":
# Handle SAML-specific logic
user.first_name = (
extra.get("firstName", [""])[0] if extra.get("firstName") else ""
)
user.last_name = (
extra.get("lastName", [""])[0] if extra.get("lastName") else ""
)
user.company_name = (
extra.get("organization", [""])[0]
if extra.get("organization")
else ""
)
user.name = f"{user.first_name} {user.last_name}".strip()
if user.name == "":
user.name = "N/A"
user.save(using=MainRouter.admin_db)
email_domain = user.email.split("@")[-1]
tenant = (
SAMLConfiguration.objects.using(MainRouter.admin_db)
.get(email_domain=email_domain)
.tenant
)
with rls_transaction(str(tenant.id)):
role_name = (
extra.get("userType", ["saml_default_role"])[0].strip()
if extra.get("userType")
else "saml_default_role"
)
try:
role = Role.objects.using(MainRouter.admin_db).get(
name=role_name, tenant_id=tenant.id
)
except Role.DoesNotExist:
role = Role.objects.using(MainRouter.admin_db).create(
name=role_name,
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
Membership.objects.using(MainRouter.admin_db).create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.MEMBER,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
else:
# Handle other providers (e.g., GitHub, Google)
user.save(using=MainRouter.admin_db)
social_account_name = extra.get("name")
@@ -65,7 +133,5 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
role=role,
tenant_id=tenant.id,
)
else:
request.session["saml_user_created"] = str(user.id)
return user
-35
@@ -175,29 +175,6 @@ def create_objects_in_batches(
model.objects.bulk_create(chunk, batch_size)
def update_objects_in_batches(
tenant_id: str, model, objects: list, fields: list, batch_size: int = 500
):
"""
Bulk-update model instances in repeated, per-tenant RLS transactions.
All chunks execute in their own transaction, so no single transaction
grows too large.
Args:
tenant_id (str): UUID string of the tenant under which to set RLS.
model: Django model class whose `.objects.bulk_update()` will be called.
objects (list): List of model instances (saved) to bulk-update.
fields (list): List of field names to update.
batch_size (int): Maximum number of objects per bulk_update call.
"""
total = len(objects)
for start in range(0, total, batch_size):
chunk = objects[start : start + batch_size]
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
model.objects.bulk_update(chunk, fields, batch_size)
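For context on what is being removed, a hypothetical call site for `update_objects_in_batches` (everything except the helper, `Resource`, and `failed_findings_count` is illustrative):

```python
from api.db_utils import update_objects_in_batches  # the helper removed above
from api.models import Resource

def refresh_failed_counts(tenant_id: str, resources: list, counts: dict) -> None:
    # Recompute a counter on already-fetched rows, then persist only the
    # changed field in tenant-scoped batches of 500.
    for resource in resources:
        resource.failed_findings_count = counts.get(resource.id, 0)
    update_objects_in_batches(
        tenant_id=tenant_id,
        model=Resource,
        objects=resources,
        fields=["failed_findings_count"],
        batch_size=500,
    )
```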
# Postgres Enums
@@ -552,15 +529,3 @@ class IntegrationTypeEnum(EnumType):
class IntegrationTypeEnumField(PostgresEnumField):
def __init__(self, *args, **kwargs):
super().__init__("integration_type", *args, **kwargs)
# Postgres enum definition for Processor type
class ProcessorTypeEnum(EnumType):
enum_type_name = "processor_type"
class ProcessorTypeEnumField(PostgresEnumField):
def __init__(self, *args, **kwargs):
super().__init__("processor_type", *args, **kwargs)
-23
@@ -57,11 +57,6 @@ class TaskInProgressException(TaskManagementError):
super().__init__()
# Provider connection errors
class ProviderConnectionError(Exception):
"""Base exception for provider connection errors."""
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):
@@ -78,21 +73,3 @@ def custom_exception_handler(exc, context):
message_item["message"] for message_item in exc.detail["messages"]
]
return exception_handler(exc, context)
class ConflictException(APIException):
status_code = status.HTTP_409_CONFLICT
default_detail = "A conflict occurred. The resource already exists."
default_code = "conflict"
def __init__(self, detail=None, code=None, pointer=None):
error_detail = {
"detail": detail or self.default_detail,
"status": self.status_code,
"code": self.default_code,
}
if pointer:
error_detail["source"] = {"pointer": pointer}
super().__init__(detail=[error_detail])
-89
View File
@@ -1,6 +1,5 @@
from datetime import date, datetime, timedelta, timezone
from dateutil.parser import parse
from django.conf import settings
from django.db.models import Q
from django_filters.rest_framework import (
@@ -29,7 +28,6 @@ from api.models import (
Invitation,
Membership,
PermissionChoices,
Processor,
Provider,
ProviderGroup,
ProviderSecret,
@@ -340,8 +338,6 @@ class ResourceFilter(ProviderRelationshipFilterSet):
tags = CharFilter(method="filter_tag")
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
scan = UUIDFilter(field_name="provider__scan", lookup_expr="exact")
scan__in = UUIDInFilter(field_name="provider__scan", lookup_expr="in")
class Meta:
model = Resource
@@ -356,82 +352,6 @@ class ResourceFilter(ProviderRelationshipFilterSet):
"updated_at": ["gte", "lte"],
}
def filter_queryset(self, queryset):
if not (self.data.get("scan") or self.data.get("scan__in")) and not (
self.data.get("updated_at")
or self.data.get("updated_at__date")
or self.data.get("updated_at__gte")
or self.data.get("updated_at__lte")
):
raise ValidationError(
[
{
"detail": "At least one date filter is required: filter[updated_at], filter[updated_at.gte], "
"or filter[updated_at.lte].",
"status": 400,
"source": {"pointer": "/data/attributes/updated_at"},
"code": "required",
}
]
)
gte_date = (
parse(self.data.get("updated_at__gte")).date()
if self.data.get("updated_at__gte")
else datetime.now(timezone.utc).date()
)
lte_date = (
parse(self.data.get("updated_at__lte")).date()
if self.data.get("updated_at__lte")
else datetime.now(timezone.utc).date()
)
if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
):
raise ValidationError(
[
{
"detail": f"The date range cannot exceed {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
"status": 400,
"source": {"pointer": "/data/attributes/updated_at"},
"code": "invalid",
}
]
)
return super().filter_queryset(queryset)
def filter_tag_key(self, queryset, name, value):
return queryset.filter(Q(tags__key=value) | Q(tags__key__icontains=value))
def filter_tag_value(self, queryset, name, value):
return queryset.filter(Q(tags__value=value) | Q(tags__value__icontains=value))
def filter_tag(self, queryset, name, value):
# We won't know what the user wants to filter on just based on the value,
# and we don't want to build special filtering logic for every possible
# provider tag spec, so we'll just do a full text search
return queryset.filter(tags__text_search=value)
class LatestResourceFilter(ProviderRelationshipFilterSet):
tag_key = CharFilter(method="filter_tag_key")
tag_value = CharFilter(method="filter_tag_value")
tag = CharFilter(method="filter_tag")
tags = CharFilter(method="filter_tag")
class Meta:
model = Resource
fields = {
"provider": ["exact", "in"],
"uid": ["exact", "icontains"],
"name": ["exact", "icontains"],
"region": ["exact", "icontains", "in"],
"service": ["exact", "icontains", "in"],
"type": ["exact", "icontains", "in"],
}
def filter_tag_key(self, queryset, name, value):
return queryset.filter(Q(tags__key=value) | Q(tags__key__icontains=value))
@@ -784,12 +704,3 @@ class IntegrationFilter(FilterSet):
fields = {
"inserted_at": ["date", "gte", "lte"],
}
class ProcessorFilter(FilterSet):
processor_type = ChoiceFilter(choices=Processor.ProcessorChoices.choices)
processor_type__in = ChoiceInFilter(
choices=Processor.ProcessorChoices.choices,
field_name="processor_type",
lookup_expr="in",
)
@@ -3,7 +3,7 @@
"model": "api.user",
"pk": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"password": "pbkdf2_sha256$720000$vA62S78kog2c2ytycVQdke$Fp35GVLLMyy5fUq3krSL9I02A+ocQ+RVa4S22LIAO5s=",
"last_login": null,
"name": "Devie Prowlerson",
"email": "dev@prowler.com",
@@ -16,7 +16,7 @@
"model": "api.user",
"pk": "b6493a3a-c997-489b-8b99-278bf74de9f6",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"password": "pbkdf2_sha256$720000$vA62S78kog2c2ytycVQdke$Fp35GVLLMyy5fUq3krSL9I02A+ocQ+RVa4S22LIAO5s=",
"last_login": null,
"name": "Devietoo Prowlerson",
"email": "dev2@prowler.com",
@@ -24,18 +24,5 @@
"is_active": true,
"date_joined": "2024-09-18T09:04:20.850Z"
}
},
{
"model": "api.user",
"pk": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"last_login": null,
"name": "E2E Test User",
"email": "e2e@prowler.com",
"company_name": "Prowler E2E Tests",
"is_active": true,
"date_joined": "2024-01-01T00:00:00.850Z"
}
}
]
@@ -46,24 +46,5 @@
"role": "member",
"date_joined": "2024-09-19T11:03:59.712Z"
}
},
{
"model": "api.tenant",
"pk": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"fields": {
"inserted_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"name": "E2E Test Tenant"
}
},
{
"model": "api.membership",
"pk": "9b1a2c3d-4e5f-6789-abc1-23456789def0",
"fields": {
"user": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"role": "owner",
"date_joined": "2024-01-01T00:00:00.000Z"
}
}
]
@@ -149,32 +149,5 @@
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"inserted_at": "2024-11-20T15:36:14.302Z"
}
},
{
"model": "api.role",
"pk": "a5b6c7d8-9e0f-1234-5678-90abcdef1234",
"fields": {
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"name": "e2e_admin",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": true,
"manage_scans": true,
"unlimited_visibility": true,
"inserted_at": "2024-01-01T00:00:00.000Z",
"updated_at": "2024-01-01T00:00:00.000Z"
}
},
{
"model": "api.userrolerelationship",
"pk": "f1e2d3c4-b5a6-9876-5432-10fedcba9876",
"fields": {
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"role": "a5b6c7d8-9e0f-1234-5678-90abcdef1234",
"user": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"inserted_at": "2024-01-01T00:00:00.000Z"
}
}
]
@@ -1,61 +1,57 @@
# Generated by Django 5.1.10 on 2025-07-02 15:47
# Generated by Django 5.1.8 on 2025-05-15 09:54
import uuid
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0031_scan_disable_on_cascade_periodic_tasks"),
("api", "0029_findings_check_index_parent"),
]
operations = [
migrations.AlterField(
model_name="integration",
name="integration_type",
field=api.db_utils.IntegrationTypeEnumField(
choices=[
("amazon_s3", "Amazon S3"),
("aws_security_hub", "AWS Security Hub"),
("jira", "JIRA"),
("slack", "Slack"),
]
),
),
migrations.CreateModel(
name="SAMLToken",
name="SAMLDomainIndex",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("expires_at", models.DateTimeField(editable=False)),
("token", models.JSONField(unique=True)),
("email_domain", models.CharField(max_length=254, unique=True)),
(
"user",
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_tokens",
"db_table": "saml_domain_index",
},
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=api.rls.BaseSecurityConstraint(
name="statements_on_samldomainindex",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.CreateModel(
name="SAMLConfiguration",
fields=[
@@ -109,42 +105,16 @@ class Migration(migrations.Migration):
fields=("tenant",), name="unique_samlconfig_per_tenant"
),
),
migrations.CreateModel(
name="SAMLDomainIndex",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("email_domain", models.CharField(max_length=254, unique=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_domain_index",
},
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=api.rls.BaseSecurityConstraint(
name="statements_on_samldomainindex",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
migrations.AlterField(
model_name="integration",
name="integration_type",
field=api.db_utils.IntegrationTypeEnumField(
choices=[
("amazon_s3", "Amazon S3"),
("aws_security_hub", "AWS Security Hub"),
("jira", "JIRA"),
("slack", "Slack"),
]
),
),
]
@@ -11,7 +11,7 @@ import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0029_findings_check_index_parent"),
("api", "0030_samlconfigurations"),
]
operations = [
@@ -6,7 +6,7 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0030_lighthouseconfiguration"),
("api", "0031_lighthouseconfiguration"),
("django_celery_beat", "0019_alter_periodictasks_options"),
]
@@ -1,34 +0,0 @@
# Generated by Django 5.1.5 on 2025-03-03 15:46
from functools import partial
from django.db import migrations
from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum
from api.models import Processor
ProcessorTypeEnumMigration = PostgresEnumMigration(
enum_name="processor_type",
enum_values=tuple(
processor_type[0] for processor_type in Processor.ProcessorChoices.choices
),
)
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0032_saml"),
]
operations = [
migrations.RunPython(
ProcessorTypeEnumMigration.create_enum_type,
reverse_code=ProcessorTypeEnumMigration.drop_enum_type,
),
migrations.RunPython(
partial(register_enum, enum_class=ProcessorTypeEnum),
reverse_code=migrations.RunPython.noop,
),
]
@@ -1,88 +0,0 @@
# Generated by Django 5.1.5 on 2025-03-26 13:04
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
from api.rls import RowLevelSecurityConstraint
class Migration(migrations.Migration):
dependencies = [
("api", "0033_processors_enum"),
]
operations = [
migrations.CreateModel(
name="Processor",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"processor_type",
api.db_utils.ProcessorTypeEnumField(
choices=[("mutelist", "Mutelist")]
),
),
("configuration", models.JSONField(default=dict)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "processors",
"abstract": False,
"indexes": [
models.Index(
fields=["tenant_id", "id"], name="processor_tenant_id_idx"
),
models.Index(
fields=["tenant_id", "processor_type"],
name="processor_tenant_type_idx",
),
],
},
),
migrations.AddConstraint(
model_name="processor",
constraint=models.UniqueConstraint(
fields=("tenant_id", "processor_type"),
name="unique_processor_types_tenant",
),
),
migrations.AddConstraint(
model_name="processor",
constraint=RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_processor",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddField(
model_name="scan",
name="processor",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="scans",
related_query_name="scan",
to="api.processor",
),
),
]
@@ -1,22 +0,0 @@
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0034_processors"),
]
operations = [
migrations.AddField(
model_name="finding",
name="muted_reason",
field=models.TextField(
blank=True,
max_length=500,
null=True,
validators=[django.core.validators.MinLengthValidator(3)],
),
),
]
@@ -1,30 +0,0 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0035_finding_muted_reason"),
]
operations = [
migrations.RunPython(
partial(
create_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_finding_idx",
columns="tenant_id, finding_id",
method="BTREE",
),
reverse_code=partial(
drop_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_finding_idx",
),
),
]
@@ -1,17 +0,0 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0036_rfm_tenant_finding_index_partitions"),
]
operations = [
migrations.AddIndex(
model_name="resourcefindingmapping",
index=models.Index(
fields=["tenant_id", "finding_id"],
name="rfm_tenant_finding_idx",
),
),
]
@@ -1,15 +0,0 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0037_rfm_tenant_finding_index_parent"),
]
operations = [
migrations.AddField(
model_name="resource",
name="failed_findings_count",
field=models.IntegerField(default=0),
)
]
@@ -1,20 +0,0 @@
from django.contrib.postgres.operations import AddIndexConcurrently
from django.db import migrations, models
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0038_resource_failed_findings_count"),
]
operations = [
AddIndexConcurrently(
model_name="resource",
index=models.Index(
fields=["tenant_id", "-failed_findings_count", "id"],
name="resources_failed_findings_idx",
),
),
]
@@ -1,30 +0,0 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0039_resource_resources_failed_findings_idx"),
]
operations = [
migrations.RunPython(
partial(
create_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_resource_idx",
columns="tenant_id, resource_id",
method="BTREE",
),
reverse_code=partial(
drop_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_resource_idx",
),
),
]
@@ -1,17 +0,0 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0040_rfm_tenant_resource_index_partitions"),
]
operations = [
migrations.AddIndex(
model_name="resourcefindingmapping",
index=models.Index(
fields=["tenant_id", "resource_id"],
name="rfm_tenant_resource_idx",
),
),
]
@@ -1,23 +0,0 @@
from django.contrib.postgres.operations import AddIndexConcurrently
from django.db import migrations, models
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0041_rfm_tenant_resource_parent_partitions"),
("django_celery_beat", "0019_alter_periodictasks_options"),
]
operations = [
AddIndexConcurrently(
model_name="scan",
index=models.Index(
condition=models.Q(("state", "completed")),
fields=["tenant_id", "provider_id", "-inserted_at"],
include=("id",),
name="scans_prov_ins_desc_idx",
),
),
]
@@ -1,33 +0,0 @@
# Generated by Django 5.1.7 on 2025-07-09 14:44
from django.db import migrations
import api.db_utils
class Migration(migrations.Migration):
dependencies = [
("api", "0042_scan_scans_prov_ins_desc_idx"),
]
operations = [
migrations.AlterField(
model_name="provider",
name="provider",
field=api.db_utils.ProviderEnumField(
choices=[
("aws", "AWS"),
("azure", "Azure"),
("gcp", "GCP"),
("kubernetes", "Kubernetes"),
("m365", "M365"),
("github", "GitHub"),
],
default="aws",
),
),
migrations.RunSQL(
"ALTER TYPE provider ADD VALUE IF NOT EXISTS 'github';",
reverse_sql=migrations.RunSQL.noop,
),
]
@@ -1,19 +0,0 @@
# Generated by Django 5.1.10 on 2025-07-17 11:52
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0043_github_provider"),
]
operations = [
migrations.AddConstraint(
model_name="integration",
constraint=models.UniqueConstraint(
fields=("configuration", "tenant"),
name="unique_configuration_per_tenant",
),
),
]
@@ -1,17 +0,0 @@
# Generated by Django 5.1.10 on 2025-07-21 16:08
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0044_integration_unique_configuration_per_tenant"),
]
operations = [
migrations.AlterField(
model_name="scan",
name="output_location",
field=models.CharField(blank=True, max_length=4096, null=True),
),
]
+26 -171
View File
@@ -2,7 +2,6 @@ import json
import logging
import re
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
from allauth.socialaccount.models import SocialApp
@@ -34,7 +33,6 @@ from api.db_utils import (
IntegrationTypeEnumField,
InvitationStateEnumField,
MemberRoleEnumField,
ProcessorTypeEnumField,
ProviderEnumField,
ProviderSecretTypeEnumField,
ScanTriggerEnumField,
@@ -205,7 +203,6 @@ class Provider(RowLevelSecurityProtectedModel):
GCP = "gcp", _("GCP")
KUBERNETES = "kubernetes", _("Kubernetes")
M365 = "m365", _("M365")
GITHUB = "github", _("GitHub")
@staticmethod
def validate_aws_uid(value):
@@ -266,16 +263,6 @@ class Provider(RowLevelSecurityProtectedModel):
pointer="/data/attributes/uid",
)
@staticmethod
def validate_github_uid(value):
if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9-]{0,38}$", value):
raise ModelValidationError(
detail="GitHub provider ID must be a valid GitHub username or organization name (1-39 characters, "
"starting with alphanumeric, containing only alphanumeric characters and hyphens).",
code="github-uid",
pointer="/data/attributes/uid",
)
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
@@ -421,6 +408,20 @@ class Scan(RowLevelSecurityProtectedModel):
name = models.CharField(
blank=True, null=True, max_length=100, validators=[MinLengthValidator(3)]
)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
)
task = models.ForeignKey(
Task,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
)
trigger = ScanTriggerEnumField(
choices=TriggerChoices.choices,
)
@@ -438,29 +439,9 @@ class Scan(RowLevelSecurityProtectedModel):
scheduler_task = models.ForeignKey(
PeriodicTask, on_delete=models.SET_NULL, null=True, blank=True
)
output_location = models.CharField(blank=True, null=True, max_length=4096)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
)
task = models.ForeignKey(
Task,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
)
processor = models.ForeignKey(
"Processor",
on_delete=models.SET_NULL,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
)
output_location = models.CharField(blank=True, null=True, max_length=200)
# TODO: mutelist foreign key
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "scans"
@@ -487,13 +468,6 @@ class Scan(RowLevelSecurityProtectedModel):
condition=Q(state=StateChoices.COMPLETED),
name="scans_prov_state_ins_desc_idx",
),
# TODO This might replace `scans_prov_state_ins_desc_idx` completely. Review usage
models.Index(
fields=["tenant_id", "provider_id", "-inserted_at"],
condition=Q(state=StateChoices.COMPLETED),
include=["id"],
name="scans_prov_ins_desc_idx",
),
]
class JSONAPIMeta:
@@ -579,8 +553,6 @@ class Resource(RowLevelSecurityProtectedModel):
details = models.TextField(blank=True, null=True)
partition = models.TextField(blank=True, null=True)
failed_findings_count = models.IntegerField(default=0)
# Relationships
tags = models.ManyToManyField(
ResourceTag,
@@ -627,10 +599,6 @@ class Resource(RowLevelSecurityProtectedModel):
fields=["tenant_id", "provider_id"],
name="resources_tenant_provider_idx",
),
models.Index(
fields=["tenant_id", "-failed_findings_count", "id"],
name="resources_failed_findings_idx",
),
]
constraints = [
@@ -729,9 +697,6 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel):
check_id = models.CharField(max_length=100, blank=False, null=False)
check_metadata = models.JSONField(default=dict, null=False)
muted = models.BooleanField(default=False, null=False)
muted_reason = models.TextField(
blank=True, null=True, validators=[MinLengthValidator(3)], max_length=500
)
compliance = models.JSONField(default=dict, null=True, blank=True)
# Denormalize resource data for performance
@@ -873,16 +838,6 @@ class ResourceFindingMapping(PostgresPartitionedModel, RowLevelSecurityProtected
# - tenant_id
# - id
indexes = [
models.Index(
fields=["tenant_id", "finding_id"],
name="rfm_tenant_finding_idx",
),
models.Index(
fields=["tenant_id", "resource_id"],
name="rfm_tenant_resource_idx",
),
]
constraints = [
models.UniqueConstraint(
fields=("tenant_id", "resource_id", "finding_id"),
@@ -987,11 +942,6 @@ class Invitation(RowLevelSecurityProtectedModel):
null=True,
)
def save(self, *args, **kwargs):
if self.email:
self.email = self.email.strip().lower()
super().save(*args, **kwargs)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "invitations"
@@ -1346,7 +1296,7 @@ class ScanSummary(RowLevelSecurityProtectedModel):
class Integration(RowLevelSecurityProtectedModel):
class IntegrationChoices(models.TextChoices):
AMAZON_S3 = "amazon_s3", _("Amazon S3")
S3 = "amazon_s3", _("Amazon S3")
AWS_SECURITY_HUB = "aws_security_hub", _("AWS Security Hub")
JIRA = "jira", _("JIRA")
SLACK = "slack", _("Slack")
@@ -1372,10 +1322,6 @@ class Integration(RowLevelSecurityProtectedModel):
db_table = "integrations"
constraints = [
models.UniqueConstraint(
fields=("configuration", "tenant"),
name="unique_configuration_per_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
@@ -1424,26 +1370,6 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel):
]
class SAMLToken(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
expires_at = models.DateTimeField(editable=False)
token = models.JSONField(unique=True)
user = models.ForeignKey(User, on_delete=models.CASCADE)
class Meta:
db_table = "saml_tokens"
def save(self, *args, **kwargs):
if not self.expires_at:
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
super().save(*args, **kwargs)
def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at
class SAMLDomainIndex(models.Model):
"""
Public index of SAML domains. No RLS. Used for fast lookup in SAML login flow.
@@ -1521,7 +1447,7 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
),
]
def clean(self, old_email_domain=None, is_create=False):
def clean(self, old_email_domain=None):
# Domain must not contain @
if "@" in self.email_domain:
raise ValidationError({"email_domain": "Domain must not contain @"})
@@ -1545,25 +1471,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
{"tenant": "There is a problem with your email domain."}
)
# The entityID must be unique in the system
idp_settings = self._parsed_metadata
entity_id = idp_settings.get("entity_id")
if entity_id:
# Find any SocialApp with this entityID
q = SocialApp.objects.filter(provider="saml", provider_id=entity_id)
# If updating, exclude our own SocialApp from the check
if not is_create:
q = q.exclude(client_id=old_email_domain)
else:
q = q.exclude(client_id=self.email_domain)
if q.exists():
raise ValidationError(
{"metadata_xml": "There is a problem with your metadata."}
)
def save(self, *args, **kwargs):
self.email_domain = self.email_domain.strip().lower()
is_create = not SAMLConfiguration.objects.filter(pk=self.pk).exists()
@@ -1576,8 +1483,7 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
old_email_domain = None
old_metadata_xml = None
self._parsed_metadata = self._parse_metadata()
self.clean(old_email_domain, is_create)
self.clean(old_email_domain)
super().save(*args, **kwargs)
if is_create or (
@@ -1595,12 +1501,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
email_domain=self.email_domain, defaults={"tenant": self.tenant}
)
def delete(self, *args, **kwargs):
super().delete(*args, **kwargs)
SocialApp.objects.filter(provider="saml", client_id=self.email_domain).delete()
SAMLDomainIndex.objects.filter(email_domain=self.email_domain).delete()
def _parse_metadata(self):
"""
Parse the raw IdP metadata XML and extract:
@@ -1620,8 +1520,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
# Entity ID
entity_id = root.attrib.get("entityID")
if not entity_id:
raise ValidationError({"metadata_xml": "Missing entityID in metadata."})
# SSO endpoint (must exist)
sso = root.find(".//md:IDPSSODescriptor/md:SingleSignOnService", ns)
@@ -1660,8 +1558,9 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
Create or update the corresponding SocialApp based on email_domain.
If the domain changed, update the matching SocialApp.
"""
idp_settings = self._parse_metadata()
settings_dict = SOCIALACCOUNT_PROVIDERS["saml"].copy()
settings_dict["idp"] = self._parsed_metadata
settings_dict["idp"] = idp_settings
current_site = Site.objects.get(id=settings.SITE_ID)
@@ -1669,24 +1568,19 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
provider="saml", client_id=previous_email_domain or self.email_domain
)
client_id = self.email_domain[:191]
name = f"SAML-{self.email_domain}"[:40]
if social_app_qs.exists():
social_app = social_app_qs.first()
social_app.client_id = client_id
social_app.name = name
social_app.client_id = self.email_domain
social_app.name = f"{self.tenant.name} SAML ({self.email_domain})"
social_app.settings = settings_dict
social_app.provider_id = self._parsed_metadata["entity_id"]
social_app.save()
social_app.sites.set([current_site])
else:
social_app = SocialApp.objects.create(
provider="saml",
client_id=client_id,
name=name,
client_id=self.email_domain,
name=f"{self.tenant.name} SAML ({self.email_domain})",
settings=settings_dict,
provider_id=self._parsed_metadata["entity_id"],
)
social_app.sites.set([current_site])
@@ -1865,42 +1759,3 @@ class LighthouseConfiguration(RowLevelSecurityProtectedModel):
class JSONAPIMeta:
resource_name = "lighthouse-configurations"
class Processor(RowLevelSecurityProtectedModel):
class ProcessorChoices(models.TextChoices):
MUTELIST = "mutelist", _("Mutelist")
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
processor_type = ProcessorTypeEnumField(choices=ProcessorChoices.choices)
configuration = models.JSONField(default=dict)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "processors"
constraints = [
models.UniqueConstraint(
fields=("tenant_id", "processor_type"),
name="unique_processor_types_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
indexes = [
models.Index(
fields=["tenant_id", "id"],
name="processor_tenant_id_idx",
),
models.Index(
fields=["tenant_id", "processor_type"],
name="processor_tenant_type_idx",
),
]
class JSONAPIMeta:
resource_name = "processors"
File diff suppressed because it is too large
@@ -11,7 +11,7 @@ def test_basic_authentication():
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password@1"
test_password = "test_password"
# Check that a 401 is returned when no basic authentication is provided
no_auth_response = client.get(reverse("provider-list"))
@@ -108,7 +108,7 @@ def test_user_me_when_inviting_users(create_test_user, tenants_fixture, roles_fi
user1_email = "user1@testing.com"
user2_email = "user2@testing.com"
password = "Thisisapassword123@"
password = "thisisapassword123"
user1_response = client.post(
reverse("user-list"),
@@ -187,7 +187,7 @@ class TestTokenSwitchTenant:
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password1@"
test_password = "test_password"
# Check that we can create a new user without any kind of authentication
user_creation_response = client.post(
@@ -17,7 +17,7 @@ def test_delete_provider_without_executing_task(
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password1@"
test_password = "test_password"
prowler_task = tasks_fixture[0]
task_mock = Mock()
+34 -21
View File
@@ -1,10 +1,12 @@
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from allauth.socialaccount.models import SocialLogin
from django.contrib.auth import get_user_model
from api.adapters import ProwlerSocialAccountAdapter
from api.db_router import MainRouter
from api.models import Membership, SAMLConfiguration, Tenant
User = get_user_model()
@@ -25,8 +27,7 @@ class TestProwlerSocialAccountAdapter:
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.provider = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.provider = "saml"
sociallogin.account.extra_data = {}
sociallogin.user = create_test_user
sociallogin.connect = MagicMock()
@@ -45,9 +46,7 @@ class TestProwlerSocialAccountAdapter:
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.provider = MagicMock()
sociallogin.user = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.provider = "github"
sociallogin.account.extra_data = {}
sociallogin.connect = MagicMock()
@@ -55,23 +54,37 @@ class TestProwlerSocialAccountAdapter:
sociallogin.connect.assert_not_called()
def test_save_user_saml_sets_session_flag(self, rf):
def test_save_user_saml_flow(
self,
rf,
saml_setup,
saml_sociallogin,
):
adapter = ProwlerSocialAccountAdapter()
request = rf.get("/")
request.session = {}
saml_sociallogin.user.email = saml_setup["email"]
saml_sociallogin.account.extra_data = {
"firstName": [],
"lastName": [],
"organization": [],
"userType": [],
}
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.provider = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account = MagicMock()
sociallogin.account.extra_data = {}
tenant = Tenant.objects.using(MainRouter.admin_db).get(
id=saml_setup["tenant_id"]
)
saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get(
tenant=tenant
)
assert saml_config.email_domain == saml_setup["domain"]
mock_user = MagicMock()
mock_user.id = 123
user = adapter.save_user(request, saml_sociallogin)
with patch("api.adapters.super") as mock_super:
with patch("api.adapters.transaction"):
with patch("api.adapters.MainRouter"):
mock_super.return_value.save_user.return_value = mock_user
adapter.save_user(request, sociallogin)
assert request.session["saml_user_created"] == "123"
assert user.name == "N/A"
assert user.company_name == ""
assert user.email == saml_setup["email"]
assert (
Membership.objects.using(MainRouter.admin_db)
.filter(user=user, tenant=tenant)
.exists()
)
@@ -13,7 +13,6 @@ from api.db_utils import (
enum_to_choices,
generate_random_token,
one_week_from_now,
update_objects_in_batches,
)
from api.models import Provider
@@ -228,88 +227,3 @@ class TestCreateObjectsInBatches:
qs = Provider.objects.filter(tenant=tenant)
assert qs.count() == total
@pytest.mark.django_db
class TestUpdateObjectsInBatches:
@pytest.fixture
def tenant(self, tenants_fixture):
return tenants_fixture[0]
def make_provider_instances(self, tenant, count):
"""
Return a list of `count` unsaved Provider instances for the given tenant.
"""
base_uid = 2000
return [
Provider(
tenant=tenant,
uid=str(base_uid + i),
provider=Provider.ProviderChoices.AWS,
)
for i in range(count)
]
def test_exact_multiple_of_batch(self, tenant):
total = 6
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
# Fetch them back, mutate the `uid` field, then update in batches
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
batch_size=batch_size,
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
def test_non_multiple_of_batch(self, tenant):
total = 7
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
batch_size=batch_size,
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
def test_batch_size_default(self, tenant):
default_size = settings.DJANGO_DELETION_BATCH_SIZE
total = default_size + 2
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs)
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
# Update without specifying batch_size (uses default)
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
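update_objects_in_batches is imported from api.db_utils, but its body is outside this compare. A plausible minimal sketch, assuming it mirrors create_objects_in_batches: chunked bulk_update calls executed under the tenant's row-level-security context (the rls_transaction helper is an assumption here):

from django.conf import settings

from api.db_utils import rls_transaction  # assumed RLS tenant-context helper

def update_objects_in_batches(tenant_id, model, objects, fields, batch_size=None):
    # Update `objects` in chunks so huge updates avoid one giant transaction.
    batch_size = batch_size or settings.DJANGO_DELETION_BATCH_SIZE
    for start in range(0, len(objects), batch_size):
        batch = objects[start : start + batch_size]
        with rls_transaction(tenant_id):
            model.objects.bulk_update(batch, fields)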
+16 -71
View File
@@ -3,7 +3,7 @@ from allauth.socialaccount.models import SocialApp
from django.core.exceptions import ValidationError
from api.db_router import MainRouter
from api.models import Resource, ResourceTag, SAMLConfiguration, SAMLDomainIndex
from api.models import Resource, ResourceTag, SAMLConfiguration, Tenant
@pytest.mark.django_db
@@ -142,8 +142,8 @@ class TestSAMLConfigurationModel:
</md:EntityDescriptor>
"""
def test_creates_valid_configuration(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_creates_valid_configuration(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant A")
config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="ssoexample.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
@@ -153,8 +153,8 @@ class TestSAMLConfigurationModel:
assert config.email_domain == "ssoexample.com"
assert SocialApp.objects.filter(client_id="ssoexample.com").exists()
def test_email_domain_with_at_symbol_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_email_domain_with_at_symbol_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant B")
config = SAMLConfiguration(
email_domain="invalid@domain.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
@@ -168,8 +168,9 @@ class TestSAMLConfigurationModel:
assert "email_domain" in errors
assert "Domain must not contain @" in errors["email_domain"][0]
def test_duplicate_email_domain_fails(self, tenants_fixture):
tenant1, tenant2, *_ = tenants_fixture
def test_duplicate_email_domain_fails(self):
tenant1 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C1")
tenant2 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C2")
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="duplicate.com",
@@ -190,8 +191,8 @@ class TestSAMLConfigurationModel:
assert "tenant" in errors
assert "There is a problem with your email domain." in errors["tenant"][0]
def test_duplicate_tenant_config_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_duplicate_tenant_config_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant D")
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="unique1.com",
@@ -215,8 +216,8 @@ class TestSAMLConfigurationModel:
in errors["tenant"][0]
)
def test_invalid_metadata_xml_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_invalid_metadata_xml_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant E")
config = SAMLConfiguration(
email_domain="brokenxml.com",
metadata_xml="<bad<xml>",
@@ -231,8 +232,8 @@ class TestSAMLConfigurationModel:
assert "Invalid XML" in errors["metadata_xml"][0]
assert "not well-formed" in errors["metadata_xml"][0]
def test_metadata_missing_sso_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_metadata_missing_sso_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant F")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor></md:IDPSSODescriptor>
</md:EntityDescriptor>"""
@@ -249,8 +250,8 @@ class TestSAMLConfigurationModel:
assert "metadata_xml" in errors
assert "Missing SingleSignOnService" in errors["metadata_xml"][0]
def test_metadata_missing_certificate_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_metadata_missing_certificate_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant G")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://example.com/sso"/>
@@ -268,59 +269,3 @@ class TestSAMLConfigurationModel:
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "X509Certificate" in errors["metadata_xml"][0]
def test_deletes_saml_configuration_and_related_objects(self, tenants_fixture):
tenant = tenants_fixture[0]
email_domain = "deleteme.com"
# Create the configuration
config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain=email_domain,
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# Verify that the SocialApp and SAMLDomainIndex exist
assert SocialApp.objects.filter(client_id=email_domain).exists()
assert (
SAMLDomainIndex.objects.using(MainRouter.admin_db)
.filter(email_domain=email_domain)
.exists()
)
# Delete the configuration
config.delete()
# Verify that the configuration and its related objects are deleted
assert (
not SAMLConfiguration.objects.using(MainRouter.admin_db)
.filter(pk=config.pk)
.exists()
)
assert not SocialApp.objects.filter(client_id=email_domain).exists()
assert (
not SAMLDomainIndex.objects.using(MainRouter.admin_db)
.filter(email_domain=email_domain)
.exists()
)
def test_duplicate_entity_id_fails_on_creation(self, tenants_fixture):
tenant1, tenant2, *_ = tenants_fixture
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="first.com",
metadata_xml=self.VALID_METADATA,
tenant=tenant1,
)
config = SAMLConfiguration(
email_domain="second.com",
metadata_xml=self.VALID_METADATA,
tenant=tenant2,
)
with pytest.raises(ValidationError) as exc_info:
config.save()
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "There is a problem with your metadata." in errors["metadata_xml"][0]
+3 -88
View File
@@ -1,7 +1,6 @@
from unittest.mock import ANY, Mock, patch
import pytest
from conftest import TODAY
from django.urls import reverse
from rest_framework import status
@@ -61,7 +60,7 @@ class TestUserViewSet:
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
valid_user_payload = {
"name": "test",
"password": "Newpassword123@",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_rbac.post(
@@ -75,7 +74,7 @@ class TestUserViewSet:
):
valid_user_payload = {
"name": "test",
"password": "Newpassword123@",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_no_permissions_rbac.post(
@@ -322,7 +321,7 @@ class TestProviderViewSet:
@pytest.mark.django_db
class TestLimitedVisibility:
TEST_EMAIL = "rbac@rbac.com"
TEST_PASSWORD = "Thisisapassword123@"
TEST_PASSWORD = "thisisapassword123"
@pytest.fixture
def limited_admin_user(
@@ -410,87 +409,3 @@ class TestLimitedVisibility:
assert (
response.json()["data"]["relationships"]["providers"]["meta"]["count"] == 1
)
def test_overviews_providers(
self,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(reverse("overview-providers"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) > 0
# After changing the provider visibility, no data should be returned
# Only the provider associated with that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(reverse("overview-providers"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
@pytest.mark.parametrize(
"endpoint_name",
[
"findings",
"findings_severity",
],
)
def test_overviews_findings(
self,
endpoint_name,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(
reverse(f"overview-{endpoint_name}")
)
assert response.status_code == status.HTTP_200_OK
values = response.json()["data"]["attributes"].values()
assert any(value > 0 for value in values)
# After changing the provider visibility, no data should be returned
# Only the provider associated with that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(
reverse(f"overview-{endpoint_name}")
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]["attributes"].values()
assert all(value == 0 for value in data)
def test_overviews_services(
self,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(
reverse("overview-services"), {"filter[inserted_at]": TODAY}
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) > 0
# After changing the provider visibility, no data should be returned
# Only the provider associated with that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(
reverse("overview-services"), {"filter[inserted_at]": TODAY}
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
@@ -1,100 +0,0 @@
import pytest
from rest_framework.exceptions import ValidationError
from api.v1.serializer_utils.integrations import S3ConfigSerializer
class TestS3ConfigSerializer:
"""Test cases for S3ConfigSerializer validation."""
def test_validate_output_directory_valid_paths(self):
"""Test that valid output directory paths are accepted."""
serializer = S3ConfigSerializer()
# Test normal paths
assert serializer.validate_output_directory("test") == "test"
assert serializer.validate_output_directory("test/folder") == "test/folder"
assert serializer.validate_output_directory("my-folder_123") == "my-folder_123"
# Test paths with leading slashes (should be normalized)
assert serializer.validate_output_directory("/test") == "test"
assert serializer.validate_output_directory("/test/folder") == "test/folder"
# Test paths with excessive slashes (should be normalized)
assert serializer.validate_output_directory("///test") == "test"
assert serializer.validate_output_directory("///////test") == "test"
assert serializer.validate_output_directory("test//folder") == "test/folder"
assert serializer.validate_output_directory("test///folder") == "test/folder"
def test_validate_output_directory_empty_values(self):
"""Test that empty values raise validation errors."""
serializer = S3ConfigSerializer()
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory(".")
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory("/")
def test_validate_output_directory_invalid_characters(self):
"""Test that invalid characters are rejected."""
serializer = S3ConfigSerializer()
invalid_chars = ["<", ">", ":", '"', "|", "?", "*"]
for char in invalid_chars:
with pytest.raises(
ValidationError, match="Output directory contains invalid characters"
):
serializer.validate_output_directory(f"test{char}folder")
def test_validate_output_directory_too_long(self):
"""Test that paths that are too long are rejected."""
serializer = S3ConfigSerializer()
# Create a path longer than 900 characters
long_path = "a" * 901
with pytest.raises(ValidationError, match="Output directory path is too long"):
serializer.validate_output_directory(long_path)
def test_validate_output_directory_edge_cases(self):
"""Test edge cases for output directory validation."""
serializer = S3ConfigSerializer()
# Test path at the limit (900 characters)
path_at_limit = "a" * 900
assert serializer.validate_output_directory(path_at_limit) == path_at_limit
# Test complex normalization
assert serializer.validate_output_directory("//test/../folder//") == "folder"
assert serializer.validate_output_directory("/test/./folder/") == "test/folder"
def test_s3_config_serializer_full_validation(self):
"""Test the full S3ConfigSerializer with valid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "///////test", # This should be normalized
}
serializer = S3ConfigSerializer(data=data)
assert serializer.is_valid()
validated_data = serializer.validated_data
assert validated_data["bucket_name"] == "my-test-bucket"
assert validated_data["output_directory"] == "test" # Normalized
def test_s3_config_serializer_invalid_data(self):
"""Test the full S3ConfigSerializer with invalid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "test<invalid", # Contains invalid character
}
serializer = S3ConfigSerializer(data=data)
assert not serializer.is_valid()
assert "output_directory" in serializer.errors
+4 -60
View File
@@ -131,21 +131,6 @@ class TestInitializeProwlerProvider:
initialize_prowler_provider(provider)
mock_return_prowler_provider.return_value.assert_called_once_with(key="value")
@patch("api.utils.return_prowler_provider")
def test_initialize_prowler_provider_with_mutelist(
self, mock_return_prowler_provider
):
provider = MagicMock()
provider.secret.secret = {"key": "value"}
mutelist_processor = MagicMock()
mutelist_processor.configuration = {"Mutelist": {"key": "value"}}
mock_return_prowler_provider.return_value = MagicMock()
initialize_prowler_provider(provider, mutelist_processor)
mock_return_prowler_provider.return_value.assert_called_once_with(
key="value", mutelist_content={"key": "value"}
)
class TestProwlerProviderConnectionTest:
@patch("api.utils.return_prowler_provider")
@@ -215,25 +200,6 @@ class TestGetProwlerProviderKwargs:
expected_result = {**secret_dict, **expected_extra_kwargs}
assert result == expected_result
def test_get_prowler_provider_kwargs_with_mutelist(self):
provider_uid = "provider_uid"
secret_dict = {"key": "value"}
secret_mock = MagicMock()
secret_mock.secret = secret_dict
mutelist_processor = MagicMock()
mutelist_processor.configuration = {"Mutelist": {"key": "value"}}
provider = MagicMock()
provider.provider = Provider.ProviderChoices.AWS.value
provider.secret = secret_mock
provider.uid = provider_uid
result = get_prowler_provider_kwargs(provider, mutelist_processor)
expected_result = {**secret_dict, "mutelist_content": {"key": "value"}}
assert result == expected_result
def test_get_prowler_provider_kwargs_unsupported_provider(self):
# Setup
provider_uid = "provider_uid"
@@ -288,7 +254,7 @@ class TestValidateInvitation:
assert result == invitation
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="user@example.com"
token="VALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_validation_error(self):
@@ -303,7 +269,7 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_not_found(self):
@@ -318,7 +284,7 @@ class TestValidateInvitation:
assert exc_info.value.detail == "Invitation is not valid."
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_expired(self, invitation):
@@ -366,27 +332,5 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="different@example.com"
)
def test_valid_invitation_uppercase_email(self):
"""Test that validate_invitation works with case-insensitive email lookup."""
uppercase_email = "USER@example.com"
invitation = MagicMock(spec=Invitation)
invitation.token = "VALID_TOKEN"
invitation.email = uppercase_email
invitation.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
invitation.state = Invitation.State.PENDING
invitation.tenant = MagicMock()
with patch("api.utils.Invitation.objects.using") as mock_using:
mock_db = mock_using.return_value
mock_db.get.return_value = invitation
result = validate_invitation("VALID_TOKEN", "user@example.com")
assert result == invitation
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="user@example.com"
token="VALID_TOKEN", email="different@example.com"
)
File diff suppressed because it is too large
+9 -69
View File
@@ -7,14 +7,12 @@ from rest_framework.exceptions import NotFound, ValidationError
from api.db_router import MainRouter
from api.exceptions import InvitationTokenExpiredException
from api.models import Integration, Invitation, Processor, Provider, Resource
from api.models import Invitation, Provider, Resource
from api.v1.serializers import FindingMetadataSerializer
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.common.models import Connection
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.github.github_provider import GithubProvider
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
from prowler.providers.m365.m365_provider import M365Provider
@@ -57,21 +55,14 @@ def merge_dicts(default_dict: dict, replacement_dict: dict) -> dict:
def return_prowler_provider(
provider: Provider,
) -> [
AwsProvider
| AzureProvider
| GcpProvider
| GithubProvider
| KubernetesProvider
| M365Provider
]:
) -> [AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider]:
"""Return the Prowler provider class based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secrets.
Returns:
AwsProvider | AzureProvider | GcpProvider | GithubProvider | KubernetesProvider | M365Provider: The corresponding provider class.
AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider: The corresponding provider class.
Raises:
ValueError: If the provider type specified in `provider.provider` is not supported.
@@ -87,21 +78,16 @@ def return_prowler_provider(
prowler_provider = KubernetesProvider
case Provider.ProviderChoices.M365.value:
prowler_provider = M365Provider
case Provider.ProviderChoices.GITHUB.value:
prowler_provider = GithubProvider
case _:
raise ValueError(f"Provider type {provider.provider} not supported")
return prowler_provider
def get_prowler_provider_kwargs(
provider: Provider, mutelist_processor: Processor | None = None
) -> dict:
def get_prowler_provider_kwargs(provider: Provider) -> dict:
"""Get the Prowler provider kwargs based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secret.
mutelist_processor (Processor): The mutelist processor object containing the mutelist configuration.
Returns:
dict: The provider kwargs for the corresponding provider class.
@@ -119,39 +105,24 @@ def get_prowler_provider_kwargs(
}
elif provider.provider == Provider.ProviderChoices.KUBERNETES.value:
prowler_provider_kwargs = {**prowler_provider_kwargs, "context": provider.uid}
if mutelist_processor:
mutelist_content = mutelist_processor.configuration.get("Mutelist", {})
if mutelist_content:
prowler_provider_kwargs["mutelist_content"] = mutelist_content
return prowler_provider_kwargs
def initialize_prowler_provider(
provider: Provider,
mutelist_processor: Processor | None = None,
) -> (
AwsProvider
| AzureProvider
| GcpProvider
| GithubProvider
| KubernetesProvider
| M365Provider
):
) -> AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider:
"""Initialize a Prowler provider instance based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secrets.
mutelist_processor (Processor): The mutelist processor object containing the mutelist configuration.
Returns:
AwsProvider | AzureProvider | GcpProvider | GithubProvider | KubernetesProvider | M365Provider: An instance of the corresponding provider class
(`AwsProvider`, `AzureProvider`, `GcpProvider`, `GithubProvider`, `KubernetesProvider` or `M365Provider`) initialized with the
AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider: An instance of the corresponding provider class
(`AwsProvider`, `AzureProvider`, `GcpProvider`, `KubernetesProvider` or `M365Provider`) initialized with the
provider's secrets.
"""
prowler_provider = return_prowler_provider(provider)
prowler_provider_kwargs = get_prowler_provider_kwargs(provider, mutelist_processor)
prowler_provider_kwargs = get_prowler_provider_kwargs(provider)
return prowler_provider(**prowler_provider_kwargs)
@@ -176,37 +147,6 @@ def prowler_provider_connection_test(provider: Provider) -> Connection:
)
def prowler_integration_connection_test(integration: Integration) -> Connection:
"""Test the connection to a Prowler integration based on the given integration type.
Args:
integration (Integration): The integration object containing the integration type and associated credentials.
Returns:
Connection: A connection object representing the result of the connection test for the specified integration.
"""
if integration.integration_type == Integration.IntegrationChoices.AMAZON_S3:
return S3.test_connection(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
raise_on_exception=False,
)
# TODO: It may be possible to unify the connection test for all integrations, but that needs refactoring
# to avoid code duplication. The AWS integrations are similar, so Security Hub and S3 could be unified with some changes in the SDK.
elif (
integration.integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB
):
pass
elif integration.integration_type == Integration.IntegrationChoices.JIRA:
pass
elif integration.integration_type == Integration.IntegrationChoices.SLACK:
pass
else:
raise ValueError(
f"Integration type {integration.integration_type} not supported"
)
def validate_invitation(
invitation_token: str, email: str, raise_not_found=False
) -> Invitation:
@@ -247,7 +187,7 @@ def validate_invitation(
# Admin DB connector is used to bypass RLS protection since the invitation belongs to a tenant the user
# is not a member of yet
invitation = Invitation.objects.using(MainRouter.admin_db).get(
token=invitation_token, email__iexact=email
token=invitation_token, email=email
)
except Invitation.DoesNotExist:
if raise_not_found:
+2 -14
View File
@@ -24,32 +24,20 @@ class PaginateByPkMixin:
request, # noqa: F841
base_queryset,
manager,
select_related: list | None = None,
prefetch_related: list | None = None,
select_related: list[str] | None = None,
prefetch_related: list[str] | None = None,
) -> Response:
"""
Paginate a queryset by primary key.
This method is useful when you want to paginate a queryset that has been
filtered or annotated in a way that would be lost if you used the default
pagination method.
"""
pk_list = base_queryset.values_list("id", flat=True)
page = self.paginate_queryset(pk_list)
if page is None:
return Response(self.get_serializer(base_queryset, many=True).data)
queryset = manager.filter(id__in=page)
if select_related:
queryset = queryset.select_related(*select_related)
if prefetch_related:
queryset = queryset.prefetch_related(*prefetch_related)
# Optimize tags loading, if applicable
if hasattr(self, "_optimize_tags_loading"):
queryset = self._optimize_tags_loading(queryset)
queryset = sorted(queryset, key=lambda obj: page.index(obj.id))
serialized = self.get_serializer(queryset, many=True).data
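A minimal usage sketch of the mixin above (the method name paginate_by_pk is inferred from the mixin name, and the viewset and relation names are illustrative): the page of primary keys is computed on the cheap base queryset first, and only that page's rows pay for select_related and prefetch_related:

from rest_framework.viewsets import ModelViewSet

from api.models import Resource

class ResourceViewSet(PaginateByPkMixin, ModelViewSet):
    def list(self, request, *args, **kwargs):
        # Filters and annotations apply only to the PK listing; the heavy
        # joins run just for the current page.
        base_queryset = self.filter_queryset(self.get_queryset())
        return self.paginate_by_pk(
            request,
            base_queryset,
            manager=Resource.objects,
            select_related=["provider"],
            prefetch_related=["tags"],
        )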
@@ -1,23 +0,0 @@
import yaml
from rest_framework_json_api import serializers
from rest_framework_json_api.serializers import ValidationError
class BaseValidateSerializer(serializers.Serializer):
def validate(self, data):
if hasattr(self, "initial_data"):
initial_data = set(self.initial_data.keys()) - {"id", "type"}
unknown_keys = initial_data - set(self.fields.keys())
if unknown_keys:
raise ValidationError(f"Invalid fields: {unknown_keys}")
return data
class YamlOrJsonField(serializers.JSONField):
def to_internal_value(self, data):
if isinstance(data, str):
try:
data = yaml.safe_load(data)
except yaml.YAMLError as exc:
raise serializers.ValidationError("Invalid YAML format") from exc
return super().to_internal_value(data)
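A short sketch of what YamlOrJsonField buys (the serializer name is illustrative): the same attribute validates whether the client sends a parsed JSON object or a YAML document as a string:

from rest_framework_json_api import serializers

class ConfigurationSerializer(serializers.Serializer):
    configuration = YamlOrJsonField()

# Both payloads validate to the same Python dict.
ConfigurationSerializer(data={"configuration": {"Mutelist": {}}}).is_valid()
ConfigurationSerializer(data={"configuration": "Mutelist: {}"}).is_valid()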
@@ -1,52 +1,24 @@
import os
import re
from drf_spectacular.utils import extend_schema_field
from rest_framework_json_api import serializers
from rest_framework_json_api.serializers import ValidationError
from api.v1.serializer_utils.base import BaseValidateSerializer
class BaseValidateSerializer(serializers.Serializer):
def validate(self, data):
if hasattr(self, "initial_data"):
initial_data = set(self.initial_data.keys()) - {"id", "type"}
unknown_keys = initial_data - set(self.fields.keys())
if unknown_keys:
raise ValidationError(f"Invalid fields: {unknown_keys}")
return data
# Integrations
class S3ConfigSerializer(BaseValidateSerializer):
bucket_name = serializers.CharField()
output_directory = serializers.CharField(allow_blank=True)
def validate_output_directory(self, value):
"""
Validate the output_directory field to ensure it's a properly formatted path.
Prevents paths with excessive slashes like "///////test".
If empty, sets a default value.
"""
# If empty or None, set default value
if not value:
return "output"
# Normalize the path to remove excessive slashes
normalized_path = os.path.normpath(value)
# Remove leading slashes for S3 paths
if normalized_path.startswith("/"):
normalized_path = normalized_path.lstrip("/")
# Check for invalid characters or patterns
if re.search(r'[<>:"|?*]', normalized_path):
raise serializers.ValidationError(
'Output directory contains invalid characters. Avoid: < > : " | ? *'
)
# Check for empty path after normalization
if not normalized_path or normalized_path == ".":
raise serializers.ValidationError(
"Output directory cannot be empty or just '.' or '/'."
)
# Check for paths that are too long (S3 key limit is 1024 characters, leave some room for filename)
if len(normalized_path) > 900:
raise serializers.ValidationError(
"Output directory path is too long (max 900 characters)."
)
return normalized_path
output_directory = serializers.CharField()
class Meta:
resource_name = "integrations"
@@ -138,13 +110,10 @@ class IntegrationCredentialField(serializers.JSONField):
},
"output_directory": {
"type": "string",
"description": 'The directory path within the bucket where files will be saved. Optional - defaults to "output" if not provided. Path will be normalized to remove excessive slashes and invalid characters are not allowed (< > : " | ? *). Maximum length is 900 characters.',
"maxLength": 900,
"pattern": '^[^<>:"|?*]+$',
"default": "output",
"description": "The directory path within the bucket where files will be saved.",
},
},
"required": ["bucket_name"],
"required": ["bucket_name", "output_directory"],
},
]
}
@@ -1,21 +0,0 @@
from drf_spectacular.utils import extend_schema_field
from api.v1.serializer_utils.base import YamlOrJsonField
from prowler.lib.mutelist.mutelist import mutelist_schema
@extend_schema_field(
{
"oneOf": [
{
"type": "object",
"title": "Mutelist",
"properties": {"Mutelist": mutelist_schema},
"additionalProperties": False,
},
]
}
)
class ProcessorConfigField(YamlOrJsonField):
pass
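A representative configuration that ProcessorConfigField accepts, either as this dict or as the equivalent YAML string; the check name is illustrative and the nested shape follows Prowler's documented mutelist format:

mutelist_configuration = {
    "Mutelist": {
        "Accounts": {
            "*": {  # applies to every account
                "Checks": {
                    "s3_bucket_public_access": {
                        "Regions": ["*"],
                        "Resources": ["*"],
                    }
                }
            }
        }
    }
}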
@@ -176,43 +176,6 @@ from rest_framework_json_api import serializers
},
"required": ["kubeconfig_content"],
},
{
"type": "object",
"title": "GitHub Personal Access Token",
"properties": {
"personal_access_token": {
"type": "string",
"description": "GitHub personal access token for authentication.",
}
},
"required": ["personal_access_token"],
},
{
"type": "object",
"title": "GitHub OAuth App Token",
"properties": {
"oauth_app_token": {
"type": "string",
"description": "GitHub OAuth App token for authentication.",
}
},
"required": ["oauth_app_token"],
},
{
"type": "object",
"title": "GitHub App Credentials",
"properties": {
"github_app_id": {
"type": "integer",
"description": "GitHub App ID for authentication.",
},
"github_app_key": {
"type": "string",
"description": "Path to the GitHub App private key file.",
},
},
"required": ["github_app_id", "github_app_key"],
},
]
}
)
+12 -227
View File
@@ -7,9 +7,7 @@ from django.contrib.auth.models import update_last_login
from django.contrib.auth.password_validation import validate_password
from drf_spectacular.utils import extend_schema_field
from jwt.exceptions import InvalidKeyError
from rest_framework.validators import UniqueTogetherValidator
from rest_framework_json_api import serializers
from rest_framework_json_api.relations import SerializerMethodResourceRelatedField
from rest_framework_json_api.serializers import ValidationError
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
@@ -23,7 +21,6 @@ from api.models import (
InvitationRoleRelationship,
LighthouseConfiguration,
Membership,
Processor,
Provider,
ProviderGroup,
ProviderGroupMembership,
@@ -47,9 +44,7 @@ from api.v1.serializer_utils.integrations import (
IntegrationCredentialField,
S3ConfigSerializer,
)
from api.v1.serializer_utils.processors import ProcessorConfigField
from api.v1.serializer_utils.providers import ProviderSecretField
from prowler.lib.mutelist.mutelist import Mutelist
# Tokens
@@ -135,12 +130,6 @@ class TokenSerializer(BaseTokenSerializer):
class TokenSocialLoginSerializer(BaseTokenSerializer):
email = serializers.EmailField(write_only=True)
tenant_id = serializers.UUIDField(
write_only=True,
required=False,
help_text="If not provided, the tenant ID of the first membership that was added"
" to the user will be used.",
)
# Output tokens
refresh = serializers.CharField(read_only=True)
@@ -862,7 +851,6 @@ class ScanSerializer(RLSSerializer):
"completed_at",
"scheduled_at",
"next_scan_at",
"processor",
"url",
]
@@ -1000,12 +988,8 @@ class ResourceSerializer(RLSSerializer):
tags = serializers.SerializerMethodField()
type_ = serializers.CharField(read_only=True)
failed_findings_count = serializers.IntegerField(read_only=True)
findings = SerializerMethodResourceRelatedField(
many=True,
read_only=True,
)
findings = serializers.ResourceRelatedField(many=True, read_only=True)
class Meta:
model = Resource
@@ -1021,7 +1005,6 @@ class ResourceSerializer(RLSSerializer):
"tags",
"provider",
"findings",
"failed_findings_count",
"url",
]
extra_kwargs = {
@@ -1031,8 +1014,8 @@ class ResourceSerializer(RLSSerializer):
}
included_serializers = {
"findings": "api.v1.serializers.FindingIncludeSerializer",
"provider": "api.v1.serializers.ProviderIncludeSerializer",
"findings": "api.v1.serializers.FindingSerializer",
"provider": "api.v1.serializers.ProviderSerializer",
}
@extend_schema_field(
@@ -1043,10 +1026,6 @@ class ResourceSerializer(RLSSerializer):
}
)
def get_tags(self, obj):
# Use prefetched tags if available to avoid N+1 queries
if hasattr(obj, "prefetched_tags"):
return {tag.key: tag.value for tag in obj.prefetched_tags}
# Fallback to the original method if prefetch is not available
return obj.get_tags(self.context.get("tenant_id"))
def get_fields(self):
@@ -1056,17 +1035,10 @@ class ResourceSerializer(RLSSerializer):
fields["type"] = type_
return fields
def get_findings(self, obj):
return (
obj.latest_findings
if hasattr(obj, "latest_findings")
else obj.findings.all()
)
class ResourceIncludeSerializer(RLSSerializer):
"""
Serializer for the included Resource model.
Serializer for the Resource model.
"""
tags = serializers.SerializerMethodField()
@@ -1099,10 +1071,6 @@ class ResourceIncludeSerializer(RLSSerializer):
}
)
def get_tags(self, obj):
# Use prefetched tags if available to avoid N+1 queries
if hasattr(obj, "prefetched_tags"):
return {tag.key: tag.value for tag in obj.prefetched_tags}
# Fallback to the original method if prefetch is not available
return obj.get_tags(self.context.get("tenant_id"))
def get_fields(self):
@@ -1113,17 +1081,6 @@ class ResourceIncludeSerializer(RLSSerializer):
return fields
class ResourceMetadataSerializer(serializers.Serializer):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
types = serializers.ListField(child=serializers.CharField(), allow_empty=True)
# Temporarily disabled until we implement tag filtering in the UI
# tags = serializers.JSONField(help_text="Tags are described as key-value pairs.")
class Meta:
resource_name = "resources-metadata"
class FindingSerializer(RLSSerializer):
"""
Serializer for the Finding model.
@@ -1147,7 +1104,6 @@ class FindingSerializer(RLSSerializer):
"updated_at",
"first_seen_at",
"muted",
"muted_reason",
"url",
# Relationships
"scan",
@@ -1160,28 +1116,6 @@ class FindingSerializer(RLSSerializer):
}
class FindingIncludeSerializer(RLSSerializer):
"""
Serializer for the include Finding model.
"""
class Meta:
model = Finding
fields = [
"id",
"uid",
"status",
"severity",
"check_id",
"check_metadata",
"inserted_at",
"updated_at",
"first_seen_at",
"muted",
"muted_reason",
]
# To be removed when the related endpoint is removed as well
class FindingDynamicFilterSerializer(serializers.Serializer):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
@@ -1217,8 +1151,6 @@ class BaseWriteProviderSecretSerializer(BaseWriteSerializer):
serializer = AzureProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.GCP.value:
serializer = GCPProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.GITHUB.value:
serializer = GithubProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.KUBERNETES.value:
serializer = KubernetesProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.M365.value:
@@ -1268,8 +1200,8 @@ class M365ProviderSecret(serializers.Serializer):
client_id = serializers.CharField()
client_secret = serializers.CharField()
tenant_id = serializers.CharField()
user = serializers.EmailField(required=False)
password = serializers.CharField(required=False)
user = serializers.EmailField()
password = serializers.CharField()
class Meta:
resource_name = "provider-secrets"
@@ -1298,16 +1230,6 @@ class KubernetesProviderSecret(serializers.Serializer):
resource_name = "provider-secrets"
class GithubProviderSecret(serializers.Serializer):
personal_access_token = serializers.CharField(required=False)
oauth_app_token = serializers.CharField(required=False)
github_app_id = serializers.IntegerField(required=False)
github_app_key_content = serializers.CharField(required=False)
class Meta:
resource_name = "provider-secrets"
class AWSRoleAssumptionProviderSecret(serializers.Serializer):
role_arn = serializers.CharField()
external_id = serializers.CharField()
@@ -1387,13 +1309,12 @@ class ProviderSecretUpdateSerializer(BaseWriteProviderSecretSerializer):
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"provider": {"read_only": True},
"secret_type": {"required": False},
"secret_type": {"read_only": True},
}
def validate(self, attrs):
provider = self.instance.provider
# To allow updating a secret with the same type without making the `secret_type` mandatory
secret_type = attrs.get("secret_type") or self.instance.secret_type
secret_type = self.instance.secret_type
secret = attrs.get("secret")
validated_attrs = super().validate(attrs)
@@ -1950,16 +1871,6 @@ class ScheduleDailyCreateSerializer(serializers.Serializer):
class BaseWriteIntegrationSerializer(BaseWriteSerializer):
def validate(self, attrs):
if Integration.objects.filter(
configuration=attrs.get("configuration")
).exists():
raise serializers.ValidationError(
{"name": "This integration already exists."}
)
return super().validate(attrs)
@staticmethod
def validate_integration_data(
integration_type: str,
@@ -1967,7 +1878,7 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
configuration: dict,
credentials: dict,
):
if integration_type == Integration.IntegrationChoices.AMAZON_S3:
if integration_type == Integration.IntegrationChoices.S3:
config_serializer = S3ConfigSerializer
credentials_serializers = [AWSCredentialSerializer]
# TODO: This will be required for AWS Security Hub
@@ -1985,11 +1896,7 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
}
)
serializer_instance = config_serializer(data=configuration)
serializer_instance.is_valid(raise_exception=True)
# Apply the validated (and potentially transformed) data back to configuration
configuration.update(serializer_instance.validated_data)
config_serializer(data=configuration).is_valid(raise_exception=True)
for cred_serializer in credentials_serializers:
try:
@@ -2068,6 +1975,7 @@ class IntegrationCreateSerializer(BaseWriteIntegrationSerializer):
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"connected": {"read_only": True},
"enabled": {"read_only": True},
"connection_last_checked_at": {"read_only": True},
}
@@ -2077,10 +1985,10 @@ class IntegrationCreateSerializer(BaseWriteIntegrationSerializer):
configuration = attrs.get("configuration")
credentials = attrs.get("credentials")
validated_attrs = super().validate(attrs)
self.validate_integration_data(
integration_type, providers, configuration, credentials
)
validated_attrs = super().validate(attrs)
return validated_attrs
def create(self, validated_data):
@@ -2131,7 +2039,6 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
}
def validate(self, attrs):
super().validate(attrs)
integration_type = self.instance.integration_type
providers = attrs.get("providers")
configuration = attrs.get("configuration") or self.instance.configuration
@@ -2158,128 +2065,6 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
return super().update(instance, validated_data)
# Processors
class ProcessorSerializer(RLSSerializer):
"""
Serializer for the Processor model.
"""
configuration = ProcessorConfigField()
class Meta:
model = Processor
fields = [
"id",
"inserted_at",
"updated_at",
"processor_type",
"configuration",
"url",
]
class ProcessorCreateSerializer(RLSSerializer, BaseWriteSerializer):
configuration = ProcessorConfigField(required=True)
class Meta:
model = Processor
fields = [
"inserted_at",
"updated_at",
"processor_type",
"configuration",
]
extra_kwargs = {
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
}
validators = [
UniqueTogetherValidator(
queryset=Processor.objects.all(),
fields=["processor_type"],
message="A processor with the same type already exists.",
)
]
def validate(self, attrs):
validated_attrs = super().validate(attrs)
self.validate_processor_data(attrs)
return validated_attrs
def validate_processor_data(self, attrs):
processor_type = attrs.get("processor_type")
configuration = attrs.get("configuration")
if processor_type == "mutelist":
self.validate_mutelist_configuration(configuration)
def validate_mutelist_configuration(self, configuration):
if not isinstance(configuration, dict):
raise serializers.ValidationError("Invalid Mutelist configuration.")
mutelist_configuration = configuration.get("Mutelist", {})
if not mutelist_configuration:
raise serializers.ValidationError(
"Invalid Mutelist configuration: 'Mutelist' is a required property."
)
try:
Mutelist.validate_mutelist(mutelist_configuration, raise_on_exception=True)
return
except Exception as error:
raise serializers.ValidationError(
f"Invalid Mutelist configuration: {error}"
)
class ProcessorUpdateSerializer(BaseWriteSerializer):
configuration = ProcessorConfigField(required=True)
class Meta:
model = Processor
fields = [
"inserted_at",
"updated_at",
"configuration",
]
extra_kwargs = {
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
}
def validate(self, attrs):
validated_attrs = super().validate(attrs)
self.validate_processor_data(attrs)
return validated_attrs
def validate_processor_data(self, attrs):
processor_type = self.instance.processor_type
configuration = attrs.get("configuration")
if processor_type == "mutelist":
self.validate_mutelist_configuration(configuration)
def validate_mutelist_configuration(self, configuration):
if not isinstance(configuration, dict):
raise serializers.ValidationError("Invalid Mutelist configuration.")
mutelist_configuration = configuration.get("Mutelist", {})
if not mutelist_configuration:
raise serializers.ValidationError(
"Invalid Mutelist configuration: 'Mutelist' is a required property."
)
try:
Mutelist.validate_mutelist(mutelist_configuration, raise_on_exception=True)
return
except Exception as error:
raise serializers.ValidationError(
f"Invalid Mutelist configuration: {error}"
)
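For reference, a minimal configuration payload that passes this validation; the structure and the check id mirror the mutelist processor fixture that appears later in this diff:

configuration = {
    "Mutelist": {
        "Accounts": {
            "*": {  # applies to every account
                "Checks": {
                    "iam_user_hardware_mfa_enabled": {
                        "Regions": ["*"],
                        "Resources": ["*"],
                    }
                }
            }
        }
    }
}

# Both serializers above reduce to the same call:
# Mutelist.validate_mutelist(configuration["Mutelist"], raise_on_exception=True)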
# SSO
+3 -27
View File
@@ -1,11 +1,9 @@
from allauth.socialaccount.providers.saml.views import ACSView, MetadataView, SLSView
from django.urls import include, path
from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
ComplianceOverviewViewSet,
CustomSAMLLoginView,
CustomTokenObtainView,
CustomTokenRefreshView,
CustomTokenSwitchTenantView,
@@ -18,7 +16,6 @@ from api.v1.views import (
LighthouseConfigViewSet,
MembershipViewSet,
OverviewViewSet,
ProcessorViewSet,
ProviderGroupProvidersRelationshipView,
ProviderGroupViewSet,
ProviderSecretViewSet,
@@ -28,7 +25,6 @@ from api.v1.views import (
RoleViewSet,
SAMLConfigurationViewSet,
SAMLInitiateAPIView,
SAMLTokenValidateView,
ScanViewSet,
ScheduleViewSet,
SchemaView,
@@ -57,7 +53,6 @@ router.register(
router.register(r"overviews", OverviewViewSet, basename="overview")
router.register(r"schedules", ScheduleViewSet, basename="schedule")
router.register(r"integrations", IntegrationViewSet, basename="integration")
router.register(r"processors", ProcessorViewSet, basename="processor")
router.register(r"saml-config", SAMLConfigurationViewSet, basename="saml-config")
router.register(
r"lighthouse-configurations",
@@ -131,32 +126,13 @@ urlpatterns = [
path(
"auth/saml/initiate/", SAMLInitiateAPIView.as_view(), name="api_saml_initiate"
),
# Allauth SAML endpoints for tenants
path("accounts/", include("allauth.urls")),
path(
"accounts/saml/<organization_slug>/login/",
CustomSAMLLoginView.as_view(),
name="saml_login",
),
path(
"accounts/saml/<organization_slug>/acs/",
ACSView.as_view(),
name="saml_acs",
),
path(
"accounts/saml/<organization_slug>/acs/finish/",
"api/v1/accounts/saml/<organization_slug>/acs/finish/",
TenantFinishACSView.as_view(),
name="saml_finish_acs",
),
path(
"accounts/saml/<organization_slug>/sls/",
SLSView.as_view(),
name="saml_sls",
),
path(
"accounts/saml/<organization_slug>/metadata/",
MetadataView.as_view(),
name="saml_metadata",
),
path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"),
path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"),
path("", include(router.urls)),
+89 -589
View File
@@ -1,17 +1,14 @@
import glob
import logging
import os
from datetime import datetime, timedelta, timezone
from urllib.parse import urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
from allauth.socialaccount.providers.saml.views import FinishACSView, LoginView
from allauth.socialaccount.providers.saml.views import FinishACSView
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.custom_logging import BackendLogger
from config.env import env
from config.settings.social_login import (
GITHUB_OAUTH_CALLBACK_URL,
@@ -22,9 +19,9 @@ from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, Prefetch, Q, Subquery, Sum
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Sum
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.http import HttpResponse, JsonResponse
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.dateparse import parse_date
@@ -57,7 +54,6 @@ from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_provider_connection_task,
delete_provider_task,
@@ -79,9 +75,7 @@ from api.filters import (
IntegrationFilter,
InvitationFilter,
LatestFindingFilter,
LatestResourceFilter,
MembershipFilter,
ProcessorFilter,
ProviderFilter,
ProviderGroupFilter,
ProviderSecretFilter,
@@ -95,13 +89,13 @@ from api.filters import (
UserFilter,
)
from api.models import (
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
Invitation,
LighthouseConfiguration,
Membership,
Processor,
Provider,
ProviderGroup,
ProviderGroupMembership,
@@ -109,12 +103,10 @@ from api.models import (
Resource,
ResourceFindingMapping,
ResourceScanSummary,
ResourceTag,
Role,
RoleProviderGroupRelationship,
SAMLConfiguration,
SAMLDomainIndex,
SAMLToken,
Scan,
ScanSummary,
SeverityChoices,
@@ -156,9 +148,6 @@ from api.v1.serializers import (
OverviewProviderSerializer,
OverviewServiceSerializer,
OverviewSeveritySerializer,
ProcessorCreateSerializer,
ProcessorSerializer,
ProcessorUpdateSerializer,
ProviderCreateSerializer,
ProviderGroupCreateSerializer,
ProviderGroupMembershipSerializer,
@@ -169,7 +158,6 @@ from api.v1.serializers import (
ProviderSecretUpdateSerializer,
ProviderSerializer,
ProviderUpdateSerializer,
ResourceMetadataSerializer,
ResourceSerializer,
RoleCreateSerializer,
RoleProviderGroupRelationshipSerializer,
@@ -195,8 +183,6 @@ from api.v1.serializers import (
UserUpdateSerializer,
)
logger = logging.getLogger(BackendLogger.API)
CACHE_DECORATOR = cache_control(
max_age=django_settings.CACHE_MAX_AGE,
stale_while_revalidate=django_settings.CACHE_STALE_WHILE_REVALIDATE,
@@ -293,7 +279,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.11.2"
spectacular_settings.VERSION = "1.9.0"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -359,11 +345,6 @@ class SchemaView(SpectacularAPIView):
"description": "Endpoints for managing Lighthouse configurations, including creation, retrieval, "
"updating, and deletion of configurations such as OpenAI keys, models, and business context.",
},
{
"name": "Processor",
"description": "Endpoints for managing post-processors used to process Prowler findings, including "
"registration, configuration, and deletion of post-processing actions.",
},
]
return super().get(request, *args, **kwargs)
@@ -420,68 +401,17 @@ class GithubSocialLoginView(SocialLoginView):
return original_response
@extend_schema(exclude=True)
class SAMLTokenValidateView(GenericAPIView):
resource_name = "tokens"
http_method_names = ["post"]
def post(self, request):
token_id = request.query_params.get("id", "invalid")
try:
saml_token = SAMLToken.objects.using(MainRouter.admin_db).get(id=token_id)
except SAMLToken.DoesNotExist:
return Response({"detail": "Invalid token ID."}, status=404)
if saml_token.is_expired():
return Response({"detail": "Token expired."}, status=400)
token_data = saml_token.token
# Tokens are single-use: the stored token is deleted as soon as it is exchanged
saml_token.delete()
return Response(token_data, status=200)
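Combined with the `tokens/saml` route registered above, the UI finishes the login by exchanging the one-time `id` from the callback redirect for the actual token payload. A minimal client sketch; the base URL and the use of `requests` are illustrative, not part of this diff:

import requests

# Hypothetical one-time id taken from the ?id=... parameter on SAML_SSO_CALLBACK_URL
saml_token_id = "5f0c4c1e-0000-0000-0000-000000000000"

response = requests.post(
    "http://localhost:8080/api/v1/tokens/saml",
    params={"id": saml_token_id},
)
response.raise_for_status()
tokens = response.json()  # the JWT payload stored in SAMLToken.token

The server deletes the SAMLToken row after a successful exchange, so the same id cannot be replayed.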
@extend_schema(exclude=True)
class CustomSAMLLoginView(LoginView):
def dispatch(self, request, *args, **kwargs):
"""
Convert GET requests to POST to bypass allauth's confirmation screen.
Why this is necessary:
- django-allauth requires POST for social logins to prevent open redirect attacks
- SAML login links typically use GET requests (e.g., <a href="...">)
- This conversion allows seamless login without user-facing confirmation
Security considerations:
1. Preserves CSRF protection: Original POST handling remains intact
2. Avoids global SOCIALACCOUNT_LOGIN_ON_GET=True which would:
- Enable GET logins for ALL providers (security risk)
- Potentially expose open redirect vulnerabilities
3. SAML payloads remain signed/encrypted regardless of HTTP method
4. No sensitive parameters are exposed in URLs (copied to POST body)
This approach maintains security while providing better UX.
"""
if request.method == "GET":
# Convert GET to POST while preserving parameters
request.method = "POST"
return super().dispatch(request, *args, **kwargs)
@extend_schema(exclude=True)
class SAMLInitiateAPIView(GenericAPIView):
serializer_class = SamlInitiateSerializer
permission_classes = []
def post(self, request, *args, **kwargs):
# Validate the input payload and extract the domain
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
email = serializer.validated_data["email_domain"]
domain = email.split("@", 1)[-1].lower()
# Retrieve the SAML configuration for the given email domain
try:
check = SAMLDomainIndex.objects.get(email_domain=domain)
with rls_transaction(str(check.tenant_id)):
@@ -491,24 +421,20 @@ class SAMLInitiateAPIView(GenericAPIView):
{"detail": "Unauthorized domain."}, status=status.HTTP_403_FORBIDDEN
)
# Check certificates are not empty (TODO: Validate certificates)
# saml_public_cert = os.getenv("SAML_PUBLIC_CERT", "").strip()
# saml_private_key = os.getenv("SAML_PRIVATE_KEY", "").strip()
# Check certificates are not empty
saml_public_cert = os.getenv("SAML_PUBLIC_CERT", "").strip()
saml_private_key = os.getenv("SAML_PRIVATE_KEY", "").strip()
# if not saml_public_cert or not saml_private_key:
# return Response(
# {"detail": "SAML configuration is invalid: missing certificates."},
# status=status.HTTP_403_FORBIDDEN,
# )
if not saml_public_cert or not saml_private_key:
return Response(
{"detail": "SAML configuration is invalid: missing certificates."},
status=status.HTTP_403_FORBIDDEN,
)
# Build the SAML login URL using the configured API host
api_host = os.getenv("API_BASE_URL")
login_path = reverse(
saml_login_url = reverse(
"saml_login", kwargs={"organization_slug": config.email_domain}
)
login_url = urljoin(api_host, login_path)
return redirect(login_url)
return redirect(f"{saml_login_url}?email={email}")
@extend_schema_view(
@@ -566,52 +492,21 @@ class SAMLConfigurationViewSet(BaseRLSViewSet):
class TenantFinishACSView(FinishACSView):
def _rollback_saml_user(self, request):
"""Helper function to rollback SAML user if it was just created and validation fails"""
saml_user_id = request.session.get("saml_user_created")
if saml_user_id:
User.objects.using(MainRouter.admin_db).filter(id=saml_user_id).delete()
request.session.pop("saml_user_created", None)
def dispatch(self, request, organization_slug):
try:
super().dispatch(request, organization_slug)
except Exception as e:
logger.error(f"SAML dispatch failed: {e}")
self._rollback_saml_user(request)
callback_url = env.str("AUTH_URL")
return redirect(f"{callback_url}?sso_saml_failed=true")
response = super().dispatch(request, organization_slug)
user = getattr(request, "user", None)
if not user or not user.is_authenticated:
self._rollback_saml_user(request)
callback_url = env.str("AUTH_URL")
return redirect(f"{callback_url}?sso_saml_failed=true")
return response
# Defensive check to avoid edge case failures due to inconsistent or incomplete data in the database
# This handles scenarios like partially deleted or missing related objects
try:
check = SAMLDomainIndex.objects.get(email_domain=organization_slug)
with rls_transaction(str(check.tenant_id)):
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
social_app = SocialApp.objects.get(
provider="saml", client_id=organization_slug
)
user_id = User.objects.get(email=str(user)).id
social_account = SocialAccount.objects.get(
user=str(user_id), provider=social_app.provider_id
user=user, provider=social_app.provider
)
except (
SAMLDomainIndex.DoesNotExist,
SAMLConfiguration.DoesNotExist,
SocialApp.DoesNotExist,
SocialAccount.DoesNotExist,
User.DoesNotExist,
) as e:
logger.error(f"SAML user is not authenticated: {e}")
self._rollback_saml_user(request)
callback_url = env.str("AUTH_URL")
return redirect(f"{callback_url}?sso_saml_failed=true")
except (SocialApp.DoesNotExist, SocialAccount.DoesNotExist):
return response
extra = social_account.extra_data
user.first_name = (
@@ -633,9 +528,9 @@ class TenantFinishACSView(FinishACSView):
.tenant
)
role_name = (
extra.get("userType", ["no_permissions"])[0].strip()
extra.get("userType", ["saml_default_role"])[0].strip()
if extra.get("userType")
else "no_permissions"
else "saml_default_role"
)
try:
role = Role.objects.using(MainRouter.admin_db).get(
@@ -662,30 +557,15 @@ class TenantFinishACSView(FinishACSView):
role=role,
tenant_id=tenant.id,
)
membership, _ = Membership.objects.using(MainRouter.admin_db).get_or_create(
user=user,
tenant=tenant,
defaults={
"user": user,
"tenant": tenant,
"role": Membership.RoleChoices.MEMBER,
},
)
serializer = TokenSocialLoginSerializer(
data={"email": user.email, "tenant_id": str(tenant.id)}
)
serializer = TokenSocialLoginSerializer(data={"email": user.email})
serializer.is_valid(raise_exception=True)
token_data = serializer.validated_data
saml_token = SAMLToken.objects.using(MainRouter.admin_db).create(
token=token_data, user=user
return JsonResponse(
{
"type": "saml-social-tokens",
"attributes": serializer.validated_data,
}
)
callback_url = env.str("SAML_SSO_CALLBACK_URL")
redirect_url = f"{callback_url}?id={saml_token.id}"
request.session.pop("saml_user_created", None)
return redirect(redirect_url)
@extend_schema_view(
@@ -1886,14 +1766,6 @@ class TaskViewSet(BaseRLSViewSet):
summary="List all resources",
description="Retrieve a list of all resources with options for filtering by various criteria. Resources are "
"objects that are discovered by Prowler. They can be anything from a single host to a whole VPC.",
parameters=[
OpenApiParameter(
name="filter[updated_at]",
description="At least one of the variations of the `filter[updated_at]` filter must be provided.",
required=True,
type=OpenApiTypes.DATE,
)
],
),
retrieve=extend_schema(
tags=["Resource"],
@@ -1901,43 +1773,15 @@ class TaskViewSet(BaseRLSViewSet):
description="Fetch detailed information about a specific resource by their ID. A Resource is an object that "
"is discovered by Prowler. It can be anything from a single host to a whole VPC.",
),
metadata=extend_schema(
tags=["Resource"],
summary="Retrieve metadata values from resources",
description="Fetch unique metadata values from a set of resources. This is useful for dynamic filtering.",
parameters=[
OpenApiParameter(
name="filter[updated_at]",
description="At least one of the variations of the `filter[updated_at]` filter must be provided.",
required=True,
type=OpenApiTypes.DATE,
)
],
filters=True,
),
latest=extend_schema(
tags=["Resource"],
summary="List the latest resources",
description="Retrieve a list of the latest resources from the latest scans for each provider with options for "
"filtering by various criteria.",
filters=True,
),
metadata_latest=extend_schema(
tags=["Resource"],
summary="Retrieve metadata values from the latest resources",
description="Fetch unique metadata values from a set of resources from the latest scans for each provider. "
"This is useful for dynamic filtering.",
filters=True,
),
)
@method_decorator(CACHE_DECORATOR, name="list")
@method_decorator(CACHE_DECORATOR, name="retrieve")
class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
queryset = Resource.all_objects.all()
class ResourceViewSet(BaseRLSViewSet):
queryset = Resource.objects.all()
serializer_class = ResourceSerializer
http_method_names = ["get"]
filterset_class = ResourceFilter
ordering = ["-failed_findings_count", "-updated_at"]
ordering = ["-inserted_at"]
ordering_fields = [
"provider_uid",
"uid",
@@ -1948,14 +1792,6 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
"inserted_at",
"updated_at",
]
prefetch_for_includes = {
"__all__": [],
"provider": [
Prefetch(
"provider", queryset=Provider.all_objects.select_related("resources")
)
],
}
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
@@ -1964,284 +1800,41 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all scans
queryset = Resource.all_objects.filter(tenant_id=self.request.tenant_id)
queryset = Resource.objects.filter(tenant_id=self.request.tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
queryset = Resource.all_objects.filter(
queryset = Resource.objects.filter(
tenant_id=self.request.tenant_id, provider__in=get_providers(user_roles)
)
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
# Django's ORM will build a LEFT OUTER JOIN on the "through" table, resulting in duplicates
# The duplicates then require a `distinct` query
search_query = SearchQuery(
search_value, config="simple", search_type="plain"
)
queryset = queryset.filter(
Q(text_search=search_query) | Q(tags__text_search=search_query)
Q(tags__key=search_value)
| Q(tags__value=search_value)
| Q(tags__text_search=search_query)
| Q(tags__key__contains=search_value)
| Q(tags__value__contains=search_value)
| Q(uid=search_value)
| Q(name=search_value)
| Q(region=search_value)
| Q(service=search_value)
| Q(type=search_value)
| Q(text_search=search_query)
| Q(uid__contains=search_value)
| Q(name__contains=search_value)
| Q(region__contains=search_value)
| Q(service__contains=search_value)
| Q(type__contains=search_value)
).distinct()
return queryset
def _optimize_tags_loading(self, queryset):
"""Optimize tags loading with prefetch_related to avoid N+1 queries"""
# Use prefetch_related to load all tags in a single query
return queryset.prefetch_related(
Prefetch(
"tags",
queryset=ResourceTag.objects.filter(
tenant_id=self.request.tenant_id
).select_related(),
to_attr="prefetched_tags",
)
)
def _should_prefetch_findings(self) -> bool:
fields_param = self.request.query_params.get("fields[resources]", "")
include_param = self.request.query_params.get("include", "")
return (
fields_param == ""
or "findings" in fields_param.split(",")
or "findings" in include_param.split(",")
)
def _get_findings_prefetch(self):
findings_queryset = Finding.all_objects.defer("scan", "resources").filter(
tenant_id=self.request.tenant_id
)
return [Prefetch("findings", queryset=findings_queryset)]
def get_serializer_class(self):
if self.action in ["metadata", "metadata_latest"]:
return ResourceMetadataSerializer
return super().get_serializer_class()
def get_filterset_class(self):
if self.action in ["latest", "metadata_latest"]:
return LatestResourceFilter
return ResourceFilter
def filter_queryset(self, queryset):
# Do not apply filters when retrieving specific resource
if self.action == "retrieve":
return queryset
return super().filter_queryset(queryset)
def list(self, request, *args, **kwargs):
filtered_queryset = self.filter_queryset(self.get_queryset())
return self.paginate_by_pk(
request,
filtered_queryset,
manager=Resource.all_objects,
select_related=["provider"],
prefetch_related=(
self._get_findings_prefetch()
if self._should_prefetch_findings()
else []
),
)
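PaginateByPkMixin itself is not shown in this diff; the pattern it implements is to paginate over primary keys first and hydrate only the current page with select_related/prefetch_related. A rough sketch under that assumption (everything except the `paginate_by_pk` name and the standard DRF pagination hooks is guesswork):

def paginate_by_pk(self, request, queryset, manager, select_related=None, prefetch_related=None):
    # Paginate a cheap PK-only queryset, then fetch full rows for just this page.
    page_pks = self.paginate_queryset(queryset.values_list("pk", flat=True))
    page = manager.filter(pk__in=page_pks)
    if select_related:
        page = page.select_related(*select_related)
    if prefetch_related:
        page = page.prefetch_related(*prefetch_related)
    serializer = self.get_serializer(page, many=True)
    return self.get_paginated_response(serializer.data)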
def retrieve(self, request, *args, **kwargs):
queryset = self._optimize_tags_loading(self.get_queryset())
instance = get_object_or_404(queryset, pk=kwargs.get("pk"))
mapping_ids = list(
ResourceFindingMapping.objects.filter(
resource=instance, tenant_id=request.tenant_id
).values_list("finding_id", flat=True)
)
latest_findings = (
Finding.all_objects.filter(id__in=mapping_ids, tenant_id=request.tenant_id)
.order_by("uid", "-inserted_at")
.distinct("uid")
)
setattr(instance, "latest_findings", latest_findings)
serializer = self.get_serializer(instance)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="latest")
def latest(self, request):
tenant_id = request.tenant_id
filtered_queryset = self.filter_queryset(self.get_queryset())
latest_scans = (
Scan.all_objects.filter(
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values("provider_id")
)
filtered_queryset = filtered_queryset.filter(
provider_id__in=Subquery(latest_scans)
)
return self.paginate_by_pk(
request,
filtered_queryset,
manager=Resource.all_objects,
select_related=["provider"],
prefetch_related=(
self._get_findings_prefetch()
if self._should_prefetch_findings()
else []
),
)
@action(detail=False, methods=["get"], url_name="metadata")
def metadata(self, request):
# Force filter validation
self.filter_queryset(self.get_queryset())
tenant_id = request.tenant_id
query_params = request.query_params
queryset = ResourceScanSummary.objects.filter(tenant_id=tenant_id)
if scans := query_params.get("filter[scan__in]") or query_params.get(
"filter[scan]"
):
queryset = queryset.filter(scan_id__in=scans.split(","))
else:
exact = query_params.get("filter[inserted_at]")
gte = query_params.get("filter[inserted_at__gte]")
lte = query_params.get("filter[inserted_at__lte]")
date_filters = {}
if exact:
date = parse_date(exact)
datetime_start = datetime.combine(
date, datetime.min.time(), tzinfo=timezone.utc
)
datetime_end = datetime_start + timedelta(days=1)
date_filters["scan_id__gte"] = uuid7_start(
datetime_to_uuid7(datetime_start)
)
date_filters["scan_id__lt"] = uuid7_start(
datetime_to_uuid7(datetime_end)
)
else:
if gte:
date_start = parse_date(gte)
datetime_start = datetime.combine(
date_start, datetime.min.time(), tzinfo=timezone.utc
)
date_filters["scan_id__gte"] = uuid7_start(
datetime_to_uuid7(datetime_start)
)
if lte:
date_end = parse_date(lte)
datetime_end = datetime.combine(
date_end + timedelta(days=1),
datetime.min.time(),
tzinfo=timezone.utc,
)
date_filters["scan_id__lt"] = uuid7_start(
datetime_to_uuid7(datetime_end)
)
if date_filters:
queryset = queryset.filter(**date_filters)
if service_filter := query_params.get("filter[service]") or query_params.get(
"filter[service__in]"
):
queryset = queryset.filter(service__in=service_filter.split(","))
if region_filter := query_params.get("filter[region]") or query_params.get(
"filter[region__in]"
):
queryset = queryset.filter(region__in=region_filter.split(","))
if resource_type_filter := query_params.get("filter[type]") or query_params.get(
"filter[type__in]"
):
queryset = queryset.filter(
resource_type__in=resource_type_filter.split(",")
)
services = list(
queryset.values_list("service", flat=True).distinct().order_by("service")
)
regions = list(
queryset.values_list("region", flat=True).distinct().order_by("region")
)
resource_types = list(
queryset.values_list("resource_type", flat=True)
.exclude(resource_type__isnull=True)
.exclude(resource_type__exact="")
.distinct()
.order_by("resource_type")
)
result = {
"services": services,
"regions": regions,
"types": resource_types,
}
serializer = self.get_serializer(data=result)
serializer.is_valid(raise_exception=True)
return Response(serializer.data)
@action(
detail=False,
methods=["get"],
url_name="metadata_latest",
url_path="metadata/latest",
)
def metadata_latest(self, request):
tenant_id = request.tenant_id
query_params = request.query_params
latest_scans_queryset = (
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
)
queryset = ResourceScanSummary.objects.filter(
tenant_id=tenant_id,
scan_id__in=latest_scans_queryset.values_list("id", flat=True),
)
if service_filter := query_params.get("filter[service]") or query_params.get(
"filter[service__in]"
):
queryset = queryset.filter(service__in=service_filter.split(","))
if region_filter := query_params.get("filter[region]") or query_params.get(
"filter[region__in]"
):
queryset = queryset.filter(region__in=region_filter.split(","))
if resource_type_filter := query_params.get("filter[type]") or query_params.get(
"filter[type__in]"
):
queryset = queryset.filter(
resource_type__in=resource_type_filter.split(",")
)
services = list(
queryset.values_list("service", flat=True).distinct().order_by("service")
)
regions = list(
queryset.values_list("region", flat=True).distinct().order_by("region")
)
resource_types = list(
queryset.values_list("resource_type", flat=True)
.exclude(resource_type__isnull=True)
.exclude(resource_type__exact="")
.distinct()
.order_by("resource_type")
)
result = {
"services": services,
"regions": regions,
"types": resource_types,
}
serializer = self.get_serializer(data=result)
serializer.is_valid(raise_exception=True)
return Response(serializer.data)
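Both metadata actions avoid joining on Scan by exploiting the fact that UUIDv7 primary keys are time-ordered: a [start, end) datetime window maps directly onto a scan_id range. A condensed sketch of that mapping, reusing the datetime_to_uuid7/uuid7_start helpers this module already imports (their implementation is not part of this diff):

from datetime import date, datetime, time, timedelta, timezone

def scan_id_bounds_for_day(day: date):
    # Lower bound: the smallest UUIDv7 that can be generated at midnight UTC of `day`.
    start = datetime.combine(day, time.min, tzinfo=timezone.utc)
    # Exclusive upper bound: the smallest UUIDv7 of the following midnight.
    end = start + timedelta(days=1)
    return uuid7_start(datetime_to_uuid7(start)), uuid7_start(datetime_to_uuid7(end))

# queryset.filter(scan_id__gte=lower, scan_id__lt=upper) then selects exactly the
# scans whose ids were generated during that day, with no join on the Scan table.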
@extend_schema_view(
list=extend_schema(
@@ -2360,7 +1953,17 @@ class FindingViewSet(PaginateByPkMixin, BaseRLSViewSet):
search_value, config="simple", search_type="plain"
)
queryset = queryset.filter(text_search=search_query)
resource_match = Resource.all_objects.filter(
text_search=search_query,
id__in=ResourceFindingMapping.objects.filter(
resource_id=OuterRef("pk"),
tenant_id=tenant_id,
).values("resource_id"),
)
queryset = queryset.filter(
Q(text_search=search_query) | Q(Exists(resource_match))
)
return queryset
@@ -2463,12 +2066,9 @@ class FindingViewSet(PaginateByPkMixin, BaseRLSViewSet):
# ToRemove: Temporary fallback mechanism
if not queryset.exists():
raw_scans_ids = Scan.objects.filter(
scan_ids = Scan.objects.filter(
tenant_id=tenant_id, **scan_based_filters
).values_list("id", "unique_resource_count")
scan_ids = [
scan_id for scan_id, count in raw_scans_ids if count and count > 0
]
).values_list("id", flat=True)
for scan_id in scan_ids:
backfill_scan_resource_summaries_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
@@ -2554,12 +2154,7 @@ class FindingViewSet(PaginateByPkMixin, BaseRLSViewSet):
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
)
raw_latest_scans_ids = list(
latest_scans_queryset.values_list("id", "unique_resource_count")
)
latest_scans_ids = [
scan_id for scan_id, count in raw_latest_scans_ids if count and count > 0
]
latest_scans_ids = list(latest_scans_queryset.values_list("id", flat=True))
queryset = ResourceScanSummary.objects.filter(
tenant_id=tenant_id,
@@ -3455,7 +3050,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
@extend_schema(tags=["Overview"])
@extend_schema_view(
providers=extend_schema(
list=extend_schema(
summary="Get aggregated provider data",
description=(
"Retrieve an aggregated overview of findings and resources grouped by providers. "
@@ -3496,7 +3091,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
)
@method_decorator(CACHE_DECORATOR, name="list")
class OverviewViewSet(BaseRLSViewSet):
queryset = ScanSummary.objects.all()
queryset = ComplianceOverview.objects.all()
http_method_names = ["get"]
ordering = ["-inserted_at"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
@@ -3507,10 +3102,19 @@ class OverviewViewSet(BaseRLSViewSet):
role = get_role(self.request.user)
providers = get_providers(role)
if not role.unlimited_visibility:
self.allowed_providers = providers
def _get_filtered_queryset(model):
if role.unlimited_visibility:
return model.all_objects.filter(tenant_id=self.request.tenant_id)
return model.all_objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
return ScanSummary.all_objects.filter(tenant_id=self.request.tenant_id)
if self.action == "providers":
return _get_filtered_queryset(Finding)
elif self.action in ("findings", "findings_severity", "services"):
return _get_filtered_queryset(ScanSummary)
else:
return super().get_queryset()
def get_serializer_class(self):
if self.action == "providers":
@@ -3543,24 +3147,18 @@ class OverviewViewSet(BaseRLSViewSet):
@action(detail=False, methods=["get"], url_name="providers")
def providers(self, request):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
provider_filter = (
{"provider__in": self.allowed_providers}
if hasattr(self, "allowed_providers")
else {}
)
latest_scan_ids = (
Scan.all_objects.filter(
tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter
)
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
)
findings_aggregated = (
queryset.filter(scan_id__in=latest_scan_ids)
ScanSummary.all_objects.filter(
tenant_id=tenant_id, scan_id__in=latest_scan_ids
)
.values(
"scan__provider_id",
provider=F("scan__provider__provider"),
@@ -3596,7 +3194,7 @@ class OverviewViewSet(BaseRLSViewSet):
)
return Response(
self.get_serializer(overview, many=True).data,
OverviewProviderSerializer(overview, many=True).data,
status=status.HTTP_200_OK,
)
@@ -3605,16 +3203,9 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
provider_filter = (
{"provider__in": self.allowed_providers}
if hasattr(self, "allowed_providers")
else {}
)
latest_scan_ids = (
Scan.all_objects.filter(
tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter
)
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
@@ -3651,16 +3242,9 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
provider_filter = (
{"provider__in": self.allowed_providers}
if hasattr(self, "allowed_providers")
else {}
)
latest_scan_ids = (
Scan.all_objects.filter(
tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter
)
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
@@ -3680,7 +3264,7 @@ class OverviewViewSet(BaseRLSViewSet):
for item in severity_counts:
severity_data[item["severity"]] = item["count"]
serializer = self.get_serializer(severity_data)
serializer = OverviewSeveritySerializer(severity_data)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="services")
@@ -3688,16 +3272,9 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
provider_filter = (
{"provider__in": self.allowed_providers}
if hasattr(self, "allowed_providers")
else {}
)
latest_scan_ids = (
Scan.all_objects.filter(
tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter
)
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
@@ -3715,7 +3292,7 @@ class OverviewViewSet(BaseRLSViewSet):
.order_by("service")
)
serializer = self.get_serializer(services_data, many=True)
serializer = OverviewServiceSerializer(services_data, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@@ -3839,32 +3416,6 @@ class IntegrationViewSet(BaseRLSViewSet):
context["allowed_providers"] = self.allowed_providers
return context
@extend_schema(
tags=["Integration"],
summary="Check integration connection",
description="Try to verify integration connection",
request=None,
responses={202: OpenApiResponse(response=TaskSerializer)},
)
@action(detail=True, methods=["post"], url_name="connection")
def connection(self, request, pk=None):
get_object_or_404(Integration, pk=pk)
with transaction.atomic():
task = check_integration_connection_task.delay(
integration_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
return Response(
data=serializer.data,
status=status.HTTP_202_ACCEPTED,
headers={
"Content-Location": reverse(
"task-detail", kwargs={"pk": prowler_task.id}
)
},
)
@extend_schema_view(
list=extend_schema(
@@ -3941,54 +3492,3 @@ class LighthouseConfigViewSet(BaseRLSViewSet):
)
},
)
@extend_schema_view(
list=extend_schema(
tags=["Processor"],
summary="List all processors",
description="Retrieve a list of all configured processors with options for filtering by various criteria.",
),
retrieve=extend_schema(
tags=["Processor"],
summary="Retrieve processor details",
description="Fetch detailed information about a specific processor by its ID.",
),
create=extend_schema(
tags=["Processor"],
summary="Create a new processor",
description="Register a new processor with the system, providing necessary configuration details. There can "
"only be one processor of each type per tenant.",
),
partial_update=extend_schema(
tags=["Processor"],
summary="Partially update a processor",
description="Modify certain fields of an existing processor without affecting other settings.",
),
destroy=extend_schema(
tags=["Processor"],
summary="Delete a processor",
description="Remove a processor from the system by its ID.",
),
)
@method_decorator(CACHE_DECORATOR, name="list")
@method_decorator(CACHE_DECORATOR, name="retrieve")
class ProcessorViewSet(BaseRLSViewSet):
queryset = Processor.objects.all()
serializer_class = ProcessorSerializer
http_method_names = ["get", "post", "patch", "delete"]
filterset_class = ProcessorFilter
ordering = ["processor_type", "-inserted_at"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
queryset = Processor.objects.filter(tenant_id=self.request.tenant_id)
return queryset
def get_serializer_class(self):
if self.action == "create":
return ProcessorCreateSerializer
elif self.action == "partial_update":
return ProcessorUpdateSerializer
return super().get_serializer_class()
-88
View File
@@ -1,5 +1,3 @@
import string
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
@@ -22,89 +20,3 @@ class MaximumLengthValidator:
return _(
f"Your password must contain no more than {self.max_length} characters."
)
class SpecialCharactersValidator:
def __init__(self, special_characters=None, min_special_characters=1):
# Use string.punctuation if no custom characters provided
self.special_characters = special_characters or string.punctuation
self.min_special_characters = min_special_characters
def validate(self, password, user=None):
if (
sum(1 for char in password if char in self.special_characters)
< self.min_special_characters
):
raise ValidationError(
_("This password must contain at least one special character."),
code="password_no_special_characters",
params={
"special_characters": self.special_characters,
"min_special_characters": self.min_special_characters,
},
)
def get_help_text(self):
return _(
f"Your password must contain at least one special character from: {self.special_characters}"
)
class UppercaseValidator:
def __init__(self, min_uppercase=1):
self.min_uppercase = min_uppercase
def validate(self, password, user=None):
if sum(1 for char in password if char.isupper()) < self.min_uppercase:
raise ValidationError(
_(
"This password must contain at least %(min_uppercase)d uppercase letter."
),
code="password_no_uppercase_letters",
params={"min_uppercase": self.min_uppercase},
)
def get_help_text(self):
return _(
f"Your password must contain at least {self.min_uppercase} uppercase letter."
)
class LowercaseValidator:
def __init__(self, min_lowercase=1):
self.min_lowercase = min_lowercase
def validate(self, password, user=None):
if sum(1 for char in password if char.islower()) < self.min_lowercase:
raise ValidationError(
_(
"This password must contain at least %(min_lowercase)d lowercase letter."
),
code="password_no_lowercase_letters",
params={"min_lowercase": self.min_lowercase},
)
def get_help_text(self):
return _(
f"Your password must contain at least {self.min_lowercase} lowercase letter."
)
class NumericValidator:
def __init__(self, min_numeric=1):
self.min_numeric = min_numeric
def validate(self, password, user=None):
if sum(1 for char in password if char.isdigit()) < self.min_numeric:
raise ValidationError(
_(
"This password must contain at least %(min_numeric)d numeric character."
),
code="password_no_numeric_characters",
params={"min_numeric": self.min_numeric},
)
def get_help_text(self):
return _(
f"Your password must contain at least {self.min_numeric} numeric character."
)
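These classes follow Django's standard password-validator protocol (validate/get_help_text), so they can be exercised directly through django.contrib.auth.password_validation. A minimal sketch, assuming api.validators is importable and mirroring the AUTH_PASSWORD_VALIDATORS entries in the settings diff below:

from django.contrib.auth.password_validation import (
    get_password_validators,
    validate_password,
)
from django.core.exceptions import ValidationError

validators = get_password_validators([
    {"NAME": "api.validators.SpecialCharactersValidator", "OPTIONS": {"min_special_characters": 1}},
    {"NAME": "api.validators.UppercaseValidator", "OPTIONS": {"min_uppercase": 1}},
    {"NAME": "api.validators.LowercaseValidator", "OPTIONS": {"min_lowercase": 1}},
    {"NAME": "api.validators.NumericValidator", "OPTIONS": {"min_numeric": 1}},
])

try:
    validate_password("weak", password_validators=validators)
except ValidationError as error:
    print(error.messages)  # one message per validator that failed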
-29
View File
@@ -11,7 +11,6 @@ SECRET_KEY = env("SECRET_KEY", default="secret")
DEBUG = env.bool("DJANGO_DEBUG", default=False)
ALLOWED_HOSTS = ["localhost", "127.0.0.1"]
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
USE_X_FORWARDED_HOST = True
# Application definition
@@ -159,30 +158,6 @@ AUTH_PASSWORD_VALIDATORS = [
{
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
{
"NAME": "api.validators.SpecialCharactersValidator",
"OPTIONS": {
"min_special_characters": 1,
},
},
{
"NAME": "api.validators.UppercaseValidator",
"OPTIONS": {
"min_uppercase": 1,
},
},
{
"NAME": "api.validators.LowercaseValidator",
"OPTIONS": {
"min_lowercase": 1,
},
},
{
"NAME": "api.validators.NumericValidator",
"OPTIONS": {
"min_numeric": 1,
},
},
]
SIMPLE_JWT = {
@@ -273,7 +248,3 @@ X_FRAME_OPTIONS = "DENY"
SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin"
DJANGO_DELETION_BATCH_SIZE = env.int("DJANGO_DELETION_BATCH_SIZE", 5000)
# SAML requirement
CSRF_COOKIE_SECURE = True
SESSION_COOKIE_SECURE = True
+3 -9
View File
@@ -4,7 +4,6 @@ from config.env import env
IGNORED_EXCEPTIONS = [
# Provider is not connected due to credentials errors
"is not connected",
"ProviderConnectionError",
# Authentication Errors from AWS
"InvalidToken",
"AccessDeniedException",
@@ -17,7 +16,7 @@ IGNORED_EXCEPTIONS = [
"InternalServerErrorException",
"AccessDenied",
"No Shodan API Key", # Shodan Check
"RequestLimitExceeded", # For now, we don't want to log the RequestLimitExceeded errors
"RequestLimitExceeded", # For now we don't want to log the RequestLimitExceeded errors
"ThrottlingException",
"Rate exceeded",
"SubscriptionRequiredException",
@@ -43,9 +42,7 @@ IGNORED_EXCEPTIONS = [
"AWSAccessKeyIDInvalidError",
"AWSSessionTokenExpiredError",
"EndpointConnectionError", # AWS Service is not available in a region
# The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an
# unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
"Pool is closed",
"Pool is closed", # The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
# Authentication Errors from GCP
"ClientAuthenticationError",
"AuthorizationFailed",
@@ -69,15 +66,12 @@ IGNORED_EXCEPTIONS = [
"AzureClientIdAndClientSecretNotBelongingToTenantIdError",
"AzureHTTPResponseError",
"Error with credentials provided",
# PowerShell Errors in User Authentication
"Microsoft Teams User Auth connection failed: Please check your permissions and try again.",
"Exchange Online User Auth connection failed: Please check your permissions and try again.",
]
def before_send(event, hint):
"""
before_send handles the Sentry events in order to send them or not
before_send handles the Sentry events in order to send them or not
"""
# Ignore logs with the ignored_exceptions
# https://docs.python.org/3/library/logging.html#logrecord-objects
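The rest of before_send falls outside this hunk; a typical implementation of the pattern described here, dropping any event whose originating log record matches one of the IGNORED_EXCEPTIONS substrings, would look roughly like this sketch (not the repository's exact code):

def before_send_sketch(event, hint):
    # sentry_sdk's logging integration passes the originating LogRecord in `hint`.
    log_record = hint.get("log_record")
    message = log_record.getMessage() if log_record else ""
    if any(ignored in message for ignored in IGNORED_EXCEPTIONS):
        return None  # returning None tells sentry_sdk to drop the event
    return event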
@@ -25,18 +25,9 @@ SOCIALACCOUNT_EMAIL_AUTHENTICATION = True
SOCIALACCOUNT_EMAIL_AUTHENTICATION_AUTO_CONNECT = True
SOCIALACCOUNT_ADAPTER = "api.adapters.ProwlerSocialAccountAdapter"
# def inline(pem: str) -> str:
# return "".join(
# line.strip()
# for line in pem.splitlines()
# if "CERTIFICATE" not in line and "KEY" not in line
# )
# # SAML keys (TODO: Validate certificates)
# SAML_PUBLIC_CERT = inline(env("SAML_PUBLIC_CERT", default=""))
# SAML_PRIVATE_KEY = inline(env("SAML_PRIVATE_KEY", default=""))
# SAML keys
SAML_PUBLIC_CERT = env("SAML_PUBLIC_CERT", default="")
SAML_PRIVATE_KEY = env("SAML_PRIVATE_KEY", default="")
SOCIALACCOUNT_PROVIDERS = {
"google": {
@@ -69,14 +60,12 @@ SOCIALACCOUNT_PROVIDERS = {
"entity_id": "urn:prowler.com:sp",
},
"advanced": {
# TODO: Validate certificates
# "x509cert": SAML_PUBLIC_CERT,
# "private_key": SAML_PRIVATE_KEY,
# "authn_request_signed": True,
# "want_message_signed": True,
# "want_assertion_signed": True,
"reject_idp_initiated_sso": False,
"x509cert": SAML_PUBLIC_CERT,
"private_key": SAML_PRIVATE_KEY,
"name_id_format": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
"authn_request_signed": True,
"want_assertion_signed": True,
"want_message_signed": True,
},
},
}
+3 -96
View File
@@ -23,13 +23,11 @@ from api.models import (
Invitation,
LighthouseConfiguration,
Membership,
Processor,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
ResourceTagMapping,
Role,
SAMLConfiguration,
SAMLDomainIndex,
@@ -46,19 +44,12 @@ from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
TODAY = str(datetime.today().date())
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
TEST_USER = "dev@prowler.com"
TEST_PASSWORD = "testing_psswd"
def today_after_n_days(n_days: int) -> str:
return datetime.strftime(
datetime.today().date() + timedelta(days=n_days), "%Y-%m-%d"
)
@pytest.fixture(scope="module")
def enforce_test_user_db_connection(django_db_setup, django_db_blocker):
"""Ensure tests use the test user for database connections."""
@@ -390,27 +381,8 @@ def providers_fixture(tenants_fixture):
tenant_id=tenant.id,
scanner_args={"key1": "value1", "key2": {"key21": "value21"}},
)
provider6 = Provider.objects.create(
provider="m365",
uid="m365.test.com",
alias="m365_testing",
tenant_id=tenant.id,
)
return provider1, provider2, provider3, provider4, provider5, provider6
@pytest.fixture
def processor_fixture(tenants_fixture):
tenant, *_ = tenants_fixture
processor = Processor.objects.create(
tenant_id=tenant.id,
processor_type="mutelist",
configuration="Mutelist:\n Accounts:\n *:\n Checks:\n iam_user_hardware_mfa_enabled:\n "
" Regions:\n - *\n Resources:\n - *",
)
return processor
return provider1, provider2, provider3, provider4, provider5
@pytest.fixture
@@ -662,7 +634,6 @@ def findings_fixture(scans_fixture, resources_fixture):
check_metadata={
"CheckId": "test_check_id",
"Description": "test description apple sauce",
"servicename": "ec2",
},
first_seen_at="2024-01-02T00:00:00Z",
)
@@ -689,7 +660,6 @@ def findings_fixture(scans_fixture, resources_fixture):
check_metadata={
"CheckId": "test_check_id",
"Description": "test description orange juice",
"servicename": "s3",
},
first_seen_at="2024-01-02T00:00:00Z",
muted=True,
@@ -1065,7 +1035,7 @@ def integrations_fixture(providers_fixture):
enabled=True,
connected=True,
integration_type="amazon_s3",
configuration={"key": "value1"},
configuration={"key": "value"},
credentials={"psswd": "1234"},
)
IntegrationProviderRelationship.objects.create(
@@ -1145,73 +1115,10 @@ def latest_scan_finding(authenticated_client, providers_fixture, resources_fixtu
return finding
@pytest.fixture(scope="function")
def latest_scan_resource(authenticated_client, providers_fixture):
provider = providers_fixture[0]
tenant_id = str(providers_fixture[0].tenant_id)
scan = Scan.objects.create(
name="latest completed scan for resource",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant_id,
)
resource = Resource.objects.create(
tenant_id=tenant_id,
provider=provider,
uid="latest_resource_uid",
name="Latest Resource",
region="us-east-1",
service="ec2",
type="instance",
metadata='{"test": "metadata"}',
details='{"test": "details"}',
)
resource_tag = ResourceTag.objects.create(
tenant_id=tenant_id,
key="environment",
value="test",
)
ResourceTagMapping.objects.create(
tenant_id=tenant_id,
resource=resource,
tag=resource_tag,
)
finding = Finding.objects.create(
tenant_id=tenant_id,
uid="test_finding_uid_latest",
scan=scan,
delta="new",
status=Status.FAIL,
status_extended="test status extended ",
impact=Severity.critical,
impact_extended="test impact extended",
severity=Severity.critical,
raw_result={
"status": Status.FAIL,
"impact": Severity.critical,
"severity": Severity.critical,
},
tags={"test": "latest"},
check_id="test_check_id_latest",
check_metadata={
"CheckId": "test_check_id_latest",
"Description": "test description latest",
},
first_seen_at="2024-01-02T00:00:00Z",
)
finding.add_resources([resource])
backfill_resource_scan_summaries(tenant_id, str(scan.id))
return resource
@pytest.fixture
def saml_setup(tenants_fixture):
tenant_id = tenants_fixture[0].id
domain = "prowler.com"
domain = "example.com"
SAMLDomainIndex.objects.create(email_domain=domain, tenant_id=tenant_id)
+10 -4
View File
@@ -2,10 +2,10 @@ import json
from datetime import datetime, timedelta, timezone
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from rest_framework_json_api.serializers import ValidationError
from tasks.tasks import perform_scheduled_scan_task
from api.db_utils import rls_transaction
from api.exceptions import ConflictException
from api.models import Provider, Scan, StateChoices
@@ -24,9 +24,15 @@ def schedule_provider_scan(provider_instance: Provider):
if PeriodicTask.objects.filter(
interval=schedule, name=task_name, task="scan-perform-scheduled"
).exists():
raise ConflictException(
detail="There is already a scheduled scan for this provider.",
pointer="/data/attributes/provider_id",
raise ValidationError(
[
{
"detail": "There is already a scheduled scan for this provider.",
"status": 400,
"source": {"pointer": "/data/attributes/provider_id"},
"code": "invalid",
}
]
)
with rls_transaction(tenant_id):
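With rest_framework_json_api's exception handling, a ValidationError raised with that payload renders an error document along these lines (illustrative, field for field from the list above):

expected_error_document = {
    "errors": [
        {
            "detail": "There is already a scheduled scan for this provider.",
            "status": 400,
            "source": {"pointer": "/data/attributes/provider_id"},
            "code": "invalid",
        }
    ]
}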
+2 -37
View File
@@ -3,11 +3,8 @@ from datetime import datetime, timezone
import openai
from celery.utils.log import get_task_logger
from api.models import Integration, LighthouseConfiguration, Provider
from api.utils import (
prowler_integration_connection_test,
prowler_provider_connection_test,
)
from api.models import LighthouseConfiguration, Provider
from api.utils import prowler_provider_connection_test
logger = get_task_logger(__name__)
@@ -86,35 +83,3 @@ def check_lighthouse_connection(lighthouse_config_id: str):
lighthouse_config.is_active = False
lighthouse_config.save()
return {"connected": False, "error": str(e), "available_models": []}
def check_integration_connection(integration_id: str):
"""
Business logic to check the connection status of an integration.
Args:
integration_id (str): The primary key of the Integration instance to check.
"""
integration = Integration.objects.filter(pk=integration_id, enabled=True).first()
if not integration:
logger.info(f"Integration {integration_id} is not enabled")
return {"connected": False, "error": "Integration is not enabled"}
try:
result = prowler_integration_connection_test(integration)
except Exception as e:
logger.warning(
f"Unexpected exception checking {integration.integration_type} integration connection: {str(e)}"
)
raise e
# Update integration connection status
integration.connected = result.is_connected
integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
integration.save()
return {
"connected": result.is_connected,
"error": str(result.error) if result.error else None,
}
+5 -16
View File
@@ -8,12 +8,11 @@ from botocore.exceptions import ClientError, NoCredentialsError, ParamValidation
from celery.utils.log import get_task_logger
from django.conf import settings
from api.db_utils import rls_transaction
from api.models import Scan
from prowler.config.config import (
csv_file_suffix,
html_file_suffix,
json_ocsf_file_suffix,
output_file_timestamp,
)
from prowler.lib.outputs.compliance.aws_well_architected.aws_well_architected import (
AWSWellArchitected,
@@ -21,7 +20,6 @@ from prowler.lib.outputs.compliance.aws_well_architected.aws_well_architected im
from prowler.lib.outputs.compliance.cis.cis_aws import AWSCIS
from prowler.lib.outputs.compliance.cis.cis_azure import AzureCIS
from prowler.lib.outputs.compliance.cis.cis_gcp import GCPCIS
from prowler.lib.outputs.compliance.cis.cis_github import GithubCIS
from prowler.lib.outputs.compliance.cis.cis_kubernetes import KubernetesCIS
from prowler.lib.outputs.compliance.cis.cis_m365 import M365CIS
from prowler.lib.outputs.compliance.ens.ens_aws import AWSENS
@@ -33,7 +31,6 @@ from prowler.lib.outputs.compliance.iso27001.iso27001_gcp import GCPISO27001
from prowler.lib.outputs.compliance.iso27001.iso27001_kubernetes import (
KubernetesISO27001,
)
from prowler.lib.outputs.compliance.iso27001.iso27001_m365 import M365ISO27001
from prowler.lib.outputs.compliance.kisa_ismsp.kisa_ismsp_aws import AWSKISAISMSP
from prowler.lib.outputs.compliance.mitre_attack.mitre_attack_aws import AWSMitreAttack
from prowler.lib.outputs.compliance.mitre_attack.mitre_attack_azure import (
@@ -93,10 +90,6 @@ COMPLIANCE_CLASS_MAP = {
"m365": [
(lambda name: name.startswith("cis_"), M365CIS),
(lambda name: name == "prowler_threatscore_m365", ProwlerThreatScoreM365),
(lambda name: name.startswith("iso27001_"), M365ISO27001),
],
"github": [
(lambda name: name.startswith("cis_"), GithubCIS),
],
}
@@ -172,7 +165,7 @@ def get_s3_client():
return s3_client
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str | None:
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
@@ -189,7 +182,7 @@ def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str | None:
"""
bucket = base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET
if not bucket:
return
return None
try:
s3 = get_s3_client()
@@ -249,19 +242,15 @@ def _generate_output_directory(
# Sanitize the prowler provider name to ensure it is a valid directory name
prowler_provider_sanitized = re.sub(r"[^\w\-]", "-", prowler_provider)
with rls_transaction(tenant_id):
started_at = Scan.objects.get(id=scan_id).started_at
timestamp = started_at.strftime("%Y%m%d%H%M%S")
path = (
f"{output_directory}/{tenant_id}/{scan_id}/prowler-output-"
f"{prowler_provider_sanitized}-{timestamp}"
f"{prowler_provider_sanitized}-{output_file_timestamp}"
)
os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
compliance_path = (
f"{output_directory}/{tenant_id}/{scan_id}/compliance/prowler-output-"
f"{prowler_provider_sanitized}-{timestamp}"
f"{prowler_provider_sanitized}-{output_file_timestamp}"
)
os.makedirs("/".join(compliance_path.split("/")[:-1]), exist_ok=True)
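Illustratively, for an AWS provider the two paths above come out as (placeholders kept as-is):

<output_directory>/<tenant_id>/<scan_id>/prowler-output-aws-20250102000000
<output_directory>/<tenant_id>/<scan_id>/compliance/prowler-output-aws-20250102000000

with the timestamp taken from the scan's started_at before this change and from Prowler's output_file_timestamp after it.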
-156
View File
@@ -1,156 +0,0 @@
import os
from glob import glob
from celery.utils.log import get_task_logger
from api.db_utils import rls_transaction
from api.models import Integration
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.common.models import Connection
logger = get_task_logger(__name__)
def get_s3_client_from_integration(
integration: Integration,
) -> tuple[bool, S3 | Connection]:
"""
Create and return a boto3 S3 client using AWS credentials from an integration.
Args:
integration (Integration): The integration to get the S3 client from.
Returns:
tuple[bool, S3 | Connection]: A tuple containing a boolean indicating if the connection was successful and the S3 client or connection object.
"""
s3 = S3(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
output_directory=integration.configuration["output_directory"],
)
connection = s3.test_connection(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
)
if connection.is_connected:
return True, s3
return False, connection
def upload_s3_integration(
tenant_id: str, provider_id: str, output_directory: str
) -> bool:
"""
Upload the specified output files to an S3 bucket from an integration.
Reconstructs output objects from files in the output directory instead of using serialized data.
Args:
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
provider_id (str): The provider identifier, used as part of the S3 key prefix.
output_directory (str): Path to the directory containing output files.
Returns:
bool: True if all integrations completed successfully, False otherwise.
Raises:
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
"""
logger.info(f"Processing S3 integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
)
if not integrations:
logger.error(f"No S3 integrations found for provider {provider_id}")
return False
integration_executions = 0
for integration in integrations:
try:
connected, s3 = get_s3_client_from_integration(integration)
except Exception as e:
logger.info(
f"S3 connection failed for integration {integration.id}: {e}"
)
integration.connected = False
integration.save()
continue
if connected:
try:
# Reconstruct generated_outputs from files in output directory
# This approach scans the output directory for files and creates the appropriate
# output objects based on file extensions and naming patterns.
generated_outputs = {"regular": [], "compliance": []}
# Find and recreate regular outputs (CSV, HTML, OCSF)
output_file_patterns = {
".csv": CSV,
".html": HTML,
".ocsf.json": OCSF,
".asff.json": ASFF,
}
base_dir = os.path.dirname(output_directory)
for extension, output_class in output_file_patterns.items():
pattern = f"{output_directory}*{extension}"
for file_path in glob(pattern):
if os.path.exists(file_path):
output = output_class(findings=[], file_path=file_path)
output.create_file_descriptor(file_path)
generated_outputs["regular"].append(output)
# Find and recreate compliance outputs
compliance_pattern = os.path.join(base_dir, "compliance", "*.csv")
for file_path in glob(compliance_pattern):
if os.path.exists(file_path):
output = GenericCompliance(
findings=[],
compliance=None,
file_path=file_path,
file_extension=".csv",
)
output.create_file_descriptor(file_path)
generated_outputs["compliance"].append(output)
# Use send_to_bucket with recreated generated_outputs objects
s3.send_to_bucket(generated_outputs)
except Exception as e:
logger.error(
f"S3 upload failed for integration {integration.id}: {e}"
)
continue
integration_executions += 1
else:
integration.connected = False
integration.save()
logger.error(
f"S3 upload failed, connection failed for integration {integration.id}: {s3.error}"
)
result = integration_executions == len(integrations)
if result:
logger.info(
f"All the S3 integrations completed successfully for provider {provider_id}"
)
else:
logger.info(f"Some S3 integrations failed for provider {provider_id}")
return result
except Exception as e:
logger.error(f"S3 integrations failed for provider {provider_id}: {str(e)}")
return False
+16 -84
View File
@@ -1,29 +1,22 @@
import json
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timezone
from celery.utils.log import get_task_logger
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db.models import Case, Count, IntegerField, Prefetch, Sum, When
from django.db.models import Case, Count, IntegerField, Sum, When
from tasks.utils import CustomEncoder
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import (
create_objects_in_batches,
rls_transaction,
update_objects_in_batches,
)
from api.exceptions import ProviderConnectionError
from api.db_utils import create_objects_in_batches, rls_transaction
from api.models import (
ComplianceRequirementOverview,
Finding,
Processor,
Provider,
Resource,
ResourceScanSummary,
@@ -33,7 +26,7 @@ from api.models import (
StateChoices,
)
from api.models import StatusChoices as FindingStatus
from api.utils import initialize_prowler_provider, return_prowler_provider
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.outputs.finding import Finding as ProwlerFinding
from prowler.lib.scan.scan import Scan as ProwlerScan
@@ -108,10 +101,7 @@ def _store_resources(
def perform_prowler_scan(
tenant_id: str,
scan_id: str,
provider_id: str,
checks_to_execute: list[str] | None = None,
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
):
"""
Perform a scan using Prowler and store the findings and resources in the database.
@@ -142,28 +132,14 @@ def perform_prowler_scan(
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save()
# Find the mutelist processor if it exists
with rls_transaction(tenant_id):
try:
mutelist_processor = Processor.objects.get(
tenant_id=tenant_id, processor_type=Processor.ProcessorChoices.MUTELIST
)
except Processor.DoesNotExist:
mutelist_processor = None
except Exception as e:
logger.error(f"Error processing mutelist rules: {e}")
mutelist_processor = None
try:
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(
provider_instance, mutelist_processor
)
prowler_provider = initialize_prowler_provider(provider_instance)
provider_instance.connected = True
except Exception as e:
provider_instance.connected = False
exc = ProviderConnectionError(
exc = ValueError(
f"Provider {provider_instance.provider} is not connected: {e}"
)
finally:
@@ -173,8 +149,7 @@ def perform_prowler_scan(
provider_instance.save()
# If the provider is not connected, raise an exception outside the transaction.
# If raised within the transaction, the transaction will be rolled back and the provider will not be marked
# as not connected.
# If raised within the transaction, the transaction will be rolled back and the provider will not be marked as not connected.
if exc:
raise exc
@@ -183,7 +158,6 @@ def perform_prowler_scan(
resource_cache = {}
tag_cache = {}
last_status_cache = {}
resource_failed_findings_cache = defaultdict(int)
for progress, findings in prowler_scan.scan():
for finding in findings:
@@ -209,9 +183,6 @@ def perform_prowler_scan(
},
)
resource_cache[resource_uid] = resource_instance
# Initialize all processed resources in the cache
resource_failed_findings_cache[resource_uid] = 0
else:
resource_instance = resource_cache[resource_uid]
@@ -302,9 +273,6 @@ def perform_prowler_scan(
if not last_first_seen_at:
last_first_seen_at = datetime.now(tz=timezone.utc)
# If the finding is muted at this time the reason must be the configured Mutelist
muted_reason = "Muted by mutelist" if finding.muted else None
# Create the finding
finding_instance = Finding.objects.create(
tenant_id=tenant_id,
@@ -320,16 +288,10 @@ def perform_prowler_scan(
scan=scan_instance,
first_seen_at=last_first_seen_at,
muted=finding.muted,
muted_reason=muted_reason,
compliance=finding.compliance,
)
finding_instance.add_resources([resource_instance])
# Increment failed_findings_count cache if the finding status is FAIL and not muted
if status == FindingStatus.FAIL and not finding.muted:
resource_uid = finding.resource_uid
resource_failed_findings_cache[resource_uid] += 1
# Update scan resource summaries
scan_resource_cache.add(
(
@@ -347,24 +309,6 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.COMPLETED
# Update failed_findings_count for all resources in batches if scan completed successfully
if resource_failed_findings_cache:
resources_to_update = []
for resource_uid, failed_count in resource_failed_findings_cache.items():
if resource_uid in resource_cache:
resource_instance = resource_cache[resource_uid]
resource_instance.failed_findings_count = failed_count
resources_to_update.append(resource_instance)
if resources_to_update:
update_objects_in_batches(
tenant_id=tenant_id,
model=Resource,
objects=resources_to_update,
fields=["failed_findings_count"],
batch_size=1000,
)
except Exception as e:
logger.error(f"Error performing scan {scan_id}: {e}")
exception = e
@@ -417,9 +361,6 @@ def aggregate_findings(tenant_id: str, scan_id: str):
changed, unchanged). The results are grouped by `check_id`, `service`, `severity`, and `region`.
These aggregated metrics are then stored in the `ScanSummary` table.
Additionally, it updates the failed_findings_count field for each resource based on the most
recent findings for each finding.uid.
Args:
tenant_id (str): The ID of the tenant to which the scan belongs.
scan_id (str): The ID of the scan for which findings need to be aggregated.
@@ -585,30 +526,21 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
with rls_transaction(tenant_id):
scan_instance = Scan.objects.get(pk=scan_id)
provider_instance = scan_instance.provider
prowler_provider = return_prowler_provider(provider_instance)
prowler_provider = initialize_prowler_provider(provider_instance)
# Get check status data by region from findings
findings = (
Finding.all_objects.filter(scan_id=scan_id, muted=False)
.only("id", "check_id", "status")
.prefetch_related(
Prefetch(
"resources",
queryset=Resource.objects.only("id", "region"),
to_attr="small_resources",
)
)
.iterator(chunk_size=1000)
)
check_status_by_region = {}
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id, muted=False)
for finding in findings:
for resource in finding.small_resources:
# Get region from resources
for resource in finding.resources.all():
region = resource.region
current_status = check_status_by_region.setdefault(region, {})
if current_status.get(finding.check_id) != "FAIL":
current_status[finding.check_id] = finding.status
region_dict = check_status_by_region.setdefault(region, {})
current_status = region_dict.get(finding.check_id)
if current_status == "FAIL":
continue
region_dict[finding.check_id] = finding.status
try:
# Try to get regions from provider
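In the rewritten loop above, FAIL is sticky per region: once a check is recorded as FAIL, later findings for the same check cannot overwrite it. A self-contained illustration with hypothetical data:

findings = [  # hypothetical (region, check_id, status) triples
    ("us-east-1", "s3_bucket_public_access", "PASS"),
    ("us-east-1", "s3_bucket_public_access", "FAIL"),
    ("us-east-1", "s3_bucket_public_access", "PASS"),  # arrives after the FAIL
]
check_status_by_region = {}
for region, check_id, status in findings:
    region_dict = check_status_by_region.setdefault(region, {})
    if region_dict.get(check_id) == "FAIL":
        continue  # FAIL is terminal for this check in this region
    region_dict[check_id] = status
assert check_status_by_region == {"us-east-1": {"s3_bucket_public_access": "FAIL"}}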
+23 -150
View File
@@ -2,17 +2,13 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, group, shared_task
from celery import chain, shared_task
from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.backfill import backfill_resource_scan_summaries
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
check_provider_connection,
)
from tasks.jobs.connection import check_lighthouse_connection, check_provider_connection
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.export import (
COMPLIANCE_CLASS_MAP,
@@ -21,7 +17,6 @@ from tasks.jobs.export import (
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.integrations import upload_s3_integration
from tasks.jobs.scan import (
aggregate_findings,
create_compliance_requirements,
@@ -32,7 +27,7 @@ from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
@@ -42,30 +37,6 @@ from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str):
"""
Helper function to perform tasks after a scan is completed.
Args:
tenant_id (str): The tenant ID under which the scan was performed.
scan_id (str): The ID of the scan that was performed.
provider_id (str): The primary key of the Provider instance that was scanned.
"""
create_compliance_requirements_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
chain(
perform_scan_summary_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs_task.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
check_integrations_task.si(
tenant_id=tenant_id,
provider_id=provider_id,
),
).apply_async()
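The helper above leans on two Celery primitives: .si() builds an immutable signature, so each task ignores the return value of the one before it, and chain() queues the tasks to run sequentially. A minimal sketch, assuming a hypothetical app and in-memory broker:

from celery import Celery, chain

app = Celery("sketch", broker="memory://")

@app.task
def summarize(tenant_id: str, scan_id: str): ...

@app.task
def export(tenant_id: str, scan_id: str): ...

# Each .si() is immutable: export() will not receive summarize()'s result.
chain(
    summarize.si(tenant_id="t1", scan_id="s1"),
    export.si(tenant_id="t1", scan_id="s1"),
).apply_async()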
@shared_task(base=RLSTask, name="provider-connection-check")
@set_tenant
def check_provider_connection_task(provider_id: str):
@@ -83,18 +54,6 @@ def check_provider_connection_task(provider_id: str):
return check_provider_connection(provider_id=provider_id)
@shared_task(base=RLSTask, name="integration-connection-check")
@set_tenant
def check_integration_connection_task(integration_id: str):
"""
Task to check the connection status of an integration.
Args:
integration_id (str): The primary key of the Integration instance to check.
"""
return check_integration_connection(integration_id=integration_id)
@shared_task(
base=RLSTask, name="provider-deletion", queue="deletion", autoretry_for=(Exception,)
)
@@ -144,7 +103,13 @@ def perform_scan_task(
checks_to_execute=checks_to_execute,
)
_perform_scan_complete_tasks(tenant_id, scan_id, provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_id),
create_compliance_requirements_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@@ -249,12 +214,20 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_instance.id),
create_compliance_requirements_task.si(
tenant_id=tenant_id, scan_id=str(scan_instance.id)
),
generate_outputs.si(
scan_id=str(scan_instance.id), provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@shared_task(name="scan-summary", queue="overview")
@shared_task(name="scan-summary")
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
@@ -270,7 +243,7 @@ def delete_tenant_task(tenant_id: str):
queue="scan-reports",
)
@set_tenant(keep_tenant=True)
def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
@@ -382,34 +355,7 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
compressed = _compress_output_files(out_dir)
upload_uri = _upload_to_s3(tenant_id, compressed, scan_id)
# S3 integrations (need output_directory)
with rls_transaction(tenant_id):
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
if s3_integrations:
# Pass the output directory path to S3 integration task to reconstruct objects from files
s3_integration_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"provider_id": provider_id,
"output_directory": out_dir,
}
).get(
disable_sync_subtasks=False
) # TODO: This synchronous execution is NOT recommended
# We're forced to do this because we need the files to exist before deletion occurs.
# Once we have the periodic file cleanup task implemented, we should:
# 1. Remove this .get() call and make it fully async
# 2. For Cloud deployments, develop a secondary approach where outputs are stored
# directly in S3 and read from there, eliminating local file dependencies
if upload_uri:
# TODO: We need to create a new periodic task to delete the output files
# This task shouldn't be responsible for deleting the output files
try:
rmtree(Path(compressed).parent, ignore_errors=True)
except Exception as e:
@@ -420,10 +366,7 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
Scan.all_objects.filter(id=scan_id).update(output_location=final_location)
logger.info(f"Scan outputs at {final_location}")
return {
"upload": did_upload,
}
return {"upload": did_upload}
@shared_task(name="backfill-scan-resource-summaries", queue="backfill")
@@ -438,7 +381,7 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="overview")
@shared_task(base=RLSTask, name="scan-compliance-overviews")
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""
Creates detailed compliance requirement records for a scan.
@@ -471,73 +414,3 @@ def check_lighthouse_connection_task(lighthouse_config_id: str, tenant_id: str =
- 'available_models' (list): List of available models if connection is successful.
"""
return check_lighthouse_connection(lighthouse_config_id=lighthouse_config_id)
@shared_task(name="integration-check")
def check_integrations_task(tenant_id: str, provider_id: str):
"""
Check and execute all configured integrations for a provider.
Args:
tenant_id (str): The tenant identifier
provider_id (str): The provider identifier
"""
logger.info(f"Checking integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
enabled=True,
)
if not integrations.exists():
logger.info(f"No integrations configured for provider {provider_id}")
return {"integrations_processed": 0}
integration_tasks = []
# TODO: Add other integration types here
# slack_integrations = integrations.filter(
# integration_type=Integration.IntegrationChoices.SLACK
# )
# if slack_integrations.exists():
# integration_tasks.append(
# slack_integration_task.s(
# tenant_id=tenant_id,
# provider_id=provider_id,
# )
# )
except Exception as e:
logger.error(f"Integration check failed for provider {provider_id}: {str(e)}")
return {"integrations_processed": 0, "error": str(e)}
# Execute all integration tasks in parallel if any were found
if integration_tasks:
job = group(integration_tasks)
job.apply_async()
logger.info(f"Launched {len(integration_tasks)} integration task(s)")
return {"integrations_processed": len(integration_tasks)}
@shared_task(
base=RLSTask,
name="integration-s3",
queue="integrations",
)
def s3_integration_task(
tenant_id: str,
provider_id: str,
output_directory: str,
):
"""
Process S3 integrations for a provider.
Args:
tenant_id (str): The tenant identifier
provider_id (str): The provider identifier
output_directory (str): Path to the directory containing output files
"""
return upload_s3_integration(tenant_id, provider_id, output_directory)
+3 -3
View File
@@ -3,9 +3,9 @@ from unittest.mock import patch
import pytest
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from rest_framework_json_api.serializers import ValidationError
from tasks.beat import schedule_provider_scan
from api.exceptions import ConflictException
from api.models import Scan
@@ -48,8 +48,8 @@ class TestScheduleProviderScan:
with patch("tasks.tasks.perform_scheduled_scan_task.apply_async"):
schedule_provider_scan(provider_instance)
# Now, try scheduling again, should raise ConflictException
with pytest.raises(ConflictException) as exc_info:
# Now, try scheduling again, should raise ValidationError
with pytest.raises(ValidationError) as exc_info:
schedule_provider_scan(provider_instance)
assert "There is already a scheduled scan for this provider." in str(
+2 -132
View File
@@ -1,15 +1,10 @@
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
check_provider_connection,
)
from tasks.jobs.connection import check_lighthouse_connection, check_provider_connection
from api.models import Integration, LighthouseConfiguration, Provider
from api.models import LighthouseConfiguration, Provider
@pytest.mark.parametrize(
@@ -132,128 +127,3 @@ def test_check_lighthouse_connection_missing_api_key(mock_lighthouse_get):
assert result["available_models"] == []
assert mock_lighthouse_instance.is_active is False
mock_lighthouse_instance.save.assert_called_once()
@pytest.mark.django_db
class TestCheckIntegrationConnection:
def setup_method(self):
self.integration_id = str(uuid.uuid4())
@patch("tasks.jobs.connection.Integration.objects.filter")
@patch("tasks.jobs.connection.prowler_integration_connection_test")
def test_check_integration_connection_success(
self, mock_prowler_test, mock_integration_filter
):
"""Test successful integration connection check with enabled=True filter."""
mock_integration = MagicMock()
mock_integration.id = self.integration_id
mock_integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
mock_queryset = MagicMock()
mock_queryset.first.return_value = mock_integration
mock_integration_filter.return_value = mock_queryset
mock_connection_result = MagicMock()
mock_connection_result.is_connected = True
mock_connection_result.error = None
mock_prowler_test.return_value = mock_connection_result
result = check_integration_connection(integration_id=self.integration_id)
# Verify that Integration.objects.filter was called with enabled=True filter
mock_integration_filter.assert_called_once_with(
pk=self.integration_id, enabled=True
)
mock_queryset.first.assert_called_once()
mock_prowler_test.assert_called_once_with(mock_integration)
# Verify the integration properties were updated
assert mock_integration.connected is True
assert mock_integration.connection_last_checked_at is not None
mock_integration.save.assert_called_once()
# Verify the return value
assert result["connected"] is True
assert result["error"] is None
@patch("tasks.jobs.connection.Integration.objects.filter")
@patch("tasks.jobs.connection.prowler_integration_connection_test")
def test_check_integration_connection_failure(
self, mock_prowler_test, mock_integration_filter
):
"""Test failed integration connection check."""
mock_integration = MagicMock()
mock_integration.id = self.integration_id
mock_queryset = MagicMock()
mock_queryset.first.return_value = mock_integration
mock_integration_filter.return_value = mock_queryset
test_error = Exception("Connection failed")
mock_connection_result = MagicMock()
mock_connection_result.is_connected = False
mock_connection_result.error = test_error
mock_prowler_test.return_value = mock_connection_result
result = check_integration_connection(integration_id=self.integration_id)
# Verify that Integration.objects.filter was called with enabled=True filter
mock_integration_filter.assert_called_once_with(
pk=self.integration_id, enabled=True
)
mock_queryset.first.assert_called_once()
# Verify the integration properties were updated
assert mock_integration.connected is False
assert mock_integration.connection_last_checked_at is not None
mock_integration.save.assert_called_once()
# Verify the return value
assert result["connected"] is False
assert result["error"] == str(test_error)
@patch("tasks.jobs.connection.Integration.objects.filter")
def test_check_integration_connection_not_enabled(self, mock_integration_filter):
"""Test that disabled integrations return proper error response."""
# Mock that no enabled integration is found
mock_queryset = MagicMock()
mock_queryset.first.return_value = None
mock_integration_filter.return_value = mock_queryset
result = check_integration_connection(integration_id=self.integration_id)
# Verify the filter was called with enabled=True
mock_integration_filter.assert_called_once_with(
pk=self.integration_id, enabled=True
)
mock_queryset.first.assert_called_once()
# Verify the return value matches the expected error response
assert result["connected"] is False
assert result["error"] == "Integration is not enabled"
@patch("tasks.jobs.connection.Integration.objects.filter")
@patch("tasks.jobs.connection.prowler_integration_connection_test")
def test_check_integration_connection_exception(
self, mock_prowler_test, mock_integration_filter
):
"""Test integration connection check when prowler test raises exception."""
mock_integration = MagicMock()
mock_integration.id = self.integration_id
mock_queryset = MagicMock()
mock_queryset.first.return_value = mock_integration
mock_integration_filter.return_value = mock_queryset
test_exception = Exception("Unexpected error during connection test")
mock_prowler_test.side_effect = test_exception
with pytest.raises(Exception, match="Unexpected error during connection test"):
check_integration_connection(integration_id=self.integration_id)
# Verify that Integration.objects.filter was called with enabled=True filter
mock_integration_filter.assert_called_once_with(
pk=self.integration_id, enabled=True
)
mock_queryset.first.assert_called_once()
mock_prowler_test.assert_called_once_with(mock_integration)
+12 -38
View File
@@ -1,7 +1,5 @@
import os
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
@@ -129,26 +127,14 @@ class TestOutputs:
_upload_to_s3("tenant", str(zip_path), "scan")
mock_logger.assert_called()
@patch("tasks.jobs.export.rls_transaction")
@patch("tasks.jobs.export.Scan")
def test_generate_output_directory_creates_paths(
self, mock_scan, mock_rls_transaction, tmpdir
):
# Mock the scan object with a started_at timestamp
mock_scan_instance = MagicMock()
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_generate_output_directory_creates_paths(self, tmpdir):
from prowler.config.config import output_file_timestamp
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
tenant_id = "t1"
scan_id = "s1"
provider = "aws"
expected_timestamp = "20230615103045"
path, compliance = _generate_output_directory(
base_dir, provider, tenant_id, scan_id
@@ -157,29 +143,17 @@ class TestOutputs:
assert os.path.isdir(os.path.dirname(path))
assert os.path.isdir(os.path.dirname(compliance))
assert path.endswith(f"{provider}-{expected_timestamp}")
assert compliance.endswith(f"{provider}-{expected_timestamp}")
assert path.endswith(f"{provider}-{output_file_timestamp}")
assert compliance.endswith(f"{provider}-{output_file_timestamp}")
@patch("tasks.jobs.export.rls_transaction")
@patch("tasks.jobs.export.Scan")
def test_generate_output_directory_invalid_character(
self, mock_scan, mock_rls_transaction, tmpdir
):
# Mock the scan object with a started_at timestamp
mock_scan_instance = MagicMock()
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_generate_output_directory_invalid_character(self, tmpdir):
from prowler.config.config import output_file_timestamp
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
tenant_id = "t1"
scan_id = "s1"
provider = "aws/test@check"
expected_timestamp = "20230615103045"
path, compliance = _generate_output_directory(
base_dir, provider, tenant_id, scan_id
@@ -188,5 +162,5 @@ class TestOutputs:
assert os.path.isdir(os.path.dirname(path))
assert os.path.isdir(os.path.dirname(compliance))
assert path.endswith(f"aws-test-check-{expected_timestamp}")
assert compliance.endswith(f"aws-test-check-{expected_timestamp}")
assert path.endswith(f"aws-test-check-{output_file_timestamp}")
assert compliance.endswith(f"aws-test-check-{output_file_timestamp}")
@@ -1,475 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.integrations import (
get_s3_client_from_integration,
upload_s3_integration,
)
from api.models import Integration
from api.utils import prowler_integration_connection_test
from prowler.providers.common.models import Connection
@pytest.mark.django_db
class TestS3IntegrationUploads:
@patch("tasks.jobs.integrations.S3")
def test_get_s3_client_from_integration_success(self, mock_s3_class):
mock_integration = MagicMock()
mock_integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
mock_integration.configuration = {
"bucket_name": "test-bucket",
"output_directory": "test-prefix",
}
mock_s3 = MagicMock()
mock_connection = MagicMock()
mock_connection.is_connected = True
mock_s3.test_connection.return_value = mock_connection
mock_s3_class.return_value = mock_s3
connected, s3 = get_s3_client_from_integration(mock_integration)
assert connected is True
assert s3 == mock_s3
mock_s3_class.assert_called_once_with(
**mock_integration.credentials,
bucket_name="test-bucket",
output_directory="test-prefix",
)
mock_s3.test_connection.assert_called_once_with(
**mock_integration.credentials,
bucket_name="test-bucket",
)
@patch("tasks.jobs.integrations.S3")
def test_get_s3_client_from_integration_failure(self, mock_s3_class):
mock_integration = MagicMock()
mock_integration.credentials = {}
mock_integration.configuration = {
"bucket_name": "test-bucket",
"output_directory": "test-prefix",
}
from prowler.providers.common.models import Connection
mock_connection = Connection()
mock_connection.is_connected = False
mock_connection.error = Exception("test error")
mock_s3 = MagicMock()
mock_s3.test_connection.return_value = mock_connection
mock_s3_class.return_value = mock_s3
connected, connection = get_s3_client_from_integration(mock_integration)
assert connected is False
assert isinstance(connection, Connection)
assert str(connection.error) == "test error"
@patch("tasks.jobs.integrations.GenericCompliance")
@patch("tasks.jobs.integrations.ASFF")
@patch("tasks.jobs.integrations.OCSF")
@patch("tasks.jobs.integrations.HTML")
@patch("tasks.jobs.integrations.CSV")
@patch("tasks.jobs.integrations.glob")
@patch("tasks.jobs.integrations.get_s3_client_from_integration")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
def test_upload_s3_integration_uploads_serialized_outputs(
self,
mock_integration_model,
mock_rls,
mock_get_s3,
mock_glob,
mock_csv,
mock_html,
mock_ocsf,
mock_asff,
mock_compliance,
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.configuration = {
"bucket_name": "bucket",
"output_directory": "prefix",
}
mock_integration_model.objects.filter.return_value = [integration]
mock_s3 = MagicMock()
mock_get_s3.return_value = (True, mock_s3)
# Mock the output classes to return mock instances
mock_csv_instance = MagicMock()
mock_html_instance = MagicMock()
mock_ocsf_instance = MagicMock()
mock_asff_instance = MagicMock()
mock_compliance_instance = MagicMock()
mock_csv.return_value = mock_csv_instance
mock_html.return_value = mock_html_instance
mock_ocsf.return_value = mock_ocsf_instance
mock_asff.return_value = mock_asff_instance
mock_compliance.return_value = mock_compliance_instance
# Mock glob to return test files
output_directory = "/tmp/prowler_output/scan123"
mock_glob.side_effect = [
["/tmp/prowler_output/scan123.csv"],
["/tmp/prowler_output/scan123.html"],
["/tmp/prowler_output/scan123.ocsf.json"],
["/tmp/prowler_output/scan123.asff.json"],
["/tmp/prowler_output/compliance/compliance.csv"],
]
with patch("os.path.exists", return_value=True):
with patch("os.getenv", return_value="/tmp/prowler_api_output"):
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is True
mock_s3.send_to_bucket.assert_called_once()
@patch("tasks.jobs.integrations.get_s3_client_from_integration")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_fails_connection_logs_error(
self, mock_logger, mock_integration_model, mock_rls, mock_get_s3
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.connected = True
mock_s3_client = MagicMock()
mock_s3_client.error = "Connection failed"
mock_integration_model.objects.filter.return_value = [integration]
mock_get_s3.return_value = (False, mock_s3_client)
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is False
integration.save.assert_called_once()
assert integration.connected is False
mock_logger.error.assert_any_call(
"S3 upload failed, connection failed for integration i-1: Connection failed"
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_logs_if_no_integrations(
self, mock_logger, mock_integration_model, mock_rls
):
mock_integration_model.objects.filter.return_value = []
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration("tenant", "provider", output_directory)
assert result is False
mock_logger.error.assert_called_once_with(
"No S3 integrations found for provider provider"
)
@patch(
"tasks.jobs.integrations.get_s3_client_from_integration",
side_effect=Exception("failed"),
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_logs_connection_exception_and_continues(
self, mock_logger, mock_integration_model, mock_rls, mock_get_s3
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.configuration = {
"bucket_name": "bucket",
"output_directory": "prefix",
}
mock_integration_model.objects.filter.return_value = [integration]
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is False
mock_logger.info.assert_any_call(
"S3 connection failed for integration i-1: failed"
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration.objects.filter")
def test_upload_s3_integration_filters_enabled_only(
self, mock_integration_filter, mock_rls
):
"""Test that upload_s3_integration only processes enabled integrations."""
tenant_id = "tenant-id"
provider_id = "provider-id"
output_directory = "/tmp/prowler_output/scan123"
# Mock that no enabled integrations are found
mock_integration_filter.return_value = []
mock_rls.return_value.__enter__.return_value = None
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is False
# Verify the filter includes the correct parameters including enabled=True
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
def test_s3_integration_validates_and_normalizes_output_directory(self):
"""Test that S3 integration validation normalizes output_directory paths."""
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "///////test", # This should be normalized
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
# Should not raise an exception and should normalize the path
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Verify that the path was normalized
assert configuration["output_directory"] == "test"
def test_s3_integration_rejects_invalid_output_directory_characters(self):
"""Test that S3 integration validation rejects invalid characters."""
from rest_framework.exceptions import ValidationError
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "test<invalid", # Contains invalid character
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
with pytest.raises(ValidationError) as exc_info:
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Should contain validation error about invalid characters
assert "Output directory contains invalid characters" in str(exc_info.value)
def test_s3_integration_rejects_empty_output_directory(self):
"""Test that S3 integration validation rejects empty directories."""
from rest_framework.exceptions import ValidationError
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "/////", # This becomes empty after normalization
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
with pytest.raises(ValidationError) as exc_info:
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Should contain validation error about empty directory
assert "Output directory cannot be empty" in str(exc_info.value)
def test_s3_integration_normalizes_complex_paths(self):
"""Test that S3 integration validation handles complex path normalization."""
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "//test//folder///subfolder//",
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Verify complex path normalization
assert configuration["output_directory"] == "test/folder/subfolder"
@patch("tasks.jobs.integrations.S3")
def test_s3_client_uses_output_directory_in_object_paths(self, mock_s3_class):
"""Test that S3 client uses output_directory correctly when generating object paths."""
mock_integration = MagicMock()
mock_integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
mock_integration.configuration = {
"bucket_name": "test-bucket",
"output_directory": "my-custom-prefix/scan-results",
}
mock_s3_instance = MagicMock()
mock_connection = MagicMock()
mock_connection.is_connected = True
mock_s3_instance.test_connection.return_value = mock_connection
mock_s3_class.return_value = mock_s3_instance
connected, s3 = get_s3_client_from_integration(mock_integration)
assert connected is True
# Verify S3 was initialized with the correct output_directory
mock_s3_class.assert_called_once_with(
**mock_integration.credentials,
bucket_name="test-bucket",
output_directory="my-custom-prefix/scan-results",
)
@pytest.mark.django_db
class TestProwlerIntegrationConnectionTest:
@patch("api.utils.S3")
def test_s3_integration_connection_success(self, mock_s3_class):
"""Test successful S3 integration connection."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"bucket_name": "test-bucket"}
mock_connection = Connection(is_connected=True)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is True
mock_s3_class.test_connection.assert_called_once_with(
**integration.credentials,
bucket_name="test-bucket",
raise_on_exception=False,
)
@patch("api.utils.S3")
def test_aws_provider_exception_handling(self, mock_s3_class):
"""Test S3 connection exception is properly caught and returned."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "invalid",
"aws_secret_access_key": "credentials",
}
integration.configuration = {"bucket_name": "test-bucket"}
test_exception = Exception("Invalid credentials")
mock_connection = Connection(is_connected=False, error=test_exception)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is False
assert result.error == test_exception
mock_s3_class.test_connection.assert_called_once_with(
aws_access_key_id="invalid",
aws_secret_access_key="credentials",
bucket_name="test-bucket",
raise_on_exception=False,
)
@patch("api.utils.AwsProvider")
@patch("api.utils.S3")
def test_s3_integration_connection_failure(self, mock_s3_class, mock_aws_provider):
"""Test S3 integration connection failure."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"bucket_name": "test-bucket"}
mock_session = MagicMock()
mock_aws_provider.return_value.session.current_session = mock_session
mock_connection = Connection(
is_connected=False, error=Exception("Bucket not found")
)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is False
assert str(result.error) == "Bucket not found"
@patch("api.utils.AwsProvider")
@patch("api.utils.S3")
def test_aws_security_hub_integration_connection(
self, mock_s3_class, mock_aws_provider
):
"""Test AWS Security Hub integration only validates AWS session."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AWS_SECURITY_HUB
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"region": "us-east-1"}
mock_session = MagicMock()
mock_aws_provider.return_value.session.current_session = mock_session
# For AWS Security Hub, the function should return early after AWS session validation
result = prowler_integration_connection_test(integration)
# The function should not reach S3 test_connection for AWS_SECURITY_HUB
mock_s3_class.test_connection.assert_not_called()
# Since no exception was raised during AWS session creation, the function returns None (success)
assert result is None
def test_unsupported_integration_type(self):
"""Test unsupported integration type raises ValueError."""
integration = MagicMock()
integration.integration_type = "UNSUPPORTED_TYPE"
integration.credentials = {}
integration.configuration = {}
with pytest.raises(
ValueError, match="Integration type UNSUPPORTED_TYPE not supported"
):
prowler_integration_connection_test(integration)
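As these tests use it, Connection from prowler.providers.common.models is a plain result object: a boolean flag plus an optional error, constructed directly:

from prowler.providers.common.models import Connection

ok = Connection(is_connected=True)
failed = Connection(is_connected=False, error=Exception("Bucket not found"))
assert ok.is_connected and failed.error is not None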
File diff suppressed because it is too large
+17 -202
View File
@@ -1,18 +1,11 @@
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from tasks.tasks import (
_perform_scan_complete_tasks,
check_integrations_task,
generate_outputs_task,
s3_integration_task,
)
from api.models import Integration
from tasks.tasks import generate_outputs
# TODO Move this to outputs/reports jobs
@pytest.mark.django_db
class TestGenerateOutputs:
def setup_method(self):
@@ -24,7 +17,7 @@ class TestGenerateOutputs:
with patch("tasks.tasks.ScanSummary.objects.filter") as mock_filter:
mock_filter.return_value.exists.return_value = False
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -33,6 +26,7 @@ class TestGenerateOutputs:
assert result == {"upload": False}
mock_filter.assert_called_once_with(scan_id=self.scan_id)
@patch("tasks.tasks.rmtree")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks.get_compliance_frameworks")
@@ -51,6 +45,7 @@ class TestGenerateOutputs:
mock_get_available_frameworks,
mock_compress,
mock_upload,
mock_rmtree,
):
mock_scan_summary_filter.return_value.exists.return_value = True
@@ -100,12 +95,11 @@ class TestGenerateOutputs:
return_value=("out-dir", "comp-dir"),
),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_compress.return_value = "/tmp/zipped.zip"
mock_upload.return_value = "s3://bucket/zipped.zip"
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -115,6 +109,9 @@ class TestGenerateOutputs:
mock_scan_update.return_value.update.assert_called_once_with(
output_location="s3://bucket/zipped.zip"
)
mock_rmtree.assert_called_once_with(
Path("/tmp/zipped.zip").parent, ignore_errors=True
)
def test_generate_outputs_fails_upload(self):
with (
@@ -146,7 +143,6 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value=None),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -154,9 +150,9 @@ class TestGenerateOutputs:
True,
]
result = generate_outputs_task(
result = generate_outputs(
scan_id="scan",
provider_id=self.provider_id,
provider_id="provider",
tenant_id=self.tenant_id,
)
@@ -188,7 +184,6 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/f.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -213,7 +208,7 @@ class TestGenerateOutputs:
{"aws": [(lambda x: True, MagicMock())]},
),
):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -259,8 +254,8 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch(
"tasks.tasks.batched",
return_value=[
@@ -281,7 +276,7 @@ class TestGenerateOutputs:
}
},
):
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -337,13 +332,13 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.Scan.all_objects.filter",
return_value=MagicMock(update=lambda **kw: None),
),
patch("tasks.tasks.batched", return_value=two_batches),
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.COMPLIANCE_CLASS_MAP",
{"aws": [(lambda name: True, TrackingComplianceWriter)]},
@@ -351,7 +346,7 @@ class TestGenerateOutputs:
):
mock_summary.return_value.exists.return_value = True
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -362,7 +357,6 @@ class TestGenerateOutputs:
assert writer.transform_calls == [([raw2], compliance_obj, "cis")]
assert result == {"upload": True}
# TODO: We need to add a periodic task to delete old output files
def test_generate_outputs_logs_rmtree_exception(self, caplog):
mock_finding_output = MagicMock()
mock_finding_output.compliance = {"cis": ["requirement-1", "requirement-2"]}
@@ -413,188 +407,9 @@ class TestGenerateOutputs:
),
):
with caplog.at_level("ERROR"):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
)
assert "Error deleting output files" in caplog.text
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_generate_outputs_filters_enabled_s3_integrations(
self, mock_integration_filter, mock_rls
):
"""Test that generate_outputs_task only processes enabled S3 integrations."""
with (
patch("tasks.tasks.ScanSummary.objects.filter") as mock_summary,
patch("tasks.tasks.Provider.objects.get"),
patch("tasks.tasks.initialize_prowler_provider"),
patch("tasks.tasks.Compliance.get_bulk"),
patch("tasks.tasks.get_compliance_frameworks", return_value=[]),
patch("tasks.tasks.Finding.all_objects.filter") as mock_findings,
patch(
"tasks.tasks._generate_output_directory", return_value=("out", "comp")
),
patch("tasks.tasks.FindingOutput._transform_findings_stats"),
patch("tasks.tasks.FindingOutput.transform_api_finding"),
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/file.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
patch("tasks.tasks.s3_integration_task.apply_async") as mock_s3_task,
):
mock_summary.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
[MagicMock()],
True,
]
mock_integration_filter.return_value = [MagicMock()]
mock_rls.return_value.__enter__.return_value = None
with (
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
patch("tasks.tasks.COMPLIANCE_CLASS_MAP", {"aws": []}),
):
generate_outputs_task(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
)
# Verify the S3 integrations filters
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
mock_s3_task.assert_called_once()
class TestScanCompleteTasks:
@patch("tasks.tasks.create_compliance_requirements_task.apply_async")
@patch("tasks.tasks.perform_scan_summary_task.si")
@patch("tasks.tasks.generate_outputs_task.si")
def test_scan_complete_tasks(
self, mock_outputs_task, mock_scan_summary_task, mock_compliance_tasks
):
_perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id")
mock_compliance_tasks.assert_called_once_with(
kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"},
)
mock_scan_summary_task.assert_called_once_with(
scan_id="scan-id",
tenant_id="tenant-id",
)
mock_outputs_task.assert_called_once_with(
scan_id="scan-id",
provider_id="provider-id",
tenant_id="tenant-id",
)
@pytest.mark.django_db
class TestCheckIntegrationsTask:
def setup_method(self):
self.scan_id = str(uuid.uuid4())
self.provider_id = str(uuid.uuid4())
self.tenant_id = str(uuid.uuid4())
self.output_directory = "/tmp/some-output-dir"
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_no_integrations(
self, mock_integration_filter, mock_rls
):
mock_integration_filter.return_value.exists.return_value = False
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
result = check_integrations_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
)
assert result == {"integrations_processed": 0}
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id,
enabled=True,
)
@patch("tasks.tasks.group")
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_s3_success(
self, mock_integration_filter, mock_rls, mock_group
):
# Mock that we have some integrations
mock_integration_filter.return_value.exists.return_value = True
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
# Since the current implementation doesn't actually create tasks yet (TODO comment),
# we test that no tasks are created but the function returns the correct count
result = check_integrations_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
)
assert result == {"integrations_processed": 0}
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id,
enabled=True,
)
# group should not be called since no integration tasks are created yet
mock_group.assert_not_called()
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_disabled_integrations_ignored(
self, mock_integration_filter, mock_rls
):
"""Test that disabled integrations are not processed."""
mock_integration_filter.return_value.exists.return_value = False
mock_rls.return_value.__enter__.return_value = None
result = check_integrations_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
)
assert result == {"integrations_processed": 0}
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id,
enabled=True,
)
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_success(self, mock_upload):
mock_upload.return_value = True
output_directory = "/tmp/prowler_api_output/test"
result = s3_integration_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
output_directory=output_directory,
)
assert result is True
mock_upload.assert_called_once_with(
self.tenant_id, self.provider_id, output_directory
)
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_failure(self, mock_upload):
mock_upload.return_value = False
output_directory = "/tmp/prowler_api_output/test"
result = s3_integration_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
output_directory=output_directory,
)
assert result is False
mock_upload.assert_called_once_with(
self.tenant_id, self.provider_id, output_directory
)
@@ -1,234 +0,0 @@
from locust import events, task
from utils.config import (
L_PROVIDER_NAME,
M_PROVIDER_NAME,
RESOURCES_UI_FIELDS,
S_PROVIDER_NAME,
TARGET_INSERTED_AT,
)
from utils.helpers import (
APIUserBase,
get_api_token,
get_auth_headers,
get_dynamic_filters_pairs,
get_next_resource_filter,
get_scan_id_from_provider_name,
)
GLOBAL = {
"token": None,
"scan_ids": {},
"resource_filters": None,
"large_resource_filters": None,
}
@events.test_start.add_listener
def on_test_start(environment, **kwargs):
GLOBAL["token"] = get_api_token(environment.host)
GLOBAL["scan_ids"]["small"] = get_scan_id_from_provider_name(
environment.host, GLOBAL["token"], S_PROVIDER_NAME
)
GLOBAL["scan_ids"]["medium"] = get_scan_id_from_provider_name(
environment.host, GLOBAL["token"], M_PROVIDER_NAME
)
GLOBAL["scan_ids"]["large"] = get_scan_id_from_provider_name(
environment.host, GLOBAL["token"], L_PROVIDER_NAME
)
GLOBAL["resource_filters"] = get_dynamic_filters_pairs(
environment.host, GLOBAL["token"], "resources"
)
GLOBAL["large_resource_filters"] = get_dynamic_filters_pairs(
environment.host, GLOBAL["token"], "resources", GLOBAL["scan_ids"]["large"]
)
class APIUser(APIUserBase):
def on_start(self):
self.token = GLOBAL["token"]
self.s_scan_id = GLOBAL["scan_ids"]["small"]
self.m_scan_id = GLOBAL["scan_ids"]["medium"]
self.l_scan_id = GLOBAL["scan_ids"]["large"]
self.available_resource_filters = GLOBAL["resource_filters"]
self.available_resource_filters_large_scan = GLOBAL["large_resource_filters"]
@task
def resources_default(self):
name = "/resources"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}"
f"&filter[updated_at]={TARGET_INSERTED_AT}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_default_ui_fields(self):
name = "/resources?fields"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}"
f"&fields[resources]={','.join(RESOURCES_UI_FIELDS)}"
f"&filter[updated_at]={TARGET_INSERTED_AT}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_default_include(self):
name = "/resources?include"
page = self._next_page(name)
endpoint = (
f"/resources?page[number]={page}"
f"&filter[updated_at]={TARGET_INSERTED_AT}"
f"&include=provider"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_metadata(self):
name = "/resources/metadata"
endpoint = f"/resources/metadata?filter[updated_at]={TARGET_INSERTED_AT}"
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task
def resources_scan_small(self):
name = "/resources?filter[scan_id] - 50k"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}" f"&filter[scan]={self.s_scan_id}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task
def resources_metadata_scan_small(self):
name = "/resources/metadata?filter[scan_id] - 50k"
endpoint = f"/resources/metadata?&filter[scan]={self.s_scan_id}"
self.client.get(
endpoint,
headers=get_auth_headers(self.token),
name=name,
)
@task(2)
def resources_scan_medium(self):
name = "/resources?filter[scan_id] - 250k"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}" f"&filter[scan]={self.m_scan_id}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task
def resources_metadata_scan_medium(self):
name = "/resources/metadata?filter[scan_id] - 250k"
endpoint = f"/resources/metadata?&filter[scan]={self.m_scan_id}"
self.client.get(
endpoint,
headers=get_auth_headers(self.token),
name=name,
)
@task
def resources_scan_large(self):
name = "/resources?filter[scan_id] - 500k"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}" f"&filter[scan]={self.l_scan_id}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task
def resources_scan_large_include(self):
name = "/resources?filter[scan_id]&include - 500k"
page_number = self._next_page(name)
endpoint = (
f"/resources?page[number]={page_number}"
f"&filter[scan]={self.l_scan_id}"
f"&include=provider"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task
def resources_metadata_scan_large(self):
endpoint = f"/resources/metadata?&filter[scan]={self.l_scan_id}"
self.client.get(
endpoint,
headers=get_auth_headers(self.token),
name="/resources/metadata?filter[scan_id] - 500k",
)
@task(2)
def resources_filters(self):
name = "/resources?filter[resource_filter]&include"
filter_name, filter_value = get_next_resource_filter(
self.available_resource_filters
)
endpoint = (
f"/resources?filter[{filter_name}]={filter_value}"
f"&filter[updated_at]={TARGET_INSERTED_AT}"
f"&include=provider"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_metadata_filters(self):
name = "/resources/metadata?filter[resource_filter]"
filter_name, filter_value = get_next_resource_filter(
self.available_resource_filters
)
endpoint = (
f"/resources/metadata?filter[{filter_name}]={filter_value}"
f"&filter[updated_at]={TARGET_INSERTED_AT}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_metadata_filters_scan_large(self):
name = "/resources/metadata?filter[resource_filter]&filter[scan_id] - 500k"
filter_name, filter_value = get_next_resource_filter(
self.available_resource_filters
)
endpoint = (
f"/resources/metadata?filter[{filter_name}]={filter_value}"
f"&filter[scan]={self.l_scan_id}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(2)
def resources_filter_large_scan_include(self):
name = "/resources?filter[resource_filter][scan]&include - 500k"
filter_name, filter_value = get_next_resource_filter(
self.available_resource_filters
)
endpoint = (
f"/resources?filter[{filter_name}]={filter_value}"
f"&filter[scan]={self.l_scan_id}"
f"&include=provider"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_latest_default_ui_fields(self):
name = "/resources/latest?fields"
page_number = self._next_page(name)
endpoint = (
f"/resources/latest?page[number]={page_number}"
f"&fields[resources]={','.join(RESOURCES_UI_FIELDS)}"
)
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
@task(3)
def resources_latest_metadata_filters(self):
name = "/resources/metadata/latest?filter[resource_filter]"
filter_name, filter_value = get_next_resource_filter(
self.available_resource_filters
)
endpoint = f"/resources/metadata/latest?filter[{filter_name}]={filter_value}"
self.client.get(endpoint, headers=get_auth_headers(self.token), name=name)
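For reference, the integer passed to @task is a relative weight: Locust picks each simulated user's next task in proportion to it, so the @task(3) endpoints above run roughly three times as often as the unweighted ones. A minimal sketch:

from locust import HttpUser, between, task

class SketchUser(HttpUser):
    wait_time = between(1, 2)

    @task  # weight 1
    def list_resources(self):
        self.client.get("/resources")

    @task(3)  # weight 3: selected ~3x as often as list_resources
    def resources_metadata(self):
        self.client.get("/resources/metadata")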
-17
@@ -13,23 +13,6 @@ FINDINGS_RESOURCE_METADATA = {
"resource_types": "resource_type",
"services": "service",
}
RESOURCE_METADATA = {
"regions": "region",
"types": "type",
"services": "service",
}
RESOURCES_UI_FIELDS = [
"name",
"failed_findings_count",
"region",
"service",
"type",
"provider",
"inserted_at",
"updated_at",
"uid",
]
S_PROVIDER_NAME = "provider-50k"
M_PROVIDER_NAME = "provider-250k"
+6 -16
@@ -7,7 +7,6 @@ from locust import HttpUser, between
from utils.config import (
BASE_HEADERS,
FINDINGS_RESOURCE_METADATA,
RESOURCE_METADATA,
TARGET_INSERTED_AT,
USER_EMAIL,
USER_PASSWORD,
@@ -122,16 +121,13 @@ def get_scan_id_from_provider_name(host: str, token: str, provider_name: str) ->
return response.json()["data"][0]["id"]
def get_dynamic_filters_pairs(
host: str, token: str, endpoint: str, scan_id: str = ""
) -> dict:
def get_resource_filters_pairs(host: str, token: str, scan_id: str = "") -> dict:
"""
Retrieves and maps metadata filter values from a given endpoint.
Retrieves and maps resource metadata filter values from the findings endpoint.
Args:
host (str): The host URL of the API.
token (str): Bearer token for authentication.
endpoint (str): The API endpoint to query for metadata.
scan_id (str, optional): Optional scan ID to filter metadata. Defaults to using inserted_at timestamp.
Returns:
@@ -140,28 +136,22 @@ def get_dynamic_filters_pairs(
Raises:
AssertionError: If the request fails or does not return a 200 status code.
"""
metadata_mapping = (
FINDINGS_RESOURCE_METADATA if endpoint == "findings" else RESOURCE_METADATA
)
date_filter = "inserted_at" if endpoint == "findings" else "updated_at"
metadata_filters = (
f"filter[scan]={scan_id}"
if scan_id
else f"filter[{date_filter}]={TARGET_INSERTED_AT}"
else f"filter[inserted_at]={TARGET_INSERTED_AT}"
)
response = requests.get(
f"{host}/{endpoint}/metadata?{metadata_filters}",
headers=get_auth_headers(token),
f"{host}/findings/metadata?{metadata_filters}", headers=get_auth_headers(token)
)
assert (
response.status_code == 200
), f"Failed to get resource filters values: {response.text}"
attributes = response.json()["data"]["attributes"]
return {
metadata_mapping[key]: values
FINDINGS_RESOURCE_METADATA[key]: values
for key, values in attributes.items()
if key in metadata_mapping.keys()
if key in FINDINGS_RESOURCE_METADATA.keys()
}
+4 -5
@@ -23,7 +23,6 @@ import argparse
import json
import os
import re
import shlex
import signal
import socket
import subprocess
@@ -146,11 +145,11 @@ def _get_script_arguments():
def _run_prowler(prowler_args):
_debug("Running prowler with args: {0}".format(prowler_args), 1)
_prowler_command = shlex.split(
"{prowler}/prowler {args}".format(prowler=PATH_TO_PROWLER, args=prowler_args)
_prowler_command = "{prowler}/prowler {args}".format(
prowler=PATH_TO_PROWLER, args=prowler_args
)
_debug("Running command: {0}".format(" ".join(_prowler_command)), 2)
_process = subprocess.Popen(_prowler_command, stdout=subprocess.PIPE)
_debug("Running command: {0}".format(_prowler_command), 2)
_process = subprocess.Popen(_prowler_command, stdout=subprocess.PIPE, shell=True)
_output, _error = _process.communicate()
_debug("Raw prowler output: {0}".format(_output), 3)
_debug("Raw prowler error: {0}".format(_error), 3)
-25
@@ -1,25 +0,0 @@
import warnings
from dashboard.common_methods import get_section_containers_cis
warnings.filterwarnings("ignore")
def get_table(data):
aux = data[
[
"REQUIREMENTS_ID",
"REQUIREMENTS_DESCRIPTION",
"REQUIREMENTS_ATTRIBUTES_SECTION",
"CHECKID",
"STATUS",
"REGION",
"ACCOUNTID",
"RESOURCEID",
]
].copy()
return get_section_containers_cis(
aux, "REQUIREMENTS_ID", "REQUIREMENTS_ATTRIBUTES_SECTION"
)
+1 -4
@@ -4,10 +4,7 @@ from dash import html
def create_provider_card(
provider: str,
provider_logo: str,
account_type: str,
filtered_data,
provider: str, provider_logo: str, account_type: str, filtered_data
) -> List[html.Div]:
"""
Card to display the provider's name and icon.
-25
@@ -245,31 +245,6 @@ def create_service_dropdown(services: list) -> html.Div:
)
def create_provider_dropdown(providers: list) -> html.Div:
"""
Dropdown to select the provider.
Args:
providers (list): List of providers.
Returns:
html.Div: Dropdown to select the provider.
"""
return html.Div(
[
html.Label(
"Provider:", className="text-prowler-stone-900 font-bold text-sm"
),
dcc.Dropdown(
id="provider-filter",
options=[{"label": i, "value": i} for i in providers],
value=["All"],
clearable=False,
multi=True,
style={"color": "#000000"},
),
],
)
def create_status_dropdown(status: list) -> html.Div:
"""
Dropdown to select the status.
+2 -5
@@ -9,11 +9,9 @@ def create_layout_overview(
download_button_xlsx: html.Button,
severity_dropdown: html.Div,
service_dropdown: html.Div,
provider_dropdown: html.Div,
table_row_dropdown: html.Div,
status_dropdown: html.Div,
table_div_header: html.Div,
amount_providers: int,
) -> html.Div:
"""
Create the layout of the dashboard.
@@ -49,10 +47,9 @@ def create_layout_overview(
[
html.Div([severity_dropdown], className=""),
html.Div([service_dropdown], className=""),
html.Div([provider_dropdown], className=""),
html.Div([status_dropdown], className=""),
],
className="grid gap-x-4 mb-[30px] sm:grid-cols-2 lg:grid-cols-4",
className="grid gap-x-4 mb-[30px] sm:grid-cols-2 lg:grid-cols-3",
),
html.Div(
[
@@ -62,7 +59,7 @@ def create_layout_overview(
html.Div(className="flex", id="k8s_card", n_clicks=0),
html.Div(className="flex", id="m365_card", n_clicks=0),
],
className=f"grid gap-x-4 mb-[30px] sm:grid-cols-2 lg:grid-cols-{amount_providers}",
className="grid gap-x-4 mb-[30px] sm:grid-cols-2 lg:grid-cols-5",
),
html.H4(
"Count of Findings by severity",
+20 -13
@@ -346,27 +346,34 @@ def display_data(
if item == "nan" or item.__class__.__name__ != "str":
region_filter_options.remove(item)
# Convert ASSESSMENTDATE to datetime
data["ASSESSMENTDATE"] = pd.to_datetime(data["ASSESSMENTDATE"], errors="coerce")
data["ASSESSMENTDAY"] = data["ASSESSMENTDATE"].dt.date
data["ASSESSMENTDATE"] = data["ASSESSMENTDATE"].dt.strftime("%Y-%m-%d %H:%M:%S")
# Find the latest timestamp per account per day
latest_per_account_day = data.groupby(["ACCOUNTID", "ASSESSMENTDAY"])[
"ASSESSMENTDATE"
].transform("max")
# Choosing the date that is the most recent
data_values = data["ASSESSMENTDATE"].unique()
data_values.sort()
data_values = data_values[::-1]
aux = []
# Keep only rows with the latest timestamp for each account and day
data = data[data["ASSESSMENTDATE"] == latest_per_account_day]
data_values = [str(i) for i in data_values]
for value in data_values:
if value.split(" ")[0] not in [aux[i].split(" ")[0] for i in range(len(aux))]:
aux.append(value)
data_values = [str(i) for i in aux]
# Prepare the date filter options (unique days, as strings)
options_date = sorted(data["ASSESSMENTDAY"].astype(str).unique(), reverse=True)
data = data[data["ASSESSMENTDATE"].isin(data_values)]
data["ASSESSMENTDATE"] = data["ASSESSMENTDATE"].apply(lambda x: x.split(" ")[0])
# Filter by selected date (as string)
options_date = data["ASSESSMENTDATE"].unique()
options_date.sort()
options_date = options_date[::-1]
# Filter DATE
if date_filter_analytics in options_date:
data = data[data["ASSESSMENTDAY"].astype(str) == date_filter_analytics]
data = data[data["ASSESSMENTDATE"] == date_filter_analytics]
else:
date_filter_analytics = options_date[0]
data = data[data["ASSESSMENTDAY"].astype(str) == date_filter_analytics]
data = data[data["ASSESSMENTDATE"] == date_filter_analytics]
if data.empty:
fig = px.pie()
+39 -88
@@ -1,4 +1,5 @@
# Standard library imports
import csv
import glob
import json
import os
@@ -19,6 +20,7 @@ from dash.dependencies import Input, Output
# Config import
from dashboard.config import (
critical_color,
encoding_format,
fail_color,
folder_path_overview,
high_color,
@@ -36,7 +38,6 @@ from dashboard.lib.cards import create_provider_card
from dashboard.lib.dropdowns import (
create_account_dropdown,
create_date_dropdown,
create_provider_dropdown,
create_region_dropdown,
create_service_dropdown,
create_severity_dropdown,
@@ -44,7 +45,6 @@ from dashboard.lib.dropdowns import (
create_table_row_dropdown,
)
from dashboard.lib.layouts import create_layout_overview
from prowler.lib.logger import logger
# Suppress warnings
warnings.filterwarnings("ignore")
@@ -54,13 +54,11 @@ warnings.filterwarnings("ignore")
csv_files = []
for file in glob.glob(os.path.join(folder_path_overview, "*.csv")):
try:
df = pd.read_csv(file, sep=";")
num_rows = len(df)
with open(file, "r", newline="", encoding=encoding_format) as csvfile:
reader = csv.reader(csvfile)
num_rows = sum(1 for row in reader)
if num_rows > 1:
csv_files.append(file)
except Exception:
logger.error(f"Error reading file {file}")
# Import logos providers
@@ -192,13 +190,7 @@ else:
data.rename(columns={"RESOURCE_ID": "RESOURCE_UID"}, inplace=True)
# Remove duplicates on the finding_uid column but keep the last one, taking into account the timestamp
data["DATE"] = data["TIMESTAMP"].dt.date
data = (
data.sort_values("TIMESTAMP")
.groupby(["DATE", "FINDING_UID"], as_index=False)
.last()
)
data["TIMESTAMP"] = pd.to_datetime(data["TIMESTAMP"])
data = data.sort_values("TIMESTAMP").drop_duplicates("FINDING_UID", keep="last")
data["ASSESSMENT_TIME"] = data["TIMESTAMP"].dt.strftime("%Y-%m-%d")
data_valid = pd.DataFrame()
@@ -306,13 +298,6 @@ else:
service_dropdown = create_service_dropdown(services)
# Provider Dropdown
providers = ["All"] + list(data["PROVIDER"].unique())
providers = [
x for x in providers if str(x) != "nan" and x.__class__.__name__ == "str"
]
provider_dropdown = create_provider_dropdown(providers)
# Create the download button
download_button_csv = html.Button(
"Download this table as CSV",
@@ -494,11 +479,9 @@ else:
download_button_xlsx,
severity_dropdown,
service_dropdown,
provider_dropdown,
table_row_dropdown,
status_dropdown,
table_div_header,
len(data["PROVIDER"].unique()),
)
@@ -525,8 +508,6 @@ else:
Output("severity-filter", "value"),
Output("severity-filter", "options"),
Output("service-filter", "value"),
Output("provider-filter", "value"),
Output("provider-filter", "options"),
Output("service-filter", "options"),
Output("table-rows", "value"),
Output("table-rows", "options"),
@@ -545,7 +526,6 @@ else:
Input("download_link_xlsx", "n_clicks"),
Input("severity-filter", "value"),
Input("service-filter", "value"),
Input("provider-filter", "value"),
Input("table-rows", "value"),
Input("status-filter", "value"),
Input("search-input", "value"),
@@ -569,7 +549,6 @@ def filter_data(
n_clicks_xlsx,
severity_values,
service_values,
provider_values,
table_row_values,
status_values,
search_value,
@@ -895,25 +874,6 @@ def filter_data(
filtered_data["SERVICE_NAME"].isin(updated_service_values)
]
provider_filter_options = ["All"] + list(filtered_data["PROVIDER"].unique())
# Filter Provider
if provider_values == ["All"]:
updated_provider_values = filtered_data["PROVIDER"].unique()
elif "All" in provider_values and len(provider_values) > 1:
# Remove 'All' from the list
provider_values.remove("All")
updated_provider_values = provider_values
elif len(provider_values) == 0:
updated_provider_values = filtered_data["PROVIDER"].unique()
provider_values = ["All"]
else:
updated_provider_values = provider_values
filtered_data = filtered_data[
filtered_data["PROVIDER"].isin(updated_provider_values)
]
# Filter Status
if status_values == ["All"]:
updated_status_values = filtered_data["STATUS"].unique()
@@ -1134,17 +1094,25 @@ def filter_data(
table_row_options = []
# Calculate table row options as percentages
percentages = [0.05, 0.10, 0.25, 0.50, 0.75, 1.0]
total_rows = len(filtered_data)
for pct in percentages:
value = max(1, int(total_rows * pct))
label = f"{int(pct * 100)}%"
table_row_options.append({"label": label, "value": value})
# Default to 25% if not set
# Take the values from the table_row_values
if table_row_values is None or table_row_values == -1:
table_row_values = table_row_options[0]["value"]
if len(filtered_data) < 25:
table_row_values = len(filtered_data)
else:
table_row_values = 25
if len(filtered_data) < 25:
table_row_values = len(filtered_data)
if len(filtered_data) >= 25:
table_row_options.append(25)
if len(filtered_data) >= 50:
table_row_options.append(50)
if len(filtered_data) >= 75:
table_row_options.append(75)
if len(filtered_data) >= 100:
table_row_options.append(100)
table_row_options.append(len(filtered_data))
# For the values that are nan or none, replace them with ""
filtered_data = filtered_data.replace({np.nan: ""})
@@ -1379,36 +1347,21 @@ def filter_data(
]
# Create Provider Cards
if "aws" in list(data["PROVIDER"].unique()):
aws_card = create_provider_card(
"aws", aws_provider_logo, "Accounts", full_filtered_data
)
else:
aws_card = None
if "azure" in list(data["PROVIDER"].unique()):
azure_card = create_provider_card(
"azure", azure_provider_logo, "Subscriptions", full_filtered_data
)
else:
azure_card = None
if "gcp" in list(data["PROVIDER"].unique()):
gcp_card = create_provider_card(
"gcp", gcp_provider_logo, "Projects", full_filtered_data
)
else:
gcp_card = None
if "kubernetes" in list(data["PROVIDER"].unique()):
k8s_card = create_provider_card(
"kubernetes", ks8_provider_logo, "Clusters", full_filtered_data
)
else:
k8s_card = None
if "m365" in list(data["PROVIDER"].unique()):
m365_card = create_provider_card(
"m365", m365_provider_logo, "Accounts", full_filtered_data
)
else:
m365_card = None
aws_card = create_provider_card(
"aws", aws_provider_logo, "Accounts", full_filtered_data
)
azure_card = create_provider_card(
"azure", azure_provider_logo, "Subscriptions", full_filtered_data
)
gcp_card = create_provider_card(
"gcp", gcp_provider_logo, "Projects", full_filtered_data
)
k8s_card = create_provider_card(
"kubernetes", ks8_provider_logo, "Clusters", full_filtered_data
)
m365_card = create_provider_card(
"m365", m365_provider_logo, "Accounts", full_filtered_data
)
# Subscribe to Prowler Cloud card
subscribe_card = [
@@ -1492,8 +1445,6 @@ def filter_data(
severity_values,
severity_filter_options,
service_values,
provider_values,
provider_filter_options,
service_filter_options,
table_row_values,
table_row_options,
+2 -6
@@ -16,7 +16,7 @@ services:
volumes:
- "./api/src/backend:/home/prowler/backend"
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
- "outputs:/tmp/prowler_api_output"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -87,7 +87,7 @@ services:
- path: .env
required: false
volumes:
- "outputs:/tmp/prowler_api_output"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy
@@ -115,7 +115,3 @@ services:
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
volumes:
outputs:
driver: local
+2 -6
@@ -8,7 +8,7 @@ services:
ports:
- "${DJANGO_PORT:-8080}:${DJANGO_PORT:-8080}"
volumes:
- "output:/tmp/prowler_api_output"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -68,7 +68,7 @@ services:
- path: .env
required: false
volumes:
- "output:/tmp/prowler_api_output"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy
@@ -91,7 +91,3 @@ services:
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
volumes:
output:
driver: local
+69 -49
@@ -8,20 +8,15 @@ Checks are the core component of Prowler. A check is a piece of code designed to
### Creating a Check
The most common high-level steps to create a new check are:
To create a new check:
1. Prerequisites:
- Verify the check does not already exist by searching [Prowler Hub](https://hub.prowler.com) or checking `prowler/providers/<provider>/services/<service>/<check_name_want_to_implement>/`.
- Ensure required provider and service exist. If not, follow the [Provider](./provider.md) and [Service](./services.md) documentation to create them.
- Confirm the service has implemented all required methods and attributes for the check (in most cases, you will need to add or modify some methods in the service to get the data you need for the check).
2. Navigate to the service directory. The path should be as follows: `prowler/providers/<provider>/services/<service>`.
3. Create a check-specific folder. The path should follow this pattern: `prowler/providers/<provider>/services/<service>/<check_name_want_to_implement>`. Adhere to the [Naming Format for Checks](#naming-format-for-checks).
4. Populate the folder with files as specified in [File Creation](#file-creation).
5. Run the check locally to ensure it works as expected. You can verify it with the CLI as follows:
- To ensure the check has been detected by Prowler: `poetry run python prowler-cli.py <provider> --list-checks | grep <check_name>`.
- To run the check and surface possible issues: `poetry run python prowler-cli.py <provider> --log-level ERROR --verbose --check <check_name>`.
6. Create comprehensive tests for the check that cover multiple scenarios including both PASS (compliant) and FAIL (non-compliant) cases. For detailed information about test structure and implementation guidelines, refer to the [Testing](./unit-testing.md) documentation.
7. If the check and its corresponding tests are working as expected, you can submit a PR to Prowler.
- Prerequisites: A Prowler provider and service must exist. Verify support and check for pre-existing checks via [Prowler Hub](https://hub.prowler.com). If the provider or service is not present, please refer to the [Provider](./provider.md) and [Service](./services.md) documentation for creation instructions.
- Navigate to the service directory. The path should be as follows: `prowler/providers/<provider>/services/<service>`.
- Create a check-specific folder. The path should follow this pattern: `prowler/providers/<provider>/services/<service>/<check_name>`. Adhere to the [Naming Format for Checks](#naming-format-for-checks).
- Populate the folder with files as specified in [File Creation](#file-creation).
### Naming Format for Checks
@@ -64,19 +59,13 @@ from prowler.providers.<provider>.services.<service>.<service>_client import <se
# Each check must be implemented as a Python class with the same name as its corresponding file.
# The class must inherit from the Check base class.
class <check_name>(Check):
"""
Ensure that <resource> meets <security_requirement>.
This check evaluates whether <specific_condition> to ensure <security_benefit>.
- PASS: <description_of_compliant_state(s)>.
- FAIL: <description_of_non_compliant_state(s)>.
"""
"""Short description of what is being checked"""
def execute(self):
"""Execute the check logic.
"""Execute <check short description>
Returns:
A list of reports containing the result of the check.
List[CheckReport<Provider>]: A list of reports containing the result of the check.
"""
findings = []
# Iterate over the target resources using the provider service client
@@ -158,10 +147,12 @@ else:
Each check **must** populate the report with a unique identifier for the audited resource. Which identifier or identifiers to use depends on the provider and the resource being audited. Here are the criteria for each provider (a short sketch follows the list):
- AWS
- Amazon Resource ID — `report.resource_id`.
- The resource identifier. This is the name of the resource, the ID of the resource, or a resource path. Some resource identifiers include a parent resource (sub-resource-type/parent-resource/sub-resource) or a qualifier such as a version (resource-type:resource-name:qualifier).
- If the resource ID cannot be retrieved directly from the audited resource, it can be extracted from the ARN. It is the last part of the ARN after the last slash (`/`) or colon (`:`).
- If no actual resource to audit exists, this format can be used: `<resource_type>/unknown`
- Amazon Resource Name — `report.resource_arn`.
- The [Amazon Resource Name (ARN)](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) of the audited entity.
- If the ARN cannot be retrieved directly from the audited resource, construct a valid ARN using the `resource_id` component as the audited entity. Examples:
@@ -172,24 +163,32 @@ Each check **must** populate the report with an unique identifier for the audite
- AWS Security Hub — `arn:<partition>:security-hub:<region>:<account-id>:hub/unknown`.
- Access Analyzer — `arn:<partition>:access-analyzer:<region>:<account-id>:analyzer/unknown`.
- GuardDuty — `arn:<partition>:guardduty:<region>:<account-id>:detector/unknown`.
- GCP
- Resource ID — `report.resource_id`.
- Resource ID represents the full, [unambiguous path to a resource](https://google.aip.dev/122#full-resource-names), known as the full resource name. Typically, it follows the format: `//{api_service}/{resource_path}`.
- If the resource ID cannot be retrieved directly from the audited resource, by default the resource name is used.
- Resource Name — `report.resource_name`.
- Resource Name usually refers to the name of a resource within its service.
- Azure
- Resource ID — `report.resource_id`.
- Resource ID represents the full Azure Resource Manager path to a resource, which follows the format: `/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/{resourceProviderNamespace}/{resourceType}/{resourceName}`.
- Resource Name — `report.resource_name`.
- Resource Name usually refers to the name of a resource within its service.
- If the [resource name](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules) cannot be retrieved directly from the audited resource, the last part of the resource ID can be used.
- Kubernetes
- Resource ID — `report.resource_id`.
- The UID of the Kubernetes object. This is a system-generated string that uniquely identifies the object within the cluster for its entire lifetime. See [Kubernetes Object Names and IDs - UIDs](https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#uids).
- Resource Name — `report.resource_name`.
- The name of the Kubernetes object. This is a client-provided string that must be unique for the resource type within a namespace (for namespaced resources) or cluster (for cluster-scoped resources). Names typically follow DNS subdomain or label conventions. See [Kubernetes Object Names and IDs - Names](https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names).
- M365
- Resource ID — `report.resource_id`.
- If the audited resource has a globally unique identifier such as a `guid`, use it as the `resource_id`.
- If no `guid` exists, use another unique and relevant identifier for the resource, such as the tenant domain, the internal policy ID, or a representative string following the format `<resource_type>/<name_or_id>`.
@@ -205,15 +204,42 @@ Each check **must** populate the report with an unique identifier for the audite
- For global configurations:
- `resource_id`: Tenant domain or representative string (e.g., "userSettings")
- `resource_name`: Description of the configuration (e.g., "SharePoint Settings")
- GitHub
- Resource ID — `report.resource_id`.
- The ID of the GitHub resource. This is a system-generated integer that uniquely identifies the resource within the GitHub platform.
- Resource Name — `report.resource_name`.
- The name of the GitHub resource. For a repository, this is just the repository name; for the full repository name, use the resource's `full_name`.
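As a concrete illustration of the criteria above, here is a minimal sketch of an AWS check populating `resource_id` and `resource_arn`. The `example` service, its `example_client`, and the `name`/`arn` attributes are hypothetical placeholders, and the keyword form of the `Check_Report_AWS` constructor is an assumption; adapt both to the actual service being audited.

```python
# Hypothetical sketch: the "example" service and its attributes are placeholders.
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.providers.aws.services.example.example_client import example_client


class example_resource_identification(Check):
    """Illustrates how a check fills in the AWS resource identifiers."""

    def execute(self):
        findings = []
        for resource in example_client.resources:
            report = Check_Report_AWS(metadata=self.metadata(), resource=resource)
            # Resource identifier: the name, the ID, or the last ARN segment
            # after the final "/" or ":"
            report.resource_id = resource.name
            # Full Amazon Resource Name of the audited entity
            report.resource_arn = resource.arn
            report.status = "PASS"
            report.status_extended = f"Resource {resource.name} is correctly identified."
            findings.append(report)
        return findings
```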
### Configurable Checks in Prowler
### Using the Audit Configuration
See [Configurable Checks](./configurable-checks.md) for detailed information on making checks configurable using the `audit_config` object and configuration file.
Prowler has a [configuration file](../tutorials/configuration_file.md) which is used to pass certain configuration values to the checks. For example:
```python title="ec2_securitygroup_with_many_ingress_egress_rules.py"
class ec2_securitygroup_with_many_ingress_egress_rules(Check):
def execute(self):
findings = []
max_security_group_rules = ec2_client.audit_config.get(
"max_security_group_rules", 50
)
for security_group_arn, security_group in ec2_client.security_groups.items():
```
We use the `audit_config` object to retrieve the value of `max_security_group_rules`, falling back to the default of 50 if the configuration value is not present.
The configuration file is located at [`prowler/config/config.yaml`](https://github.com/prowler-cloud/prowler/blob/master/prowler/config/config.yaml) and is used to pass certain configuration values to the checks. For example:
```yaml title="config.yaml"
aws:
max_security_group_rules: 50
```
This `audit_config` object is a Python dictionary that stores values read from the configuration file. It can be accessed by the check using the `audit_config` attribute of the service client.
???+ note
Always use the `dictionary.get(key, default)` syntax to ensure a default value is used when the configuration value is not present.
## Metadata Structure for Prowler Checks
@@ -259,25 +285,44 @@ Below is a generic example of a check metadata file. **Do not include comments i
### Metadata Fields and Their Purpose
- **Provider** — The Prowler provider related to the check. The name **must** be lowercase and match the provider folder name. For supported providers refer to [Prowler Hub](https://hub.prowler.com/check) or directly to [Prowler Code](https://github.com/prowler-cloud/prowler/tree/master/prowler/providers).
- **CheckID** — The unique identifier for the check inside the provider. This field **must** match the name of the check's folder, Python file, and JSON metadata file. For more information about naming, refer to the [Naming Format for Checks](#naming-format-for-checks) section.
- **CheckTitle** — A concise, descriptive title for the check.
- **CheckType** — *For now this field is only standardized for the AWS provider*.
- For AWS this field must follow the [AWS Security Hub Types](https://docs.aws.amazon.com/securityhub/latest/userguide/asff-required-attributes.html#Types) format, so the common pattern to follow is `namespace/category/classifier`; refer to the linked documentation for the valid values for these fields.
- **ServiceName** — The name of the provider service being audited. This field **must** be in lowercase and match with the service folder name. For supported services refer to [Prowler Hub](https://hub.prowler.com/check) or directly to [Prowler Code](https://github.com/prowler-cloud/prowler/tree/master/prowler/providers).
- **SubServiceName** — The subservice or resource within the service, if applicable. For more information refer to the [Naming Format for Checks](#naming-format-for-checks) section.
- **ResourceIdTemplate** — A template for the unique resource identifier. For more information refer to the [Prowler's Resource Identification](#prowlers-resource-identification) section.
- **Severity** — The severity of the finding if the check fails. Must be one of: `critical`, `high`, `medium`, `low`, or `informational`, this field **must** be in lowercase. To get more information about the severity levels refer to the [Prowler's Check Severity Levels](#prowlers-check-severity-levels) section.
- **ResourceType** — The type of resource being audited. *For now this field is only standardized for the AWS provider*.
- For AWS use the [Security Hub resource types](https://docs.aws.amazon.com/securityhub/latest/userguide/asff-resources.html) or, if not available, the PascalCase version of the [CloudFormation type](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-template-resource-type-ref.html) (e.g., `AwsEc2Instance`). Use "Other" if no match exists.
- **Description** — A short description of what the check does.
- **Risk** — The risk or impact if the check fails, explaining why the finding matters.
- **RelatedUrl** — A URL to official documentation or further reading about the check's purpose. If no official documentation is available, use the risk and recommendation text from trusted third-party sources.
- **Remediation** — Guidance for fixing a failed check, including:
- **Code** — Remediation commands or code snippets for CLI, Terraform, native IaC, or other tools like the Web Console.
- **Recommendation** — A textual, human-readable recommendation. It does not need to list exact steps, just general guidance on how to remediate the finding.
- **Categories** — One or more categories for grouping checks in execution (e.g., `internet-exposed`). For the current list of categories, refer to the [Prowler Hub](https://hub.prowler.com/check).
- **DependsOn** — Currently not used.
- **RelatedTo** — Currently not used.
- **Notes** — Any additional information not covered by other fields.
### Remediation Code Guidelines
@@ -292,28 +337,3 @@ When providing remediation steps, reference the following sources:
### Python Model Reference
The metadata structure is enforced in code using a Pydantic model. For reference, see the [`CheckMetadata`](https://github.com/prowler-cloud/prowler/blob/master/prowler/lib/check/models.py).
## Generic Check Patterns and Best Practices
### Common Patterns
- Every check is implemented as a class inheriting from `Check` (from `prowler.lib.check.models`).
- The main logic is implemented in the `execute()` method (**only method that must be implemented**), which always returns a list of provider-specific report objects (e.g., `CheckReport<Provider>`)—one per finding/resource. If there are no findings/resources, return an empty list.
- **Never** use the provider's client directly; instead, use the service client (e.g., `<service>_client`) and iterate over its resources.
- For each resource, create a provider-specific report object, populate it with metadata, resource details, status (`PASS`, `FAIL`, etc.), and a human-readable `status_extended` message.
- Use the `metadata()` method to attach check metadata to each report.
- Checks are designed to be idempotent and stateless: they do not modify resources, only report on their state.
### Best Practices
- Use clear, actionable, and user-friendly language in `status_extended` to explain the result. Always provide information to identify the resource.
- Use helper functions/utilities for repeated logic to avoid code duplication. Save them in the `lib` folder of the service.
- Handle exceptions gracefully: catch errors per resource, log them, and continue processing other resources (see the sketch after this list).
- Document the check with a class and function level docstring explaining what it does, what it checks, and any caveats or provider-specific behaviors.
- Use type hints for the `execute()` method (e.g., `-> list[CheckReport<Provider>]`) for clarity and static analysis.
- Ensure checks are efficient; avoid excessive nested loops. If the complexity is high, consider refactoring the check.
- Keep the check logic focused: one check = one control/requirement. Avoid combining unrelated logic in a single check.
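To make the error-handling and typing advice above concrete, here is a hedged sketch of the per-resource pattern. The `example_client` and its `resources`/`compliant` attributes are illustrative assumptions, not a real Prowler service:

```python
# Illustrative sketch: the "example" service names are placeholders, not a real API.
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.lib.logger import logger
from prowler.providers.aws.services.example.example_client import example_client


class example_resource_compliant(Check):
    """One check, one control: report whether each resource meets the requirement."""

    def execute(self) -> list[Check_Report_AWS]:
        findings = []
        for resource in example_client.resources:
            try:
                report = Check_Report_AWS(metadata=self.metadata(), resource=resource)
                report.status = "PASS" if resource.compliant else "FAIL"
                report.status_extended = (
                    f"Resource {resource.name} "
                    f"{'meets' if resource.compliant else 'does not meet'} the requirement."
                )
                findings.append(report)
            except Exception as error:
                # Log per resource and keep going so one failure does not abort the check
                logger.error(f"{error.__class__.__name__}: {error}")
        return findings
```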
## Specific Check Patterns
Details for specific providers can be found in documentation pages named using the pattern `<provider_name>-details`.