Compare commits

..

13 Commits

Author SHA1 Message Date
sumit_chaturvedi 95cb36e09b refactor(e2e): reuse login helper and update test cases 2025-06-27 10:26:20 +05:30
sumit_chaturvedi 1880b97687 chore(UI): fix the env variable 2025-06-25 22:44:12 +05:30
sumit_chaturvedi ae1219dac8 refactor(e2e): increase timeout to avoid test failure 2025-06-25 17:56:43 +05:30
sumit_chaturvedi dc9d5b0bcd test(e2e): implement login flow tests with valid and invalid credentials 2025-06-25 15:24:26 +05:30
sumit_chaturvedi 498a38634c refactor(e2e): removed sign-up redirection 2025-06-25 14:50:36 +05:30
sumit_chaturvedi 5e8385607a chore(e2e): updated page load timing 2025-06-25 14:27:29 +05:30
sumit_chaturvedi 60b090284a refactor(e2e): remove global-setup file — Docker lifecycle handled in CI 2025-06-25 12:39:59 +05:30
sumit_chaturvedi 6c3ceda58a chore(e2e): update execSync command to use 'docker compose' format 2025-06-25 11:04:08 +05:30
sumit_chaturvedi 6080343eaf chore(e2e): add temporary AUTH_SECRET for Playwright E2E test runs 2025-06-24 19:36:05 +05:30
sumit_chaturvedi b081027f5e chore(e2e): add GitHub Actions workflow for Playwright UI E2E tests 2025-06-24 16:41:08 +05:30
sumit_chaturvedi a5c7cfc752 docs: changelog update 2025-06-23 15:37:20 +05:30
sumit_chaturvedi d68a798d25 feat: add basic Playwright tests for login and findings page 2025-06-23 15:17:04 +05:30
sumit_chaturvedi bd0749daa8 feat(ui): add Playwright setup with basic configuration for E2E testing in Next.js 2025-06-23 10:56:57 +05:30
1075 changed files with 22022 additions and 88168 deletions
+4 -7
View File
@@ -6,12 +6,9 @@
PROWLER_UI_VERSION="stable"
AUTH_URL=http://localhost:3000
API_BASE_URL=http://prowler-api:8080/api/v1
NEXT_PUBLIC_API_BASE_URL=${API_BASE_URL}
NEXT_PUBLIC_API_DOCS_URL=http://prowler-api:8080/api/v1/docs
AUTH_TRUST_HOST=true
UI_PORT=3000
# Temporary URL for feeds; replace with the actual feed URL
RSS_FEED_URL=https://prowler.com/blog/rss
# openssl rand -base64 32
AUTH_SECRET="N/c6mnaS5+SWq81+819OrzQZlmx1Vxtp/orjttJSmw8="
# Google Tag Manager ID
@@ -74,7 +71,7 @@ DJANGO_SETTINGS_MODULE=config.django.production
DJANGO_LOGGING_FORMATTER=human_readable
# Select one of [DEBUG|INFO|WARNING|ERROR|CRITICAL]
# Applies to both Django and Celery Workers
DJANGO_LOGGING_LEVEL=DEBUG
DJANGO_LOGGING_LEVEL=INFO
# Defaults to the maximum available based on CPU cores if not set.
DJANGO_WORKERS=4
# Token lifetime is in minutes
@@ -127,14 +124,13 @@ jQIDAQAB
DJANGO_SECRETS_ENCRYPTION_KEY="oE/ltOhp/n1TdbHjVmzcjDPLcLA41CVI/4Rk+UB5ESc="
DJANGO_BROKER_VISIBILITY_TIMEOUT=86400
DJANGO_SENTRY_DSN=
DJANGO_THROTTLE_TOKEN_OBTAIN=50/minute
# Sentry settings
SENTRY_ENVIRONMENT=local
SENTRY_RELEASE=local
#### Prowler release version ####
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.10.0
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.6.0
# Social login credentials
SOCIAL_GOOGLE_OAUTH_CALLBACK_URL="${AUTH_URL}/api/auth/callback/google"
@@ -146,7 +142,8 @@ SOCIAL_GITHUB_OAUTH_CLIENT_ID=""
SOCIAL_GITHUB_OAUTH_CLIENT_SECRET=""
# Single Sign-On (SSO)
SAML_SSO_CALLBACK_URL="${AUTH_URL}/api/auth/callback/saml"
SAML_PUBLIC_CERT=""
SAML_PRIVATE_KEY=""
# Lighthouse tracing
LANGSMITH_TRACING=false
-16
View File
@@ -22,11 +22,6 @@ provider/kubernetes:
- any-glob-to-any-file: "prowler/providers/kubernetes/**"
- any-glob-to-any-file: "tests/providers/kubernetes/**"
provider/m365:
- changed-files:
- any-glob-to-any-file: "prowler/providers/m365/**"
- any-glob-to-any-file: "tests/providers/m365/**"
provider/github:
- changed-files:
- any-glob-to-any-file: "prowler/providers/github/**"
@@ -37,11 +32,6 @@ provider/iac:
- any-glob-to-any-file: "prowler/providers/iac/**"
- any-glob-to-any-file: "tests/providers/iac/**"
provider/mongodbatlas:
- changed-files:
- any-glob-to-any-file: "prowler/providers/mongodbatlas/**"
- any-glob-to-any-file: "tests/providers/mongodbatlas/**"
github_actions:
- changed-files:
- any-glob-to-any-file: ".github/workflows/*"
@@ -57,13 +47,11 @@ mutelist:
- any-glob-to-any-file: "prowler/providers/azure/lib/mutelist/**"
- any-glob-to-any-file: "prowler/providers/gcp/lib/mutelist/**"
- any-glob-to-any-file: "prowler/providers/kubernetes/lib/mutelist/**"
- any-glob-to-any-file: "prowler/providers/mongodbatlas/lib/mutelist/**"
- any-glob-to-any-file: "tests/lib/mutelist/**"
- any-glob-to-any-file: "tests/providers/aws/lib/mutelist/**"
- any-glob-to-any-file: "tests/providers/azure/lib/mutelist/**"
- any-glob-to-any-file: "tests/providers/gcp/lib/mutelist/**"
- any-glob-to-any-file: "tests/providers/kubernetes/lib/mutelist/**"
- any-glob-to-any-file: "tests/providers/mongodbatlas/lib/mutelist/**"
integration/s3:
- changed-files:
@@ -119,7 +107,3 @@ compliance:
review-django-migrations:
- changed-files:
- any-glob-to-any-file: "api/src/backend/api/migrations/**"
metadata-review:
- changed-files:
- any-glob-to-any-file: "**/*.metadata.json"
-4
View File
@@ -8,10 +8,6 @@ If fixes an issue please add it with `Fix #XXXX`
Please include a summary of the change and which issue is fixed. List any dependencies that are required for this change.
### Steps to review
Please add a detailed description of how to review this PR.
### Checklist
- Are there new checks included in this PR? Yes / No
@@ -6,7 +6,6 @@ on:
- "master"
paths:
- "api/**"
- "prowler/**"
- ".github/workflows/api-build-lint-push-containers.yml"
# Uncomment the code below to test this action on PRs
@@ -77,7 +76,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
# Comment the following line for testing
+2 -2
View File
@@ -48,12 +48,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/api-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
+10 -25
View File
@@ -13,7 +13,6 @@ on:
- "master"
- "v5.*"
paths:
- ".github/workflows/api-pull-request.yml"
- "api/**"
env:
@@ -82,9 +81,7 @@ jobs:
id: are-non-ignored-files-changed
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
with:
files: |
api/**
.github/workflows/api-pull-request.yml
files: api/**
files_ignore: ${{ env.IGNORE_FILES }}
- name: Replace @master with current branch in pyproject.toml
@@ -108,23 +105,6 @@ jobs:
run: |
poetry lock
- name: Update SDK's poetry.lock resolved_reference to latest commit - Only for push events to `master`
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true' && github.event_name == 'push'
run: |
# Get the latest commit hash from the prowler-cloud/prowler repository
LATEST_COMMIT=$(curl -s "https://api.github.com/repos/prowler-cloud/prowler/commits/master" | jq -r '.sha')
echo "Latest commit hash: $LATEST_COMMIT"
# Update the resolved_reference specifically for prowler-cloud/prowler repository
sed -i '/url = "https:\/\/github\.com\/prowler-cloud\/prowler\.git"/,/resolved_reference = / {
s/resolved_reference = "[a-f0-9]\{40\}"/resolved_reference = "'"$LATEST_COMMIT"'"/
}' poetry.lock
# Verify the change was made
echo "Updated resolved_reference:"
grep -A2 -B2 "resolved_reference" poetry.lock
- name: Set up Python ${{ matrix.python-version }}
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
@@ -156,6 +136,12 @@ jobs:
run: |
poetry check --lock
- name: Prevents known compatibility error between lxml and libxml2/libxmlsec versions - https://github.com/xmlsec/python-xmlsec/issues/320
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
run: |
poetry run pip install --force-reinstall --no-binary lxml lxml
- name: Lint with ruff
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
@@ -183,10 +169,9 @@ jobs:
- name: Safety
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
# 76352, 76353, 77323 come from SDK, but they cannot upgrade it yet. It does not affect API
# TODO: Botocore needs urllib3 1.X so we need to ignore these vulnerabilities 77744,77745. Remove this once we upgrade to urllib3 2.X
# 76352 and 76353 come from SDK, but they cannot upgrade it yet. It does not affect API
run: |
poetry run safety check --ignore 70612,66963,74429,76352,76353,77323,77744,77745
poetry run safety check --ignore 70612,66963,74429,76352,76353
- name: Vulture
working-directory: ./api
@@ -226,7 +211,7 @@ jobs:
files_ignore: ${{ env.IGNORE_FILES }}
- name: Set up Docker Buildx
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build Container
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
@@ -7,7 +7,6 @@ on:
- 'v3'
paths:
- 'docs/**'
- '.github/workflows/build-documentation-on-pr.yml'
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
@@ -17,20 +16,9 @@ jobs:
name: Documentation Link
runs-on: ubuntu-latest
steps:
- name: Find existing documentation comment
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: find-comment
with:
issue-number: ${{ env.PR_NUMBER }}
comment-author: 'github-actions[bot]'
body-includes: '<!-- prowler-docs-link -->'
- name: Create or update PR comment with the Prowler Documentation URI
- name: Leave PR comment with the Prowler Documentation URI
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
comment-id: ${{ steps.find-comment.outputs.comment-id }}
issue-number: ${{ env.PR_NUMBER }}
body: |
<!-- prowler-docs-link -->
You can check the documentation for this PR here -> [Prowler Documentation](https://prowler-prowler-docs--${{ env.PR_NUMBER }}.com.readthedocs.build/projects/prowler-open-source/en/${{ env.PR_NUMBER }}/)
edit-mode: replace
+1 -1
View File
@@ -1,4 +1,4 @@
name: Prowler - Create Backport Label
name: Create Backport Label
on:
release:
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
with:
fetch-depth: 0
- name: TruffleHog OSS
uses: trufflesecurity/trufflehog@a05cf0859455b5b16317ee22d809887a4043cdf0 # v3.90.2
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
with:
path: ./
base: ${{ github.event.repository.default_branch }}
-167
View File
@@ -1,167 +0,0 @@
name: Prowler - PR Conflict Checker
on:
pull_request:
types:
- opened
- synchronize
- reopened
branches:
- "master"
- "v5.*"
pull_request_target:
types:
- opened
- synchronize
- reopened
branches:
- "master"
- "v5.*"
jobs:
conflict-checker:
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
with:
files: |
**
- name: Check for conflict markers
id: conflict-check
run: |
echo "Checking for conflict markers in changed files..."
CONFLICT_FILES=""
HAS_CONFLICTS=false
# Check each changed file for conflict markers
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
if [ -f "$file" ]; then
echo "Checking file: $file"
# Look for conflict markers
if grep -l "^<<<<<<<\|^=======\|^>>>>>>>" "$file" 2>/dev/null; then
echo "Conflict markers found in: $file"
CONFLICT_FILES="$CONFLICT_FILES$file "
HAS_CONFLICTS=true
fi
fi
done
if [ "$HAS_CONFLICTS" = true ]; then
echo "has_conflicts=true" >> $GITHUB_OUTPUT
echo "conflict_files=$CONFLICT_FILES" >> $GITHUB_OUTPUT
echo "Conflict markers detected in files: $CONFLICT_FILES"
else
echo "has_conflicts=false" >> $GITHUB_OUTPUT
echo "No conflict markers found in changed files"
fi
- name: Add conflict label
if: steps.conflict-check.outputs.has_conflicts == 'true'
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
script: |
const { data: labels } = await github.rest.issues.listLabelsOnIssue({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});
const hasConflictLabel = labels.some(label => label.name === 'has-conflicts');
if (!hasConflictLabel) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: ['has-conflicts']
});
console.log('Added has-conflicts label');
} else {
console.log('has-conflicts label already exists');
}
- name: Remove conflict label
if: steps.conflict-check.outputs.has_conflicts == 'false'
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
script: |
try {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: 'has-conflicts'
});
console.log('Removed has-conflicts label');
} catch (error) {
if (error.status === 404) {
console.log('has-conflicts label was not present');
} else {
throw error;
}
}
- name: Find existing conflict comment
if: steps.conflict-check.outputs.has_conflicts == 'true'
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: find-comment
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: 'github-actions[bot]'
body-regex: '(⚠️ \*\*Conflict Markers Detected\*\*|✅ \*\*Conflict Markers Resolved\*\*)'
- name: Create or update conflict comment
if: steps.conflict-check.outputs.has_conflicts == 'true'
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
comment-id: ${{ steps.find-comment.outputs.comment-id }}
issue-number: ${{ github.event.pull_request.number }}
edit-mode: replace
body: |
⚠️ **Conflict Markers Detected**
This pull request contains unresolved conflict markers in the following files:
```
${{ steps.conflict-check.outputs.conflict_files }}
```
Please resolve these conflicts by:
1. Locating the conflict markers: `<<<<<<<`, `=======`, and `>>>>>>>`
2. Manually editing the files to resolve the conflicts
3. Removing all conflict markers
4. Committing and pushing the changes
- name: Find existing conflict comment when resolved
if: steps.conflict-check.outputs.has_conflicts == 'false'
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: find-resolved-comment
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: 'github-actions[bot]'
body-regex: '(⚠️ \*\*Conflict Markers Detected\*\*|✅ \*\*Conflict Markers Resolved\*\*)'
- name: Update comment when conflicts resolved
if: steps.conflict-check.outputs.has_conflicts == 'false' && steps.find-resolved-comment.outputs.comment-id != ''
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
comment-id: ${{ steps.find-resolved-comment.outputs.comment-id }}
issue-number: ${{ github.event.pull_request.number }}
edit-mode: replace
body: |
✅ **Conflict Markers Resolved**
All conflict markers have been successfully resolved in this pull request.
@@ -1,300 +0,0 @@
name: Prowler - Release Preparation
run-name: Prowler Release Preparation for ${{ inputs.prowler_version }}
on:
workflow_dispatch:
inputs:
prowler_version:
description: 'Prowler version to release (e.g., 5.9.0)'
required: true
type: string
env:
PROWLER_VERSION: ${{ github.event.inputs.prowler_version }}
jobs:
prepare-release:
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'
- name: Install Poetry
run: |
python3 -m pip install --user poetry
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Configure Git
run: |
git config --global user.name "prowler-bot"
git config --global user.email "179230569+prowler-bot@users.noreply.github.com"
- name: Parse version and determine branch
run: |
# Validate version format (reusing pattern from sdk-bump-version.yml)
if [[ $PROWLER_VERSION =~ ^([0-9]+)\.([0-9]+)\.([0-9]+)$ ]]; then
MAJOR_VERSION=${BASH_REMATCH[1]}
MINOR_VERSION=${BASH_REMATCH[2]}
PATCH_VERSION=${BASH_REMATCH[3]}
# Export version components to environment
echo "MAJOR_VERSION=${MAJOR_VERSION}" >> "${GITHUB_ENV}"
echo "MINOR_VERSION=${MINOR_VERSION}" >> "${GITHUB_ENV}"
echo "PATCH_VERSION=${PATCH_VERSION}" >> "${GITHUB_ENV}"
# Determine branch name (format: v5.9)
BRANCH_NAME="v${MAJOR_VERSION}.${MINOR_VERSION}"
echo "BRANCH_NAME=${BRANCH_NAME}" >> "${GITHUB_ENV}"
# Calculate UI version (1.X.X format - matches Prowler minor version)
UI_VERSION="1.${MINOR_VERSION}.${PATCH_VERSION}"
echo "UI_VERSION=${UI_VERSION}" >> "${GITHUB_ENV}"
# Calculate API version (1.X.X format - one minor version ahead)
API_MINOR_VERSION=$((MINOR_VERSION + 1))
API_VERSION="1.${API_MINOR_VERSION}.${PATCH_VERSION}"
echo "API_VERSION=${API_VERSION}" >> "${GITHUB_ENV}"
echo "Prowler version: $PROWLER_VERSION"
echo "Branch name: $BRANCH_NAME"
echo "UI version: $UI_VERSION"
echo "API version: $API_VERSION"
echo "Is minor release: $([ $PATCH_VERSION -eq 0 ] && echo 'true' || echo 'false')"
else
echo "Invalid version syntax: '$PROWLER_VERSION' (must be N.N.N)" >&2
exit 1
fi
- name: Checkout existing branch for patch release
if: ${{ env.PATCH_VERSION != '0' }}
run: |
echo "Patch release detected, checking out existing branch $BRANCH_NAME..."
if git show-ref --verify --quiet "refs/heads/$BRANCH_NAME"; then
echo "Branch $BRANCH_NAME exists locally, checking out..."
git checkout "$BRANCH_NAME"
elif git show-ref --verify --quiet "refs/remotes/origin/$BRANCH_NAME"; then
echo "Branch $BRANCH_NAME exists remotely, checking out..."
git checkout -b "$BRANCH_NAME" "origin/$BRANCH_NAME"
else
echo "ERROR: Branch $BRANCH_NAME should exist for patch release $PROWLER_VERSION"
exit 1
fi
- name: Verify version in pyproject.toml
run: |
CURRENT_VERSION=$(grep '^version = ' pyproject.toml | sed -E 's/version = "([^"]+)"/\1/' | tr -d '[:space:]')
PROWLER_VERSION_TRIMMED=$(echo "$PROWLER_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_VERSION" != "$PROWLER_VERSION_TRIMMED" ]; then
echo "ERROR: Version mismatch in pyproject.toml (expected: '$PROWLER_VERSION_TRIMMED', found: '$CURRENT_VERSION')"
exit 1
fi
echo "✓ pyproject.toml version: $CURRENT_VERSION"
- name: Verify version in prowler/config/config.py
run: |
CURRENT_VERSION=$(grep '^prowler_version = ' prowler/config/config.py | sed -E 's/prowler_version = "([^"]+)"/\1/' | tr -d '[:space:]')
PROWLER_VERSION_TRIMMED=$(echo "$PROWLER_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_VERSION" != "$PROWLER_VERSION_TRIMMED" ]; then
echo "ERROR: Version mismatch in prowler/config/config.py (expected: '$PROWLER_VERSION_TRIMMED', found: '$CURRENT_VERSION')"
exit 1
fi
echo "✓ prowler/config/config.py version: $CURRENT_VERSION"
- name: Verify version in api/pyproject.toml
run: |
CURRENT_API_VERSION=$(grep '^version = ' api/pyproject.toml | sed -E 's/version = "([^"]+)"/\1/' | tr -d '[:space:]')
API_VERSION_TRIMMED=$(echo "$API_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_API_VERSION" != "$API_VERSION_TRIMMED" ]; then
echo "ERROR: API version mismatch in api/pyproject.toml (expected: '$API_VERSION_TRIMMED', found: '$CURRENT_API_VERSION')"
exit 1
fi
echo "✓ api/pyproject.toml version: $CURRENT_API_VERSION"
- name: Verify prowler dependency in api/pyproject.toml
if: ${{ env.PATCH_VERSION != '0' }}
run: |
CURRENT_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
BRANCH_NAME_TRIMMED=$(echo "$BRANCH_NAME" | tr -d '[:space:]')
if [ "$CURRENT_PROWLER_REF" != "$BRANCH_NAME_TRIMMED" ]; then
echo "ERROR: Prowler dependency mismatch in api/pyproject.toml (expected: '$BRANCH_NAME_TRIMMED', found: '$CURRENT_PROWLER_REF')"
exit 1
fi
echo "✓ api/pyproject.toml prowler dependency: $CURRENT_PROWLER_REF"
- name: Verify version in api/src/backend/api/v1/views.py
run: |
CURRENT_API_VERSION=$(grep 'spectacular_settings.VERSION = ' api/src/backend/api/v1/views.py | sed -E 's/.*spectacular_settings.VERSION = "([^"]+)".*/\1/' | tr -d '[:space:]')
API_VERSION_TRIMMED=$(echo "$API_VERSION" | tr -d '[:space:]')
if [ "$CURRENT_API_VERSION" != "$API_VERSION_TRIMMED" ]; then
echo "ERROR: API version mismatch in views.py (expected: '$API_VERSION_TRIMMED', found: '$CURRENT_API_VERSION')"
exit 1
fi
echo "✓ api/src/backend/api/v1/views.py version: $CURRENT_API_VERSION"
- name: Checkout existing release branch for minor release
if: ${{ env.PATCH_VERSION == '0' }}
run: |
echo "Minor release detected (patch = 0), checking out existing branch $BRANCH_NAME..."
if git show-ref --verify --quiet "refs/remotes/origin/$BRANCH_NAME"; then
echo "Branch $BRANCH_NAME exists remotely, checking out..."
git checkout -b "$BRANCH_NAME" "origin/$BRANCH_NAME"
else
echo "ERROR: Branch $BRANCH_NAME should exist for minor release $PROWLER_VERSION. Please create it manually first."
exit 1
fi
- name: Prepare prowler dependency update for minor release
if: ${{ env.PATCH_VERSION == '0' }}
run: |
CURRENT_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
BRANCH_NAME_TRIMMED=$(echo "$BRANCH_NAME" | tr -d '[:space:]')
# Create a temporary branch for the PR
TEMP_BRANCH="update-api-dependency-$BRANCH_NAME_TRIMMED-$(date +%s)"
echo "TEMP_BRANCH=$TEMP_BRANCH" >> $GITHUB_ENV
# Switch back to master and create temp branch
git checkout master
git checkout -b "$TEMP_BRANCH"
# Minor release: update the dependency to use the release branch
echo "Updating prowler dependency from '$CURRENT_PROWLER_REF' to '$BRANCH_NAME_TRIMMED'"
sed -i "s|prowler @ git+https://github.com/prowler-cloud/prowler.git@[^\"]*\"|prowler @ git+https://github.com/prowler-cloud/prowler.git@$BRANCH_NAME_TRIMMED\"|" api/pyproject.toml
# Verify the change was made
UPDATED_PROWLER_REF=$(grep 'prowler @ git+https://github.com/prowler-cloud/prowler.git@' api/pyproject.toml | sed -E 's/.*@([^"]+)".*/\1/' | tr -d '[:space:]')
if [ "$UPDATED_PROWLER_REF" != "$BRANCH_NAME_TRIMMED" ]; then
echo "ERROR: Failed to update prowler dependency in api/pyproject.toml"
exit 1
fi
# Update poetry lock file
echo "Updating poetry.lock file..."
cd api
poetry lock
cd ..
# Commit and push the temporary branch
git add api/pyproject.toml api/poetry.lock
git commit -m "chore(api): update prowler dependency to $BRANCH_NAME_TRIMMED for release $PROWLER_VERSION"
git push origin "$TEMP_BRANCH"
echo "✓ Prepared prowler dependency update to: $UPDATED_PROWLER_REF"
- name: Create Pull Request against release branch
if: ${{ env.PATCH_VERSION == '0' }}
uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
with:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
branch: ${{ env.TEMP_BRANCH }}
base: ${{ env.BRANCH_NAME }}
title: "chore(api): Update prowler dependency to ${{ env.BRANCH_NAME }} for release ${{ env.PROWLER_VERSION }}"
body: |
### Description
Updates the API prowler dependency for release ${{ env.PROWLER_VERSION }}.
**Changes:**
- Updates `api/pyproject.toml` prowler dependency from `@master` to `@${{ env.BRANCH_NAME }}`
- Updates `api/poetry.lock` file with resolved dependencies
This PR should be merged into the `${{ env.BRANCH_NAME }}` release branch.
### License
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
author: prowler-bot <179230569+prowler-bot@users.noreply.github.com>
labels: |
component/api
no-changelog
- name: Extract changelog entries
run: |
set -e
# Function to extract changelog for a specific version
extract_changelog() {
local file="$1"
local version="$2"
local output_file="$3"
if [ ! -f "$file" ]; then
echo "Warning: $file not found, skipping..."
touch "$output_file"
return
fi
# Extract changelog section for this version
awk -v version="$version" '
/^## \[v?'"$version"'\]/ { found=1; next }
found && /^## \[v?[0-9]+\.[0-9]+\.[0-9]+\]/ { found=0 }
found && !/^## \[v?'"$version"'\]/ { print }
' "$file" > "$output_file"
# Remove --- separators
sed -i '/^---$/d' "$output_file"
# Remove all empty lines (not only trailing ones)
sed -i '/^$/d' "$output_file"
}
# Extract changelogs
echo "Extracting changelog entries..."
extract_changelog "prowler/CHANGELOG.md" "$PROWLER_VERSION" "prowler_changelog.md"
extract_changelog "api/CHANGELOG.md" "$API_VERSION" "api_changelog.md"
extract_changelog "ui/CHANGELOG.md" "$UI_VERSION" "ui_changelog.md"
# Combine changelogs in order: UI, API, SDK
> combined_changelog.md
if [ -s "ui_changelog.md" ]; then
echo "## UI" >> combined_changelog.md
echo "" >> combined_changelog.md
cat ui_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
if [ -s "api_changelog.md" ]; then
echo "## API" >> combined_changelog.md
echo "" >> combined_changelog.md
cat api_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
if [ -s "prowler_changelog.md" ]; then
echo "## SDK" >> combined_changelog.md
echo "" >> combined_changelog.md
cat prowler_changelog.md >> combined_changelog.md
echo "" >> combined_changelog.md
fi
echo "Combined changelog preview:"
cat combined_changelog.md
- name: Create draft release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: ${{ env.PROWLER_VERSION }}
name: Prowler ${{ env.PROWLER_VERSION }}
body_path: combined_changelog.md
draft: true
target_commitish: ${{ env.PATCH_VERSION == '0' && 'master' || env.BRANCH_NAME }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Clean up temporary files
run: |
rm -f prowler_changelog.md api_changelog.md ui_changelog.md combined_changelog.md
@@ -1,4 +1,4 @@
name: Prowler - Check Changelog
name: Check Changelog
on:
pull_request:
@@ -9,11 +9,9 @@ jobs:
if: contains(github.event.pull_request.labels.*.name, 'no-changelog') == false
runs-on: ubuntu-latest
permissions:
id-token: write
contents: read
pull-requests: write
env:
MONITORED_FOLDERS: "api ui prowler dashboard"
MONITORED_FOLDERS: "api ui prowler"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -47,7 +45,6 @@ jobs:
echo "EOF" >> $GITHUB_OUTPUT
- name: Find existing changelog comment
if: github.event.pull_request.head.repo.full_name == github.repository
id: find_comment
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e #v3.1.0
with:
@@ -55,20 +52,29 @@ jobs:
comment-author: 'github-actions[bot]'
body-includes: '<!-- changelog-check -->'
- name: Update PR comment with changelog status
if: github.event.pull_request.head.repo.full_name == github.repository
- name: Comment on PR if changelog is missing
if: steps.check_folders.outputs.missing_changelogs != ''
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find_comment.outputs.comment-id }}
edit-mode: replace
body: |
<!-- changelog-check -->
${{ steps.check_folders.outputs.missing_changelogs != '' && format('⚠️ **Changes detected in the following folders without a corresponding update to the `CHANGELOG.md`:**
⚠️ **Changes detected in the following folders without a corresponding update to the `CHANGELOG.md`:**
{0}
${{ steps.check_folders.outputs.missing_changelogs }}
Please add an entry to the corresponding `CHANGELOG.md` file to maintain a clear history of changes.', steps.check_folders.outputs.missing_changelogs) || '✅ All necessary `CHANGELOG.md` files have been updated. Great job! 🎉' }}
Please add an entry to the corresponding `CHANGELOG.md` file to maintain a clear history of changes.
- name: Comment on PR if all changelogs are present
if: steps.check_folders.outputs.missing_changelogs == ''
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find_comment.outputs.comment-id }}
body: |
<!-- changelog-check -->
✅ All necessary `CHANGELOG.md` files have been updated. Great job! 🎉
- name: Fail if changelog is missing
if: steps.check_folders.outputs.missing_changelogs != ''
+8 -9
View File
@@ -27,12 +27,11 @@ jobs:
token: ${{ secrets.PROWLER_BOT_ACCESS_TOKEN }}
repository: ${{ secrets.CLOUD_DISPATCH }}
event-type: prowler-pull-request-merged
client-payload: |
{
"PROWLER_COMMIT_SHA": "${{ github.event.pull_request.merge_commit_sha }}",
"PROWLER_COMMIT_SHORT_SHA": "${{ env.SHORT_SHA }}",
"PROWLER_PR_TITLE": ${{ toJson(github.event.pull_request.title) }},
"PROWLER_PR_LABELS": ${{ toJson(github.event.pull_request.labels.*.name) }},
"PROWLER_PR_BODY": ${{ toJson(github.event.pull_request.body) }},
"PROWLER_PR_URL": ${{ toJson(github.event.pull_request.html_url) }}
}
client-payload: '{
"PROWLER_COMMIT_SHA": "${{ github.event.pull_request.merge_commit_sha }}",
"PROWLER_COMMIT_SHORT_SHA": "${{ env.SHORT_SHA }}",
"PROWLER_PR_TITLE": "${{ github.event.pull_request.title }}",
"PROWLER_PR_LABELS": ${{ toJson(github.event.pull_request.labels.*.name) }},
"PROWLER_PR_BODY": ${{ toJson(github.event.pull_request.body) }},
"PROWLER_PR_URL":${{ toJson(github.event.pull_request.html_url) }}
}'
@@ -123,7 +123,7 @@ jobs:
AWS_REGION: ${{ env.AWS_REGION }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
if: github.event_name == 'push'
@@ -157,22 +157,6 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max
# - name: Push README to Docker Hub (toniblyx)
# uses: peter-evans/dockerhub-description@432a30c9e07499fd01da9f8a49f0faf9e0ca5b77 # v4.0.2
# with:
# username: ${{ secrets.DOCKERHUB_USERNAME }}
# password: ${{ secrets.DOCKERHUB_TOKEN }}
# repository: ${{ env.DOCKER_HUB_REPOSITORY }}/${{ env.IMAGE_NAME }}
# readme-filepath: ./README.md
#
# - name: Push README to Docker Hub (prowlercloud)
# uses: peter-evans/dockerhub-description@432a30c9e07499fd01da9f8a49f0faf9e0ca5b77 # v4.0.2
# with:
# username: ${{ secrets.DOCKERHUB_USERNAME }}
# password: ${{ secrets.DOCKERHUB_TOKEN }}
# repository: ${{ env.PROWLERCLOUD_DOCKERHUB_REPOSITORY }}/${{ env.PROWLERCLOUD_DOCKERHUB_IMAGE }}
# readme-filepath: ./README.md
dispatch-action:
needs: container-build-push
runs-on: ubuntu-latest
+1 -2
View File
@@ -12,6 +12,7 @@ env:
jobs:
bump-version:
name: Bump Version
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -96,7 +97,6 @@ jobs:
commit-message: "chore(release): Bump version to v${{ env.BUMP_VERSION_TO }}"
branch: "version-bump-to-v${{ env.BUMP_VERSION_TO }}"
title: "chore(release): Bump version to v${{ env.BUMP_VERSION_TO }}"
labels: no-changelog
body: |
### Description
@@ -135,7 +135,6 @@ jobs:
commit-message: "chore(release): Bump version to v${{ env.PATCH_VERSION_TO }}"
branch: "version-bump-to-v${{ env.PATCH_VERSION_TO }}"
title: "chore(release): Bump version to v${{ env.PATCH_VERSION_TO }}"
labels: no-changelog
body: |
### Description
+2 -2
View File
@@ -56,12 +56,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/sdk-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
+1 -23
View File
@@ -102,15 +102,8 @@ jobs:
run: |
poetry run vulture --exclude "contrib,api,ui" --min-confidence 100 .
- name: Dockerfile - Check if Dockerfile has changed
id: dockerfile-changed-files
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
with:
files: |
Dockerfile
- name: Hadolint
if: steps.dockerfile-changed-files.outputs.any_changed == 'true'
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
run: |
/tmp/hadolint Dockerfile --ignore=DL3013
@@ -234,21 +227,6 @@ jobs:
run: |
poetry run pytest -n auto --cov=./prowler/providers/iac --cov-report=xml:iac_coverage.xml tests/providers/iac
# Test MongoDB Atlas
- name: MongoDB Atlas - Check if any file has changed
id: mongodb-atlas-changed-files
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
with:
files: |
./prowler/providers/mongodbatlas/**
./tests/providers/mongodbatlas/**
.poetry.lock
- name: MongoDB Atlas - Test
if: steps.mongodb-atlas-changed-files.outputs.any_changed == 'true'
run: |
poetry run pytest -n auto --cov=./prowler/providers/mongodbatlas --cov-report=xml:mongodb_atlas_coverage.xml tests/providers/mongodbatlas
# Common Tests
- name: Lib - Test
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
@@ -30,7 +30,6 @@ env:
# Container Registries
PROWLERCLOUD_DOCKERHUB_REPOSITORY: prowlercloud
PROWLERCLOUD_DOCKERHUB_IMAGE: prowler-ui
NEXT_PUBLIC_API_BASE_URL: http://prowler-api:8080/api/v1
jobs:
repository-check:
@@ -77,7 +76,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build and push container image (latest)
# Comment the following line for testing
@@ -87,7 +86,6 @@ jobs:
context: ${{ env.WORKING_DIRECTORY }}
build-args: |
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=${{ env.SHORT_SHA }}
NEXT_PUBLIC_API_BASE_URL=${{ env.NEXT_PUBLIC_API_BASE_URL }}
# Set push: false for testing
push: true
tags: |
@@ -103,7 +101,6 @@ jobs:
context: ${{ env.WORKING_DIRECTORY }}
build-args: |
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v${{ env.RELEASE_TAG }}
NEXT_PUBLIC_API_BASE_URL=${{ env.NEXT_PUBLIC_API_BASE_URL }}
push: true
tags: |
${{ env.PROWLERCLOUD_DOCKERHUB_REPOSITORY }}/${{ env.PROWLERCLOUD_DOCKERHUB_IMAGE }}:${{ env.RELEASE_TAG }}
+2 -2
View File
@@ -48,12 +48,12 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/ui-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3.29.5
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
with:
category: "/language:${{matrix.language}}"
-98
View File
@@ -1,98 +0,0 @@
name: UI - E2E Tests
on:
pull_request:
branches:
- master
- "v5.*"
paths:
- '.github/workflows/ui-e2e-tests.yml'
- 'ui/**'
jobs:
e2e-tests:
if: github.repository == 'prowler-cloud/prowler'
runs-on: ubuntu-latest
env:
AUTH_SECRET: 'fallback-ci-secret-for-testing'
AUTH_TRUST_HOST: true
NEXTAUTH_URL: 'http://localhost:3000'
NEXT_PUBLIC_API_BASE_URL: 'http://localhost:8080/api/v1'
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Start API services
run: |
# Override docker-compose image tag to use latest instead of stable
# This overrides any PROWLER_API_VERSION set in .env file
export PROWLER_API_VERSION=latest
echo "Using PROWLER_API_VERSION=${PROWLER_API_VERSION}"
docker compose up -d api worker worker-beat
- name: Wait for API to be ready
run: |
echo "Waiting for prowler-api..."
timeout=150 # 150 seconds (2.5 minutes) max
elapsed=0
while [ $elapsed -lt $timeout ]; do
if curl -s ${NEXT_PUBLIC_API_BASE_URL}/docs >/dev/null 2>&1; then
echo "Prowler API is ready!"
exit 0
fi
echo "Waiting for prowler-api... (${elapsed}s elapsed)"
sleep 5
elapsed=$((elapsed + 5))
done
echo "Timeout waiting for prowler-api to start"
exit 1
- name: Load database fixtures for E2E tests
run: |
docker compose exec -T api sh -c '
echo "Loading all fixtures from api/fixtures/dev/..."
for fixture in api/fixtures/dev/*.json; do
if [ -f "$fixture" ]; then
echo "Loading $fixture"
poetry run python manage.py loaddata "$fixture" --database admin
fi
done
echo "All database fixtures loaded successfully!"
'
- name: Setup Node.js environment
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '20.x'
cache: 'npm'
cache-dependency-path: './ui/package-lock.json'
- name: Install UI dependencies
working-directory: ./ui
run: npm ci
- name: Build UI application
working-directory: ./ui
run: npm run build
- name: Cache Playwright browsers
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('ui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Install Playwright browsers
working-directory: ./ui
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npm run test:e2e:install
- name: Run E2E tests
working-directory: ./ui
run: npm run test:e2e
- name: Upload test reports
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
if: failure()
with:
name: playwright-report
path: ui/playwright-report/
retention-days: 30
- name: Cleanup services
if: always()
run: |
echo "Shutting down services..."
docker compose down -v || true
echo "Cleanup completed"
+82
View File
@@ -0,0 +1,82 @@
name: UI - E2E Tests
on:
pull_request:
branches:
- master
- "v5.*"
paths:
- 'ui/**'
env:
# Temporary secret for CI test runs only; replace with a GitHub Secret later.
AUTH_SECRET: "N/c6mnaS5+SWq81+819OrzQZlmx1Vxtp/orjttJSmw8="
API_BASE_URL: "http://localhost:8080/api/v1"
SERVICES_TO_START: "api-dev postgres valkey worker-beat worker-dev"
DOCKER_COMPOSE_FILE: "docker-compose-dev.yml"
jobs:
e2e:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: 20
cache: 'npm'
cache-dependency-path: './ui/package-lock.json'
# - name: Cache Playwright Browsers
# uses: actions/cache@v4
# with:
# path: ~/.cache/ms-playwright
# key: playwright-${{ runner.os }}-${{ hashFiles('**/package-lock.json') }}
# restore-keys: |
# playwright-${{ runner.os }}-
- name: Install dependencies
run: npm ci
working-directory: ./ui
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ./ui
- name: Set up Docker Compose
uses: docker/setup-compose-action@364cc21a5de5b1ee4a7f5f9d3fa374ce0ccde746 #v1.2.0
- name: Start Docker Compose
run: docker compose -f ${DOCKER_COMPOSE_FILE} up -d ${SERVICES_TO_START}
- name: Wait for API to be ready
run: |
for i in {1..30}; do
if curl -s http://localhost:8000/api/v1; then
echo "API is up!"
break
fi
echo "Waiting for API..."
sleep 5
done
- name: Run Playwright tests
run: npx playwright test
working-directory: ./ui
- name: Upload Playwright report
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: ./ui/playwright-report
- name: Upload Playwright videos
uses: actions/upload-artifact@v4
with:
name: test-videos
path: ./ui/test-results/**/*.webm
- name: Docker Compose Down
if: always()
run: docker compose -f ${DOCKER_COMPOSE_FILE} down
+2 -5
View File
@@ -34,24 +34,21 @@ jobs:
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: './ui/package-lock.json'
- name: Install dependencies
working-directory: ./ui
run: npm ci
run: npm install
- name: Run Healthcheck
working-directory: ./ui
run: npm run healthcheck
- name: Build the application
working-directory: ./ui
run: npm run build
test-container-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- name: Build Container
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
with:
-13
View File
@@ -44,16 +44,6 @@ junit-reports/
# Cursor files
.cursorignore
.cursor/
# RooCode files
.roo/
.rooignore
.roomodes
# Cline files
.cline/
.clineignore
# Terraform
.terraform*
@@ -75,6 +65,3 @@ node_modules
# Persistent data
_data/
# Claude
CLAUDE.md
+1 -2
View File
@@ -115,8 +115,7 @@ repos:
- id: safety
name: safety
description: "Safety is a tool that checks your installed dependencies for known security vulnerabilities"
# TODO: Botocore needs urllib3 1.X so we need to ignore these vulnerabilities 77744,77745. Remove this once we upgrade to urllib3 2.X
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745'
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353'
language: system
- id: vulture
+6 -3
View File
@@ -1,4 +1,4 @@
FROM python:3.12.11-slim-bookworm AS build
FROM python:3.12.10-slim-bookworm AS build
LABEL maintainer="https://github.com/prowler-cloud/prowler"
LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
@@ -6,8 +6,7 @@ LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
ARG POWERSHELL_VERSION=7.5.0
# hadolint ignore=DL3008
RUN apt-get update && apt-get install -y --no-install-recommends \
wget libicu72 libunwind8 libssl3 libcurl4 ca-certificates apt-transport-https gnupg \
RUN apt-get update && apt-get install -y --no-install-recommends wget libicu72 \
&& rm -rf /var/lib/apt/lists/*
# Install PowerShell
@@ -47,6 +46,10 @@ ENV PATH="${HOME}/.local/bin:${PATH}"
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir poetry
# By default poetry does not compile Python source files to bytecode during installation.
# This speeds up the installation process, but the first execution may take a little more
# time because Python then compiles source files to bytecode automatically. If you want to
# compile source files to bytecode during installation, you can use the --compile option
RUN poetry install --compile && \
rm -rf ~/.cache/pip
+26 -35
View File
@@ -19,16 +19,19 @@
<a href="https://goto.prowler.com/slack"><img alt="Slack Shield" src="https://img.shields.io/badge/slack-prowler-brightgreen.svg?logo=slack"></a>
<a href="https://pypi.org/project/prowler/"><img alt="Python Version" src="https://img.shields.io/pypi/v/prowler.svg"></a>
<a href="https://pypi.python.org/pypi/prowler/"><img alt="Python Version" src="https://img.shields.io/pypi/pyversions/prowler.svg"></a>
<a href="https://pypistats.org/packages/prowler"><img alt="PyPI Downloads" src="https://img.shields.io/pypi/dw/prowler.svg?label=downloads"></a>
<a href="https://pypistats.org/packages/prowler"><img alt="PyPI Prowler Downloads" src="https://img.shields.io/pypi/dw/prowler.svg?label=prowler%20downloads"></a>
<a href="https://hub.docker.com/r/toniblyx/prowler"><img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/toniblyx/prowler"></a>
<a href="https://hub.docker.com/r/toniblyx/prowler"><img alt="Docker" src="https://img.shields.io/docker/cloud/build/toniblyx/prowler"></a>
<a href="https://hub.docker.com/r/toniblyx/prowler"><img alt="Docker" src="https://img.shields.io/docker/image-size/toniblyx/prowler"></a>
<a href="https://gallery.ecr.aws/prowler-cloud/prowler"><img width="120" height="19" alt="AWS ECR Gallery" src="https://user-images.githubusercontent.com/3985464/151531396-b6535a68-c907-44eb-95a1-a09508178616.png"></a>
<a href="https://codecov.io/gh/prowler-cloud/prowler"><img src="https://codecov.io/gh/prowler-cloud/prowler/graph/badge.svg?token=OflBGsdpDl"/></a>
</p>
<p align="center">
<a href="https://github.com/prowler-cloud/prowler/releases"><img alt="Version" src="https://img.shields.io/github/v/release/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler"><img alt="Repo size" src="https://img.shields.io/github/repo-size/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler/issues"><img alt="Issues" src="https://img.shields.io/github/issues/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler/releases"><img alt="Version" src="https://img.shields.io/github/v/release/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler/releases"><img alt="Release Date" src="https://img.shields.io/github/release-date/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler"><img alt="Contributors" src="https://img.shields.io/github/contributors-anon/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler/issues"><img alt="Issues" src="https://img.shields.io/github/issues/prowler-cloud/prowler"></a>
<a href="https://github.com/prowler-cloud/prowler"><img alt="License" src="https://img.shields.io/github/license/prowler-cloud/prowler"></a>
<a href="https://twitter.com/ToniBlyx"><img alt="Twitter" src="https://img.shields.io/twitter/follow/toniblyx?style=social"></a>
<a href="https://twitter.com/prowlercloud"><img alt="Twitter" src="https://img.shields.io/twitter/follow/prowlercloud?style=social"></a>
@@ -52,11 +55,15 @@ Prowler includes hundreds of built-in controls to ensure compliance with standar
- **National Security Standards:** ENS (Spanish National Security Scheme)
- **Custom Security Frameworks:** Tailored to your needs
## Prowler CLI and Prowler Cloud
Prowler offers a Command Line Interface (CLI), known as Prowler Open Source, and an additional service built on top of it, called <a href="https://prowler.com">Prowler Cloud</a>.
## Prowler App
Prowler App is a web-based application that simplifies running Prowler across your cloud provider accounts. It provides a user-friendly interface to visualize the results and streamline your security assessments.
![Prowler App](docs/products/img/overview.png)
![Prowler App](docs/img/overview.png)
>For more details, refer to the [Prowler App Documentation](https://docs.prowler.com/projects/prowler-open-source/en/latest/#prowler-app-installation)
@@ -73,36 +80,28 @@ prowler <provider>
```console
prowler dashboard
```
![Prowler Dashboard](docs/products/img/dashboard.png)
![Prowler Dashboard](docs/img/dashboard.png)
# Prowler at a Glance
> [!Tip]
> For the most accurate and up-to-date information about checks, services, frameworks, and categories, visit [**Prowler Hub**](https://hub.prowler.com).
| Provider | Checks | Services | [Compliance Frameworks](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/compliance/) | [Categories](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/misc/#categories) | Support | Stage | Interface |
|---|---|---|---|---|---|---|---|
| AWS | 576 | 82 | 36 | 10 | Official | Stable | UI, API, CLI |
| GCP | 79 | 13 | 10 | 3 | Official | Stable | UI, API, CLI |
| Azure | 162 | 19 | 11 | 4 | Official | Stable | UI, API, CLI |
| Kubernetes | 83 | 7 | 5 | 7 | Official | Stable | UI, API, CLI |
| GitHub | 17 | 2 | 1 | 0 | Official | Stable | UI, API, CLI |
| M365 | 70 | 7 | 3 | 2 | Official | Stable | UI, API, CLI |
| IaC | [See `trivy` docs.](https://trivy.dev/latest/docs/coverage/iac/) | N/A | N/A | N/A | Official | Beta | CLI |
| MongoDB Atlas | 10 | 3 | 0 | 0 | Official | Beta | CLI |
| NHN | 6 | 2 | 1 | 0 | Unofficial | Beta | CLI |
| Provider | Checks | Services | [Compliance Frameworks](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/compliance/) | [Categories](https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/misc/#categories) |
|---|---|---|---|---|
| AWS | 567 | 82 | 36 | 10 |
| GCP | 79 | 13 | 10 | 3 |
| Azure | 142 | 18 | 10 | 3 |
| Kubernetes | 83 | 7 | 5 | 7 |
| GitHub | 16 | 2 | 1 | 0 |
| M365 | 69 | 7 | 3 | 2 |
| NHN (Unofficial) | 6 | 2 | 1 | 0 |
> [!Note]
> The numbers in the table are updated periodically.
> [!Tip]
> For the most accurate and up-to-date information about checks, services, frameworks, and categories, visit [**Prowler Hub**](https://hub.prowler.com).
> [!Note]
> Use the following commands to list Prowler's available checks, services, compliance frameworks, and categories:
> - `prowler <provider> --list-checks`
> - `prowler <provider> --list-services`
> - `prowler <provider> --list-compliance`
> - `prowler <provider> --list-categories`
> Use the following commands to list Prowler's available checks, services, compliance frameworks, and categories: `prowler <provider> --list-checks`, `prowler <provider> --list-services`, `prowler <provider> --list-compliance` and `prowler <provider> --list-categories`.
# 💻 Installation
@@ -137,14 +136,6 @@ If your workstation's architecture is incompatible, you can resolve this by:
> Once configured, access the Prowler App at http://localhost:3000. Sign up using your email and password to get started.
### Common Issues with Docker Pull Installation
> [!Note]
If you want to use AWS role assumption (e.g., with the "Connect assuming IAM Role" option), you may need to mount your local `.aws` directory into the container as a volume (e.g., `- "${HOME}/.aws:/home/prowler/.aws:ro"`). There are several ways to configure credentials for Docker containers. See the [Troubleshooting](./docs/troubleshooting.md) section for more details and examples.
You can find more information in the [Troubleshooting](./docs/troubleshooting.md) section.
### From GitHub
**Requirements**
@@ -240,7 +231,7 @@ The following versions of Prowler CLI are available, depending on your requireme
The container images are available here:
- Prowler CLI:
- [DockerHub](https://hub.docker.com/r/prowlercloud/prowler/tags)
- [DockerHub](https://hub.docker.com/r/toniblyx/prowler/tags)
- [AWS Public ECR](https://gallery.ecr.aws/prowler-cloud/prowler)
- Prowler App:
- [DockerHub - Prowler UI](https://hub.docker.com/r/prowlercloud/prowler-ui/tags)
@@ -275,7 +266,7 @@ python prowler-cli.py -v
- **Prowler API**: A backend service, developed with Django REST Framework, responsible for running Prowler scans and storing the generated results.
- **Prowler SDK**: A Python SDK designed to extend the functionality of the Prowler CLI for advanced capabilities.
![Prowler App Architecture](docs/products/img/prowler-app-architecture.png)
![Prowler App Architecture](docs/img/prowler-app-architecture.png)
## Prowler CLI
+15 -57
View File
@@ -1,65 +1,23 @@
# Security
# Security Policy
## Reporting Vulnerabilities
## Software Security
As an **AWS Partner**, we have passed the [AWS Foundation Technical Review (FTR)](https://aws.amazon.com/partners/foundational-technical-review/), and we use the following tools and automation to make sure our code is secure and our dependencies are up to date:
At Prowler, we consider the security of our open source software and systems a top priority. But no matter how much effort we put into system security, there can still be vulnerabilities present.
- `bandit` for code security review.
- `safety` and `dependabot` for dependencies.
- `hadolint` and `dockle` for our containers security.
- `snyk` in Docker Hub.
- `clair` in Amazon ECR.
- `vulture`, `flake8`, `black` and `pylint` for formatting and best practices.
If you discover a vulnerability, we would like to know about it so we can take steps to address it as quickly as possible. We would like to ask you to help us better protect our users, our clients and our systems.
## Reporting a Vulnerability
When reporting vulnerabilities, please consider (1) attack scenario / exploitability, and (2) the security impact of the bug. The following issues are considered out of scope:
If you would like to report a vulnerability or have a security concern regarding Prowler Open Source or the ProwlerPro service, please submit the information through https://support.prowler.com.
- Social engineering support or attacks requiring social engineering.
- Clickjacking on pages with no sensitive actions.
- Cross-Site Request Forgery (CSRF) on unauthenticated forms or forms with no sensitive actions.
- Attacks requiring Man-In-The-Middle (MITM) or physical access to a user's device.
- Previously known vulnerable libraries without a working Proof of Concept (PoC).
- Comma Separated Values (CSV) injection without demonstrating a vulnerability.
- Missing best practices in SSL/TLS configuration.
- Any activity that could lead to the disruption of service (DoS).
- Rate limiting or brute force issues on non-authentication endpoints.
- Missing best practices in Content Security Policy (CSP).
- Missing HttpOnly or Secure flags on cookies.
- Configuration of or missing security headers.
- Missing email best practices, such as invalid, incomplete, or missing SPF/DKIM/DMARC records.
- Vulnerabilities only affecting users of outdated or unpatched browsers (less than two stable versions behind).
- Software version disclosure, banner identification issues, or descriptive error messages.
- Tabnabbing.
- Issues that require unlikely user interaction.
- Improper logout functionality and improper session timeout.
- CORS misconfiguration without an exploitation scenario.
- Broken link hijacking.
- Automated scanning results (e.g., sqlmap, Burp active scanner) that have not been manually verified.
- Content spoofing and text injection issues without a clear attack vector.
- Email spoofing without exploiting security flaws.
- Dead links or broken links.
- User enumeration.
The information you share with ProwlerPro as part of this process is kept confidential within ProwlerPro. We will only share this information with a third party if the vulnerability you report is found to affect a third-party product, in which case we will share this information with the third-party product's author or manufacturer. Otherwise, we will only share this information as permitted by you.
Testing guidelines:
- Do not run automated scanners on other customer projects. Running automated scanners can run up costs for our users. Aggressively configured scanners might inadvertently disrupt services, exploit vulnerabilities, lead to system instability or breaches and violate Terms of Service from our upstream providers. Our own security systems won't be able to distinguish hostile reconnaissance from whitehat research. If you wish to run an automated scanner, notify us at support@prowler.com and only run it on your own Prowler app project. Do NOT attack Prowler projects in use by other customers.
- Do not take advantage of the vulnerability or problem you have discovered, for example by downloading more data than necessary to demonstrate the vulnerability or deleting or modifying other people's data.
We will review the submitted report, and assign it a tracking number. We will then respond to you, acknowledging receipt of the report, and outline the next steps in the process.
Reporting guidelines:
- File a report through our Support Desk at https://support.prowler.com
- If it is about a lack of a security functionality, please file a feature request instead at https://github.com/prowler-cloud/prowler/issues
- Do provide sufficient information to reproduce the problem, so we will be able to resolve it as quickly as possible.
- If you have further questions and want direct interaction with the Prowler team, please contact us via our Community Slack at goto.prowler.com/slack.
You will receive a non-automated response to your initial contact within 24 hours, confirming receipt of your reported vulnerability.
Disclosure guidelines:
- In order to protect our users and customers, do not reveal the problem to others until we have researched, addressed and informed our affected customers.
- If you want to publicly share your research about Prowler at a conference, in a blog or any other public forum, you should share a draft with us for review and approval at least 30 days prior to the publication date. Please note that the following should not be included:
- Data regarding any Prowler user or customer projects.
- Prowler customers' data.
- Information about Prowler employees, contractors or partners.
What we promise:
- We will respond to your report within 5 business days with our evaluation of the report and an expected resolution date.
- If you have followed the instructions above, we will not take any legal action against you in regard to the report.
- We will handle your report with strict confidentiality, and not pass on your personal details to third parties without your permission.
- We will keep you informed of the progress towards resolving the problem.
- In the public information concerning the problem reported, we will give your name as the discoverer of the problem (unless you desire otherwise).
We strive to resolve all problems as quickly as possible, and we would like to play an active role in the ultimate publication on the problem after it is resolved.
---
For more information about our security policies, please refer to our [Security](https://docs.prowler.com/projects/prowler-open-source/en/latest/security/) section in our documentation.
We will coordinate public notification of any validated vulnerability with you. Where possible, we prefer that our respective public disclosures be posted simultaneously.
-2
View File
@@ -19,8 +19,6 @@ DJANGO_REFRESH_TOKEN_LIFETIME=1440
DJANGO_CACHE_MAX_AGE=3600
DJANGO_STALE_WHILE_REVALIDATE=60
DJANGO_SECRETS_ENCRYPTION_KEY=""
# Throttle, two options: Empty means no throttle; or if desired use one in DRF format: https://www.django-rest-framework.org/api-guide/throttling/#setting-the-throttling-policy
DJANGO_THROTTLE_TOKEN_OBTAIN=50/minute
# Decide whether to allow Django manage database table partitions
DJANGO_MANAGE_DB_PARTITIONS=[True|False]
DJANGO_CELERY_DEADLOCK_ATTEMPTS=5
+2 -97
View File
@@ -2,111 +2,16 @@
All notable changes to the **Prowler API** are documented in this file.
## [Unreleased]
### Added
- IaC (Infrastructure as Code) provider support for remote repositories [(#TBD)](https://github.com/prowler-cloud/prowler/pull/TBD)
---
## [1.13.0] (Prowler 5.12.0)
### Added
- Integration with JIRA, enabling sending findings to a JIRA project [(#8622)](https://github.com/prowler-cloud/prowler/pull/8622), [(#8637)](https://github.com/prowler-cloud/prowler/pull/8637)
- `GET /overviews/findings_severity` now supports `filter[status]` and `filter[status__in]` to aggregate by specific statuses (`FAIL`, `PASS`)[(#8186)](https://github.com/prowler-cloud/prowler/pull/8186)
- Throttling options for `/api/v1/tokens` using the `DJANGO_THROTTLE_TOKEN_OBTAIN` environment variable [(#8647)](https://github.com/prowler-cloud/prowler/pull/8647)
---
## [1.12.0] (Prowler 5.11.0)
### Added
- Lighthouse support for OpenAI GPT-5 [(#8527)](https://github.com/prowler-cloud/prowler/pull/8527)
- Integration with Amazon Security Hub, enabling sending findings to Security Hub [(#8365)](https://github.com/prowler-cloud/prowler/pull/8365)
- Generate ASFF output for AWS providers with SecurityHub integration enabled [(#8569)](https://github.com/prowler-cloud/prowler/pull/8569)
### Fixed
- GitHub provider always scans user instead of organization when using provider UID [(#8587)](https://github.com/prowler-cloud/prowler/pull/8587)
## [1.11.0] (Prowler 5.10.0)
### Added
- GitHub provider support [(#8271)](https://github.com/prowler-cloud/prowler/pull/8271)
- Integration with Amazon S3, enabling storage and retrieval of scan data via S3 buckets [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
### Fixed
- Avoid sending errors to Sentry in M365 provider when user authentication fails [(#8420)](https://github.com/prowler-cloud/prowler/pull/8420)
---
## [1.10.2] (Prowler v5.9.2)
### Changed
- Optimized queries for resources views [(#8336)](https://github.com/prowler-cloud/prowler/pull/8336)
---
## [v1.10.1] (Prowler v5.9.1)
### Fixed
- Calculate failed findings during scans to prevent heavy database queries [(#8322)](https://github.com/prowler-cloud/prowler/pull/8322)
---
## [v1.10.0] (Prowler v5.9.0)
### Added
- SSO with SAML support [(#8175)](https://github.com/prowler-cloud/prowler/pull/8175)
- `GET /resources/metadata`, `GET /resources/metadata/latest` and `GET /resources/latest` to expose resource metadata and latest scan results [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
### Changed
- `/processors` endpoints to post-process findings. Currently, only the Mutelist processor is supported, which allows muting findings.
- Optimized the underlying queries for resources endpoints [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
- Optimized include parameters for resources view [(#8229)](https://github.com/prowler-cloud/prowler/pull/8229)
- Optimized overview background tasks [(#8300)](https://github.com/prowler-cloud/prowler/pull/8300)
### Fixed
- Search filter for findings and resources [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
- RBAC is now applied to `GET /overviews/providers` [(#8277)](https://github.com/prowler-cloud/prowler/pull/8277)
### Changed
- `POST /schedules/daily` returns a `409 CONFLICT` if already created [(#8258)](https://github.com/prowler-cloud/prowler/pull/8258)
### Security
- Enhanced password validation to enforce 12+ character passwords with special characters, uppercase, lowercase, and numbers [(#8225)](https://github.com/prowler-cloud/prowler/pull/8225)
---
## [v1.9.1] (Prowler v5.8.1)
### Added
- Custom exception for provider connection errors during scans [(#8234)](https://github.com/prowler-cloud/prowler/pull/8234)
### Changed
- Summary and overview tasks now use a dedicated queue and no longer propagate errors to compliance tasks [(#8214)](https://github.com/prowler-cloud/prowler/pull/8214)
### Fixed
- Scan with no resources will not trigger legacy code for findings metadata [(#8183)](https://github.com/prowler-cloud/prowler/pull/8183)
- Invitation email comparison is now case-insensitive [(#8206)](https://github.com/prowler-cloud/prowler/pull/8206)
### Removed
- Validation of the provider's secret type during updates [(#8197)](https://github.com/prowler-cloud/prowler/pull/8197)
---
## [v1.9.0] (Prowler v5.8.0)
## [v1.9.0] (Prowler UNRELEASED)
### Added
- SSO with SAML support [(#7822)](https://github.com/prowler-cloud/prowler/pull/7822)
- Support GCP Service Account key [(#7824)](https://github.com/prowler-cloud/prowler/pull/7824)
- `GET /compliance-overviews` endpoints to retrieve compliance metadata and specific requirements statuses [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
- Lighthouse configuration support [(#7848)](https://github.com/prowler-cloud/prowler/pull/7848)
### Changed
- Reworked `GET /compliance-overviews` to return proper requirement metrics [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877)
- Optional `user` and `password` for M365 provider [(#7992)](https://github.com/prowler-cloud/prowler/pull/7992)
### Fixed
- Scheduled scans are no longer deleted when their daily schedule run is disabled [(#8082)](https://github.com/prowler-cloud/prowler/pull/8082)
---
+4 -22
View File
@@ -36,25 +36,6 @@ RUN ARCH=$(uname -m) && \
ln -s /opt/microsoft/powershell/7/pwsh /usr/bin/pwsh && \
rm /tmp/powershell.tar.gz
# Install Trivy for IaC scanning
ARG TRIVY_VERSION=0.66.0
RUN ARCH=$(uname -m) && \
if [ "$ARCH" = "x86_64" ]; then \
TRIVY_ARCH="Linux-64bit" ; \
elif [ "$ARCH" = "aarch64" ]; then \
TRIVY_ARCH="Linux-ARM64" ; \
else \
echo "Unsupported architecture for Trivy: $ARCH" && exit 1 ; \
fi && \
wget --progress=dot:giga "https://github.com/aquasecurity/trivy/releases/download/v${TRIVY_VERSION}/trivy_${TRIVY_VERSION}_${TRIVY_ARCH}.tar.gz" -O /tmp/trivy.tar.gz && \
tar zxf /tmp/trivy.tar.gz -C /tmp && \
mv /tmp/trivy /usr/local/bin/trivy && \
chmod +x /usr/local/bin/trivy && \
rm /tmp/trivy.tar.gz && \
# Create trivy cache directory with proper permissions
mkdir -p /tmp/.cache/trivy && \
chmod 777 /tmp/.cache/trivy
# Add prowler user
RUN addgroup --gid 1000 prowler && \
adduser --uid 1000 --gid 1000 --disabled-password --gecos "" prowler
@@ -63,9 +44,6 @@ USER prowler
WORKDIR /home/prowler
# Ensure output directory exists
RUN mkdir -p /tmp/prowler_api_output
COPY pyproject.toml ./
RUN pip install --no-cache-dir --upgrade pip && \
@@ -79,6 +57,10 @@ RUN poetry install --no-root && \
RUN poetry run python "$(poetry env info --path)/src/prowler/prowler/providers/m365/lib/powershell/m365_powershell.py"
# Prevents known compatibility error between lxml and libxml2/libxmlsec versions.
# See: https://github.com/xmlsec/python-xmlsec/issues/320
RUN poetry run pip install --force-reinstall --no-binary lxml lxml
COPY src/backend/ ./backend/
COPY docker-entrypoint.sh ./docker-entrypoint.sh
+1 -1
View File
@@ -257,7 +257,7 @@ cd src/backend
python manage.py loaddata api/fixtures/0_dev_users.json --database admin
```
> The default credentials are `dev@prowler.com:Thisisapassword123@` or `dev2@prowler.com:Thisisapassword123@`
> The default credentials are `dev@prowler.com:thisisapassword123` or `dev2@prowler.com:thisisapassword123`
## Run tests
+1 -1
View File
@@ -32,7 +32,7 @@ start_prod_server() {
start_worker() {
echo "Starting the worker..."
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview,integrations -E --max-tasks-per-child 1
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill -E --max-tasks-per-child 1
}
start_worker_beat() {
+1399 -1738
View File
File diff suppressed because it is too large Load Diff
+3 -6
View File
@@ -23,15 +23,12 @@ dependencies = [
"drf-spectacular==0.27.2",
"drf-spectacular-jsonapi==0.5.1",
"gunicorn==23.0.0",
"lxml==5.3.2",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@iac-in-the-app",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
"psycopg2-binary==2.9.9",
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
"uuid6==2024.7.10",
"openai (>=1.82.0,<2.0.0)",
"xmlsec==1.3.14",
"h2 (==4.3.0)"
"openai (>=1.82.0,<2.0.0)"
]
description = "Prowler's API (Django/DRF)"
license = "Apache-2.0"
@@ -39,7 +36,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.13.0"
version = "1.9.0"
[project.scripts]
celery = "src.backend.config.settings.celery"
+58 -6
View File
@@ -3,7 +3,14 @@ from django.db import transaction
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.models import Membership, Role, Tenant, User, UserRoleRelationship
from api.models import (
Membership,
Role,
SAMLConfiguration,
Tenant,
User,
UserRoleRelationship,
)
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
@@ -17,7 +24,7 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
def pre_social_login(self, request, sociallogin):
# Link existing accounts with the same email address
email = sociallogin.account.extra_data.get("email")
if sociallogin.provider.id == "saml":
if sociallogin.account.provider == "saml":
email = sociallogin.user.email
if email:
existing_user = self.get_user_by_email(email)
@@ -31,10 +38,57 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
"""
with transaction.atomic(using=MainRouter.admin_db):
user = super().save_user(request, sociallogin, form)
provider = sociallogin.provider.id
provider = sociallogin.account.provider
extra = sociallogin.account.extra_data
if provider != "saml":
if provider == "saml":
# Handle SAML-specific logic
user.first_name = extra.get("firstName", [""])[0]
user.last_name = extra.get("lastName", [""])[0]
user.company_name = extra.get("organization", [""])[0]
user.name = f"{user.first_name} {user.last_name}".strip()
user.save(using=MainRouter.admin_db)
email_domain = user.email.split("@")[-1]
tenant = (
SAMLConfiguration.objects.using(MainRouter.admin_db)
.get(email_domain=email_domain)
.tenant
)
with rls_transaction(str(tenant.id)):
role_name = extra.get("userType", ["saml_default_role"])[0].strip()
try:
role = Role.objects.using(MainRouter.admin_db).get(
name=role_name, tenant_id=tenant.id
)
except Role.DoesNotExist:
role = Role.objects.using(MainRouter.admin_db).create(
name=role_name,
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
Membership.objects.using(MainRouter.admin_db).create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.MEMBER,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
else:
# Handle other providers (e.g., GitHub, Google)
user.save(using=MainRouter.admin_db)
social_account_name = extra.get("name")
@@ -65,7 +119,5 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
role=role,
tenant_id=tenant.id,
)
else:
request.session["saml_user_created"] = str(user.id)
return user
-35
View File
@@ -175,29 +175,6 @@ def create_objects_in_batches(
model.objects.bulk_create(chunk, batch_size)
def update_objects_in_batches(
    tenant_id: str, model, objects: list, fields: list, batch_size: int = 500
):
    """
    Bulk-update already-saved instances chunk by chunk under tenant RLS.

    Every chunk of at most `batch_size` instances is written inside its own
    `rls_transaction`, so no single transaction has to carry the whole list.

    Args:
        tenant_id (str): UUID string of the tenant used as the RLS value.
        model: Django model class whose `.objects.bulk_update()` is invoked.
        objects (list): Saved model instances to update.
        fields (list): Names of the fields to write.
        batch_size (int): Maximum number of instances per `bulk_update` call.
    """
    for offset in range(0, len(objects), batch_size):
        batch = objects[offset : offset + batch_size]
        # One RLS-scoped transaction per chunk.
        with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
            model.objects.bulk_update(batch, fields, batch_size)
# Postgres Enums
@@ -552,15 +529,3 @@ class IntegrationTypeEnum(EnumType):
class IntegrationTypeEnumField(PostgresEnumField):
def __init__(self, *args, **kwargs):
super().__init__("integration_type", *args, **kwargs)
# Postgres enum definition for Processor type
# EnumType subclass that names the Postgres enum "processor_type".
class ProcessorTypeEnum(EnumType):
# Name of the enum type; the same string is passed by ProcessorTypeEnumField.
enum_type_name = "processor_type"
# Model field bound to the "processor_type" enum: it forwards that name as the
# first positional argument to PostgresEnumField, followed by the caller's own
# args and kwargs.
class ProcessorTypeEnumField(PostgresEnumField):
def __init__(self, *args, **kwargs):
super().__init__("processor_type", *args, **kwargs)
-23
View File
@@ -57,11 +57,6 @@ class TaskInProgressException(TaskManagementError):
super().__init__()
# Provider connection errors
# Derives directly from the built-in Exception and adds no behavior of its own.
class ProviderConnectionError(Exception):
"""Base exception for provider connection errors."""
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):
@@ -78,21 +73,3 @@ def custom_exception_handler(exc, context):
message_item["message"] for message_item in exc.detail["messages"]
]
return exception_handler(exc, context)
class ConflictException(APIException):
    """
    API exception rendered as HTTP 409 Conflict.

    The exception detail is a one-element list holding a dict with the keys
    `detail`, `status` and `code`, plus `source.pointer` when a pointer is
    supplied.

    Args:
        detail: Human-readable message; falls back to `default_detail`.
        code: Machine-readable error code; falls back to `default_code`.
        pointer: Optional pointer to the offending attribute, stored under
            `source.pointer`.
    """

    status_code = status.HTTP_409_CONFLICT
    default_detail = "A conflict occurred. The resource already exists."
    default_code = "conflict"

    def __init__(self, detail=None, code=None, pointer=None):
        error_detail = {
            "detail": detail or self.default_detail,
            "status": self.status_code,
            # Honor a caller-supplied code; it used to be accepted but ignored.
            "code": code or self.default_code,
        }

        # Only attach a source when the caller points at a specific attribute.
        if pointer:
            error_detail["source"] = {"pointer": pointer}

        super().__init__(detail=[error_detail])
+1 -177
View File
@@ -1,8 +1,7 @@
from datetime import date, datetime, timedelta, timezone
from dateutil.parser import parse
from django.conf import settings
from django.db.models import F, Q
from django.db.models import Q
from django_filters.rest_framework import (
BaseInFilter,
BooleanFilter,
@@ -28,9 +27,7 @@ from api.models import (
Integration,
Invitation,
Membership,
OverviewStatusChoices,
PermissionChoices,
Processor,
Provider,
ProviderGroup,
ProviderSecret,
@@ -341,8 +338,6 @@ class ResourceFilter(ProviderRelationshipFilterSet):
tags = CharFilter(method="filter_tag")
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
scan = UUIDFilter(field_name="provider__scan", lookup_expr="exact")
scan__in = UUIDInFilter(field_name="provider__scan", lookup_expr="in")
class Meta:
model = Resource
@@ -357,82 +352,6 @@ class ResourceFilter(ProviderRelationshipFilterSet):
"updated_at": ["gte", "lte"],
}
def filter_queryset(self, queryset):
    """
    Validate the scan/date filters, then delegate to the parent filter set.

    A request must carry a scan filter (`scan`, `scan__in`) or one of the
    `updated_at` filters. The span between the `updated_at__gte` and
    `updated_at__lte` dates (each defaulting to today's UTC date when absent)
    may not exceed `settings.FINDINGS_MAX_DAYS_IN_RANGE` days.

    Raises:
        ValidationError: when no accepted filter is present or the date range
            is too wide.
    """
    accepted_keys = (
        "scan",
        "scan__in",
        "updated_at",
        "updated_at__date",
        "updated_at__gte",
        "updated_at__lte",
    )
    if not any(self.data.get(key) for key in accepted_keys):
        raise ValidationError(
            [
                {
                    "detail": "At least one date filter is required: filter[updated_at], filter[updated_at.gte], "
                    "or filter[updated_at.lte].",
                    "status": 400,
                    "source": {"pointer": "/data/attributes/updated_at"},
                    "code": "required",
                }
            ]
        )

    def _bound(key):
        # A missing or empty bound falls back to the current UTC date.
        raw = self.data.get(key)
        if raw:
            return parse(raw).date()
        return datetime.now(timezone.utc).date()

    gte_date = _bound("updated_at__gte")
    lte_date = _bound("updated_at__lte")

    max_span = timedelta(days=settings.FINDINGS_MAX_DAYS_IN_RANGE)
    if abs(lte_date - gte_date) > max_span:
        raise ValidationError(
            [
                {
                    "detail": f"The date range cannot exceed {settings.FINDINGS_MAX_DAYS_IN_RANGE} days.",
                    "status": 400,
                    "source": {"pointer": "/data/attributes/updated_at"},
                    "code": "invalid",
                }
            ]
        )

    return super().filter_queryset(queryset)
def filter_tag_key(self, queryset, name, value):
    """Keep rows with a tag key equal to `value` or containing it (case-insensitive)."""
    exact_match = Q(tags__key=value)
    partial_match = Q(tags__key__icontains=value)
    return queryset.filter(exact_match | partial_match)
def filter_tag_value(self, queryset, name, value):
    """Keep rows with a tag value equal to `value` or containing it (case-insensitive)."""
    exact_match = Q(tags__value=value)
    partial_match = Q(tags__value__icontains=value)
    return queryset.filter(exact_match | partial_match)
def filter_tag(self, queryset, name, value):
    """Filter through the `tags__text_search` lookup using the raw value."""
    # The value alone does not tell us what the user wants to match, and we do
    # not want provider-specific tag logic, so a full text search is used.
    lookup = {"tags__text_search": value}
    return queryset.filter(**lookup)
class LatestResourceFilter(ProviderRelationshipFilterSet):
tag_key = CharFilter(method="filter_tag_key")
tag_value = CharFilter(method="filter_tag_value")
tag = CharFilter(method="filter_tag")
tags = CharFilter(method="filter_tag")
class Meta:
model = Resource
fields = {
"provider": ["exact", "in"],
"uid": ["exact", "icontains"],
"name": ["exact", "icontains"],
"region": ["exact", "icontains", "in"],
"service": ["exact", "icontains", "in"],
"type": ["exact", "icontains", "in"],
}
def filter_tag_key(self, queryset, name, value):
return queryset.filter(Q(tags__key=value) | Q(tags__key__icontains=value))
@@ -751,72 +670,6 @@ class ScanSummaryFilter(FilterSet):
}
class ScanSummarySeverityFilter(ScanSummaryFilter):
    """
    ScanSummary filter for the findings_severity endpoint, with status filters.

    Both status filters annotate the queryset with a `status_count` expression
    built from the `fail`, `_pass` or `total` columns; neither calls `.filter()`.
    """

    # Custom status filters - only for severity grouping endpoint
    status = ChoiceFilter(method="filter_status", choices=OverviewStatusChoices.choices)
    status__in = CharInFilter(method="filter_status_in", lookup_expr="in")

    def filter_status(self, queryset, name, value):
        """Annotate `status_count` with the column matching a single status."""
        allowed = [entry[0] for entry in OverviewStatusChoices.choices]
        if value not in allowed:
            raise ValidationError(f"Invalid status value: {value}")

        # Pick the source column, then annotate once.
        if value == OverviewStatusChoices.FAIL:
            column = "fail"
        elif value == OverviewStatusChoices.PASS:
            column = "_pass"
        else:
            column = "total"
        return queryset.annotate(status_count=F(column))

    def filter_status_in(self, queryset, name, value):
        """Annotate `status_count` with the sum of the columns for several statuses."""
        allowed = [entry[0] for entry in OverviewStatusChoices.choices]
        for status_val in value:
            if status_val not in allowed:
                raise ValidationError(f"Invalid status value: {status_val}")

        # Asking for both statuses (or for none) is the same as the total.
        every_status = {OverviewStatusChoices.FAIL, OverviewStatusChoices.PASS}
        if set(value) >= every_status or not value:
            return queryset.annotate(status_count=F("total"))

        # Add up one column expression per requested status.
        combined = None
        for status in value:
            if status == OverviewStatusChoices.FAIL:
                term = F("fail")
            elif status == OverviewStatusChoices.PASS:
                term = F("_pass")
            else:
                continue
            combined = term if combined is None else combined + term

        if combined is None:
            return queryset.annotate(status_count=F("total"))
        return queryset.annotate(status_count=combined)

    class Meta:
        model = ScanSummary
        fields = {
            "inserted_at": ["date", "gte", "lte"],
            "region": ["exact", "icontains", "in"],
        }
class ServiceOverviewFilter(ScanSummaryFilter):
def is_valid(self):
# Check if at least one of the inserted_at filters is present
@@ -851,32 +704,3 @@ class IntegrationFilter(FilterSet):
fields = {
"inserted_at": ["date", "gte", "lte"],
}
# Filter set exposing two filters on processor_type, both restricted to
# Processor.ProcessorChoices: a single-value filter and an "in" lookup.
class ProcessorFilter(FilterSet):
processor_type = ChoiceFilter(choices=Processor.ProcessorChoices.choices)
# Multi-value variant targeting the same processor_type field.
processor_type__in = ChoiceInFilter(
choices=Processor.ProcessorChoices.choices,
field_name="processor_type",
lookup_expr="in",
)
class IntegrationJiraFindingsFilter(FilterSet):
    """
    Filter set over Finding rows keyed by finding id.

    Filtering is rejected when the request carries no filter data at all.
    """

    # To be expanded as needed
    finding_id = UUIDFilter(field_name="id", lookup_expr="exact")
    finding_id__in = UUIDInFilter(field_name="id", lookup_expr="in")

    class Meta:
        model = Finding
        fields = {}

    def filter_queryset(self, queryset):
        """Delegate to the parent filter set, requiring at least one filter."""
        if self.data:
            return super().filter_queryset(queryset)

        # Validate that there is at least one filter provided
        raise ValidationError(
            {
                "findings": "No finding filters provided. At least one filter is required."
            }
        )
@@ -3,7 +3,7 @@
"model": "api.user",
"pk": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"password": "pbkdf2_sha256$720000$vA62S78kog2c2ytycVQdke$Fp35GVLLMyy5fUq3krSL9I02A+ocQ+RVa4S22LIAO5s=",
"last_login": null,
"name": "Devie Prowlerson",
"email": "dev@prowler.com",
@@ -16,7 +16,7 @@
"model": "api.user",
"pk": "b6493a3a-c997-489b-8b99-278bf74de9f6",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"password": "pbkdf2_sha256$720000$vA62S78kog2c2ytycVQdke$Fp35GVLLMyy5fUq3krSL9I02A+ocQ+RVa4S22LIAO5s=",
"last_login": null,
"name": "Devietoo Prowlerson",
"email": "dev2@prowler.com",
@@ -24,18 +24,5 @@
"is_active": true,
"date_joined": "2024-09-18T09:04:20.850Z"
}
},
{
"model": "api.user",
"pk": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"fields": {
"password": "pbkdf2_sha256$870000$Z63pGJ7nre48hfcGbk5S0O$rQpKczAmijs96xa+gPVJifpT3Fetb8DOusl5Eq6gxac=",
"last_login": null,
"name": "E2E Test User",
"email": "e2e@prowler.com",
"company_name": "Prowler E2E Tests",
"is_active": true,
"date_joined": "2024-01-01T00:00:00.850Z"
}
}
]
@@ -46,24 +46,5 @@
"role": "member",
"date_joined": "2024-09-19T11:03:59.712Z"
}
},
{
"model": "api.tenant",
"pk": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"fields": {
"inserted_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"name": "E2E Test Tenant"
}
},
{
"model": "api.membership",
"pk": "9b1a2c3d-4e5f-6789-abc1-23456789def0",
"fields": {
"user": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"role": "owner",
"date_joined": "2024-01-01T00:00:00.000Z"
}
}
]
@@ -149,32 +149,5 @@
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"inserted_at": "2024-11-20T15:36:14.302Z"
}
},
{
"model": "api.role",
"pk": "a5b6c7d8-9e0f-1234-5678-90abcdef1234",
"fields": {
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"name": "e2e_admin",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": true,
"manage_scans": true,
"unlimited_visibility": true,
"inserted_at": "2024-01-01T00:00:00.000Z",
"updated_at": "2024-01-01T00:00:00.000Z"
}
},
{
"model": "api.userrolerelationship",
"pk": "f1e2d3c4-b5a6-9876-5432-10fedcba9876",
"fields": {
"tenant": "7c8f94a3-e2d1-4b3a-9f87-2c4d5e6f1a2b",
"role": "a5b6c7d8-9e0f-1234-5678-90abcdef1234",
"user": "6d4f8a91-3c2e-4b5a-8f7d-1e9c5b2a4d6f",
"inserted_at": "2024-01-01T00:00:00.000Z"
}
}
]
File diff suppressed because one or more lines are too long
@@ -1,61 +1,57 @@
# Generated by Django 5.1.10 on 2025-07-02 15:47
# Generated by Django 5.1.8 on 2025-05-15 09:54
import uuid
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0031_scan_disable_on_cascade_periodic_tasks"),
("api", "0029_findings_check_index_parent"),
]
operations = [
migrations.AlterField(
model_name="integration",
name="integration_type",
field=api.db_utils.IntegrationTypeEnumField(
choices=[
("amazon_s3", "Amazon S3"),
("aws_security_hub", "AWS Security Hub"),
("jira", "JIRA"),
("slack", "Slack"),
]
),
),
migrations.CreateModel(
name="SAMLToken",
name="SAMLDomainIndex",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("expires_at", models.DateTimeField(editable=False)),
("token", models.JSONField(unique=True)),
("email_domain", models.CharField(max_length=254, unique=True)),
(
"user",
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_tokens",
"db_table": "saml_domain_index",
},
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=api.rls.BaseSecurityConstraint(
name="statements_on_samldomainindex",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.CreateModel(
name="SAMLConfiguration",
fields=[
@@ -109,42 +105,16 @@ class Migration(migrations.Migration):
fields=("tenant",), name="unique_samlconfig_per_tenant"
),
),
migrations.CreateModel(
name="SAMLDomainIndex",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("email_domain", models.CharField(max_length=254, unique=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_domain_index",
},
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=api.rls.BaseSecurityConstraint(
name="statements_on_samldomainindex",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
migrations.AlterField(
model_name="integration",
name="integration_type",
field=api.db_utils.IntegrationTypeEnumField(
choices=[
("amazon_s3", "Amazon S3"),
("aws_security_hub", "AWS Security Hub"),
("jira", "JIRA"),
("slack", "Slack"),
]
),
),
]
@@ -11,7 +11,7 @@ import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0029_findings_check_index_parent"),
("api", "0030_samlconfigurations"),
]
operations = [
@@ -54,7 +54,6 @@ class Migration(migrations.Migration):
("gpt-4o-mini-2024-07-18", "GPT-4o Mini v2024-07-18"),
("gpt-4o-mini", "GPT-4o Mini Default"),
],
default="gpt-4o-2024-08-06",
help_text="Must be one of the supported model names",
max_length=50,
),
@@ -1,24 +0,0 @@
# Generated by Django 5.1.10 on 2025-06-23 10:04
import django.db.models.deletion
from django.db import migrations, models
# Alters Scan.scheduler_task into a nullable foreign key to
# django_celery_beat's PeriodicTask that uses on_delete=SET_NULL.
class Migration(migrations.Migration):
dependencies = [
("api", "0030_lighthouseconfiguration"),
("django_celery_beat", "0019_alter_periodictasks_options"),
]
operations = [
migrations.AlterField(
model_name="scan",
name="scheduler_task",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="django_celery_beat.periodictask",
),
),
]
@@ -1,34 +0,0 @@
# Generated by Django 5.1.5 on 2025-03-03 15:46
from functools import partial
from django.db import migrations
from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum
from api.models import Processor
# Helper for the "processor_type" Postgres enum. Its values are the first
# element of every entry in Processor.ProcessorChoices.choices.
ProcessorTypeEnumMigration = PostgresEnumMigration(
enum_name="processor_type",
enum_values=tuple(
processor_type[0] for processor_type in Processor.ProcessorChoices.choices
),
)
# Runs ProcessorTypeEnumMigration.create_enum_type and then register_enum for
# ProcessorTypeEnum. On reverse, the first step calls drop_enum_type and the
# second step is a no-op.
class Migration(migrations.Migration):
# Django does not wrap a migration in a transaction when atomic is False.
atomic = False
dependencies = [
("api", "0032_saml"),
]
operations = [
migrations.RunPython(
ProcessorTypeEnumMigration.create_enum_type,
reverse_code=ProcessorTypeEnumMigration.drop_enum_type,
),
migrations.RunPython(
partial(register_enum, enum_class=ProcessorTypeEnum),
reverse_code=migrations.RunPython.noop,
),
]
@@ -1,88 +0,0 @@
# Generated by Django 5.1.5 on 2025-03-26 13:04
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
from api.rls import RowLevelSecurityConstraint
# Creates the Processor model (table "processors") with two indexes, adds a
# unique constraint and a row-level-security constraint to it, and adds a
# nullable Scan.processor foreign key.
class Migration(migrations.Migration):
dependencies = [
("api", "0033_processors_enum"),
]
operations = [
migrations.CreateModel(
name="Processor",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"processor_type",
api.db_utils.ProcessorTypeEnumField(
choices=[("mutelist", "Mutelist")]
),
),
("configuration", models.JSONField(default=dict)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "processors",
"abstract": False,
"indexes": [
models.Index(
fields=["tenant_id", "id"], name="processor_tenant_id_idx"
),
models.Index(
fields=["tenant_id", "processor_type"],
name="processor_tenant_type_idx",
),
],
},
),
# At most one processor of each type per tenant.
migrations.AddConstraint(
model_name="processor",
constraint=models.UniqueConstraint(
fields=("tenant_id", "processor_type"),
name="unique_processor_types_tenant",
),
),
# Row-level-security constraint on tenant_id covering all four statements.
migrations.AddConstraint(
model_name="processor",
constraint=RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_processor",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
# Optional link from Scan to Processor, declared with on_delete=SET_NULL.
migrations.AddField(
model_name="scan",
name="processor",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="scans",
related_query_name="scan",
to="api.processor",
),
),
]
@@ -1,22 +0,0 @@
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
# Adds the optional Finding.muted_reason text field (max_length=500, with a
# minimum-length validator of 3).
class Migration(migrations.Migration):
dependencies = [
("api", "0034_processors"),
]
operations = [
migrations.AddField(
model_name="finding",
name="muted_reason",
field=models.TextField(
blank=True,
max_length=500,
null=True,
validators=[django.core.validators.MinLengthValidator(3)],
),
),
]
@@ -1,30 +0,0 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
# Calls create_index_on_partitions for parent table "resource_finding_mappings"
# with a BTREE index "rfm_tenant_finding_idx" on (tenant_id, finding_id).
# Reversing calls drop_index_on_partitions for the same table and index name.
class Migration(migrations.Migration):
# Django does not wrap a migration in a transaction when atomic is False.
atomic = False
dependencies = [
("api", "0035_finding_muted_reason"),
]
operations = [
migrations.RunPython(
partial(
create_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_finding_idx",
columns="tenant_id, finding_id",
method="BTREE",
),
reverse_code=partial(
drop_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_finding_idx",
),
),
]
@@ -1,17 +0,0 @@
from django.db import migrations, models
# Adds the index "rfm_tenant_finding_idx" on (tenant_id, finding_id) to the
# ResourceFindingMapping model.
# NOTE(review): presumably the parent-table counterpart of the per-partition
# index created by the dependency - confirm.
class Migration(migrations.Migration):
dependencies = [
("api", "0036_rfm_tenant_finding_index_partitions"),
]
operations = [
migrations.AddIndex(
model_name="resourcefindingmapping",
index=models.Index(
fields=["tenant_id", "finding_id"],
name="rfm_tenant_finding_idx",
),
),
]
@@ -1,15 +0,0 @@
from django.db import migrations, models
# Adds the integer field Resource.failed_findings_count with a default of 0.
class Migration(migrations.Migration):
dependencies = [
("api", "0037_rfm_tenant_finding_index_parent"),
]
operations = [
migrations.AddField(
model_name="resource",
name="failed_findings_count",
field=models.IntegerField(default=0),
)
]
@@ -1,20 +0,0 @@
from django.contrib.postgres.operations import AddIndexConcurrently
from django.db import migrations, models
# Adds the index "resources_failed_findings_idx" on Resource over
# (tenant_id, failed_findings_count descending, id) via AddIndexConcurrently.
class Migration(migrations.Migration):
# AddIndexConcurrently requires a non-atomic migration.
atomic = False
dependencies = [
("api", "0038_resource_failed_findings_count"),
]
operations = [
AddIndexConcurrently(
model_name="resource",
index=models.Index(
fields=["tenant_id", "-failed_findings_count", "id"],
name="resources_failed_findings_idx",
),
),
]
@@ -1,30 +0,0 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
# Calls create_index_on_partitions for parent table "resource_finding_mappings"
# with a BTREE index "rfm_tenant_resource_idx" on (tenant_id, resource_id).
# Reversing calls drop_index_on_partitions for the same table and index name.
class Migration(migrations.Migration):
# Django does not wrap a migration in a transaction when atomic is False.
atomic = False
dependencies = [
("api", "0039_resource_resources_failed_findings_idx"),
]
operations = [
migrations.RunPython(
partial(
create_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_resource_idx",
columns="tenant_id, resource_id",
method="BTREE",
),
reverse_code=partial(
drop_index_on_partitions,
parent_table="resource_finding_mappings",
index_name="rfm_tenant_resource_idx",
),
),
]
@@ -1,17 +0,0 @@
from django.db import migrations, models
# Adds the index "rfm_tenant_resource_idx" on (tenant_id, resource_id) to the
# ResourceFindingMapping model.
# NOTE(review): presumably the parent-table counterpart of the per-partition
# index created by the dependency - confirm.
class Migration(migrations.Migration):
dependencies = [
("api", "0040_rfm_tenant_resource_index_partitions"),
]
operations = [
migrations.AddIndex(
model_name="resourcefindingmapping",
index=models.Index(
fields=["tenant_id", "resource_id"],
name="rfm_tenant_resource_idx",
),
),
]
@@ -1,23 +0,0 @@
from django.contrib.postgres.operations import AddIndexConcurrently
from django.db import migrations, models
# Adds the partial covering index "scans_prov_ins_desc_idx" on Scan via
# AddIndexConcurrently. It is keyed on (tenant_id, provider_id, inserted_at
# descending), includes id, and is limited to rows whose state is "completed".
class Migration(migrations.Migration):
# AddIndexConcurrently requires a non-atomic migration.
atomic = False
dependencies = [
("api", "0041_rfm_tenant_resource_parent_partitions"),
("django_celery_beat", "0019_alter_periodictasks_options"),
]
operations = [
AddIndexConcurrently(
model_name="scan",
index=models.Index(
condition=models.Q(("state", "completed")),
fields=["tenant_id", "provider_id", "-inserted_at"],
include=("id",),
name="scans_prov_ins_desc_idx",
),
),
]
@@ -1,33 +0,0 @@
# Generated by Django 5.1.7 on 2025-07-09 14:44
from django.db import migrations
import api.db_utils
# Adds "github" to the Provider.provider choices and to the Postgres enum type
# "provider".
class Migration(migrations.Migration):
dependencies = [
("api", "0042_scan_scans_prov_ins_desc_idx"),
]
operations = [
migrations.AlterField(
model_name="provider",
name="provider",
field=api.db_utils.ProviderEnumField(
choices=[
("aws", "AWS"),
("azure", "Azure"),
("gcp", "GCP"),
("kubernetes", "Kubernetes"),
("m365", "M365"),
("github", "GitHub"),
],
default="aws",
),
),
# The reverse is a no-op, so rolling back leaves the enum value in place.
migrations.RunSQL(
"ALTER TYPE provider ADD VALUE IF NOT EXISTS 'github';",
reverse_sql=migrations.RunSQL.noop,
),
]
@@ -1,19 +0,0 @@
# Generated by Django 5.1.10 on 2025-07-17 11:52
from django.db import migrations, models
# Adds the unique constraint "unique_configuration_per_tenant" on Integration
# over (configuration, tenant).
class Migration(migrations.Migration):
dependencies = [
("api", "0043_github_provider"),
]
operations = [
migrations.AddConstraint(
model_name="integration",
constraint=models.UniqueConstraint(
fields=("configuration", "tenant"),
name="unique_configuration_per_tenant",
),
),
]
@@ -1,17 +0,0 @@
# Generated by Django 5.1.10 on 2025-07-21 16:08
from django.db import migrations, models
# Alters Scan.output_location to a nullable CharField with max_length=4096.
class Migration(migrations.Migration):
dependencies = [
("api", "0044_integration_unique_configuration_per_tenant"),
]
operations = [
migrations.AlterField(
model_name="scan",
name="output_location",
field=models.CharField(blank=True, max_length=4096, null=True),
),
]
@@ -1,33 +0,0 @@
# Generated by Django 5.1.10 on 2025-08-20 09:04
from django.db import migrations, models
# Redefines the LighthouseConfiguration.model choices; the list includes four
# GPT-5 entries. The default remains "gpt-4o-2024-08-06".
class Migration(migrations.Migration):
dependencies = [
("api", "0045_alter_scan_output_location"),
]
operations = [
migrations.AlterField(
model_name="lighthouseconfiguration",
name="model",
field=models.CharField(
choices=[
("gpt-4o-2024-11-20", "GPT-4o v2024-11-20"),
("gpt-4o-2024-08-06", "GPT-4o v2024-08-06"),
("gpt-4o-2024-05-13", "GPT-4o v2024-05-13"),
("gpt-4o", "GPT-4o Default"),
("gpt-4o-mini-2024-07-18", "GPT-4o Mini v2024-07-18"),
("gpt-4o-mini", "GPT-4o Mini Default"),
("gpt-5-2025-08-07", "GPT-5 v2025-08-07"),
("gpt-5", "GPT-5 Default"),
("gpt-5-mini-2025-08-07", "GPT-5 Mini v2025-08-07"),
("gpt-5-mini", "GPT-5 Mini Default"),
],
default="gpt-4o-2024-08-06",
help_text="Must be one of the supported model names",
max_length=50,
),
),
]
@@ -1,16 +0,0 @@
# Generated by Django 5.1.10 on 2025-08-20 08:24
from django.db import migrations
# Removes the "unique_configuration_per_tenant" constraint from Integration.
class Migration(migrations.Migration):
dependencies = [
("api", "0046_lighthouse_gpt5"),
]
operations = [
migrations.RemoveConstraint(
model_name="integration",
name="unique_configuration_per_tenant",
),
]
@@ -1,34 +0,0 @@
# Generated by Django 5.1.10 on 2025-09-09 09:25
from django.db import migrations
import api.db_utils
# Adds "iac" to the Provider.provider choices and to the Postgres enum type
# "provider".
class Migration(migrations.Migration):
dependencies = [
("api", "0047_remove_integration_unique_configuration_per_tenant"),
]
operations = [
migrations.AlterField(
model_name="provider",
name="provider",
field=api.db_utils.ProviderEnumField(
choices=[
("aws", "AWS"),
("azure", "Azure"),
("gcp", "GCP"),
("kubernetes", "Kubernetes"),
("m365", "M365"),
("github", "GitHub"),
("iac", "IaC"),
],
default="aws",
),
),
# The reverse is a no-op, so rolling back leaves the enum value in place.
migrations.RunSQL(
"ALTER TYPE provider ADD VALUE IF NOT EXISTS 'iac';",
reverse_sql=migrations.RunSQL.noop,
),
]
+27 -195
View File
@@ -2,7 +2,6 @@ import json
import logging
import re
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
from allauth.socialaccount.models import SocialApp
@@ -34,7 +33,6 @@ from api.db_utils import (
IntegrationTypeEnumField,
InvitationStateEnumField,
MemberRoleEnumField,
ProcessorTypeEnumField,
ProviderEnumField,
ProviderSecretTypeEnumField,
ScanTriggerEnumField,
@@ -74,15 +72,6 @@ class StatusChoices(models.TextChoices):
MANUAL = "MANUAL", _("Manual")
# Two-member TextChoices enum; each member's stored value equals its name.
class OverviewStatusChoices(models.TextChoices):
"""
Status filters allowed in overview/severity endpoints.
"""
FAIL = "FAIL", _("Fail")
PASS = "PASS", _("Pass")
class StateChoices(models.TextChoices):
AVAILABLE = "available", _("Available")
SCHEDULED = "scheduled", _("Scheduled")
@@ -214,8 +203,6 @@ class Provider(RowLevelSecurityProtectedModel):
GCP = "gcp", _("GCP")
KUBERNETES = "kubernetes", _("Kubernetes")
M365 = "m365", _("M365")
GITHUB = "github", _("GitHub")
IAC = "iac", _("IaC")
@staticmethod
def validate_aws_uid(value):
@@ -276,29 +263,6 @@ class Provider(RowLevelSecurityProtectedModel):
pointer="/data/attributes/uid",
)
@staticmethod
def validate_github_uid(value):
    """
    Check that `value` looks like a GitHub username or organization name.

    Accepted values start with an alphanumeric character, followed by up to
    38 alphanumeric characters or hyphens.

    Raises:
        ModelValidationError: with code "github-uid" when the pattern fails.
    """
    if re.match(r"^[a-zA-Z0-9][a-zA-Z0-9-]{0,38}$", value):
        return

    raise ModelValidationError(
        detail="GitHub provider ID must be a valid GitHub username or organization name (1-39 characters, "
        "starting with alphanumeric, containing only alphanumeric characters and hyphens).",
        code="github-uid",
        pointer="/data/attributes/uid",
    )
@staticmethod
def validate_iac_uid(value):
    """
    Check that `value` looks like a repository URL.

    Two shapes are accepted: an `http(s)://`, `git@` or `ssh://` URL ending
    in `.git`, or any `http(s)://` URL without whitespace.

    Raises:
        ModelValidationError: with code "iac-uid" when neither shape matches.
    """
    # Validate that it's a valid repository URL (git URL format)
    repository_url = re.match(
        r"^(https?://|git@|ssh://)[^\s/]+[^\s]*\.git$|^(https?://)[^\s/]+[^\s]*$",
        value,
    )
    if repository_url:
        return

    raise ModelValidationError(
        detail="IaC provider ID must be a valid repository URL (e.g., https://github.com/user/repo or https://github.com/user/repo.git).",
        code="iac-uid",
        pointer="/data/attributes/uid",
    )
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
@@ -444,6 +408,20 @@ class Scan(RowLevelSecurityProtectedModel):
name = models.CharField(
blank=True, null=True, max_length=100, validators=[MinLengthValidator(3)]
)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
)
task = models.ForeignKey(
Task,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
)
trigger = ScanTriggerEnumField(
choices=TriggerChoices.choices,
)
@@ -459,31 +437,11 @@ class Scan(RowLevelSecurityProtectedModel):
completed_at = models.DateTimeField(null=True, blank=True)
next_scan_at = models.DateTimeField(null=True, blank=True)
scheduler_task = models.ForeignKey(
PeriodicTask, on_delete=models.SET_NULL, null=True, blank=True
)
output_location = models.CharField(blank=True, null=True, max_length=4096)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
)
task = models.ForeignKey(
Task,
on_delete=models.CASCADE,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
)
processor = models.ForeignKey(
"Processor",
on_delete=models.SET_NULL,
related_name="scans",
related_query_name="scan",
null=True,
blank=True,
PeriodicTask, on_delete=models.CASCADE, null=True, blank=True
)
output_location = models.CharField(blank=True, null=True, max_length=200)
# TODO: mutelist foreign key
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "scans"
@@ -510,13 +468,6 @@ class Scan(RowLevelSecurityProtectedModel):
condition=Q(state=StateChoices.COMPLETED),
name="scans_prov_state_ins_desc_idx",
),
# TODO This might replace `scans_prov_state_ins_desc_idx` completely. Review usage
models.Index(
fields=["tenant_id", "provider_id", "-inserted_at"],
condition=Q(state=StateChoices.COMPLETED),
include=["id"],
name="scans_prov_ins_desc_idx",
),
]
class JSONAPIMeta:
@@ -602,8 +553,6 @@ class Resource(RowLevelSecurityProtectedModel):
details = models.TextField(blank=True, null=True)
partition = models.TextField(blank=True, null=True)
failed_findings_count = models.IntegerField(default=0)
# Relationships
tags = models.ManyToManyField(
ResourceTag,
@@ -650,10 +599,6 @@ class Resource(RowLevelSecurityProtectedModel):
fields=["tenant_id", "provider_id"],
name="resources_tenant_provider_idx",
),
models.Index(
fields=["tenant_id", "-failed_findings_count", "id"],
name="resources_failed_findings_idx",
),
]
constraints = [
@@ -752,9 +697,6 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel):
check_id = models.CharField(max_length=100, blank=False, null=False)
check_metadata = models.JSONField(default=dict, null=False)
muted = models.BooleanField(default=False, null=False)
muted_reason = models.TextField(
blank=True, null=True, validators=[MinLengthValidator(3)], max_length=500
)
compliance = models.JSONField(default=dict, null=True, blank=True)
# Denormalize resource data for performance
@@ -896,16 +838,6 @@ class ResourceFindingMapping(PostgresPartitionedModel, RowLevelSecurityProtected
# - tenant_id
# - id
indexes = [
models.Index(
fields=["tenant_id", "finding_id"],
name="rfm_tenant_finding_idx",
),
models.Index(
fields=["tenant_id", "resource_id"],
name="rfm_tenant_resource_idx",
),
]
constraints = [
models.UniqueConstraint(
fields=("tenant_id", "resource_id", "finding_id"),
@@ -1010,11 +942,6 @@ class Invitation(RowLevelSecurityProtectedModel):
null=True,
)
def save(self, *args, **kwargs):
if self.email:
self.email = self.email.strip().lower()
super().save(*args, **kwargs)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "invitations"
@@ -1369,7 +1296,7 @@ class ScanSummary(RowLevelSecurityProtectedModel):
class Integration(RowLevelSecurityProtectedModel):
class IntegrationChoices(models.TextChoices):
AMAZON_S3 = "amazon_s3", _("Amazon S3")
S3 = "amazon_s3", _("Amazon S3")
AWS_SECURITY_HUB = "aws_security_hub", _("AWS Security Hub")
JIRA = "jira", _("JIRA")
SLACK = "slack", _("Slack")
@@ -1443,26 +1370,6 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel):
]
class SAMLToken(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
expires_at = models.DateTimeField(editable=False)
token = models.JSONField(unique=True)
user = models.ForeignKey(User, on_delete=models.CASCADE)
class Meta:
db_table = "saml_tokens"
def save(self, *args, **kwargs):
if not self.expires_at:
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
super().save(*args, **kwargs)
def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at
class SAMLDomainIndex(models.Model):
"""
Public index of SAML domains. No RLS. Used for fast lookup in SAML login flow.
@@ -1540,7 +1447,7 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
),
]
def clean(self, old_email_domain=None, is_create=False):
def clean(self, old_email_domain=None):
# Domain must not contain @
if "@" in self.email_domain:
raise ValidationError({"email_domain": "Domain must not contain @"})
@@ -1564,25 +1471,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
{"tenant": "There is a problem with your email domain."}
)
# The entityID must be unique in the system
idp_settings = self._parsed_metadata
entity_id = idp_settings.get("entity_id")
if entity_id:
# Find any SocialApp with this entityID
q = SocialApp.objects.filter(provider="saml", provider_id=entity_id)
# If updating, exclude our own SocialApp from the check
if not is_create:
q = q.exclude(client_id=old_email_domain)
else:
q = q.exclude(client_id=self.email_domain)
if q.exists():
raise ValidationError(
{"metadata_xml": "There is a problem with your metadata."}
)
def save(self, *args, **kwargs):
self.email_domain = self.email_domain.strip().lower()
is_create = not SAMLConfiguration.objects.filter(pk=self.pk).exists()
@@ -1595,8 +1483,7 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
old_email_domain = None
old_metadata_xml = None
self._parsed_metadata = self._parse_metadata()
self.clean(old_email_domain, is_create)
self.clean(old_email_domain)
super().save(*args, **kwargs)
if is_create or (
@@ -1614,12 +1501,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
email_domain=self.email_domain, defaults={"tenant": self.tenant}
)
def delete(self, *args, **kwargs):
super().delete(*args, **kwargs)
SocialApp.objects.filter(provider="saml", client_id=self.email_domain).delete()
SAMLDomainIndex.objects.filter(email_domain=self.email_domain).delete()
def _parse_metadata(self):
"""
Parse the raw IdP metadata XML and extract:
@@ -1639,8 +1520,6 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
# Entity ID
entity_id = root.attrib.get("entityID")
if not entity_id:
raise ValidationError({"metadata_xml": "Missing entityID in metadata."})
# SSO endpoint (must exist)
sso = root.find(".//md:IDPSSODescriptor/md:SingleSignOnService", ns)
@@ -1679,8 +1558,9 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
Create or update the corresponding SocialApp based on email_domain.
If the domain changed, update the matching SocialApp.
"""
idp_settings = self._parse_metadata()
settings_dict = SOCIALACCOUNT_PROVIDERS["saml"].copy()
settings_dict["idp"] = self._parsed_metadata
settings_dict["idp"] = idp_settings
current_site = Site.objects.get(id=settings.SITE_ID)
@@ -1688,24 +1568,19 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
provider="saml", client_id=previous_email_domain or self.email_domain
)
client_id = self.email_domain[:191]
name = f"SAML-{self.email_domain}"[:40]
if social_app_qs.exists():
social_app = social_app_qs.first()
social_app.client_id = client_id
social_app.name = name
social_app.client_id = self.email_domain
social_app.name = f"{self.tenant.name} SAML ({self.email_domain})"
social_app.settings = settings_dict
social_app.provider_id = self._parsed_metadata["entity_id"]
social_app.save()
social_app.sites.set([current_site])
else:
social_app = SocialApp.objects.create(
provider="saml",
client_id=client_id,
name=name,
client_id=self.email_domain,
name=f"{self.tenant.name} SAML ({self.email_domain})",
settings=settings_dict,
provider_id=self._parsed_metadata["entity_id"],
)
social_app.sites.set([current_site])
@@ -1771,10 +1646,6 @@ class LighthouseConfiguration(RowLevelSecurityProtectedModel):
GPT_4O = "gpt-4o", _("GPT-4o Default")
GPT_4O_MINI_2024_07_18 = "gpt-4o-mini-2024-07-18", _("GPT-4o Mini v2024-07-18")
GPT_4O_MINI = "gpt-4o-mini", _("GPT-4o Mini Default")
GPT_5_2025_08_07 = "gpt-5-2025-08-07", _("GPT-5 v2025-08-07")
GPT_5 = "gpt-5", _("GPT-5 Default")
GPT_5_MINI_2025_08_07 = "gpt-5-mini-2025-08-07", _("GPT-5 Mini v2025-08-07")
GPT_5_MINI = "gpt-5-mini", _("GPT-5 Mini Default")
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
@@ -1888,42 +1759,3 @@ class LighthouseConfiguration(RowLevelSecurityProtectedModel):
class JSONAPIMeta:
resource_name = "lighthouse-configurations"
class Processor(RowLevelSecurityProtectedModel):
class ProcessorChoices(models.TextChoices):
MUTELIST = "mutelist", _("Mutelist")
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
processor_type = ProcessorTypeEnumField(choices=ProcessorChoices.choices)
configuration = models.JSONField(default=dict)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "processors"
constraints = [
models.UniqueConstraint(
fields=("tenant_id", "processor_type"),
name="unique_processor_types_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
indexes = [
models.Index(
fields=["tenant_id", "id"],
name="processor_tenant_id_idx",
),
models.Index(
fields=["tenant_id", "processor_type"],
name="processor_tenant_type_idx",
),
]
class JSONAPIMeta:
resource_name = "processors"
-95
View File
@@ -1,95 +0,0 @@
def _pick_task_response_component(components):
schemas = components.get("schemas", {}) or {}
for candidate in ("TaskResponse",):
if candidate in schemas:
return candidate
return None
def _extract_task_example_from_components(components):
schemas = components.get("schemas", {}) or {}
candidate = "TaskResponse"
doc = schemas.get(candidate)
if isinstance(doc, dict) and "example" in doc:
return doc["example"]
res = schemas.get(candidate)
if isinstance(res, dict) and "example" in res:
example = res["example"]
return example if "data" in example else {"data": example}
# Fallback
return {
"data": {
"type": "tasks",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"attributes": {
"inserted_at": "2019-08-24T14:15:22Z",
"completed_at": "2019-08-24T14:15:22Z",
"name": "string",
"state": "available",
"result": None,
"task_args": None,
"metadata": None,
},
}
}
def attach_task_202_examples(result, generator, request, public):  # noqa: F841
    """Add a task example and schema reference to every JSON:API 202 response.

    *result* is mutated in place and returned. Non-dict input is returned
    untouched. For each operation whose ``202`` response has
    ``application/vnd.api+json`` content:

    * a "Task queued" example is added when the media type has neither
      ``examples`` nor ``example``;
    * when a task response component exists, ``schema`` is replaced with a
      ``$ref`` to it unless it already references that component.

    ``generator``, ``request`` and ``public`` are accepted but not used.
    """
    if not isinstance(result, dict):
        return result

    components = result.get("components", {}) or {}
    component_name = _pick_task_response_component(components)
    example_payload = _extract_task_example_from_components(components)

    for path_item in (result.get("paths", {}) or {}).values():
        if not isinstance(path_item, dict):
            continue

        for operation in path_item.values():
            if not isinstance(operation, dict):
                continue

            accepted = (operation.get("responses", {}) or {}).get("202")
            if not isinstance(accepted, dict):
                continue

            media = (accepted.get("content", {}) or {}).get(
                "application/vnd.api+json"
            )
            if not isinstance(media, dict):
                continue

            # Inject example if missing
            if not ("examples" in media or "example" in media):
                media["examples"] = {
                    "Task queued": {
                        "summary": "Task queued",
                        "value": example_payload,
                    }
                }

            # Rewrite schema $ref if needed
            if not component_name:
                continue

            schema = media.get("schema")
            ref = schema.get("$ref") if isinstance(schema, dict) else None
            if not ref or ref.split("/")[-1] != component_name:
                media["schema"] = {
                    "$ref": f"#/components/schemas/{component_name}"
                }

    return result
File diff suppressed because it is too large Load Diff
@@ -11,7 +11,7 @@ def test_basic_authentication():
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password@1"
test_password = "test_password"
# Check that a 401 is returned when no basic authentication is provided
no_auth_response = client.get(reverse("provider-list"))
@@ -108,7 +108,7 @@ def test_user_me_when_inviting_users(create_test_user, tenants_fixture, roles_fi
user1_email = "user1@testing.com"
user2_email = "user2@testing.com"
password = "Thisisapassword123@"
password = "thisisapassword123"
user1_response = client.post(
reverse("user-list"),
@@ -187,7 +187,7 @@ class TestTokenSwitchTenant:
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password1@"
test_password = "test_password"
# Check that we can create a new user without any kind of authentication
user_creation_response = client.post(
@@ -17,7 +17,7 @@ def test_delete_provider_without_executing_task(
client = APIClient()
test_user = "test_email@prowler.com"
test_password = "Test_password1@"
test_password = "test_password"
prowler_task = tasks_fixture[0]
task_mock = Mock()
+26 -21
View File
@@ -1,10 +1,12 @@
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from allauth.socialaccount.models import SocialLogin
from django.contrib.auth import get_user_model
from api.adapters import ProwlerSocialAccountAdapter
from api.db_router import MainRouter
from api.models import Membership, SAMLConfiguration, Tenant
User = get_user_model()
@@ -25,8 +27,7 @@ class TestProwlerSocialAccountAdapter:
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.provider = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.provider = "saml"
sociallogin.account.extra_data = {}
sociallogin.user = create_test_user
sociallogin.connect = MagicMock()
@@ -45,9 +46,7 @@ class TestProwlerSocialAccountAdapter:
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.provider = MagicMock()
sociallogin.user = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.provider = "github"
sociallogin.account.extra_data = {}
sociallogin.connect = MagicMock()
@@ -55,23 +54,29 @@ class TestProwlerSocialAccountAdapter:
sociallogin.connect.assert_not_called()
def test_save_user_saml_sets_session_flag(self, rf):
def test_save_user_saml_flow(
self,
rf,
saml_setup,
saml_sociallogin,
):
adapter = ProwlerSocialAccountAdapter()
request = rf.get("/")
request.session = {}
saml_sociallogin.user.email = saml_setup["email"]
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.provider = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account = MagicMock()
sociallogin.account.extra_data = {}
tenant = Tenant.objects.using(MainRouter.admin_db).get(
id=saml_setup["tenant_id"]
)
saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get(
tenant=tenant
)
assert saml_config.email_domain == saml_setup["domain"]
mock_user = MagicMock()
mock_user.id = 123
user = adapter.save_user(request, saml_sociallogin)
with patch("api.adapters.super") as mock_super:
with patch("api.adapters.transaction"):
with patch("api.adapters.MainRouter"):
mock_super.return_value.save_user.return_value = mock_user
adapter.save_user(request, sociallogin)
assert request.session["saml_user_created"] == "123"
assert user.email == saml_setup["email"]
assert (
Membership.objects.using(MainRouter.admin_db)
.filter(user=user, tenant=tenant)
.exists()
)
@@ -13,7 +13,6 @@ from api.db_utils import (
enum_to_choices,
generate_random_token,
one_week_from_now,
update_objects_in_batches,
)
from api.models import Provider
@@ -228,88 +227,3 @@ class TestCreateObjectsInBatches:
qs = Provider.objects.filter(tenant=tenant)
assert qs.count() == total
@pytest.mark.django_db
class TestUpdateObjectsInBatches:
@pytest.fixture
def tenant(self, tenants_fixture):
return tenants_fixture[0]
def make_provider_instances(self, tenant, count):
"""
Return a list of `count` unsaved Provider instances for the given tenant.
"""
base_uid = 2000
return [
Provider(
tenant=tenant,
uid=str(base_uid + i),
provider=Provider.ProviderChoices.AWS,
)
for i in range(count)
]
def test_exact_multiple_of_batch(self, tenant):
total = 6
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
# Fetch them back, mutate the `uid` field, then update in batches
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
batch_size=batch_size,
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
def test_non_multiple_of_batch(self, tenant):
total = 7
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
batch_size=batch_size,
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
def test_batch_size_default(self, tenant):
default_size = settings.DJANGO_DELETION_BATCH_SIZE
total = default_size + 2
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs)
providers = list(Provider.objects.filter(tenant=tenant))
for p in providers:
p.uid = f"{p.uid}_upd"
# Update without specifying batch_size (uses default)
update_objects_in_batches(
tenant_id=str(tenant.id),
model=Provider,
objects=providers,
fields=["uid"],
)
qs = Provider.objects.filter(tenant=tenant, uid__endswith="_upd")
assert qs.count() == total
+16 -71
View File
@@ -3,7 +3,7 @@ from allauth.socialaccount.models import SocialApp
from django.core.exceptions import ValidationError
from api.db_router import MainRouter
from api.models import Resource, ResourceTag, SAMLConfiguration, SAMLDomainIndex
from api.models import Resource, ResourceTag, SAMLConfiguration, Tenant
@pytest.mark.django_db
@@ -142,8 +142,8 @@ class TestSAMLConfigurationModel:
</md:EntityDescriptor>
"""
def test_creates_valid_configuration(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_creates_valid_configuration(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant A")
config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="ssoexample.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
@@ -153,8 +153,8 @@ class TestSAMLConfigurationModel:
assert config.email_domain == "ssoexample.com"
assert SocialApp.objects.filter(client_id="ssoexample.com").exists()
def test_email_domain_with_at_symbol_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_email_domain_with_at_symbol_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant B")
config = SAMLConfiguration(
email_domain="invalid@domain.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
@@ -168,8 +168,9 @@ class TestSAMLConfigurationModel:
assert "email_domain" in errors
assert "Domain must not contain @" in errors["email_domain"][0]
def test_duplicate_email_domain_fails(self, tenants_fixture):
tenant1, tenant2, *_ = tenants_fixture
def test_duplicate_email_domain_fails(self):
tenant1 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C1")
tenant2 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C2")
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="duplicate.com",
@@ -190,8 +191,8 @@ class TestSAMLConfigurationModel:
assert "tenant" in errors
assert "There is a problem with your email domain." in errors["tenant"][0]
def test_duplicate_tenant_config_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_duplicate_tenant_config_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant D")
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="unique1.com",
@@ -215,8 +216,8 @@ class TestSAMLConfigurationModel:
in errors["tenant"][0]
)
def test_invalid_metadata_xml_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_invalid_metadata_xml_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant E")
config = SAMLConfiguration(
email_domain="brokenxml.com",
metadata_xml="<bad<xml>",
@@ -231,8 +232,8 @@ class TestSAMLConfigurationModel:
assert "Invalid XML" in errors["metadata_xml"][0]
assert "not well-formed" in errors["metadata_xml"][0]
def test_metadata_missing_sso_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_metadata_missing_sso_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant F")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor></md:IDPSSODescriptor>
</md:EntityDescriptor>"""
@@ -249,8 +250,8 @@ class TestSAMLConfigurationModel:
assert "metadata_xml" in errors
assert "Missing SingleSignOnService" in errors["metadata_xml"][0]
def test_metadata_missing_certificate_fails(self, tenants_fixture):
tenant = tenants_fixture[0]
def test_metadata_missing_certificate_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant G")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://example.com/sso"/>
@@ -268,59 +269,3 @@ class TestSAMLConfigurationModel:
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "X509Certificate" in errors["metadata_xml"][0]
def test_deletes_saml_configuration_and_related_objects(self, tenants_fixture):
tenant = tenants_fixture[0]
email_domain = "deleteme.com"
# Create the configuration
config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain=email_domain,
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# Verify that the SocialApp and SAMLDomainIndex exist
assert SocialApp.objects.filter(client_id=email_domain).exists()
assert (
SAMLDomainIndex.objects.using(MainRouter.admin_db)
.filter(email_domain=email_domain)
.exists()
)
# Delete the configuration
config.delete()
# Verify that the configuration and its related objects are deleted
assert (
not SAMLConfiguration.objects.using(MainRouter.admin_db)
.filter(pk=config.pk)
.exists()
)
assert not SocialApp.objects.filter(client_id=email_domain).exists()
assert (
not SAMLDomainIndex.objects.using(MainRouter.admin_db)
.filter(email_domain=email_domain)
.exists()
)
def test_duplicate_entity_id_fails_on_creation(self, tenants_fixture):
tenant1, tenant2, *_ = tenants_fixture
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="first.com",
metadata_xml=self.VALID_METADATA,
tenant=tenant1,
)
config = SAMLConfiguration(
email_domain="second.com",
metadata_xml=self.VALID_METADATA,
tenant=tenant2,
)
with pytest.raises(ValidationError) as exc_info:
config.save()
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "There is a problem with your metadata." in errors["metadata_xml"][0]
+3 -88
View File
@@ -1,7 +1,6 @@
from unittest.mock import ANY, Mock, patch
import pytest
from conftest import TODAY
from django.urls import reverse
from rest_framework import status
@@ -61,7 +60,7 @@ class TestUserViewSet:
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
valid_user_payload = {
"name": "test",
"password": "Newpassword123@",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_rbac.post(
@@ -75,7 +74,7 @@ class TestUserViewSet:
):
valid_user_payload = {
"name": "test",
"password": "Newpassword123@",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_no_permissions_rbac.post(
@@ -322,7 +321,7 @@ class TestProviderViewSet:
@pytest.mark.django_db
class TestLimitedVisibility:
TEST_EMAIL = "rbac@rbac.com"
TEST_PASSWORD = "Thisisapassword123@"
TEST_PASSWORD = "thisisapassword123"
@pytest.fixture
def limited_admin_user(
@@ -410,87 +409,3 @@ class TestLimitedVisibility:
assert (
response.json()["data"]["relationships"]["providers"]["meta"]["count"] == 1
)
def test_overviews_providers(
self,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(reverse("overview-providers"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) > 0
# Changing the provider visibility, no data should be returned
# Only the associated provider to that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(reverse("overview-providers"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
@pytest.mark.parametrize(
"endpoint_name",
[
"findings",
"findings_severity",
],
)
def test_overviews_findings(
self,
endpoint_name,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(
reverse(f"overview-{endpoint_name}")
)
assert response.status_code == status.HTTP_200_OK
values = response.json()["data"]["attributes"].values()
assert any(value > 0 for value in values)
# Changing the provider visibility, no data should be returned
# Only the associated provider to that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(
reverse(f"overview-{endpoint_name}")
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]["attributes"].values()
assert all(value == 0 for value in data)
def test_overviews_services(
self,
authenticated_client_rbac_limited,
scan_summaries_fixture,
providers_fixture,
):
# By default, the associated provider is the one which has the overview data
response = authenticated_client_rbac_limited.get(
reverse("overview-services"), {"filter[inserted_at]": TODAY}
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) > 0
# Changing the provider visibility, no data should be returned
# Only the associated provider to that group is changed
new_provider = providers_fixture[1]
ProviderGroupMembership.objects.all().update(provider=new_provider)
response = authenticated_client_rbac_limited.get(
reverse("overview-services"), {"filter[inserted_at]": TODAY}
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
@@ -1,100 +0,0 @@
import pytest
from rest_framework.exceptions import ValidationError
from api.v1.serializer_utils.integrations import S3ConfigSerializer
class TestS3ConfigSerializer:
"""Test cases for S3ConfigSerializer validation."""
def test_validate_output_directory_valid_paths(self):
"""Test that valid output directory paths are accepted."""
serializer = S3ConfigSerializer()
# Test normal paths
assert serializer.validate_output_directory("test") == "test"
assert serializer.validate_output_directory("test/folder") == "test/folder"
assert serializer.validate_output_directory("my-folder_123") == "my-folder_123"
# Test paths with leading slashes (should be normalized)
assert serializer.validate_output_directory("/test") == "test"
assert serializer.validate_output_directory("/test/folder") == "test/folder"
# Test paths with excessive slashes (should be normalized)
assert serializer.validate_output_directory("///test") == "test"
assert serializer.validate_output_directory("///////test") == "test"
assert serializer.validate_output_directory("test//folder") == "test/folder"
assert serializer.validate_output_directory("test///folder") == "test/folder"
def test_validate_output_directory_empty_values(self):
"""Test that empty values raise validation errors."""
serializer = S3ConfigSerializer()
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory(".")
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory("/")
def test_validate_output_directory_invalid_characters(self):
"""Test that invalid characters are rejected."""
serializer = S3ConfigSerializer()
invalid_chars = ["<", ">", ":", '"', "|", "?", "*"]
for char in invalid_chars:
with pytest.raises(
ValidationError, match="Output directory contains invalid characters"
):
serializer.validate_output_directory(f"test{char}folder")
def test_validate_output_directory_too_long(self):
"""Test that paths that are too long are rejected."""
serializer = S3ConfigSerializer()
# Create a path longer than 900 characters
long_path = "a" * 901
with pytest.raises(ValidationError, match="Output directory path is too long"):
serializer.validate_output_directory(long_path)
def test_validate_output_directory_edge_cases(self):
"""Test edge cases for output directory validation."""
serializer = S3ConfigSerializer()
# Test path at the limit (900 characters)
path_at_limit = "a" * 900
assert serializer.validate_output_directory(path_at_limit) == path_at_limit
# Test complex normalization
assert serializer.validate_output_directory("//test/../folder//") == "folder"
assert serializer.validate_output_directory("/test/./folder/") == "test/folder"
def test_s3_config_serializer_full_validation(self):
"""Test the full S3ConfigSerializer with valid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "///////test", # This should be normalized
}
serializer = S3ConfigSerializer(data=data)
assert serializer.is_valid()
validated_data = serializer.validated_data
assert validated_data["bucket_name"] == "my-test-bucket"
assert validated_data["output_directory"] == "test" # Normalized
def test_s3_config_serializer_invalid_data(self):
"""Test the full S3ConfigSerializer with invalid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "test<invalid", # Contains invalid character
}
serializer = S3ConfigSerializer(data=data)
assert not serializer.is_valid()
assert "output_directory" in serializer.errors
+5 -322
View File
@@ -6,18 +6,16 @@ from rest_framework.exceptions import NotFound, ValidationError
from api.db_router import MainRouter
from api.exceptions import InvitationTokenExpiredException
from api.models import Integration, Invitation, Provider
from api.models import Invitation, Provider
from api.utils import (
get_prowler_provider_kwargs,
initialize_prowler_provider,
merge_dicts,
prowler_integration_connection_test,
prowler_provider_connection_test,
return_prowler_provider,
validate_invitation,
)
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHubConnection
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
@@ -133,21 +131,6 @@ class TestInitializeProwlerProvider:
initialize_prowler_provider(provider)
mock_return_prowler_provider.return_value.assert_called_once_with(key="value")
@patch("api.utils.return_prowler_provider")
def test_initialize_prowler_provider_with_mutelist(
    self, mock_return_prowler_provider
):
    """The mutelist configuration is forwarded to the provider as ``mutelist_content``."""
    provider = MagicMock()
    provider.secret.secret = {"key": "value"}
    mutelist_processor = MagicMock(configuration={"Mutelist": {"key": "value"}})
    provider_class = MagicMock()
    mock_return_prowler_provider.return_value = provider_class

    initialize_prowler_provider(provider, mutelist_processor)

    # The provider class must be instantiated once with secret + mutelist kwargs.
    provider_class.assert_called_once_with(
        key="value", mutelist_content={"key": "value"}
    )
class TestProwlerProviderConnectionTest:
@patch("api.utils.return_prowler_provider")
@@ -199,10 +182,6 @@ class TestGetProwlerProviderKwargs:
Provider.ProviderChoices.M365.value,
{},
),
(
Provider.ProviderChoices.GITHUB.value,
{"organizations": ["provider_uid"]},
),
],
)
def test_get_prowler_provider_kwargs(self, provider_type, expected_extra_kwargs):
@@ -221,25 +200,6 @@ class TestGetProwlerProviderKwargs:
expected_result = {**secret_dict, **expected_extra_kwargs}
assert result == expected_result
def test_get_prowler_provider_kwargs_with_mutelist(self):
    """An AWS provider with a mutelist processor gets ``mutelist_content`` in its kwargs."""
    secret_dict = {"key": "value"}

    provider = MagicMock()
    provider.provider = Provider.ProviderChoices.AWS.value
    provider.secret = MagicMock(secret=secret_dict)
    provider.uid = "provider_uid"

    mutelist_processor = MagicMock(configuration={"Mutelist": {"key": "value"}})

    kwargs = get_prowler_provider_kwargs(provider, mutelist_processor)

    assert kwargs == {**secret_dict, "mutelist_content": {"key": "value"}}
def test_get_prowler_provider_kwargs_unsupported_provider(self):
# Setup
provider_uid = "provider_uid"
@@ -294,7 +254,7 @@ class TestValidateInvitation:
assert result == invitation
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="user@example.com"
token="VALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_validation_error(self):
@@ -309,7 +269,7 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_not_found(self):
@@ -324,7 +284,7 @@ class TestValidateInvitation:
assert exc_info.value.detail == "Invitation is not valid."
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_expired(self, invitation):
@@ -372,282 +332,5 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="different@example.com"
token="VALID_TOKEN", email="different@example.com"
)
def test_valid_invitation_uppercase_email(self):
    """An invitation stored with an upper-case email is found via a case-insensitive lookup."""
    stored_invitation = MagicMock(spec=Invitation)
    stored_invitation.token = "VALID_TOKEN"
    stored_invitation.email = "USER@example.com"
    stored_invitation.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
    stored_invitation.state = Invitation.State.PENDING
    stored_invitation.tenant = MagicMock()

    with patch("api.utils.Invitation.objects.using") as mock_using:
        admin_manager = mock_using.return_value
        admin_manager.get.return_value = stored_invitation

        # The lookup uses the lower-case spelling of the stored address.
        found = validate_invitation("VALID_TOKEN", "user@example.com")

        assert found == stored_invitation
        admin_manager.get.assert_called_once_with(
            token="VALID_TOKEN", email__iexact="user@example.com"
        )
class TestProwlerIntegrationConnectionTest:
    """Tests for ``prowler_integration_connection_test`` (Security Hub and Jira paths).

    The patched ``rls_transaction`` is a ``MagicMock``, which already supports
    the context-manager protocol and whose default ``__exit__`` returns
    ``False``. It is deliberately NOT re-stubbed with ``__exit__ = MagicMock()``:
    that would return a truthy mock and silently suppress any exception raised
    inside the ``with`` block, hiding real failures.
    """

    @staticmethod
    def _security_hub_integration(credentials, configuration):
        """Build a mocked Security Hub integration linked to one AWS provider."""
        integration = MagicMock()
        integration.integration_type = Integration.IntegrationChoices.AWS_SECURITY_HUB
        integration.credentials = credentials
        integration.configuration = configuration

        # Provider relationship returned by ``integrationproviderrelationship_set.first()``.
        relationship = MagicMock()
        relationship.provider = MagicMock()
        relationship.provider.uid = "123456789012"
        integration.integrationproviderrelationship_set.first.return_value = (
            relationship
        )
        return integration

    @staticmethod
    def _jira_integration(credentials, configuration):
        """Build a mocked Jira integration owned by the test tenant."""
        integration = MagicMock()
        integration.integration_type = Integration.IntegrationChoices.JIRA
        integration.tenant_id = "test-tenant-id"
        integration.credentials = credentials
        integration.configuration = configuration
        return integration

    @staticmethod
    def _jira_connection(is_connected, error, projects):
        """Build the mocked result object returned by ``Jira.test_connection``."""
        connection = MagicMock()
        connection.is_connected = is_connected
        connection.error = error
        connection.projects = projects
        return connection

    @patch("api.utils.SecurityHub")
    def test_security_hub_connection_failure_resets_regions(
        self, mock_security_hub_class
    ):
        """Test that SecurityHub connection failure resets regions to empty dict."""
        # Integration that already has a regions map stored.
        integration = self._security_hub_integration(
            credentials={
                "aws_access_key_id": "test_key",
                "aws_secret_access_key": "test_secret",
            },
            configuration={
                "send_only_fails": True,
                "regions": {
                    "us-east-1": True,
                    "us-west-2": True,
                    "eu-west-1": False,
                    "ap-south-1": False,
                },
            },
        )
        mock_security_hub_class.test_connection.return_value = SecurityHubConnection(
            is_connected=False,
            error=Exception("SecurityHub testing"),
            enabled_regions=set(),
            disabled_regions=set(),
        )

        result = prowler_integration_connection_test(integration)

        assert result.is_connected is False
        assert str(result.error) == "SecurityHub testing"
        # Regions are wiped and the change is persisted.
        assert integration.configuration["regions"] == {}
        integration.save.assert_called_once()
        mock_security_hub_class.test_connection.assert_called_once_with(
            aws_account_id="123456789012",
            raise_on_exception=False,
            aws_access_key_id="test_key",
            aws_secret_access_key="test_secret",
        )

    @patch("api.utils.SecurityHub")
    def test_security_hub_connection_success_saves_regions(
        self, mock_security_hub_class
    ):
        """Test that successful SecurityHub connection saves regions correctly."""
        integration = self._security_hub_integration(
            credentials={
                "aws_access_key_id": "valid_key",
                "aws_secret_access_key": "valid_secret",
            },
            configuration={"send_only_fails": False},
        )
        mock_security_hub_class.test_connection.return_value = SecurityHubConnection(
            is_connected=True,
            error=None,
            enabled_regions={"us-east-1", "eu-west-1"},
            disabled_regions={"ap-south-1"},
        )

        result = prowler_integration_connection_test(integration)

        assert result.is_connected is True
        # Enabled regions map to True, disabled regions to False.
        saved_regions = integration.configuration["regions"]
        assert saved_regions["us-east-1"] is True
        assert saved_regions["eu-west-1"] is True
        assert saved_regions["ap-south-1"] is False
        integration.save.assert_called_once()

    @patch("api.utils.rls_transaction")
    @patch("api.utils.Jira")
    def test_jira_connection_success_basic_auth(
        self, mock_jira_class, mock_rls_transaction
    ):
        """A successful Jira check stores the returned projects under the tenant's RLS context."""
        integration = self._jira_integration(
            credentials={
                "user_mail": "test@example.com",
                "api_token": "test_api_token",
                "domain": "example.atlassian.net",
            },
            configuration={},
        )
        mock_jira_class.test_connection.return_value = self._jira_connection(
            is_connected=True,
            error=None,
            projects={"PROJ1": "Project 1", "PROJ2": "Project 2"},
        )

        result = prowler_integration_connection_test(integration)

        assert result.is_connected is True
        assert result.error is None
        # Credentials (including the domain) are forwarded unchanged.
        mock_jira_class.test_connection.assert_called_once_with(
            user_mail="test@example.com",
            api_token="test_api_token",
            domain="example.atlassian.net",
            raise_on_exception=False,
        )
        mock_rls_transaction.assert_called_once_with("test-tenant-id")
        assert integration.configuration["projects"] == {
            "PROJ1": "Project 1",
            "PROJ2": "Project 2",
        }
        integration.save.assert_called_once()

    @patch("api.utils.rls_transaction")
    @patch("api.utils.Jira")
    def test_jira_connection_failure_invalid_credentials(
        self, mock_jira_class, mock_rls_transaction
    ):
        """A failed Jira check still saves the integration, with an empty projects map."""
        integration = self._jira_integration(
            credentials={
                "user_mail": "invalid@example.com",
                "api_token": "invalid_token",
                "domain": "invalid.atlassian.net",
            },
            configuration={},
        )
        mock_jira_class.test_connection.return_value = self._jira_connection(
            is_connected=False,
            error=Exception("Authentication failed: Invalid credentials"),
            projects={},  # Empty projects when connection fails
        )

        result = prowler_integration_connection_test(integration)

        assert result.is_connected is False
        assert "Authentication failed: Invalid credentials" in str(result.error)
        mock_jira_class.test_connection.assert_called_once_with(
            user_mail="invalid@example.com",
            api_token="invalid_token",
            domain="invalid.atlassian.net",
            raise_on_exception=False,
        )
        # The RLS transaction and the save happen even on failure.
        mock_rls_transaction.assert_called_once_with("test-tenant-id")
        assert integration.configuration["projects"] == {}
        integration.save.assert_called_once()

    @patch("api.utils.rls_transaction")
    @patch("api.utils.Jira")
    def test_jira_connection_projects_update_with_existing_configuration(
        self, mock_jira_class, mock_rls_transaction
    ):
        """Test that projects are properly updated when integration already has configuration data"""
        integration = self._jira_integration(
            credentials={
                "user_mail": "test@example.com",
                "api_token": "test_api_token",
                "domain": "example.atlassian.net",
            },
            configuration={
                "issue_types": ["Task"],  # Existing configuration
                "projects": {"OLD_PROJ": "Old Project"},  # Will be overwritten
            },
        )
        mock_jira_class.test_connection.return_value = self._jira_connection(
            is_connected=True,
            error=None,
            projects={
                "NEW_PROJ1": "New Project 1",
                "NEW_PROJ2": "New Project 2",
            },
        )

        result = prowler_integration_connection_test(integration)

        assert result.is_connected is True
        assert result.error is None
        # Old projects are replaced; unrelated configuration keys are preserved.
        assert integration.configuration["projects"] == {
            "NEW_PROJ1": "New Project 1",
            "NEW_PROJ2": "New Project 2",
        }
        assert integration.configuration["issue_types"] == ["Task"]
        integration.save.assert_called_once()
File diff suppressed because it is too large Load Diff
+12 -164
View File
@@ -6,19 +6,13 @@ from django.db.models import Subquery
from rest_framework.exceptions import NotFound, ValidationError
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.exceptions import InvitationTokenExpiredException
from api.models import Integration, Invitation, Processor, Provider, Resource
from api.models import Invitation, Provider, Resource
from api.v1.serializers import FindingMetadataSerializer
from prowler.lib.outputs.jira.jira import Jira, JiraBasicAuthError
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.common.models import Connection
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.github.github_provider import GithubProvider
from prowler.providers.iac.iac_provider import IacProvider
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
from prowler.providers.m365.m365_provider import M365Provider
@@ -61,22 +55,14 @@ def merge_dicts(default_dict: dict, replacement_dict: dict) -> dict:
def return_prowler_provider(
provider: Provider,
) -> [
AwsProvider
| AzureProvider
| GcpProvider
| GithubProvider
| IacProvider
| KubernetesProvider
| M365Provider
]:
) -> [AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider]:
"""Return the Prowler provider class based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secrets.
Returns:
AwsProvider | AzureProvider | GcpProvider | GithubProvider | IacProvider | KubernetesProvider | M365Provider: The corresponding provider class.
AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider: The corresponding provider class.
Raises:
ValueError: If the provider type specified in `provider.provider` is not supported.
@@ -92,23 +78,16 @@ def return_prowler_provider(
prowler_provider = KubernetesProvider
case Provider.ProviderChoices.M365.value:
prowler_provider = M365Provider
case Provider.ProviderChoices.GITHUB.value:
prowler_provider = GithubProvider
case Provider.ProviderChoices.IAC.value:
prowler_provider = IacProvider
case _:
raise ValueError(f"Provider type {provider.provider} not supported")
return prowler_provider
def get_prowler_provider_kwargs(
provider: Provider, mutelist_processor: Processor | None = None
) -> dict:
def get_prowler_provider_kwargs(provider: Provider) -> dict:
"""Get the Prowler provider kwargs based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secret.
mutelist_processor (Processor): The mutelist processor object containing the mutelist configuration.
Returns:
dict: The provider kwargs for the corresponding provider class.
@@ -126,56 +105,24 @@ def get_prowler_provider_kwargs(
}
elif provider.provider == Provider.ProviderChoices.KUBERNETES.value:
prowler_provider_kwargs = {**prowler_provider_kwargs, "context": provider.uid}
elif provider.provider == Provider.ProviderChoices.GITHUB.value:
if provider.uid:
prowler_provider_kwargs = {
**prowler_provider_kwargs,
"organizations": [provider.uid],
}
elif provider.provider == Provider.ProviderChoices.IAC.value:
# For IaC provider, uid contains the repository URL
# Extract the access token if present in the secret
prowler_provider_kwargs = {
"scan_repository_url": provider.uid,
}
if "access_token" in provider.secret.secret:
prowler_provider_kwargs["oauth_app_token"] = provider.secret.secret[
"access_token"
]
if mutelist_processor:
mutelist_content = mutelist_processor.configuration.get("Mutelist", {})
if mutelist_content:
prowler_provider_kwargs["mutelist_content"] = mutelist_content
return prowler_provider_kwargs
def initialize_prowler_provider(
provider: Provider,
mutelist_processor: Processor | None = None,
) -> (
AwsProvider
| AzureProvider
| GcpProvider
| GithubProvider
| IacProvider
| KubernetesProvider
| M365Provider
):
) -> AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider:
"""Initialize a Prowler provider instance based on the given provider type.
Args:
provider (Provider): The provider object containing the provider type and associated secrets.
mutelist_processor (Processor): The mutelist processor object containing the mutelist configuration.
Returns:
AwsProvider | AzureProvider | GcpProvider | GithubProvider | IacProvider | KubernetesProvider | M365Provider: An instance of the corresponding provider class
(`AwsProvider`, `AzureProvider`, `GcpProvider`, `GithubProvider`, `IacProvider`, `KubernetesProvider` or `M365Provider`) initialized with the
AwsProvider | AzureProvider | GcpProvider | KubernetesProvider | M365Provider: An instance of the corresponding provider class
(`AwsProvider`, `AzureProvider`, `GcpProvider`, `KubernetesProvider` or `M365Provider`) initialized with the
provider's secrets.
"""
prowler_provider = return_prowler_provider(provider)
prowler_provider_kwargs = get_prowler_provider_kwargs(provider, mutelist_processor)
prowler_provider_kwargs = get_prowler_provider_kwargs(provider)
return prowler_provider(**prowler_provider_kwargs)
@@ -195,94 +142,9 @@ def prowler_provider_connection_test(provider: Provider) -> Connection:
except Provider.secret.RelatedObjectDoesNotExist as secret_error:
return Connection(is_connected=False, error=secret_error)
# For IaC provider, construct the kwargs properly for test_connection
if provider.provider == Provider.ProviderChoices.IAC.value:
# Don't pass repository_url from secret, use scan_repository_url with the UID
iac_test_kwargs = {
"scan_repository_url": provider.uid,
"raise_on_exception": False,
}
# Add access_token if present in the secret
if "access_token" in prowler_provider_kwargs:
iac_test_kwargs["access_token"] = prowler_provider_kwargs["access_token"]
return prowler_provider.test_connection(**iac_test_kwargs)
else:
return prowler_provider.test_connection(
**prowler_provider_kwargs,
provider_id=provider.uid,
raise_on_exception=False,
)
def prowler_integration_connection_test(integration: Integration) -> Connection:
    """Test the connection to a Prowler integration based on the given integration type.

    Besides running the check, the Security Hub and Jira branches persist what
    the check discovered (``regions`` / ``projects``) into
    ``integration.configuration`` and save the integration.

    Args:
        integration (Integration): The integration object containing the integration type and associated credentials.

    Returns:
        Connection: A connection object representing the result of the connection test for the specified integration.
            For Slack, where no connection test exists yet, a disconnected
            ``Connection`` carrying a ``NotImplementedError`` is returned
            instead of ``None`` so callers can always read ``is_connected``.

    Raises:
        ValueError: If the integration type is not one of the supported choices.
    """
    integration_type = integration.integration_type

    if integration_type == Integration.IntegrationChoices.AMAZON_S3:
        return S3.test_connection(
            **integration.credentials,
            bucket_name=integration.configuration["bucket_name"],
            raise_on_exception=False,
        )

    # TODO: It is possible that we can unify the connection test for all integrations, but need refactoring
    # to avoid code duplication. Actually the AWS integrations are similar, so SecurityHub and S3 can be unified
    # making some changes in the SDK.
    if integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB:
        # Get the provider associated with this integration
        provider_relationship = integration.integrationproviderrelationship_set.first()
        if not provider_relationship:
            return Connection(
                is_connected=False, error="No provider associated with this integration"
            )

        # Integration-specific credentials win; otherwise fall back to the provider secret.
        credentials = (
            integration.credentials or provider_relationship.provider.secret.secret
        )
        connection = SecurityHub.test_connection(
            aws_account_id=provider_relationship.provider.uid,
            raise_on_exception=False,
            **credentials,
        )

        if connection.is_connected:
            # Enabled regions map to True, disabled regions to False.
            regions_status = {r: True for r in connection.enabled_regions}
            regions_status.update({r: False for r in connection.disabled_regions})
        else:
            # Reset regions information if connection fails
            regions_status = {}

        integration.configuration["regions"] = regions_status
        integration.save()
        return connection

    if integration_type == Integration.IntegrationChoices.JIRA:
        jira_connection = Jira.test_connection(
            **integration.credentials,
            raise_on_exception=False,
        )
        # A failed check clears the stored projects.
        project_keys = jira_connection.projects if jira_connection.is_connected else {}
        with rls_transaction(str(integration.tenant_id)):
            integration.configuration["projects"] = project_keys
            integration.save()
        return jira_connection

    if integration_type == Integration.IntegrationChoices.SLACK:
        # No Slack connection test exists yet. Return an explicit "not connected"
        # result rather than falling through and returning None.
        return Connection(
            is_connected=False,
            error=NotImplementedError(
                "Connection test is not implemented for Slack integrations"
            ),
        )

    raise ValueError(f"Integration type {integration.integration_type} not supported")
return prowler_provider.test_connection(
**prowler_provider_kwargs, provider_id=provider.uid, raise_on_exception=False
)
def validate_invitation(
@@ -325,7 +187,7 @@ def validate_invitation(
# Admin DB connector is used to bypass RLS protection since the invitation belongs to a tenant the user
# is not a member of yet
invitation = Invitation.objects.using(MainRouter.admin_db).get(
token=invitation_token, email__iexact=email
token=invitation_token, email=email
)
except Invitation.DoesNotExist:
if raise_not_found:
@@ -376,17 +238,3 @@ def get_findings_metadata_no_aggregations(tenant_id: str, filtered_queryset):
serializer.is_valid(raise_exception=True)
return serializer.data
def initialize_prowler_integration(integration: Integration) -> Jira:
    """Instantiate the SDK client for an integration (currently Jira only).

    Args:
        integration (Integration): The integration whose credentials are used to build the client.

    Returns:
        Jira: A client built from ``integration.credentials``. For any
            non-Jira integration type nothing is built and ``None`` is returned.

    Raises:
        JiraBasicAuthError: Re-raised after the integration has been marked
            as disconnected and its stored projects cleared.
    """
    # TODO Refactor other integrations to use this function
    if integration.integration_type != Integration.IntegrationChoices.JIRA:
        return None

    try:
        return Jira(**integration.credentials)
    except JiraBasicAuthError as jira_auth_error:
        # Record the authentication failure on the integration before propagating it.
        with rls_transaction(str(integration.tenant_id)):
            integration.configuration["projects"] = {}
            integration.connected = False
            integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
            integration.save()
        raise jira_auth_error
+2 -14
View File
@@ -24,32 +24,20 @@ class PaginateByPkMixin:
request, # noqa: F841
base_queryset,
manager,
select_related: list | None = None,
prefetch_related: list | None = None,
select_related: list[str] | None = None,
prefetch_related: list[str] | None = None,
) -> Response:
"""
Paginate a queryset by primary key.
This method is useful when you want to paginate a queryset that has been
filtered or annotated in a way that would be lost if you used the default
pagination method.
"""
pk_list = base_queryset.values_list("id", flat=True)
page = self.paginate_queryset(pk_list)
if page is None:
return Response(self.get_serializer(base_queryset, many=True).data)
queryset = manager.filter(id__in=page)
if select_related:
queryset = queryset.select_related(*select_related)
if prefetch_related:
queryset = queryset.prefetch_related(*prefetch_related)
# Optimize tags loading, if applicable
if hasattr(self, "_optimize_tags_loading"):
queryset = self._optimize_tags_loading(queryset)
queryset = sorted(queryset, key=lambda obj: page.index(obj.id))
serialized = self.get_serializer(queryset, many=True).data
@@ -1,23 +0,0 @@
import yaml
from rest_framework_json_api import serializers
from rest_framework_json_api.serializers import ValidationError
class BaseValidateSerializer(serializers.Serializer):
    """Serializer base class that rejects payload keys not declared as fields."""

    def validate(self, data):
        """Return ``data`` unchanged, or raise if the raw payload has unknown keys.

        The check only runs when the serializer has ``initial_data``; the keys
        ``"id"`` and ``"type"`` are always allowed.

        Raises:
            ValidationError: If the payload contains keys that are not serializer fields.
        """
        if not hasattr(self, "initial_data"):
            return data

        submitted_keys = set(self.initial_data.keys()) - {"id", "type"}
        declared_keys = set(self.fields.keys())
        unknown_keys = submitted_keys - declared_keys
        if unknown_keys:
            raise ValidationError(f"Invalid fields: {unknown_keys}")
        return data
class YamlOrJsonField(serializers.JSONField):
    """JSON field that also accepts its value as a YAML document in a string."""

    def to_internal_value(self, data):
        """Parse string input as YAML, then delegate to ``JSONField`` validation.

        Raises:
            serializers.ValidationError: If a string value is not valid YAML.
        """
        # Non-string input goes straight to the parent implementation.
        if not isinstance(data, str):
            return super().to_internal_value(data)

        try:
            parsed = yaml.safe_load(data)
        except yaml.YAMLError as exc:
            raise serializers.ValidationError("Invalid YAML format") from exc
        return super().to_internal_value(parsed)
@@ -1,78 +1,24 @@
import os
import re
from drf_spectacular.utils import extend_schema_field
from rest_framework_json_api import serializers
from rest_framework_json_api.serializers import ValidationError
from api.v1.serializer_utils.base import BaseValidateSerializer
class BaseValidateSerializer(serializers.Serializer):
    """Serializer base class that rejects payload keys not declared as fields."""

    def validate(self, data):
        """Return ``data`` unchanged, or raise if the raw payload has unknown keys.

        The check only runs when the serializer has ``initial_data``; the keys
        ``"id"`` and ``"type"`` are always allowed.

        Raises:
            ValidationError: If the payload contains keys that are not serializer fields.
        """
        if not hasattr(self, "initial_data"):
            return data

        submitted_keys = set(self.initial_data.keys()) - {"id", "type"}
        declared_keys = set(self.fields.keys())
        unknown_keys = submitted_keys - declared_keys
        if unknown_keys:
            raise ValidationError(f"Invalid fields: {unknown_keys}")
        return data
# Integrations
class S3ConfigSerializer(BaseValidateSerializer):
    """Configuration payload for the Amazon S3 integration."""

    bucket_name = serializers.CharField()
    # Blank is allowed; validate_output_directory substitutes a default for it.
    output_directory = serializers.CharField(allow_blank=True)

    def validate_output_directory(self, value):
        """
        Validate the output_directory field to ensure it's a properly formatted path.
        Prevents paths with excessive slashes like "///////test".
        If empty, sets a default value.

        Returns:
            str: ``"output"`` for an empty value, otherwise the normalized
                path without leading slashes.

        Raises:
            serializers.ValidationError: If the normalized path contains one of
                ``< > : " | ? *``, is empty or ``"."``, or is longer than 900 characters.
        """
        # S3 keys always use "/" as the separator, so normalize with POSIX rules
        # explicitly. os.path.normpath follows the host OS and would emit
        # backslashes on Windows.
        import posixpath

        # If empty or None, set default value
        if not value:
            return "output"

        # Collapse repeated slashes and "."/".." segments, then drop leading
        # slashes (lstrip is a no-op when there are none).
        normalized_path = posixpath.normpath(value).lstrip("/")

        # Check for invalid characters or patterns
        if re.search(r'[<>:"|?*]', normalized_path):
            raise serializers.ValidationError(
                'Output directory contains invalid characters. Avoid: < > : " | ? *'
            )

        # Check for empty path after normalization
        if not normalized_path or normalized_path == ".":
            raise serializers.ValidationError(
                "Output directory cannot be empty or just '.' or '/'."
            )

        # Check for paths that are too long (S3 key limit is 1024 characters, leave some room for filename)
        if len(normalized_path) > 900:
            raise serializers.ValidationError(
                "Output directory path is too long (max 900 characters)."
            )

        return normalized_path

    class Meta:
        resource_name = "integrations"
class SecurityHubConfigSerializer(BaseValidateSerializer):
    """Configuration payload for the AWS Security Hub integration."""

    send_only_fails = serializers.BooleanField(default=False)
    archive_previous_findings = serializers.BooleanField(default=False)
    # Read-only for clients; to_internal_value always sets it to an empty dict.
    regions = serializers.DictField(default=dict, read_only=True)

    def to_internal_value(self, data):
        """Run the standard conversion, then force ``regions`` to an empty dict."""
        internal = super().to_internal_value(data)
        internal["regions"] = {}
        return internal

    class Meta:
        resource_name = "integrations"
class JiraConfigSerializer(BaseValidateSerializer):
domain = serializers.CharField(read_only=True)
issue_types = serializers.ListField(
read_only=True, child=serializers.CharField(), default=["Task"]
)
projects = serializers.DictField(read_only=True)
output_directory = serializers.CharField()
class Meta:
resource_name = "integrations"
@@ -93,15 +39,6 @@ class AWSCredentialSerializer(BaseValidateSerializer):
resource_name = "integrations"
class JiraCredentialSerializer(BaseValidateSerializer):
    """Credential payload for the Jira integration; all three fields are required."""

    # Must be a syntactically valid email address.
    user_mail = serializers.EmailField(required=True)
    api_token = serializers.CharField(required=True)
    domain = serializers.CharField(required=True)

    class Meta:
        resource_name = "integrations"
@extend_schema_field(
{
"oneOf": [
@@ -153,27 +90,6 @@ class JiraCredentialSerializer(BaseValidateSerializer):
},
},
},
{
"type": "object",
"title": "JIRA Credentials",
"properties": {
"user_mail": {
"type": "string",
"format": "email",
"description": "The email address of the JIRA user account.",
},
"api_token": {
"type": "string",
"description": "The API token for authentication with JIRA. This can be generated from your "
"Atlassian account settings.",
},
"domain": {
"type": "string",
"description": "The JIRA domain/instance URL (e.g., 'your-domain.atlassian.net').",
},
},
"required": ["user_mail", "api_token", "domain"],
},
]
}
)
@@ -194,40 +110,10 @@ class IntegrationCredentialField(serializers.JSONField):
},
"output_directory": {
"type": "string",
"description": "The directory path within the bucket where files will be saved. Optional - "
'defaults to "output" if not provided. Path will be normalized to remove '
'excessive slashes and invalid characters are not allowed (< > : " | ? *). '
"Maximum length is 900 characters.",
"maxLength": 900,
"pattern": '^[^<>:"|?*]+$',
"default": "output",
"description": "The directory path within the bucket where files will be saved.",
},
},
"required": ["bucket_name"],
},
{
"type": "object",
"title": "AWS Security Hub",
"properties": {
"send_only_fails": {
"type": "boolean",
"default": False,
"description": "If true, only findings with status 'FAIL' will be sent to Security Hub.",
},
"archive_previous_findings": {
"type": "boolean",
"default": False,
"description": "If true, archives findings that are not present in the current execution.",
},
},
},
{
"type": "object",
"title": "JIRA",
"description": "JIRA integration does not accept any configuration in the payload. Leave it as an "
"empty JSON object (`{}`).",
"properties": {},
"additionalProperties": False,
"required": ["bucket_name", "output_directory"],
},
]
}
@@ -1,21 +0,0 @@
from drf_spectacular.utils import extend_schema_field
from api.v1.serializer_utils.base import YamlOrJsonField
from prowler.lib.mutelist.mutelist import mutelist_schema
@extend_schema_field(
    {
        "oneOf": [
            {
                "type": "object",
                "title": "Mutelist",
                "properties": {"Mutelist": mutelist_schema},
                "additionalProperties": False,
            },
        ]
    }
)
class ProcessorConfigField(YamlOrJsonField):
    """Processor configuration field documented in the schema as a Mutelist object.

    Adds no behavior of its own; it exists to attach the schema annotation
    above to ``YamlOrJsonField``.
    """
@@ -176,58 +176,6 @@ from rest_framework_json_api import serializers
},
"required": ["kubeconfig_content"],
},
{
"type": "object",
"title": "GitHub Personal Access Token",
"properties": {
"personal_access_token": {
"type": "string",
"description": "GitHub personal access token for authentication.",
}
},
"required": ["personal_access_token"],
},
{
"type": "object",
"title": "GitHub OAuth App Token",
"properties": {
"oauth_app_token": {
"type": "string",
"description": "GitHub OAuth App token for authentication.",
}
},
"required": ["oauth_app_token"],
},
{
"type": "object",
"title": "GitHub App Credentials",
"properties": {
"github_app_id": {
"type": "integer",
"description": "GitHub App ID for authentication.",
},
"github_app_key": {
"type": "string",
"description": "Path to the GitHub App private key file.",
},
},
"required": ["github_app_id", "github_app_key"],
},
{
"type": "object",
"title": "IaC Repository Credentials",
"properties": {
"repository_url": {
"type": "string",
"description": "Repository URL to scan for IaC files.",
},
"access_token": {
"type": "string",
"description": "Optional access token for private repositories.",
},
},
"required": ["repository_url"],
},
]
}
)
+22 -397
View File
@@ -7,15 +7,12 @@ from django.contrib.auth.models import update_last_login
from django.contrib.auth.password_validation import validate_password
from drf_spectacular.utils import extend_schema_field
from jwt.exceptions import InvalidKeyError
from rest_framework.validators import UniqueTogetherValidator
from rest_framework_json_api import serializers
from rest_framework_json_api.relations import SerializerMethodResourceRelatedField
from rest_framework_json_api.serializers import ValidationError
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.tokens import RefreshToken
from api.exceptions import ConflictException
from api.models import (
Finding,
Integration,
@@ -24,7 +21,6 @@ from api.models import (
InvitationRoleRelationship,
LighthouseConfiguration,
Membership,
Processor,
Provider,
ProviderGroup,
ProviderGroupMembership,
@@ -46,14 +42,9 @@ from api.v1.serializer_utils.integrations import (
AWSCredentialSerializer,
IntegrationConfigField,
IntegrationCredentialField,
JiraConfigSerializer,
JiraCredentialSerializer,
S3ConfigSerializer,
SecurityHubConfigSerializer,
)
from api.v1.serializer_utils.processors import ProcessorConfigField
from api.v1.serializer_utils.providers import ProviderSecretField
from prowler.lib.mutelist.mutelist import Mutelist
# Tokens
@@ -139,12 +130,6 @@ class TokenSerializer(BaseTokenSerializer):
class TokenSocialLoginSerializer(BaseTokenSerializer):
email = serializers.EmailField(write_only=True)
tenant_id = serializers.UUIDField(
write_only=True,
required=False,
help_text="If not provided, the tenant ID of the first membership that was added"
" to the user will be used.",
)
# Output tokens
refresh = serializers.CharField(read_only=True)
@@ -866,7 +851,6 @@ class ScanSerializer(RLSSerializer):
"completed_at",
"scheduled_at",
"next_scan_at",
"processor",
"url",
]
@@ -1004,12 +988,8 @@ class ResourceSerializer(RLSSerializer):
tags = serializers.SerializerMethodField()
type_ = serializers.CharField(read_only=True)
failed_findings_count = serializers.IntegerField(read_only=True)
findings = SerializerMethodResourceRelatedField(
many=True,
read_only=True,
)
findings = serializers.ResourceRelatedField(many=True, read_only=True)
class Meta:
model = Resource
@@ -1025,7 +1005,6 @@ class ResourceSerializer(RLSSerializer):
"tags",
"provider",
"findings",
"failed_findings_count",
"url",
]
extra_kwargs = {
@@ -1035,8 +1014,8 @@ class ResourceSerializer(RLSSerializer):
}
included_serializers = {
"findings": "api.v1.serializers.FindingIncludeSerializer",
"provider": "api.v1.serializers.ProviderIncludeSerializer",
"findings": "api.v1.serializers.FindingSerializer",
"provider": "api.v1.serializers.ProviderSerializer",
}
@extend_schema_field(
@@ -1047,10 +1026,6 @@ class ResourceSerializer(RLSSerializer):
}
)
def get_tags(self, obj):
    """Return the resource's tags as a ``{key: value}`` mapping.

    When ``obj`` carries a ``prefetched_tags`` attribute the mapping is built
    from it directly (avoids N+1 queries); otherwise the lookup is delegated
    to ``obj.get_tags`` with the tenant ID taken from the serializer context.
    """
    if not hasattr(obj, "prefetched_tags"):
        # No prefetch available: fall back to the model's own lookup.
        tenant_id = self.context.get("tenant_id")
        return obj.get_tags(tenant_id)
    return dict((tag.key, tag.value) for tag in obj.prefetched_tags)
def get_fields(self):
@@ -1060,17 +1035,10 @@ class ResourceSerializer(RLSSerializer):
fields["type"] = type_
return fields
def get_findings(self, obj):
    """Return the findings to serialize for ``obj``.

    Uses ``obj.latest_findings`` when that attribute is present on the
    object; otherwise returns ``obj.findings.all()``.
    """
    if hasattr(obj, "latest_findings"):
        return obj.latest_findings
    return obj.findings.all()
class ResourceIncludeSerializer(RLSSerializer):
"""
Serializer for the included Resource model.
Serializer for the Resource model.
"""
tags = serializers.SerializerMethodField()
@@ -1103,10 +1071,6 @@ class ResourceIncludeSerializer(RLSSerializer):
}
)
def get_tags(self, obj):
    """Build the ``{key: value}`` tag mapping for an included resource.

    Prefers the ``prefetched_tags`` attribute when the object has one (this
    avoids N+1 queries); otherwise asks the object itself via
    ``obj.get_tags`` using the tenant ID from the serializer context.
    """
    if hasattr(obj, "prefetched_tags"):
        tag_map = {}
        for tag in obj.prefetched_tags:
            tag_map[tag.key] = tag.value
        return tag_map
    # Prefetch not available: use the model's own method.
    return obj.get_tags(self.context.get("tenant_id"))
def get_fields(self):
@@ -1117,17 +1081,6 @@ class ResourceIncludeSerializer(RLSSerializer):
return fields
class ResourceMetadataSerializer(serializers.Serializer):
    """
    Serializer for resource metadata: lists of service, region and type names.

    Every field is a list of strings and is allowed to be empty. The payload
    is exposed under the ``resources-metadata`` resource name.
    """

    services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
    regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
    types = serializers.ListField(child=serializers.CharField(), allow_empty=True)
    # Temporarily disabled until we implement tag filtering in the UI
    # tags = serializers.JSONField(help_text="Tags are described as key-value pairs.")

    class Meta:
        resource_name = "resources-metadata"
class FindingSerializer(RLSSerializer):
"""
Serializer for the Finding model.
@@ -1151,7 +1104,6 @@ class FindingSerializer(RLSSerializer):
"updated_at",
"first_seen_at",
"muted",
"muted_reason",
"url",
# Relationships
"scan",
@@ -1164,28 +1116,6 @@ class FindingSerializer(RLSSerializer):
}
class FindingIncludeSerializer(RLSSerializer):
    """
    Serializer for Finding objects when they are embedded as included resources.

    Only the attributes listed in ``Meta.fields`` are exposed; no ``url`` or
    relationship fields are declared here.
    """

    class Meta:
        model = Finding
        fields = [
            "id",
            "uid",
            "status",
            "severity",
            "check_id",
            "check_metadata",
            "inserted_at",
            "updated_at",
            "first_seen_at",
            "muted",
            "muted_reason",
        ]
# To be removed when the related endpoint is removed as well
class FindingDynamicFilterSerializer(serializers.Serializer):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
@@ -1221,10 +1151,6 @@ class BaseWriteProviderSecretSerializer(BaseWriteSerializer):
serializer = AzureProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.GCP.value:
serializer = GCPProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.GITHUB.value:
serializer = GithubProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.IAC.value:
serializer = IacProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.KUBERNETES.value:
serializer = KubernetesProviderSecret(data=secret)
elif provider_type == Provider.ProviderChoices.M365.value:
@@ -1274,8 +1200,8 @@ class M365ProviderSecret(serializers.Serializer):
client_id = serializers.CharField()
client_secret = serializers.CharField()
tenant_id = serializers.CharField()
user = serializers.EmailField(required=False)
password = serializers.CharField(required=False)
user = serializers.EmailField()
password = serializers.CharField()
class Meta:
resource_name = "provider-secrets"
@@ -1304,24 +1230,6 @@ class KubernetesProviderSecret(serializers.Serializer):
resource_name = "provider-secrets"
class GithubProviderSecret(serializers.Serializer):
    """
    Secret payload for GitHub providers.

    Every field is optional at the serializer level.
    NOTE(review): presumably exactly one authentication method (personal
    access token, OAuth app token, or GitHub App ID + key content) must be
    supplied and that is enforced elsewhere -- confirm.
    """

    personal_access_token = serializers.CharField(required=False)
    oauth_app_token = serializers.CharField(required=False)
    # GitHub App authentication: numeric app ID plus the key content.
    github_app_id = serializers.IntegerField(required=False)
    github_app_key_content = serializers.CharField(required=False)

    class Meta:
        resource_name = "provider-secrets"
class IacProviderSecret(serializers.Serializer):
    """
    Secret payload for IaC providers.

    ``repository_url`` is required; ``access_token`` may be omitted.
    """

    repository_url = serializers.CharField()
    access_token = serializers.CharField(required=False)

    class Meta:
        resource_name = "provider-secrets"
class AWSRoleAssumptionProviderSecret(serializers.Serializer):
role_arn = serializers.CharField()
external_id = serializers.CharField()
@@ -1401,13 +1309,12 @@ class ProviderSecretUpdateSerializer(BaseWriteProviderSecretSerializer):
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"provider": {"read_only": True},
"secret_type": {"required": False},
"secret_type": {"read_only": True},
}
def validate(self, attrs):
provider = self.instance.provider
# To allow updating a secret with the same type without making the `secret_type` mandatory
secret_type = attrs.get("secret_type") or self.instance.secret_type
secret_type = self.instance.secret_type
secret = attrs.get("secret")
validated_attrs = super().validate(attrs)
@@ -1964,62 +1871,6 @@ class ScheduleDailyCreateSerializer(serializers.Serializer):
class BaseWriteIntegrationSerializer(BaseWriteSerializer):
def validate(self, attrs):
    """Reject integrations that would duplicate an existing one.

    Checks, in order:
    * Amazon S3: another integration with the same configuration exists.
    * JIRA: another integration whose configuration contains the same
      ``domain`` exists.
    * AWS Security Hub: any provider in ``attrs["providers"]`` is already
      linked to a Security Hub integration for the tenant in the serializer
      context (the instance being updated, if any, is excluded).

    Raises:
        ConflictException: when one of the checks above finds a duplicate.

    Returns:
        The result of the parent ``validate``.
    """
    integration_type = attrs.get("integration_type")

    if integration_type == Integration.IntegrationChoices.AMAZON_S3:
        same_configuration = Integration.objects.filter(
            configuration=attrs.get("configuration")
        )
        if same_configuration.exists():
            raise ConflictException(
                detail="This integration already exists.",
                pointer="/data/attributes/configuration",
            )

    if integration_type == Integration.IntegrationChoices.JIRA:
        same_domain = Integration.objects.filter(
            configuration__contains={
                "domain": attrs.get("configuration").get("domain")
            }
        )
        if same_domain.exists():
            raise ConflictException(
                detail="This integration already exists.",
                pointer="/data/attributes/configuration",
            )

    # Check if any provider already has a SecurityHub integration
    current_instance = getattr(self, "instance", None)
    if current_instance and not integration_type:
        # On updates the type is not part of the payload; take it from the instance.
        integration_type = current_instance.integration_type

    is_security_hub = (
        integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB
    )
    if is_security_hub and "providers" in attrs:
        providers = attrs.get("providers", [])
        tenant_id = self.context.get("tenant_id")
        for provider in providers:
            existing_links = IntegrationProviderRelationship.objects.filter(
                provider=provider,
                integration__integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
                tenant_id=tenant_id,
            )
            # For updates, exclude the current instance from the check
            if current_instance:
                existing_links = existing_links.exclude(integration=current_instance)
            if existing_links.exists():
                raise ConflictException(
                    detail=f"Provider {provider.id} already has a Security Hub integration. Only one "
                    "Security Hub integration is allowed per provider.",
                    pointer="/data/relationships/providers",
                )

    return super().validate(attrs)
@staticmethod
def validate_integration_data(
integration_type: str,
@@ -2027,49 +1878,17 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
configuration: dict,
credentials: dict,
):
if integration_type == Integration.IntegrationChoices.AMAZON_S3:
if integration_type == Integration.IntegrationChoices.S3:
config_serializer = S3ConfigSerializer
credentials_serializers = [AWSCredentialSerializer]
elif integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB:
if providers:
if len(providers) > 1:
raise serializers.ValidationError(
{
"providers": "Only one provider is supported for the Security Hub integration."
}
)
if providers[0].provider != Provider.ProviderChoices.AWS:
raise serializers.ValidationError(
{
"providers": "The provider must be AWS type for the Security Hub integration."
}
)
config_serializer = SecurityHubConfigSerializer
credentials_serializers = [AWSCredentialSerializer]
elif integration_type == Integration.IntegrationChoices.JIRA:
if providers:
raise serializers.ValidationError(
{
"providers": "Relationship field is not accepted. This integration applies to all providers."
}
)
if configuration:
raise serializers.ValidationError(
{
"configuration": "This integration does not support custom configuration."
}
)
config_serializer = JiraConfigSerializer
# Create non-editable configuration for JIRA integration
default_jira_issue_types = ["Task"]
configuration.update(
{
"projects": {},
"issue_types": default_jira_issue_types,
"domain": credentials.get("domain"),
}
)
credentials_serializers = [JiraCredentialSerializer]
# TODO: This will be required for AWS Security Hub
# if providers and not all(
# provider.provider == Provider.ProviderChoices.AWS
# for provider in providers
# ):
# raise serializers.ValidationError(
# {"providers": "All providers must be AWS for the S3 integration."}
# )
else:
raise serializers.ValidationError(
{
@@ -2077,11 +1896,7 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
}
)
serializer_instance = config_serializer(data=configuration)
serializer_instance.is_valid(raise_exception=True)
# Apply the validated (and potentially transformed) data back to configuration
configuration.update(serializer_instance.validated_data)
config_serializer(data=configuration).is_valid(raise_exception=True)
for cred_serializer in credentials_serializers:
try:
@@ -2133,10 +1948,6 @@ class IntegrationSerializer(RLSSerializer):
for provider in representation["providers"]
if provider["id"] in allowed_provider_ids
]
if instance.integration_type == Integration.IntegrationChoices.JIRA:
representation["configuration"].update(
{"domain": instance.credentials.get("domain")}
)
return representation
@@ -2164,6 +1975,7 @@ class IntegrationCreateSerializer(BaseWriteIntegrationSerializer):
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"connected": {"read_only": True},
"enabled": {"read_only": True},
"connection_last_checked_at": {"read_only": True},
}
@@ -2173,18 +1985,10 @@ class IntegrationCreateSerializer(BaseWriteIntegrationSerializer):
configuration = attrs.get("configuration")
credentials = attrs.get("credentials")
if (
not providers
and integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB
):
raise serializers.ValidationError(
{"providers": "At least one provider is required for this integration."}
)
validated_attrs = super().validate(attrs)
self.validate_integration_data(
integration_type, providers, configuration, credentials
)
validated_attrs = super().validate(attrs)
return validated_attrs
def create(self, validated_data):
@@ -2237,16 +2041,13 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
def validate(self, attrs):
integration_type = self.instance.integration_type
providers = attrs.get("providers")
if integration_type != Integration.IntegrationChoices.JIRA:
configuration = attrs.get("configuration") or self.instance.configuration
else:
configuration = attrs.get("configuration", {})
configuration = attrs.get("configuration") or self.instance.configuration
credentials = attrs.get("credentials") or self.instance.credentials
validated_attrs = super().validate(attrs)
self.validate_integration_data(
integration_type, providers, configuration, credentials
)
validated_attrs = super().validate(attrs)
return validated_attrs
def update(self, instance, validated_data):
@@ -2261,184 +2062,8 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
]
IntegrationProviderRelationship.objects.bulk_create(new_relationships)
# Preserve regions field for Security Hub integrations
if instance.integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB:
if "configuration" in validated_data:
# Preserve the existing regions field if it exists
existing_regions = instance.configuration.get("regions", {})
validated_data["configuration"]["regions"] = existing_regions
return super().update(instance, validated_data)
def to_representation(self, instance):
    """Serialize ``instance``; for JIRA integrations, copy the credentials'
    ``domain`` into the represented configuration."""
    data = super().to_representation(instance)
    if instance.integration_type != Integration.IntegrationChoices.JIRA:
        return data
    # Ensure JIRA integrations show updated domain in configuration from credentials
    jira_domain = instance.credentials.get("domain")
    data["configuration"].update({"domain": jira_domain})
    return data
class IntegrationJiraDispatchSerializer(serializers.Serializer):
    """
    Serializer for dispatching findings to JIRA integration.

    Expects ``integration_id`` in the serializer context and validates that
    the referenced integration is an enabled JIRA integration whose
    configured projects include the requested ``project_key``.
    """

    project_key = serializers.CharField(required=True)
    issue_type = serializers.ChoiceField(required=True, choices=["Task"])

    class JSONAPIMeta:
        resource_name = "integrations-jira-dispatches"

    def validate(self, attrs):
        """Validate the dispatch request against the target integration.

        Raises:
            ValidationError: if the integration is not of JIRA type, is not
                enabled, or does not list ``project_key`` among its projects.
        """
        validated_attrs = super().validate(attrs)

        target_id = self.context.get("integration_id")
        jira_integration = Integration.objects.get(id=target_id)

        if jira_integration.integration_type != Integration.IntegrationChoices.JIRA:
            raise ValidationError(
                {"integration_type": "The given integration is not a JIRA integration"}
            )

        if not jira_integration.enabled:
            raise ValidationError(
                {"integration": "The given integration is not enabled"}
            )

        available_projects = jira_integration.configuration.get("projects", {})
        if attrs.get("project_key") not in available_projects:
            raise ValidationError(
                {
                    "project_key": "The given project key is not available for this JIRA integration. Refresh the "
                    "connection if this is an error."
                }
            )

        return validated_attrs
# Processors
class ProcessorSerializer(RLSSerializer):
    """
    Serializer for the Processor model.

    Exposes the processor's identifiers, timestamps, type and configuration.
    """

    # Custom field class handling the processor configuration payload.
    configuration = ProcessorConfigField()

    class Meta:
        model = Processor
        fields = [
            "id",
            "inserted_at",
            "updated_at",
            "processor_type",
            "configuration",
            "url",
        ]
class ProcessorCreateSerializer(RLSSerializer, BaseWriteSerializer):
    """
    Serializer for creating Processor objects.

    ``configuration`` is required. ``processor_type`` must be unique (enforced
    by the ``UniqueTogetherValidator`` below), and for the ``mutelist`` type
    the configuration is checked with ``Mutelist.validate_mutelist``.
    """

    configuration = ProcessorConfigField(required=True)

    class Meta:
        model = Processor
        fields = [
            "inserted_at",
            "updated_at",
            "processor_type",
            "configuration",
        ]
        extra_kwargs = {
            "inserted_at": {"read_only": True},
            "updated_at": {"read_only": True},
        }
        validators = [
            UniqueTogetherValidator(
                queryset=Processor.objects.all(),
                fields=["processor_type"],
                message="A processor with the same type already exists.",
            )
        ]

    def validate(self, attrs):
        """Run the parent validation, then the processor-specific checks."""
        validated_attrs = super().validate(attrs)
        self.validate_processor_data(attrs)
        return validated_attrs

    def validate_processor_data(self, attrs):
        """Dispatch configuration validation based on ``processor_type``."""
        processor_type = attrs.get("processor_type")
        configuration = attrs.get("configuration")
        if processor_type == "mutelist":
            self.validate_mutelist_configuration(configuration)

    def validate_mutelist_configuration(self, configuration):
        """Validate a mutelist processor configuration.

        Raises:
            serializers.ValidationError: if ``configuration`` is not a dict,
                has no non-empty ``Mutelist`` entry, or is rejected by
                ``Mutelist.validate_mutelist``.
        """
        if not isinstance(configuration, dict):
            raise serializers.ValidationError("Invalid Mutelist configuration.")

        mutelist_configuration = configuration.get("Mutelist", {})

        if not mutelist_configuration:
            raise serializers.ValidationError(
                "Invalid Mutelist configuration: 'Mutelist' is a required property."
            )

        try:
            Mutelist.validate_mutelist(mutelist_configuration, raise_on_exception=True)
        except Exception as error:
            # Chain the original error so the root cause stays in the traceback.
            raise serializers.ValidationError(
                f"Invalid Mutelist configuration: {error}"
            ) from error
class ProcessorUpdateSerializer(BaseWriteSerializer):
    """
    Serializer for updating Processor objects.

    Only ``configuration`` is writable. The processor type is taken from the
    instance being updated, and for the ``mutelist`` type the configuration is
    checked with ``Mutelist.validate_mutelist``.
    """

    configuration = ProcessorConfigField(required=True)

    class Meta:
        model = Processor
        fields = [
            "inserted_at",
            "updated_at",
            "configuration",
        ]
        extra_kwargs = {
            "inserted_at": {"read_only": True},
            "updated_at": {"read_only": True},
        }

    def validate(self, attrs):
        """Run the parent validation, then the processor-specific checks."""
        validated_attrs = super().validate(attrs)
        self.validate_processor_data(attrs)
        return validated_attrs

    def validate_processor_data(self, attrs):
        """Dispatch configuration validation based on the instance's type."""
        processor_type = self.instance.processor_type
        configuration = attrs.get("configuration")
        if processor_type == "mutelist":
            self.validate_mutelist_configuration(configuration)

    def validate_mutelist_configuration(self, configuration):
        """Validate a mutelist processor configuration.

        Raises:
            serializers.ValidationError: if ``configuration`` is not a dict,
                has no non-empty ``Mutelist`` entry, or is rejected by
                ``Mutelist.validate_mutelist``.
        """
        if not isinstance(configuration, dict):
            raise serializers.ValidationError("Invalid Mutelist configuration.")

        mutelist_configuration = configuration.get("Mutelist", {})

        if not mutelist_configuration:
            raise serializers.ValidationError(
                "Invalid Mutelist configuration: 'Mutelist' is a required property."
            )

        try:
            Mutelist.validate_mutelist(mutelist_configuration, raise_on_exception=True)
        except Exception as error:
            # Chain the original error so the root cause stays in the traceback.
            raise serializers.ValidationError(
                f"Invalid Mutelist configuration: {error}"
            ) from error
# SSO
+3 -36
View File
@@ -1,25 +1,21 @@
from allauth.socialaccount.providers.saml.views import ACSView, MetadataView, SLSView
from django.urls import include, path
from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
ComplianceOverviewViewSet,
CustomSAMLLoginView,
CustomTokenObtainView,
CustomTokenRefreshView,
CustomTokenSwitchTenantView,
FindingViewSet,
GithubSocialLoginView,
GoogleSocialLoginView,
IntegrationJiraViewSet,
IntegrationViewSet,
InvitationAcceptViewSet,
InvitationViewSet,
LighthouseConfigViewSet,
MembershipViewSet,
OverviewViewSet,
ProcessorViewSet,
ProviderGroupProvidersRelationshipView,
ProviderGroupViewSet,
ProviderSecretViewSet,
@@ -29,7 +25,6 @@ from api.v1.views import (
RoleViewSet,
SAMLConfigurationViewSet,
SAMLInitiateAPIView,
SAMLTokenValidateView,
ScanViewSet,
ScheduleViewSet,
SchemaView,
@@ -58,7 +53,6 @@ router.register(
router.register(r"overviews", OverviewViewSet, basename="overview")
router.register(r"schedules", ScheduleViewSet, basename="schedule")
router.register(r"integrations", IntegrationViewSet, basename="integration")
router.register(r"processors", ProcessorViewSet, basename="processor")
router.register(r"saml-config", SAMLConfigurationViewSet, basename="saml-config")
router.register(
r"lighthouse-configurations",
@@ -74,13 +68,6 @@ tenants_router.register(
users_router = routers.NestedSimpleRouter(router, r"users", lookup="user")
users_router.register(r"memberships", MembershipViewSet, basename="user-membership")
integrations_router = routers.NestedSimpleRouter(
router, r"integrations", lookup="integration"
)
integrations_router.register(
r"jira", IntegrationJiraViewSet, basename="integration-jira"
)
urlpatterns = [
path("tokens", CustomTokenObtainView.as_view(), name="token-obtain"),
path("tokens/refresh", CustomTokenRefreshView.as_view(), name="token-refresh"),
@@ -139,38 +126,18 @@ urlpatterns = [
path(
"auth/saml/initiate/", SAMLInitiateAPIView.as_view(), name="api_saml_initiate"
),
# Allauth SAML endpoints for tenants
path("accounts/", include("allauth.urls")),
path(
"accounts/saml/<organization_slug>/login/",
CustomSAMLLoginView.as_view(),
name="saml_login",
),
path(
"accounts/saml/<organization_slug>/acs/",
ACSView.as_view(),
name="saml_acs",
),
path(
"accounts/saml/<organization_slug>/acs/finish/",
"api/v1/accounts/saml/<organization_slug>/acs/finish/",
TenantFinishACSView.as_view(),
name="saml_finish_acs",
),
path(
"accounts/saml/<organization_slug>/sls/",
SLSView.as_view(),
name="saml_sls",
),
path(
"accounts/saml/<organization_slug>/metadata/",
MetadataView.as_view(),
name="saml_metadata",
),
path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"),
path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"),
path("", include(router.urls)),
path("", include(tenants_router.urls)),
path("", include(users_router.urls)),
path("", include(integrations_router.urls)),
path("schema", SchemaView.as_view(), name="schema"),
path("docs", SpectacularRedocView.as_view(url_name="schema"), name="docs"),
]
File diff suppressed because it is too large Load Diff
-88
View File
@@ -1,5 +1,3 @@
import string
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
@@ -22,89 +20,3 @@ class MaximumLengthValidator:
return _(
f"Your password must contain no more than {self.max_length} characters."
)
class SpecialCharactersValidator:
    """Password validator requiring a minimum number of special characters.

    Args:
        special_characters: Characters counted as special. Falls back to
            ``string.punctuation`` when omitted or empty.
        min_special_characters: Minimum number of special characters the
            password must contain.
    """

    def __init__(self, special_characters=None, min_special_characters=1):
        # Use string.punctuation if no custom characters provided
        self.special_characters = special_characters or string.punctuation
        self.min_special_characters = min_special_characters

    def validate(self, password, user=None):
        """Raise ``ValidationError`` if ``password`` has too few special characters."""
        special_count = sum(1 for char in password if char in self.special_characters)
        if special_count < self.min_special_characters:
            raise ValidationError(
                _("This password must contain at least one special character."),
                code="password_no_special_characters",
                params={
                    "special_characters": self.special_characters,
                    "min_special_characters": self.min_special_characters,
                },
            )

    def get_help_text(self):
        """Return the help text describing this requirement."""
        # Translate the constant template first and interpolate afterwards: an
        # f-string inside _() is evaluated before gettext sees it, so the
        # message could never match a translation catalog entry.
        return _(
            "Your password must contain at least one special character from: %(special_characters)s"
        ) % {"special_characters": self.special_characters}
class UppercaseValidator:
    """Password validator requiring a minimum number of uppercase letters.

    Args:
        min_uppercase: Minimum number of uppercase characters required.
    """

    def __init__(self, min_uppercase=1):
        self.min_uppercase = min_uppercase

    def validate(self, password, user=None):
        """Raise ``ValidationError`` if ``password`` has too few uppercase letters."""
        if sum(1 for char in password if char.isupper()) < self.min_uppercase:
            raise ValidationError(
                _(
                    "This password must contain at least %(min_uppercase)d uppercase letter."
                ),
                code="password_no_uppercase_letters",
                params={"min_uppercase": self.min_uppercase},
            )

    def get_help_text(self):
        """Return the help text describing this requirement."""
        # Translate the constant template, then interpolate: an f-string
        # inside _() is evaluated before gettext can look the message up.
        return _(
            "Your password must contain at least %(min_uppercase)d uppercase letter."
        ) % {"min_uppercase": self.min_uppercase}
class LowercaseValidator:
    """Password validator requiring a minimum number of lowercase letters.

    Args:
        min_lowercase: Minimum number of lowercase characters required.
    """

    def __init__(self, min_lowercase=1):
        self.min_lowercase = min_lowercase

    def validate(self, password, user=None):
        """Raise ``ValidationError`` if ``password`` has too few lowercase letters."""
        if sum(1 for char in password if char.islower()) < self.min_lowercase:
            raise ValidationError(
                _(
                    "This password must contain at least %(min_lowercase)d lowercase letter."
                ),
                code="password_no_lowercase_letters",
                params={"min_lowercase": self.min_lowercase},
            )

    def get_help_text(self):
        """Return the help text describing this requirement."""
        # Translate the constant template, then interpolate: an f-string
        # inside _() is evaluated before gettext can look the message up.
        return _(
            "Your password must contain at least %(min_lowercase)d lowercase letter."
        ) % {"min_lowercase": self.min_lowercase}
class NumericValidator:
    """Password validator requiring a minimum number of numeric characters.

    Args:
        min_numeric: Minimum number of digit characters required.
    """

    def __init__(self, min_numeric=1):
        self.min_numeric = min_numeric

    def validate(self, password, user=None):
        """Raise ``ValidationError`` if ``password`` has too few digits."""
        if sum(1 for char in password if char.isdigit()) < self.min_numeric:
            raise ValidationError(
                _(
                    "This password must contain at least %(min_numeric)d numeric character."
                ),
                code="password_no_numeric_characters",
                params={"min_numeric": self.min_numeric},
            )

    def get_help_text(self):
        """Return the help text describing this requirement."""
        # Translate the constant template, then interpolate: an f-string
        # inside _() is evaluated before gettext can look the message up.
        return _(
            "Your password must contain at least %(min_numeric)d numeric character."
        ) % {"min_numeric": self.min_numeric}
-39
View File
@@ -11,7 +11,6 @@ SECRET_KEY = env("SECRET_KEY", default="secret")
DEBUG = env.bool("DJANGO_DEBUG", default=False)
ALLOWED_HOSTS = ["localhost", "127.0.0.1"]
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
USE_X_FORWARDED_HOST = True
# Application definition
@@ -108,13 +107,6 @@ REST_FRAMEWORK = {
),
"TEST_REQUEST_DEFAULT_FORMAT": "vnd.api+json",
"JSON_API_UNIFORM_EXCEPTIONS": True,
"DEFAULT_THROTTLE_CLASSES": [
"rest_framework.throttling.ScopedRateThrottle",
],
"DEFAULT_THROTTLE_RATES": {
"token-obtain": env("DJANGO_THROTTLE_TOKEN_OBTAIN", default=None),
"dj_rest_auth": None,
},
}
SPECTACULAR_SETTINGS = {
@@ -123,9 +115,6 @@ SPECTACULAR_SETTINGS = {
"PREPROCESSING_HOOKS": [
"drf_spectacular_jsonapi.hooks.fix_nested_path_parameters",
],
"POSTPROCESSING_HOOKS": [
"api.schema_hooks.attach_task_202_examples",
],
"TITLE": "API Reference - Prowler",
}
@@ -169,30 +158,6 @@ AUTH_PASSWORD_VALIDATORS = [
{
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
{
"NAME": "api.validators.SpecialCharactersValidator",
"OPTIONS": {
"min_special_characters": 1,
},
},
{
"NAME": "api.validators.UppercaseValidator",
"OPTIONS": {
"min_uppercase": 1,
},
},
{
"NAME": "api.validators.LowercaseValidator",
"OPTIONS": {
"min_lowercase": 1,
},
},
{
"NAME": "api.validators.NumericValidator",
"OPTIONS": {
"min_numeric": 1,
},
},
]
SIMPLE_JWT = {
@@ -283,7 +248,3 @@ X_FRAME_OPTIONS = "DENY"
SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin"
DJANGO_DELETION_BATCH_SIZE = env.int("DJANGO_DELETION_BATCH_SIZE", 5000)
# SAML requirement
CSRF_COOKIE_SECURE = True
SESSION_COOKIE_SECURE = True
+3 -9
View File
@@ -4,7 +4,6 @@ from config.env import env
IGNORED_EXCEPTIONS = [
# Provider is not connected due to credentials errors
"is not connected",
"ProviderConnectionError",
# Authentication Errors from AWS
"InvalidToken",
"AccessDeniedException",
@@ -17,7 +16,7 @@ IGNORED_EXCEPTIONS = [
"InternalServerErrorException",
"AccessDenied",
"No Shodan API Key", # Shodan Check
"RequestLimitExceeded", # For now, we don't want to log the RequestLimitExceeded errors
"RequestLimitExceeded", # For now we don't want to log the RequestLimitExceeded errors
"ThrottlingException",
"Rate exceeded",
"SubscriptionRequiredException",
@@ -43,9 +42,7 @@ IGNORED_EXCEPTIONS = [
"AWSAccessKeyIDInvalidError",
"AWSSessionTokenExpiredError",
"EndpointConnectionError", # AWS Service is not available in a region
# The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an
# unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
"Pool is closed",
"Pool is closed", # The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
# Authentication Errors from GCP
"ClientAuthenticationError",
"AuthorizationFailed",
@@ -69,15 +66,12 @@ IGNORED_EXCEPTIONS = [
"AzureClientIdAndClientSecretNotBelongingToTenantIdError",
"AzureHTTPResponseError",
"Error with credentials provided",
# PowerShell Errors in User Authentication
"Microsoft Teams User Auth connection failed: Please check your permissions and try again.",
"Exchange Online User Auth connection failed: Please check your permissions and try again.",
]
def before_send(event, hint):
"""
before_send handles the Sentry events in order to send them or not
before_send handles the Sentry events in order to sent them or not
"""
# Ignore logs with the ignored_exceptions
# https://docs.python.org/3/library/logging.html#logrecord-objects
@@ -25,18 +25,9 @@ SOCIALACCOUNT_EMAIL_AUTHENTICATION = True
SOCIALACCOUNT_EMAIL_AUTHENTICATION_AUTO_CONNECT = True
SOCIALACCOUNT_ADAPTER = "api.adapters.ProwlerSocialAccountAdapter"
# def inline(pem: str) -> str:
# return "".join(
# line.strip()
# for line in pem.splitlines()
# if "CERTIFICATE" not in line and "KEY" not in line
# )
# # SAML keys (TODO: Validate certificates)
# SAML_PUBLIC_CERT = inline(env("SAML_PUBLIC_CERT", default=""))
# SAML_PRIVATE_KEY = inline(env("SAML_PRIVATE_KEY", default=""))
# SAML keys
SAML_PUBLIC_CERT = env("SAML_PUBLIC_CERT", default="")
SAML_PRIVATE_KEY = env("SAML_PRIVATE_KEY", default="")
SOCIALACCOUNT_PROVIDERS = {
"google": {
@@ -69,14 +60,12 @@ SOCIALACCOUNT_PROVIDERS = {
"entity_id": "urn:prowler.com:sp",
},
"advanced": {
# TODO: Validate certificates
# "x509cert": SAML_PUBLIC_CERT,
# "private_key": SAML_PRIVATE_KEY,
# "authn_request_signed": True,
# "want_message_signed": True,
# "want_assertion_signed": True,
"reject_idp_initiated_sso": False,
"x509cert": SAML_PUBLIC_CERT,
"private_key": SAML_PRIVATE_KEY,
"name_id_format": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
"authn_request_signed": True,
"want_assertion_signed": True,
"want_message_signed": True,
},
},
}
+3 -96
View File
@@ -23,13 +23,11 @@ from api.models import (
Invitation,
LighthouseConfiguration,
Membership,
Processor,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
ResourceTagMapping,
Role,
SAMLConfiguration,
SAMLDomainIndex,
@@ -46,19 +44,12 @@ from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
TODAY = str(datetime.today().date())
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
TEST_USER = "dev@prowler.com"
TEST_PASSWORD = "testing_psswd"
def today_after_n_days(n_days: int) -> str:
    """Return the local date ``n_days`` from today, formatted as ``YYYY-MM-DD``."""
    target_date = datetime.today().date() + timedelta(days=n_days)
    return target_date.strftime("%Y-%m-%d")
@pytest.fixture(scope="module")
def enforce_test_user_db_connection(django_db_setup, django_db_blocker):
"""Ensure tests use the test user for database connections."""
@@ -390,27 +381,8 @@ def providers_fixture(tenants_fixture):
tenant_id=tenant.id,
scanner_args={"key1": "value1", "key2": {"key21": "value21"}},
)
provider6 = Provider.objects.create(
provider="m365",
uid="m365.test.com",
alias="m365_testing",
tenant_id=tenant.id,
)
return provider1, provider2, provider3, provider4, provider5, provider6
@pytest.fixture
def processor_fixture(tenants_fixture):
    """Create and return a ``mutelist`` Processor owned by the first tenant."""
    first_tenant, *_others = tenants_fixture
    mutelist_configuration = (
        "Mutelist:\n Accounts:\n *:\n Checks:\n iam_user_hardware_mfa_enabled:\n "
        " Regions:\n - *\n Resources:\n - *"
    )
    return Processor.objects.create(
        tenant_id=first_tenant.id,
        processor_type="mutelist",
        configuration=mutelist_configuration,
    )
return provider1, provider2, provider3, provider4, provider5
@pytest.fixture
@@ -662,7 +634,6 @@ def findings_fixture(scans_fixture, resources_fixture):
check_metadata={
"CheckId": "test_check_id",
"Description": "test description apple sauce",
"servicename": "ec2",
},
first_seen_at="2024-01-02T00:00:00Z",
)
@@ -689,7 +660,6 @@ def findings_fixture(scans_fixture, resources_fixture):
check_metadata={
"CheckId": "test_check_id",
"Description": "test description orange juice",
"servicename": "s3",
},
first_seen_at="2024-01-02T00:00:00Z",
muted=True,
@@ -1065,7 +1035,7 @@ def integrations_fixture(providers_fixture):
enabled=True,
connected=True,
integration_type="amazon_s3",
configuration={"key": "value1"},
configuration={"key": "value"},
credentials={"psswd": "1234"},
)
IntegrationProviderRelationship.objects.create(
@@ -1145,73 +1115,10 @@ def latest_scan_finding(authenticated_client, providers_fixture, resources_fixtu
return finding
@pytest.fixture(scope="function")
def latest_scan_resource(authenticated_client, providers_fixture):
    """Create a completed scan with one tagged resource and one failing finding.

    The resource is attached to the finding, resource scan summaries are
    backfilled for the scan, and the resource is returned.
    """
    first_provider = providers_fixture[0]
    tenant_id = str(first_provider.tenant_id)

    completed_scan = Scan.objects.create(
        name="latest completed scan for resource",
        provider=first_provider,
        trigger=Scan.TriggerChoices.MANUAL,
        state=StateChoices.COMPLETED,
        tenant_id=tenant_id,
    )

    resource = Resource.objects.create(
        tenant_id=tenant_id,
        provider=first_provider,
        uid="latest_resource_uid",
        name="Latest Resource",
        region="us-east-1",
        service="ec2",
        type="instance",
        metadata='{"test": "metadata"}',
        details='{"test": "details"}',
    )

    environment_tag = ResourceTag.objects.create(
        tenant_id=tenant_id,
        key="environment",
        value="test",
    )
    ResourceTagMapping.objects.create(
        tenant_id=tenant_id,
        resource=resource,
        tag=environment_tag,
    )

    raw_result = {
        "status": Status.FAIL,
        "impact": Severity.critical,
        "severity": Severity.critical,
    }
    check_metadata = {
        "CheckId": "test_check_id_latest",
        "Description": "test description latest",
    }
    failing_finding = Finding.objects.create(
        tenant_id=tenant_id,
        uid="test_finding_uid_latest",
        scan=completed_scan,
        delta="new",
        status=Status.FAIL,
        status_extended="test status extended ",
        impact=Severity.critical,
        impact_extended="test impact extended",
        severity=Severity.critical,
        raw_result=raw_result,
        tags={"test": "latest"},
        check_id="test_check_id_latest",
        check_metadata=check_metadata,
        first_seen_at="2024-01-02T00:00:00Z",
    )
    failing_finding.add_resources([resource])

    backfill_resource_scan_summaries(tenant_id, str(completed_scan.id))

    return resource
@pytest.fixture
def saml_setup(tenants_fixture):
tenant_id = tenants_fixture[0].id
domain = "prowler.com"
domain = "example.com"
SAMLDomainIndex.objects.create(email_domain=domain, tenant_id=tenant_id)
+10 -4
View File
@@ -2,10 +2,10 @@ import json
from datetime import datetime, timedelta, timezone
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from rest_framework_json_api.serializers import ValidationError
from tasks.tasks import perform_scheduled_scan_task
from api.db_utils import rls_transaction
from api.exceptions import ConflictException
from api.models import Provider, Scan, StateChoices
@@ -24,9 +24,15 @@ def schedule_provider_scan(provider_instance: Provider):
if PeriodicTask.objects.filter(
interval=schedule, name=task_name, task="scan-perform-scheduled"
).exists():
raise ConflictException(
detail="There is already a scheduled scan for this provider.",
pointer="/data/attributes/provider_id",
raise ValidationError(
[
{
"detail": "There is already a scheduled scan for this provider.",
"status": 400,
"source": {"pointer": "/data/attributes/provider_id"},
"code": "invalid",
}
]
)
with rls_transaction(tenant_id):
+2 -37
View File
@@ -3,11 +3,8 @@ from datetime import datetime, timezone
import openai
from celery.utils.log import get_task_logger
from api.models import Integration, LighthouseConfiguration, Provider
from api.utils import (
prowler_integration_connection_test,
prowler_provider_connection_test,
)
from api.models import LighthouseConfiguration, Provider
from api.utils import prowler_provider_connection_test
logger = get_task_logger(__name__)
@@ -86,35 +83,3 @@ def check_lighthouse_connection(lighthouse_config_id: str):
lighthouse_config.is_active = False
lighthouse_config.save()
return {"connected": False, "error": str(e), "available_models": []}
def check_integration_connection(integration_id: str):
    """
    Test the connection of an enabled integration and store the outcome.

    Args:
        integration_id (str): The primary key of the Integration instance to check.

    Returns:
        dict: ``connected`` (bool) and ``error`` (str or None). When no enabled
        integration matches ``integration_id`` the check is skipped and
        ``connected`` is False.

    Raises:
        Exception: Whatever the connection test raises is logged and re-raised.
    """
    enabled_matches = Integration.objects.filter(pk=integration_id, enabled=True)
    integration = enabled_matches.first()

    if not integration:
        logger.info(f"Integration {integration_id} is not enabled")
        return {"connected": False, "error": "Integration is not enabled"}

    try:
        test_result = prowler_integration_connection_test(integration)
    except Exception as exc:
        logger.warning(
            f"Unexpected exception checking {integration.integration_type} integration connection: {str(exc)}"
        )
        raise exc

    # Update integration connection status
    connected = test_result.is_connected
    integration.connected = connected
    integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
    integration.save()

    error = test_result.error
    return {
        "connected": connected,
        "error": str(error) if error else None,
    }
+5 -23
View File
@@ -8,22 +8,18 @@ from botocore.exceptions import ClientError, NoCredentialsError, ParamValidation
from celery.utils.log import get_task_logger
from django.conf import settings
from api.db_utils import rls_transaction
from api.models import Scan
from prowler.config.config import (
csv_file_suffix,
html_file_suffix,
json_asff_file_suffix,
json_ocsf_file_suffix,
output_file_timestamp,
)
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.compliance.aws_well_architected.aws_well_architected import (
AWSWellArchitected,
)
from prowler.lib.outputs.compliance.cis.cis_aws import AWSCIS
from prowler.lib.outputs.compliance.cis.cis_azure import AzureCIS
from prowler.lib.outputs.compliance.cis.cis_gcp import GCPCIS
from prowler.lib.outputs.compliance.cis.cis_github import GithubCIS
from prowler.lib.outputs.compliance.cis.cis_kubernetes import KubernetesCIS
from prowler.lib.outputs.compliance.cis.cis_m365 import M365CIS
from prowler.lib.outputs.compliance.ens.ens_aws import AWSENS
@@ -35,7 +31,6 @@ from prowler.lib.outputs.compliance.iso27001.iso27001_gcp import GCPISO27001
from prowler.lib.outputs.compliance.iso27001.iso27001_kubernetes import (
KubernetesISO27001,
)
from prowler.lib.outputs.compliance.iso27001.iso27001_m365 import M365ISO27001
from prowler.lib.outputs.compliance.kisa_ismsp.kisa_ismsp_aws import AWSKISAISMSP
from prowler.lib.outputs.compliance.mitre_attack.mitre_attack_aws import AWSMitreAttack
from prowler.lib.outputs.compliance.mitre_attack.mitre_attack_azure import (
@@ -95,14 +90,6 @@ COMPLIANCE_CLASS_MAP = {
"m365": [
(lambda name: name.startswith("cis_"), M365CIS),
(lambda name: name == "prowler_threatscore_m365", ProwlerThreatScoreM365),
(lambda name: name.startswith("iso27001_"), M365ISO27001),
],
"github": [
(lambda name: name.startswith("cis_"), GithubCIS),
],
"iac": [
# IaC provider doesn't have specific compliance frameworks yet
# Trivy handles its own compliance checks
],
}
@@ -115,7 +102,6 @@ OUTPUT_FORMATS_MAPPING = {
"kwargs": {},
},
"json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
"json-asff": {"class": ASFF, "suffix": json_asff_file_suffix, "kwargs": {}},
"html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
}
@@ -179,7 +165,7 @@ def get_s3_client():
return s3_client
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str | None:
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
@@ -196,7 +182,7 @@ def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str | None:
"""
bucket = base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET
if not bucket:
return
return None
try:
s3 = get_s3_client()
@@ -256,19 +242,15 @@ def _generate_output_directory(
# Sanitize the prowler provider name to ensure it is a valid directory name
prowler_provider_sanitized = re.sub(r"[^\w\-]", "-", prowler_provider)
with rls_transaction(tenant_id):
started_at = Scan.objects.get(id=scan_id).started_at
timestamp = started_at.strftime("%Y%m%d%H%M%S")
path = (
f"{output_directory}/{tenant_id}/{scan_id}/prowler-output-"
f"{prowler_provider_sanitized}-{timestamp}"
f"{prowler_provider_sanitized}-{output_file_timestamp}"
)
os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
compliance_path = (
f"{output_directory}/{tenant_id}/{scan_id}/compliance/prowler-output-"
f"{prowler_provider_sanitized}-{timestamp}"
f"{prowler_provider_sanitized}-{output_file_timestamp}"
)
os.makedirs("/".join(compliance_path.split("/")[:-1]), exist_ok=True)
-506
View File
@@ -1,506 +0,0 @@
import os
from glob import glob
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from tasks.utils import batched
from api.db_utils import rls_transaction
from api.models import Finding, Integration, Provider
from api.utils import initialize_prowler_integration, initialize_prowler_provider
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.finding import Finding as FindingOutput
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
from prowler.providers.common.models import Connection
logger = get_task_logger(__name__)
def get_s3_client_from_integration(
    integration: Integration,
) -> tuple[bool, S3 | Connection]:
    """
    Build an ``S3`` helper from an integration and test its bucket connection.

    The helper is created from the integration credentials together with the
    ``bucket_name`` and ``output_directory`` configuration entries, then
    ``test_connection`` is run with the same credentials and bucket.

    Args:
        integration (Integration): The integration providing credentials and configuration.

    Returns:
        tuple[bool, S3 | Connection]: ``(True, s3)`` when the connection test
        succeeds, otherwise ``(False, connection)`` with the failed test result.
    """
    credentials = integration.credentials
    bucket_name = integration.configuration["bucket_name"]

    s3_helper = S3(
        **credentials,
        bucket_name=bucket_name,
        output_directory=integration.configuration["output_directory"],
    )

    test_result = s3_helper.test_connection(**credentials, bucket_name=bucket_name)

    if not test_result.is_connected:
        return False, test_result
    return True, s3_helper
def upload_s3_integration(
    tenant_id: str, provider_id: str, output_directory: str
) -> bool:
    """
    Upload previously generated output files to every enabled S3 integration of a provider.

    Output objects are rebuilt from the files found on disk (instead of using
    serialized data) and handed to ``S3.send_to_bucket``.

    Args:
        tenant_id (str): The tenant identifier, used to open the RLS transaction
            in which the integrations are loaded.
        provider_id (str): The provider identifier, used to select the integrations.
        output_directory (str): Path prefix of the generated outputs. Files matching
            ``{output_directory}*<extension>`` are collected as regular outputs, and
            ``compliance/*.csv`` under its parent directory as compliance outputs.

    Returns:
        bool: True if every selected integration uploaded successfully. False if no
        integration was found, if any integration failed to connect or upload, or
        if an unexpected exception was raised.

    Note:
        Exceptions raised inside the ``try`` block are caught and logged, and the
        function reports the failure through its return value.
    """
    logger.info(f"Processing S3 integrations for provider {provider_id}")

    try:
        with rls_transaction(tenant_id):
            integrations = list(
                Integration.objects.filter(
                    integrationproviderrelationship__provider_id=provider_id,
                    integration_type=Integration.IntegrationChoices.AMAZON_S3,
                    enabled=True,
                )
            )

        if not integrations:
            logger.error(f"No S3 integrations found for provider {provider_id}")
            return False

        # Counts integrations whose upload finished without raising
        integration_executions = 0
        for integration in integrations:
            try:
                connected, s3 = get_s3_client_from_integration(integration)
            except Exception as e:
                # Building or testing the client raised: mark as disconnected and move on
                logger.info(
                    f"S3 connection failed for integration {integration.id}: {e}"
                )
                integration.connected = False
                integration.save()
                continue

            if connected:
                try:
                    # Reconstruct generated_outputs from files in output directory
                    # This approach scans the output directory for files and creates the appropriate
                    # output objects based on file extensions and naming patterns.
                    generated_outputs = {"regular": [], "compliance": []}

                    # Find and recreate regular outputs (CSV, HTML, OCSF, ASFF)
                    output_file_patterns = {
                        ".csv": CSV,
                        ".html": HTML,
                        ".ocsf.json": OCSF,
                        ".asff.json": ASFF,
                    }

                    base_dir = os.path.dirname(output_directory)
                    for extension, output_class in output_file_patterns.items():
                        # output_directory is used as a filename prefix, not as a directory
                        pattern = f"{output_directory}*{extension}"
                        for file_path in glob(pattern):
                            if os.path.exists(file_path):
                                output = output_class(findings=[], file_path=file_path)
                                output.create_file_descriptor(file_path)
                                generated_outputs["regular"].append(output)

                    # Find and recreate compliance outputs
                    compliance_pattern = os.path.join(base_dir, "compliance", "*.csv")
                    for file_path in glob(compliance_pattern):
                        if os.path.exists(file_path):
                            output = GenericCompliance(
                                findings=[],
                                compliance=None,
                                file_path=file_path,
                                file_extension=".csv",
                            )
                            output.create_file_descriptor(file_path)
                            generated_outputs["compliance"].append(output)

                    # Use send_to_bucket with recreated generated_outputs objects
                    s3.send_to_bucket(generated_outputs)
                except Exception as e:
                    # A failed upload is logged and not counted as an execution
                    logger.error(
                        f"S3 upload failed for integration {integration.id}: {e}"
                    )
                    continue
                integration_executions += 1
            else:
                # Here `s3` holds the failed connection result, not an S3 client
                integration.connected = False
                integration.save()
                logger.error(
                    f"S3 upload failed, connection failed for integration {integration.id}: {s3.error}"
                )

        # Success only when every integration was executed
        result = integration_executions == len(integrations)
        if result:
            logger.info(
                f"All the S3 integrations completed successfully for provider {provider_id}"
            )
        else:
            logger.info(f"Some S3 integrations failed for provider {provider_id}")
        return result
    except Exception as e:
        logger.error(f"S3 integrations failed for provider {provider_id}: {str(e)}")
        return False
def get_security_hub_client_from_integration(
    integration: Integration, tenant_id: str, findings: list
) -> tuple[bool, SecurityHub | Connection]:
    """
    Create and return a SecurityHub client using AWS credentials from an integration.

    The integration's own credentials are used when present; otherwise the secret
    of the provider linked to the integration is used. The per-region status
    found by the connection test is stored in
    ``integration.configuration["regions"]``, and reset to an empty dict when the
    connection test fails.

    Args:
        integration (Integration): The integration to get the Security Hub client from.
        tenant_id (str): The tenant identifier.
        findings (list): List of findings in ASFF format to send to Security Hub.

    Returns:
        tuple[bool, SecurityHub | Connection]: ``(True, security_hub)`` when the
        connection test succeeds. ``(False, connection)`` when it fails or when
        the integration has no associated provider.
    """
    # Get the provider associated with this integration
    with rls_transaction(tenant_id):
        provider_relationship = integration.integrationproviderrelationship_set.first()
        if not provider_relationship:
            # Callers unpack two values, so return the documented (bool, object)
            # pair here as well instead of a bare Connection.
            return False, Connection(
                is_connected=False, error="No provider associated with this integration"
            )
        provider_uid = provider_relationship.provider.uid
        provider_secret = provider_relationship.provider.secret.secret
        credentials = (
            integration.credentials if integration.credentials else provider_secret
        )

    connection = SecurityHub.test_connection(
        aws_account_id=provider_uid,
        raise_on_exception=False,
        **credentials,
    )

    if not connection.is_connected:
        # Reset regions information if connection fails
        with rls_transaction(tenant_id):
            integration.configuration["regions"] = {}
            integration.save()

        return False, connection

    all_security_hub_regions = AwsProvider.get_available_aws_service_regions(
        "securityhub", connection.partition
    )

    # Map every available Security Hub region to whether it is enabled
    regions_status = {
        region: region in connection.enabled_regions
        for region in set(all_security_hub_regions)
    }

    # Save regions information in the integration configuration
    with rls_transaction(tenant_id):
        integration.configuration["regions"] = regions_status
        integration.save()

    # Create SecurityHub client with all necessary parameters
    security_hub = SecurityHub(
        aws_account_id=provider_uid,
        findings=findings,
        send_only_fails=integration.configuration.get("send_only_fails", False),
        aws_security_hub_available_regions=list(connection.enabled_regions),
        **credentials,
    )

    return True, security_hub
def upload_security_hub_integration(
    tenant_id: str, provider_id: str, scan_id: str
) -> bool:
    """
    Upload findings to AWS Security Hub using configured integrations.

    This function retrieves findings from the database, transforms them to ASFF format,
    and sends them to AWS Security Hub using the configured integration credentials.
    Findings are processed in batches of ``DJANGO_FINDINGS_BATCH_SIZE``; the Security
    Hub client is created on the first batch that yields ASFF findings and reused
    for the following batches.

    Args:
        tenant_id (str): The tenant identifier.
        provider_id (str): The provider identifier.
        scan_id (str): The scan identifier for which to send findings.

    Returns:
        bool: True if all integrations executed successfully, False otherwise
        (including when no integration is found or an unexpected exception is raised).
    """
    logger.info(f"Processing Security Hub integrations for provider {provider_id}")

    try:
        with rls_transaction(tenant_id):
            # Get Security Hub integrations for this provider
            integrations = list(
                Integration.objects.filter(
                    integrationproviderrelationship__provider_id=provider_id,
                    integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
                    enabled=True,
                )
            )

            if not integrations:
                logger.error(
                    f"No Security Hub integrations found for provider {provider_id}"
                )
                return False

            # Get the provider object
            provider = Provider.objects.get(id=provider_id)

            # Initialize prowler provider for finding transformation
            prowler_provider = initialize_prowler_provider(provider)

        # Process each Security Hub integration
        integration_executions = 0
        total_findings_sent = {}  # Track findings sent per integration

        for integration in integrations:
            try:
                # Initialize Security Hub client for this integration
                # We'll create the client once and reuse it for all batches
                security_hub_client = None
                send_only_fails = integration.configuration.get(
                    "send_only_fails", False
                )
                total_findings_sent[integration.id] = 0

                # Process findings in batches to avoid memory issues
                has_findings = False
                batch_number = 0

                with rls_transaction(tenant_id):
                    qs = (
                        Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
                        .order_by("uid")
                        .iterator()
                    )

                    for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
                        batch_number += 1
                        has_findings = True

                        # Transform findings for this batch
                        transformed_findings = [
                            FindingOutput.transform_api_finding(
                                finding, prowler_provider
                            )
                            for finding in batch
                        ]

                        # Convert to ASFF format
                        asff_transformer = ASFF(
                            findings=transformed_findings,
                            file_path="",
                            file_extension="json",
                        )
                        asff_transformer.transform(transformed_findings)

                        # Get the batch of ASFF findings
                        batch_asff_findings = asff_transformer.data

                        if batch_asff_findings:
                            # Create Security Hub client for first batch or reuse existing
                            if not security_hub_client:
                                connected, security_hub = (
                                    get_security_hub_client_from_integration(
                                        integration, tenant_id, batch_asff_findings
                                    )
                                )

                                if not connected:
                                    # On failure `security_hub` is the failed connection result
                                    logger.error(
                                        f"Security Hub connection failed for integration {integration.id}: "
                                        f"{security_hub.error}"
                                    )
                                    integration.connected = False
                                    integration.save()
                                    break  # Skip this integration

                                security_hub_client = security_hub
                                logger.info(
                                    f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
                                    f"integration {integration.id}"
                                )
                            else:
                                # Update findings in existing client for this batch
                                security_hub_client._findings_per_region = (
                                    security_hub_client.filter(
                                        batch_asff_findings, send_only_fails
                                    )
                                )

                            # Send this batch to Security Hub
                            try:
                                findings_sent = (
                                    security_hub_client.batch_send_to_security_hub()
                                )
                                total_findings_sent[integration.id] += findings_sent

                                if findings_sent > 0:
                                    logger.debug(
                                        f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
                                    )
                            except Exception as batch_error:
                                # A failed batch is logged; remaining batches are still attempted
                                logger.error(
                                    f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
                                )

                        # Clear memory after processing each batch
                        asff_transformer._data.clear()
                        del batch_asff_findings
                        del transformed_findings

                if not has_findings:
                    # A scan without findings still counts as a successful execution
                    logger.info(
                        f"No findings to send to Security Hub for scan {scan_id}"
                    )
                    integration_executions += 1
                elif security_hub_client:
                    # The integration counts as executed only if at least one finding was sent
                    if total_findings_sent[integration.id] > 0:
                        logger.info(
                            f"Successfully sent {total_findings_sent[integration.id]} total findings to Security Hub via integration {integration.id}"
                        )
                        integration_executions += 1
                    else:
                        logger.warning(
                            f"No findings were sent to Security Hub via integration {integration.id}"
                        )

                    # Archive previous findings if configured to do so
                    if integration.configuration.get(
                        "archive_previous_findings", False
                    ):
                        logger.info(
                            f"Archiving previous findings in Security Hub via integration {integration.id}"
                        )
                        try:
                            findings_archived = (
                                security_hub_client.archive_previous_findings()
                            )
                            logger.info(
                                f"Successfully archived {findings_archived} previous findings in Security Hub"
                            )
                        except Exception as archive_error:
                            # Archiving is best-effort: a failure is logged and does not
                            # change the execution count
                            logger.warning(
                                f"Failed to archive previous findings: {str(archive_error)}"
                            )

            except Exception as e:
                logger.error(
                    f"Security Hub integration {integration.id} failed: {str(e)}"
                )
                continue

        # Success only when every integration was counted as executed
        result = integration_executions == len(integrations)
        if result:
            logger.info(
                f"All Security Hub integrations completed successfully for provider {provider_id}"
            )
        else:
            logger.error(
                f"Some Security Hub integrations failed for provider {provider_id}"
            )

        return result

    except Exception as e:
        logger.error(
            f"Security Hub integrations failed for provider {provider_id}: {str(e)}"
        )
        return False
def send_findings_to_jira(
    tenant_id: str,
    integration_id: str,
    project_key: str,
    issue_type: str,
    finding_ids: list[str],
):
    """
    Send each of the given findings to Jira through an integration.

    Args:
        tenant_id (str): The tenant identifier used for the RLS transactions.
        integration_id (str): The primary key of the Integration to send through.
        project_key (str): Passed to ``send_finding`` as ``project_key``.
        issue_type (str): Passed to ``send_finding`` as ``issue_type``.
        finding_ids (list[str]): Primary keys of the findings to send.

    Returns:
        dict: ``created_count`` with the number of findings for which
        ``send_finding`` returned a truthy result, and ``failed_count`` with the
        remaining ones.
    """
    with rls_transaction(tenant_id):
        integration = Integration.objects.get(id=integration_id)
        jira_integration = initialize_prowler_integration(integration)

    num_tickets_created = 0
    for finding_id in finding_ids:
        with rls_transaction(tenant_id):
            finding_instance = (
                Finding.all_objects.select_related("scan__provider")
                .prefetch_related("resources")
                .get(id=finding_id)
            )

            # Extract resource information. `first()` already returns None when the
            # finding has no resources, so no separate `exists()` check is needed.
            resource = finding_instance.resources.first()
            resource_uid = resource.uid if resource else ""
            resource_name = resource.name if resource else ""
            resource_tags = {}
            if resource and hasattr(resource, "tags"):
                resource_tags = resource.get_tags(tenant_id)

            # Get region
            region = resource.region if resource and resource.region else ""

            # Extract remediation information from check_metadata
            check_metadata = finding_instance.check_metadata
            remediation = check_metadata.get("remediation", {})
            recommendation = remediation.get("recommendation", {})
            remediation_code = remediation.get("code", {})

            # Send the individual finding to Jira
            result = jira_integration.send_finding(
                check_id=finding_instance.check_id,
                check_title=check_metadata.get("checktitle", ""),
                severity=finding_instance.severity,
                status=finding_instance.status,
                status_extended=finding_instance.status_extended or "",
                provider=finding_instance.scan.provider.provider,
                region=region,
                resource_uid=resource_uid,
                resource_name=resource_name,
                risk=check_metadata.get("risk", ""),
                recommendation_text=recommendation.get("text", ""),
                recommendation_url=recommendation.get("url", ""),
                remediation_code_native_iac=remediation_code.get("nativeiac", ""),
                remediation_code_terraform=remediation_code.get("terraform", ""),
                remediation_code_cli=remediation_code.get("cli", ""),
                remediation_code_other=remediation_code.get("other", ""),
                resource_tags=resource_tags,
                compliance=finding_instance.compliance or {},
                project_key=project_key,
                issue_type=issue_type,
            )
            if result:
                num_tickets_created += 1
            else:
                logger.error(f"Failed to send finding {finding_id} to Jira")

    return {
        "created_count": num_tickets_created,
        "failed_count": len(finding_ids) - num_tickets_created,
    }
+16 -84
View File
@@ -1,29 +1,22 @@
import json
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timezone
from celery.utils.log import get_task_logger
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db.models import Case, Count, IntegerField, Prefetch, Sum, When
from django.db.models import Case, Count, IntegerField, Sum, When
from tasks.utils import CustomEncoder
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import (
create_objects_in_batches,
rls_transaction,
update_objects_in_batches,
)
from api.exceptions import ProviderConnectionError
from api.db_utils import create_objects_in_batches, rls_transaction
from api.models import (
ComplianceRequirementOverview,
Finding,
Processor,
Provider,
Resource,
ResourceScanSummary,
@@ -33,7 +26,7 @@ from api.models import (
StateChoices,
)
from api.models import StatusChoices as FindingStatus
from api.utils import initialize_prowler_provider, return_prowler_provider
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.outputs.finding import Finding as ProwlerFinding
from prowler.lib.scan.scan import Scan as ProwlerScan
@@ -108,10 +101,7 @@ def _store_resources(
def perform_prowler_scan(
tenant_id: str,
scan_id: str,
provider_id: str,
checks_to_execute: list[str] | None = None,
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
):
"""
Perform a scan using Prowler and store the findings and resources in the database.
@@ -142,28 +132,14 @@ def perform_prowler_scan(
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save()
# Find the mutelist processor if it exists
with rls_transaction(tenant_id):
try:
mutelist_processor = Processor.objects.get(
tenant_id=tenant_id, processor_type=Processor.ProcessorChoices.MUTELIST
)
except Processor.DoesNotExist:
mutelist_processor = None
except Exception as e:
logger.error(f"Error processing mutelist rules: {e}")
mutelist_processor = None
try:
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(
provider_instance, mutelist_processor
)
prowler_provider = initialize_prowler_provider(provider_instance)
provider_instance.connected = True
except Exception as e:
provider_instance.connected = False
exc = ProviderConnectionError(
exc = ValueError(
f"Provider {provider_instance.provider} is not connected: {e}"
)
finally:
@@ -173,8 +149,7 @@ def perform_prowler_scan(
provider_instance.save()
# If the provider is not connected, raise an exception outside the transaction.
# If raised within the transaction, the transaction will be rolled back and the provider will not be marked
# as not connected.
# If raised within the transaction, the transaction will be rolled back and the provider will not be marked as not connected.
if exc:
raise exc
@@ -183,7 +158,6 @@ def perform_prowler_scan(
resource_cache = {}
tag_cache = {}
last_status_cache = {}
resource_failed_findings_cache = defaultdict(int)
for progress, findings in prowler_scan.scan():
for finding in findings:
@@ -209,9 +183,6 @@ def perform_prowler_scan(
},
)
resource_cache[resource_uid] = resource_instance
# Initialize all processed resources in the cache
resource_failed_findings_cache[resource_uid] = 0
else:
resource_instance = resource_cache[resource_uid]
@@ -302,9 +273,6 @@ def perform_prowler_scan(
if not last_first_seen_at:
last_first_seen_at = datetime.now(tz=timezone.utc)
# If the finding is muted at this time the reason must be the configured Mutelist
muted_reason = "Muted by mutelist" if finding.muted else None
# Create the finding
finding_instance = Finding.objects.create(
tenant_id=tenant_id,
@@ -320,16 +288,10 @@ def perform_prowler_scan(
scan=scan_instance,
first_seen_at=last_first_seen_at,
muted=finding.muted,
muted_reason=muted_reason,
compliance=finding.compliance,
)
finding_instance.add_resources([resource_instance])
# Increment failed_findings_count cache if the finding status is FAIL and not muted
if status == FindingStatus.FAIL and not finding.muted:
resource_uid = finding.resource_uid
resource_failed_findings_cache[resource_uid] += 1
# Update scan resource summaries
scan_resource_cache.add(
(
@@ -347,24 +309,6 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.COMPLETED
# Update failed_findings_count for all resources in batches if scan completed successfully
if resource_failed_findings_cache:
resources_to_update = []
for resource_uid, failed_count in resource_failed_findings_cache.items():
if resource_uid in resource_cache:
resource_instance = resource_cache[resource_uid]
resource_instance.failed_findings_count = failed_count
resources_to_update.append(resource_instance)
if resources_to_update:
update_objects_in_batches(
tenant_id=tenant_id,
model=Resource,
objects=resources_to_update,
fields=["failed_findings_count"],
batch_size=1000,
)
except Exception as e:
logger.error(f"Error performing scan {scan_id}: {e}")
exception = e
@@ -417,9 +361,6 @@ def aggregate_findings(tenant_id: str, scan_id: str):
changed, unchanged). The results are grouped by `check_id`, `service`, `severity`, and `region`.
These aggregated metrics are then stored in the `ScanSummary` table.
Additionally, it updates the failed_findings_count field for each resource based on the most
recent findings for each finding.uid.
Args:
tenant_id (str): The ID of the tenant to which the scan belongs.
scan_id (str): The ID of the scan for which findings need to be aggregated.
@@ -585,30 +526,21 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
with rls_transaction(tenant_id):
scan_instance = Scan.objects.get(pk=scan_id)
provider_instance = scan_instance.provider
prowler_provider = return_prowler_provider(provider_instance)
prowler_provider = initialize_prowler_provider(provider_instance)
# Get check status data by region from findings
findings = (
Finding.all_objects.filter(scan_id=scan_id, muted=False)
.only("id", "check_id", "status")
.prefetch_related(
Prefetch(
"resources",
queryset=Resource.objects.only("id", "region"),
to_attr="small_resources",
)
)
.iterator(chunk_size=1000)
)
check_status_by_region = {}
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id, muted=False)
for finding in findings:
for resource in finding.small_resources:
# Get region from resources
for resource in finding.resources.all():
region = resource.region
current_status = check_status_by_region.setdefault(region, {})
if current_status.get(finding.check_id) != "FAIL":
current_status[finding.check_id] = finding.status
region_dict = check_status_by_region.setdefault(region, {})
current_status = region_dict.get(finding.check_id)
if current_status == "FAIL":
continue
region_dict[finding.check_id] = finding.status
try:
# Try to get regions from provider
+24 -223
View File
@@ -2,17 +2,13 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, group, shared_task
from celery import chain, shared_task
from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.backfill import backfill_resource_scan_summaries
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
check_provider_connection,
)
from tasks.jobs.connection import check_lighthouse_connection, check_provider_connection
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.export import (
COMPLIANCE_CLASS_MAP,
@@ -21,11 +17,6 @@ from tasks.jobs.export import (
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.integrations import (
send_findings_to_jira,
upload_s3_integration,
upload_security_hub_integration,
)
from tasks.jobs.scan import (
aggregate_findings,
create_compliance_requirements,
@@ -36,7 +27,7 @@ from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
@@ -46,31 +37,6 @@ from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str):
    """
    Dispatch the follow-up tasks that run once a scan has been performed.

    The compliance requirements task is sent on its own. The scan summary,
    output generation and integrations check are sent as a chain, in that order.

    Args:
        tenant_id (str): The tenant ID under which the scan was performed.
        scan_id (str): The ID of the scan that was performed.
        provider_id (str): The primary key of the Provider instance that was scanned.
    """
    scan_ref = {"tenant_id": tenant_id, "scan_id": scan_id}

    # Dispatched independently of the chain below
    create_compliance_requirements_task.apply_async(kwargs=dict(scan_ref))

    # Immutable signatures: no step receives the previous step's result
    summary_step = perform_scan_summary_task.si(**scan_ref)
    outputs_step = generate_outputs_task.si(provider_id=provider_id, **scan_ref)
    integrations_step = check_integrations_task.si(provider_id=provider_id, **scan_ref)

    chain(summary_step, outputs_step, integrations_step).apply_async()
@shared_task(base=RLSTask, name="provider-connection-check")
@set_tenant
def check_provider_connection_task(provider_id: str):
@@ -88,18 +54,6 @@ def check_provider_connection_task(provider_id: str):
return check_provider_connection(provider_id=provider_id)
@shared_task(base=RLSTask, name="integration-connection-check")
@set_tenant
def check_integration_connection_task(integration_id: str):
    """
    Celery task that checks the connection status of an integration.

    Args:
        integration_id (str): The primary key of the Integration instance to check.

    Returns:
        The value returned by ``check_integration_connection`` for this integration.
    """
    connection_status = check_integration_connection(integration_id=integration_id)
    return connection_status
@shared_task(
base=RLSTask, name="provider-deletion", queue="deletion", autoretry_for=(Exception,)
)
@@ -149,7 +103,13 @@ def perform_scan_task(
checks_to_execute=checks_to_execute,
)
_perform_scan_complete_tasks(tenant_id, scan_id, provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_id),
create_compliance_requirements_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@@ -254,12 +214,20 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_instance.id),
create_compliance_requirements_task.si(
tenant_id=tenant_id, scan_id=str(scan_instance.id)
),
generate_outputs.si(
scan_id=str(scan_instance.id), provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@shared_task(name="scan-summary", queue="overview")
@shared_task(name="scan-summary")
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
@@ -275,7 +243,7 @@ def delete_tenant_task(tenant_id: str):
queue="scan-reports",
)
@set_tenant(keep_tenant=True)
def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
@@ -328,30 +296,12 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
ScanSummary.objects.filter(scan_id=scan_id)
)
# Check if we need to generate ASFF output for AWS providers with SecurityHub integration
generate_asff = False
if provider_type == "aws":
security_hub_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
enabled=True,
)
generate_asff = security_hub_integrations.exists()
qs = (
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
.order_by("uid")
.iterator()
)
qs = Finding.all_objects.filter(scan_id=scan_id).order_by("uid").iterator()
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
fos = [FindingOutput.transform_api_finding(f, prowler_provider) for f in batch]
# Outputs
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
# Skip ASFF generation if not needed
if mode == "json-asff" and not generate_asff:
continue
cls = cfg["class"]
suffix = cfg["suffix"]
extra = cfg.get("kwargs", {}).copy()
@@ -405,34 +355,7 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
compressed = _compress_output_files(out_dir)
upload_uri = _upload_to_s3(tenant_id, compressed, scan_id)
# S3 integrations (need output_directory)
with rls_transaction(tenant_id):
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
if s3_integrations:
# Pass the output directory path to S3 integration task to reconstruct objects from files
s3_integration_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"provider_id": provider_id,
"output_directory": out_dir,
}
).get(
disable_sync_subtasks=False
) # TODO: This synchronous execution is NOT recommended
# We're forced to do this because we need the files to exist before deletion occurs.
# Once we have the periodic file cleanup task implemented, we should:
# 1. Remove this .get() call and make it fully async
# 2. For Cloud deployments, develop a secondary approach where outputs are stored
# directly in S3 and read from there, eliminating local file dependencies
if upload_uri:
# TODO: We need to create a new periodic task to delete the output files
# This task shouldn't be responsible for deleting the output files
try:
rmtree(Path(compressed).parent, ignore_errors=True)
except Exception as e:
@@ -443,10 +366,7 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
Scan.all_objects.filter(id=scan_id).update(output_location=final_location)
logger.info(f"Scan outputs at {final_location}")
return {
"upload": did_upload,
}
return {"upload": did_upload}
@shared_task(name="backfill-scan-resource-summaries", queue="backfill")
@@ -461,7 +381,7 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="overview")
@shared_task(base=RLSTask, name="scan-compliance-overviews")
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""
Creates detailed compliance requirement records for a scan.
@@ -494,122 +414,3 @@ def check_lighthouse_connection_task(lighthouse_config_id: str, tenant_id: str =
- 'available_models' (list): List of available models if connection is successful.
"""
return check_lighthouse_connection(lighthouse_config_id=lighthouse_config_id)
@shared_task(name="integration-check")
def check_integrations_task(tenant_id: str, provider_id: str, scan_id: str = None):
    """
    Dispatch the enabled integrations configured for a provider.

    Looks up the provider's enabled integrations inside the tenant's RLS
    transaction, collects one Celery signature per supported integration type
    (currently only AWS Security Hub) and launches them together as a group.

    Args:
        tenant_id (str): The tenant identifier
        provider_id (str): The provider identifier
        scan_id (str, optional): The scan identifier for integrations that need scan data

    Returns:
        dict: ``{"integrations_processed": <count>}``. When the lookup raises,
        the count is 0 and an ``"error"`` key carries the exception text.
    """
    logger.info(f"Checking integrations for provider {provider_id}")

    integration_tasks = []
    try:
        with rls_transaction(tenant_id):
            enabled_integrations = Integration.objects.filter(
                integrationproviderrelationship__provider_id=provider_id,
                enabled=True,
            )

            # Nothing enabled for this provider: report and stop early.
            if not enabled_integrations.exists():
                logger.info(f"No integrations configured for provider {provider_id}")
                return {"integrations_processed": 0}

            # Security Hub integration
            has_security_hub = enabled_integrations.filter(
                integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
            ).exists()
            if has_security_hub:
                security_hub_signature = security_hub_integration_task.s(
                    tenant_id=tenant_id, provider_id=provider_id, scan_id=scan_id
                )
                integration_tasks.append(security_hub_signature)

        # TODO: Add other integration types here
        # slack_integrations = integrations.filter(
        #     integration_type=Integration.IntegrationChoices.SLACK
        # )
        # if slack_integrations.exists():
        #     integration_tasks.append(
        #         slack_integration_task.s(
        #             tenant_id=tenant_id,
        #             provider_id=provider_id,
        #         )
        #     )
    except Exception as e:
        # Best effort: a failed lookup is reported in the result, not raised.
        logger.error(f"Integration check failed for provider {provider_id}: {str(e)}")
        return {"integrations_processed": 0, "error": str(e)}

    # Execute all integration tasks in parallel if any were found
    if integration_tasks:
        group(integration_tasks).apply_async()
        logger.info(f"Launched {len(integration_tasks)} integration task(s)")

    return {"integrations_processed": len(integration_tasks)}
@shared_task(
    base=RLSTask,
    name="integration-s3",
    queue="integrations",
)
def s3_integration_task(
    tenant_id: str,
    provider_id: str,
    output_directory: str,
):
    """
    Run the S3 integrations for a provider.

    Forwards its arguments positionally to `upload_s3_integration` and returns
    whatever that helper returns.

    Args:
        tenant_id (str): The tenant identifier
        provider_id (str): The provider identifier
        output_directory (str): Path to the directory containing output files
    """
    upload_result = upload_s3_integration(tenant_id, provider_id, output_directory)
    return upload_result
@shared_task(
    base=RLSTask,
    name="integration-security-hub",
    queue="integrations",
)
def security_hub_integration_task(
    tenant_id: str,
    provider_id: str,
    scan_id: str,
):
    """
    Run the Security Hub integrations for a provider.

    Forwards its arguments positionally to `upload_security_hub_integration`
    and returns whatever that helper returns.

    Args:
        tenant_id (str): The tenant identifier
        provider_id (str): The provider identifier
        scan_id (str): The scan identifier
    """
    upload_result = upload_security_hub_integration(tenant_id, provider_id, scan_id)
    return upload_result
@shared_task(
    base=RLSTask,
    name="integration-jira",
    queue="integrations",
)
def jira_integration_task(
    tenant_id: str,
    integration_id: str,
    project_key: str,
    issue_type: str,
    finding_ids: list[str],
):
    """
    Send findings to Jira.

    Forwards its arguments positionally to `send_findings_to_jira` and returns
    whatever that helper returns.

    Args:
        tenant_id (str): The tenant identifier
        integration_id (str): The integration identifier
        project_key (str): Presumably the target Jira project key; confirm
            against `send_findings_to_jira`.
        issue_type (str): Presumably the Jira issue type to create; confirm
            against `send_findings_to_jira`.
        finding_ids (list[str]): Identifiers of the findings to send
    """
    jira_result = send_findings_to_jira(
        tenant_id, integration_id, project_key, issue_type, finding_ids
    )
    return jira_result
+3 -22
View File
@@ -3,9 +3,9 @@ from unittest.mock import patch
import pytest
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from rest_framework_json_api.serializers import ValidationError
from tasks.beat import schedule_provider_scan
from api.exceptions import ConflictException
from api.models import Scan
@@ -48,29 +48,10 @@ class TestScheduleProviderScan:
with patch("tasks.tasks.perform_scheduled_scan_task.apply_async"):
schedule_provider_scan(provider_instance)
# Now, try scheduling again, should raise ConflictException
with pytest.raises(ConflictException) as exc_info:
# Now, try scheduling again, should raise ValidationError
with pytest.raises(ValidationError) as exc_info:
schedule_provider_scan(provider_instance)
assert "There is already a scheduled scan for this provider." in str(
exc_info.value
)
def test_remove_periodic_task(self, providers_fixture):
    """Deleting the periodic task keeps the scan and clears its scheduler_task."""
    provider = providers_fixture[0]
    assert Scan.objects.count() == 0

    with patch("tasks.tasks.perform_scheduled_scan_task.apply_async"):
        schedule_provider_scan(provider)

    assert Scan.objects.count() == 1
    scheduled_scan = Scan.objects.first()
    scheduler_task = scheduled_scan.scheduler_task
    assert scheduler_task is not None

    scheduler_task.delete()
    scheduled_scan.refresh_from_db()

    # The scan must survive the deletion (otherwise Scan.DoesNotExist would
    # be raised here) with its scheduler_task reset to None.
    surviving_scan = Scan.objects.get(id=scheduled_scan.id)
    assert surviving_scan.scheduler_task is None
+2 -132
View File
@@ -1,15 +1,10 @@
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
check_provider_connection,
)
from tasks.jobs.connection import check_lighthouse_connection, check_provider_connection
from api.models import Integration, LighthouseConfiguration, Provider
from api.models import LighthouseConfiguration, Provider
@pytest.mark.parametrize(
@@ -132,128 +127,3 @@ def test_check_lighthouse_connection_missing_api_key(mock_lighthouse_get):
assert result["available_models"] == []
assert mock_lighthouse_instance.is_active is False
mock_lighthouse_instance.save.assert_called_once()
@pytest.mark.django_db
class TestCheckIntegrationConnection:
    """Tests for `check_integration_connection` with the ORM lookup mocked."""

    def setup_method(self):
        # Random id per test; the lookup itself is always mocked.
        self.integration_id = str(uuid.uuid4())

    @staticmethod
    def _wire_lookup(filter_mock, found):
        """Make `Integration.objects.filter(...).first()` return `found`."""
        lookup = MagicMock()
        lookup.first.return_value = found
        filter_mock.return_value = lookup
        return lookup

    @staticmethod
    def _connection_outcome(is_connected, error):
        """Build the object returned by the mocked Prowler connection test."""
        outcome = MagicMock()
        outcome.is_connected = is_connected
        outcome.error = error
        return outcome

    @patch("tasks.jobs.connection.Integration.objects.filter")
    @patch("tasks.jobs.connection.prowler_integration_connection_test")
    def test_check_integration_connection_success(
        self, mock_prowler_test, mock_integration_filter
    ):
        """Test successful integration connection check with enabled=True filter."""
        integration = MagicMock()
        integration.id = self.integration_id
        integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
        lookup = self._wire_lookup(mock_integration_filter, integration)
        mock_prowler_test.return_value = self._connection_outcome(True, None)

        result = check_integration_connection(integration_id=self.integration_id)

        # The lookup must be restricted to enabled integrations.
        mock_integration_filter.assert_called_once_with(
            pk=self.integration_id, enabled=True
        )
        lookup.first.assert_called_once()
        mock_prowler_test.assert_called_once_with(integration)

        # The integration record is updated and persisted.
        assert integration.connected is True
        assert integration.connection_last_checked_at is not None
        integration.save.assert_called_once()

        assert result["connected"] is True
        assert result["error"] is None

    @patch("tasks.jobs.connection.Integration.objects.filter")
    @patch("tasks.jobs.connection.prowler_integration_connection_test")
    def test_check_integration_connection_failure(
        self, mock_prowler_test, mock_integration_filter
    ):
        """Test failed integration connection check."""
        integration = MagicMock()
        integration.id = self.integration_id
        lookup = self._wire_lookup(mock_integration_filter, integration)

        test_error = Exception("Connection failed")
        mock_prowler_test.return_value = self._connection_outcome(False, test_error)

        result = check_integration_connection(integration_id=self.integration_id)

        # The lookup must be restricted to enabled integrations.
        mock_integration_filter.assert_called_once_with(
            pk=self.integration_id, enabled=True
        )
        lookup.first.assert_called_once()

        # The integration record is updated and persisted.
        assert integration.connected is False
        assert integration.connection_last_checked_at is not None
        integration.save.assert_called_once()

        assert result["connected"] is False
        assert result["error"] == str(test_error)

    @patch("tasks.jobs.connection.Integration.objects.filter")
    def test_check_integration_connection_not_enabled(self, mock_integration_filter):
        """Test that disabled integrations return proper error response."""
        # No enabled integration matches the id.
        lookup = self._wire_lookup(mock_integration_filter, None)

        result = check_integration_connection(integration_id=self.integration_id)

        mock_integration_filter.assert_called_once_with(
            pk=self.integration_id, enabled=True
        )
        lookup.first.assert_called_once()

        assert result["connected"] is False
        assert result["error"] == "Integration is not enabled"

    @patch("tasks.jobs.connection.Integration.objects.filter")
    @patch("tasks.jobs.connection.prowler_integration_connection_test")
    def test_check_integration_connection_exception(
        self, mock_prowler_test, mock_integration_filter
    ):
        """Test integration connection check when prowler test raises exception."""
        integration = MagicMock()
        integration.id = self.integration_id
        lookup = self._wire_lookup(mock_integration_filter, integration)
        mock_prowler_test.side_effect = Exception(
            "Unexpected error during connection test"
        )

        with pytest.raises(Exception, match="Unexpected error during connection test"):
            check_integration_connection(integration_id=self.integration_id)

        # The lookup and the connection test both ran before the error surfaced.
        mock_integration_filter.assert_called_once_with(
            pk=self.integration_id, enabled=True
        )
        lookup.first.assert_called_once()
        mock_prowler_test.assert_called_once_with(integration)
+12 -38
View File
@@ -1,7 +1,5 @@
import os
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
@@ -129,26 +127,14 @@ class TestOutputs:
_upload_to_s3("tenant", str(zip_path), "scan")
mock_logger.assert_called()
@patch("tasks.jobs.export.rls_transaction")
@patch("tasks.jobs.export.Scan")
def test_generate_output_directory_creates_paths(
self, mock_scan, mock_rls_transaction, tmpdir
):
# Mock the scan object with a started_at timestamp
mock_scan_instance = MagicMock()
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_generate_output_directory_creates_paths(self, tmpdir):
from prowler.config.config import output_file_timestamp
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
tenant_id = "t1"
scan_id = "s1"
provider = "aws"
expected_timestamp = "20230615103045"
path, compliance = _generate_output_directory(
base_dir, provider, tenant_id, scan_id
@@ -157,29 +143,17 @@ class TestOutputs:
assert os.path.isdir(os.path.dirname(path))
assert os.path.isdir(os.path.dirname(compliance))
assert path.endswith(f"{provider}-{expected_timestamp}")
assert compliance.endswith(f"{provider}-{expected_timestamp}")
assert path.endswith(f"{provider}-{output_file_timestamp}")
assert compliance.endswith(f"{provider}-{output_file_timestamp}")
@patch("tasks.jobs.export.rls_transaction")
@patch("tasks.jobs.export.Scan")
def test_generate_output_directory_invalid_character(
self, mock_scan, mock_rls_transaction, tmpdir
):
# Mock the scan object with a started_at timestamp
mock_scan_instance = MagicMock()
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_generate_output_directory_invalid_character(self, tmpdir):
from prowler.config.config import output_file_timestamp
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
tenant_id = "t1"
scan_id = "s1"
provider = "aws/test@check"
expected_timestamp = "20230615103045"
path, compliance = _generate_output_directory(
base_dir, provider, tenant_id, scan_id
@@ -188,5 +162,5 @@ class TestOutputs:
assert os.path.isdir(os.path.dirname(path))
assert os.path.isdir(os.path.dirname(compliance))
assert path.endswith(f"aws-test-check-{expected_timestamp}")
assert compliance.endswith(f"aws-test-check-{expected_timestamp}")
assert path.endswith(f"aws-test-check-{output_file_timestamp}")
assert compliance.endswith(f"aws-test-check-{output_file_timestamp}")
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+17 -633
View File
@@ -1,19 +1,11 @@
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from tasks.tasks import (
_perform_scan_complete_tasks,
check_integrations_task,
generate_outputs_task,
s3_integration_task,
security_hub_integration_task,
)
from api.models import Integration
from tasks.tasks import generate_outputs
# TODO Move this to outputs/reports jobs
@pytest.mark.django_db
class TestGenerateOutputs:
def setup_method(self):
@@ -25,7 +17,7 @@ class TestGenerateOutputs:
with patch("tasks.tasks.ScanSummary.objects.filter") as mock_filter:
mock_filter.return_value.exists.return_value = False
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -34,6 +26,7 @@ class TestGenerateOutputs:
assert result == {"upload": False}
mock_filter.assert_called_once_with(scan_id=self.scan_id)
@patch("tasks.tasks.rmtree")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks.get_compliance_frameworks")
@@ -52,6 +45,7 @@ class TestGenerateOutputs:
mock_get_available_frameworks,
mock_compress,
mock_upload,
mock_rmtree,
):
mock_scan_summary_filter.return_value.exists.return_value = True
@@ -101,12 +95,11 @@ class TestGenerateOutputs:
return_value=("out-dir", "comp-dir"),
),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_compress.return_value = "/tmp/zipped.zip"
mock_upload.return_value = "s3://bucket/zipped.zip"
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -116,6 +109,9 @@ class TestGenerateOutputs:
mock_scan_update.return_value.update.assert_called_once_with(
output_location="s3://bucket/zipped.zip"
)
mock_rmtree.assert_called_once_with(
Path("/tmp/zipped.zip").parent, ignore_errors=True
)
def test_generate_outputs_fails_upload(self):
with (
@@ -147,7 +143,6 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value=None),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -155,9 +150,9 @@ class TestGenerateOutputs:
True,
]
result = generate_outputs_task(
result = generate_outputs(
scan_id="scan",
provider_id=self.provider_id,
provider_id="provider",
tenant_id=self.tenant_id,
)
@@ -189,7 +184,6 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/f.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -214,7 +208,7 @@ class TestGenerateOutputs:
{"aws": [(lambda x: True, MagicMock())]},
),
):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -260,8 +254,8 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch(
"tasks.tasks.batched",
return_value=[
@@ -282,7 +276,7 @@ class TestGenerateOutputs:
}
},
):
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -338,13 +332,13 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.Scan.all_objects.filter",
return_value=MagicMock(update=lambda **kw: None),
),
patch("tasks.tasks.batched", return_value=two_batches),
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.COMPLIANCE_CLASS_MAP",
{"aws": [(lambda name: True, TrackingComplianceWriter)]},
@@ -352,7 +346,7 @@ class TestGenerateOutputs:
):
mock_summary.return_value.exists.return_value = True
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -363,7 +357,6 @@ class TestGenerateOutputs:
assert writer.transform_calls == [([raw2], compliance_obj, "cis")]
assert result == {"upload": True}
# TODO: We need to add a periodic task to delete old output files
def test_generate_outputs_logs_rmtree_exception(self, caplog):
mock_finding_output = MagicMock()
mock_finding_output.compliance = {"cis": ["requirement-1", "requirement-2"]}
@@ -414,618 +407,9 @@ class TestGenerateOutputs:
),
):
with caplog.at_level("ERROR"):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
)
assert "Error deleting output files" in caplog.text
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_generate_outputs_filters_enabled_s3_integrations(
    self, mock_integration_filter, mock_rls
):
    """Test that generate_outputs_task only processes enabled S3 integrations."""
    # Every collaborator of the task is patched so that only the integration
    # lookup and the S3 task dispatch are observed.
    with (
        patch("tasks.tasks.ScanSummary.objects.filter") as mock_summary,
        patch("tasks.tasks.Provider.objects.get"),
        patch("tasks.tasks.initialize_prowler_provider"),
        patch("tasks.tasks.Compliance.get_bulk"),
        patch("tasks.tasks.get_compliance_frameworks", return_value=[]),
        patch("tasks.tasks.Finding.all_objects.filter") as mock_findings,
        patch(
            "tasks.tasks._generate_output_directory", return_value=("out", "comp")
        ),
        patch("tasks.tasks.FindingOutput._transform_findings_stats"),
        patch("tasks.tasks.FindingOutput.transform_api_finding"),
        patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
        patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/file.zip"),
        patch("tasks.tasks.Scan.all_objects.filter"),
        patch("tasks.tasks.rmtree"),
        patch("tasks.tasks.s3_integration_task.apply_async") as mock_s3_task,
    ):
        # Scan summaries exist, so the task does not exit early.
        mock_summary.return_value.exists.return_value = True
        mock_findings.return_value.order_by.return_value.iterator.return_value = [
            [MagicMock()],
            True,
        ]
        # A non-empty (truthy) result stands in for one enabled S3 integration.
        mock_integration_filter.return_value = [MagicMock()]
        mock_rls.return_value.__enter__.return_value = None

        # Empty format and compliance maps keep the task from writing outputs.
        with (
            patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
            patch("tasks.tasks.COMPLIANCE_CLASS_MAP", {"aws": []}),
        ):
            generate_outputs_task(
                scan_id=self.scan_id,
                provider_id=self.provider_id,
                tenant_id=self.tenant_id,
            )

        # Verify the S3 integrations filters
        mock_integration_filter.assert_called_once_with(
            integrationproviderrelationship__provider_id=self.provider_id,
            integration_type=Integration.IntegrationChoices.AMAZON_S3,
            enabled=True,
        )
        mock_s3_task.assert_called_once()
class TestScanCompleteTasks:
    """Tests for `_perform_scan_complete_tasks`."""

    @patch("tasks.tasks.create_compliance_requirements_task.apply_async")
    @patch("tasks.tasks.perform_scan_summary_task.si")
    @patch("tasks.tasks.generate_outputs_task.si")
    def test_scan_complete_tasks(
        self, mock_outputs_task, mock_scan_summary_task, mock_compliance_tasks
    ):
        """Each follow-up task is requested once with the scan's identifiers."""
        tenant, scan, provider = "tenant-id", "scan-id", "provider-id"

        _perform_scan_complete_tasks(tenant, scan, provider)

        # Compliance requirements are dispatched on their own via apply_async.
        mock_compliance_tasks.assert_called_once_with(
            kwargs={"tenant_id": tenant, "scan_id": scan},
        )
        # Summary and outputs are requested as immutable signatures.
        mock_scan_summary_task.assert_called_once_with(
            scan_id=scan,
            tenant_id=tenant,
        )
        mock_outputs_task.assert_called_once_with(
            scan_id=scan,
            provider_id=provider,
            tenant_id=tenant,
        )
@pytest.mark.django_db
class TestCheckIntegrationsTask:
def setup_method(self):
    """Give every test fresh random identifiers and a fixed output path."""
    self.scan_id, self.provider_id, self.tenant_id = (
        str(uuid.uuid4()) for _ in range(3)
    )
    self.output_directory = "/tmp/some-output-dir"
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_no_integrations(
    self, mock_integration_filter, mock_rls
):
    """With no matching integrations the task reports zero processed."""
    # Ensure rls_transaction is mocked
    mock_rls.return_value.__enter__.return_value = None
    empty_queryset = mock_integration_filter.return_value
    empty_queryset.exists.return_value = False

    outcome = check_integrations_task(
        tenant_id=self.tenant_id,
        provider_id=self.provider_id,
    )

    assert outcome == {"integrations_processed": 0}
    mock_integration_filter.assert_called_once_with(
        integrationproviderrelationship__provider_id=self.provider_id,
        enabled=True,
    )
@patch("tasks.tasks.security_hub_integration_task")
@patch("tasks.tasks.group")
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_security_hub_success(
    self, mock_integration_filter, mock_rls, mock_group, mock_security_hub_task
):
    """Test that SecurityHub integrations are processed correctly."""
    scan_id = "test-scan-id"

    # Filter chain: enabled integrations -> Security Hub subset, both non-empty.
    security_hub_qs = MagicMock()
    security_hub_qs.exists.return_value = True
    enabled_qs = MagicMock()
    enabled_qs.exists.return_value = True
    enabled_qs.filter.return_value = security_hub_qs
    mock_integration_filter.return_value = enabled_qs

    # Signature produced by the task and the group job that launches it.
    signature = MagicMock()
    mock_security_hub_task.s.return_value = signature
    group_job = MagicMock()
    mock_group.return_value = group_job

    # Ensure rls_transaction is mocked
    mock_rls.return_value.__enter__.return_value = None

    result = check_integrations_task(
        tenant_id=self.tenant_id,
        provider_id=self.provider_id,
        scan_id=scan_id,
    )

    # Exactly one integration (Security Hub) was processed.
    assert result == {"integrations_processed": 1}

    mock_integration_filter.assert_called_once_with(
        integrationproviderrelationship__provider_id=self.provider_id,
        enabled=True,
    )
    enabled_qs.filter.assert_called_once_with(
        integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
    )
    mock_security_hub_task.s.assert_called_once_with(
        tenant_id=self.tenant_id,
        provider_id=self.provider_id,
        scan_id=scan_id,
    )

    # The single signature was grouped and launched.
    mock_group.assert_called_once_with([signature])
    group_job.apply_async.assert_called_once()
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_disabled_integrations_ignored(
    self, mock_integration_filter, mock_rls
):
    """Test that disabled integrations are not processed."""
    mock_rls.return_value.__enter__.return_value = None
    # The enabled=True lookup finds nothing.
    no_enabled_qs = mock_integration_filter.return_value
    no_enabled_qs.exists.return_value = False

    outcome = check_integrations_task(
        tenant_id=self.tenant_id,
        provider_id=self.provider_id,
    )

    assert outcome == {"integrations_processed": 0}
    mock_integration_filter.assert_called_once_with(
        integrationproviderrelationship__provider_id=self.provider_id,
        enabled=True,
    )
# NOTE: @patch decorators are applied bottom-up, so the mock parameters below
# are listed in the reverse order of the decorator lines.
@patch("tasks.tasks.s3_integration_task")
@patch("tasks.tasks.Integration.objects.filter")
@patch("tasks.tasks.ScanSummary.objects.filter")
@patch("tasks.tasks.Provider.objects.get")
@patch("tasks.tasks.initialize_prowler_provider")
@patch("tasks.tasks.Compliance.get_bulk")
@patch("tasks.tasks.get_compliance_frameworks")
@patch("tasks.tasks.Finding.all_objects.filter")
@patch("tasks.tasks._generate_output_directory")
@patch("tasks.tasks.FindingOutput._transform_findings_stats")
@patch("tasks.tasks.FindingOutput.transform_api_finding")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks.Scan.all_objects.filter")
@patch("tasks.tasks.rmtree")
def test_generate_outputs_with_asff_for_aws_with_security_hub(
    self,
    mock_rmtree,
    mock_scan_update,
    mock_upload,
    mock_compress,
    mock_transform_finding,
    mock_transform_stats,
    mock_generate_dir,
    mock_findings,
    mock_get_frameworks,
    mock_compliance_bulk,
    mock_initialize_provider,
    mock_provider_get,
    mock_scan_summary,
    mock_integration_filter,
    mock_s3_task,
):
    """Test that ASFF output is generated for AWS providers with SecurityHub integration."""
    # Setup
    mock_scan_summary_qs = MagicMock()
    mock_scan_summary_qs.exists.return_value = True
    mock_scan_summary.return_value = mock_scan_summary_qs

    # Mock AWS provider
    mock_provider = MagicMock()
    mock_provider.uid = "aws-account-123"
    mock_provider.provider = "aws"
    mock_provider_get.return_value = mock_provider

    # Mock SecurityHub integration exists
    # (the same mock answers every Integration.objects.filter call in the task)
    mock_security_hub_integrations = MagicMock()
    mock_security_hub_integrations.exists.return_value = True
    mock_integration_filter.return_value = mock_security_hub_integrations

    # Mock s3_integration_task
    mock_s3_task.apply_async.return_value.get.return_value = True

    # Mock other necessary components
    mock_initialize_provider.return_value = MagicMock()
    mock_compliance_bulk.return_value = {}
    mock_get_frameworks.return_value = []
    mock_generate_dir.return_value = ("out-dir", "comp-dir")
    mock_transform_stats.return_value = {"stats": "data"}

    # Mock findings
    mock_finding = MagicMock()
    mock_findings.return_value.order_by.return_value.iterator.return_value = [
        [mock_finding],
        True,
    ]
    mock_transform_finding.return_value = MagicMock(compliance={})

    # Track which output formats were created
    created_writers = {}

    def track_writer_creation(cls_type):
        # Returns a stand-in for an output class: calling it records a mock
        # writer in created_writers under cls_type.
        def factory(*args, **kwargs):
            writer = MagicMock()
            writer._data = []
            writer.transform = MagicMock()
            writer.batch_write_data_to_file = MagicMock()
            created_writers[cls_type] = writer
            return writer

        return factory

    # Mock OUTPUT_FORMATS_MAPPING with tracking
    with patch(
        "tasks.tasks.OUTPUT_FORMATS_MAPPING",
        {
            "csv": {
                "class": track_writer_creation("csv"),
                "suffix": ".csv",
                "kwargs": {},
            },
            "json-asff": {
                "class": track_writer_creation("asff"),
                "suffix": ".asff.json",
                "kwargs": {},
            },
            "json-ocsf": {
                "class": track_writer_creation("ocsf"),
                "suffix": ".ocsf.json",
                "kwargs": {},
            },
        },
    ):
        mock_compress.return_value = "/tmp/compressed.zip"
        mock_upload.return_value = "s3://bucket/file.zip"

        # Execute
        result = generate_outputs_task(
            scan_id=self.scan_id,
            provider_id=self.provider_id,
            tenant_id=self.tenant_id,
        )

        # Verify ASFF was created for AWS with SecurityHub
        assert "asff" in created_writers, "ASFF writer should be created"
        assert "csv" in created_writers, "CSV writer should be created"
        assert "ocsf" in created_writers, "OCSF writer should be created"

        # Verify SecurityHub integration was checked
        # Two filter calls are expected: presumably the Security Hub check and
        # the S3 integration lookup. Only the first is asserted explicitly.
        assert mock_integration_filter.call_count == 2
        mock_integration_filter.assert_any_call(
            integrationproviderrelationship__provider_id=self.provider_id,
            integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
            enabled=True,
        )

        assert result == {"upload": True}
@patch("tasks.tasks.s3_integration_task")
@patch("tasks.tasks.Integration.objects.filter")
@patch("tasks.tasks.ScanSummary.objects.filter")
@patch("tasks.tasks.Provider.objects.get")
@patch("tasks.tasks.initialize_prowler_provider")
@patch("tasks.tasks.Compliance.get_bulk")
@patch("tasks.tasks.get_compliance_frameworks")
@patch("tasks.tasks.Finding.all_objects.filter")
@patch("tasks.tasks._generate_output_directory")
@patch("tasks.tasks.FindingOutput._transform_findings_stats")
@patch("tasks.tasks.FindingOutput.transform_api_finding")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks.Scan.all_objects.filter")
@patch("tasks.tasks.rmtree")
def test_generate_outputs_no_asff_for_aws_without_security_hub(
    self,
    mock_rmtree,
    mock_scan_update,
    mock_upload,
    mock_compress,
    mock_transform_finding,
    mock_transform_stats,
    mock_generate_dir,
    mock_findings,
    mock_get_frameworks,
    mock_compliance_bulk,
    mock_initialize_provider,
    mock_provider_get,
    mock_scan_summary,
    mock_integration_filter,
    mock_s3_task,
):
    """No ASFF writer is instantiated for an AWS provider lacking SecurityHub.

    The provider mock reports ``provider == "aws"`` while every
    ``Integration.objects.filter(...)`` queryset answers ``exists() == False``.
    The test then checks that only the CSV and OCSF writer factories ran,
    that the integration filter was called twice (once with the
    SecurityHub arguments) and that the task returned ``{"upload": True}``.
    """
    # --- Arrange -----------------------------------------------------------
    # The scan-summary queryset reports that summaries exist.
    mock_scan_summary.return_value = MagicMock(**{"exists.return_value": True})

    # AWS provider returned by Provider.objects.get.
    mock_provider_get.return_value = MagicMock(
        uid="aws-account-123", provider="aws"
    )

    # Integration lookups find nothing: exists() is False.
    mock_integration_filter.return_value = MagicMock(
        **{"exists.return_value": False}
    )

    # Remaining collaborators get minimal canned values.
    mock_initialize_provider.return_value = MagicMock()
    mock_compliance_bulk.return_value = {}
    mock_get_frameworks.return_value = []
    mock_generate_dir.return_value = ("out-dir", "comp-dir")
    mock_transform_stats.return_value = {"stats": "data"}
    mock_compress.return_value = "/tmp/compressed.zip"
    mock_upload.return_value = "s3://bucket/file.zip"

    # iterator() returns a two-item list: a one-finding batch, then True.
    # NOTE(review): presumably mirrors what the task's batching consumes - confirm.
    single_finding = MagicMock()
    finding_iterator = mock_findings.return_value.order_by.return_value.iterator
    finding_iterator.return_value = [[single_finding], True]
    mock_transform_finding.return_value = MagicMock(compliance={})

    # Registry of writer mocks, keyed by the short tag of the format created.
    created_writers = {}

    def recording_writer_class(tag):
        """Build a stand-in writer class that records its instances under ``tag``."""

        def build(*args, **kwargs):
            writer = MagicMock()
            writer.configure_mock(
                _data=[],
                transform=MagicMock(),
                batch_write_data_to_file=MagicMock(),
            )
            created_writers[tag] = writer
            return writer

        return build

    tracked_formats = {
        output_key: {
            "class": recording_writer_class(tag),
            "suffix": suffix,
            "kwargs": {},
        }
        for output_key, tag, suffix in (
            ("csv", "csv", ".csv"),
            ("json-asff", "asff", ".asff.json"),
            ("json-ocsf", "ocsf", ".ocsf.json"),
        )
    }

    # --- Act ---------------------------------------------------------------
    with patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", tracked_formats):
        result = generate_outputs_task(
            scan_id=self.scan_id,
            provider_id=self.provider_id,
            tenant_id=self.tenant_id,
        )

    # --- Assert ------------------------------------------------------------
    # ASFF must be skipped; the other two formats must still be produced.
    assert "asff" not in created_writers, "ASFF writer should NOT be created"
    assert "csv" in created_writers, "CSV writer should be created"
    assert "ocsf" in created_writers, "OCSF writer should be created"

    # The SecurityHub lookup is one of the two integration queries issued.
    assert mock_integration_filter.call_count == 2
    mock_integration_filter.assert_any_call(
        integrationproviderrelationship__provider_id=self.provider_id,
        integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
        enabled=True,
    )

    assert result == {"upload": True}
@patch("tasks.tasks.ScanSummary.objects.filter")
@patch("tasks.tasks.Provider.objects.get")
@patch("tasks.tasks.initialize_prowler_provider")
@patch("tasks.tasks.Compliance.get_bulk")
@patch("tasks.tasks.get_compliance_frameworks")
@patch("tasks.tasks.Finding.all_objects.filter")
@patch("tasks.tasks._generate_output_directory")
@patch("tasks.tasks.FindingOutput._transform_findings_stats")
@patch("tasks.tasks.FindingOutput.transform_api_finding")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks.Scan.all_objects.filter")
@patch("tasks.tasks.rmtree")
def test_generate_outputs_no_asff_for_non_aws_provider(
    self,
    mock_rmtree,
    mock_scan_update,
    mock_upload,
    mock_compress,
    mock_transform_finding,
    mock_transform_stats,
    mock_generate_dir,
    mock_findings,
    mock_get_frameworks,
    mock_compliance_bulk,
    mock_initialize_provider,
    mock_provider_get,
    mock_scan_summary,
):
    """No ASFF writer is instantiated when the provider is not AWS.

    The provider mock reports ``provider == "azure"``. The test checks that
    only the CSV and OCSF writer factories ran and that the task returned
    ``{"upload": True}``.
    """
    # --- Arrange -----------------------------------------------------------
    # The scan-summary queryset reports that summaries exist.
    mock_scan_summary.return_value = MagicMock(**{"exists.return_value": True})

    # Azure stands in for "any non-AWS provider".
    mock_provider_get.return_value = MagicMock(
        uid="azure-subscription-123", provider="azure"
    )

    # Remaining collaborators get minimal canned values.
    mock_initialize_provider.return_value = MagicMock()
    mock_compliance_bulk.return_value = {}
    mock_get_frameworks.return_value = []
    mock_generate_dir.return_value = ("out-dir", "comp-dir")
    mock_transform_stats.return_value = {"stats": "data"}
    mock_compress.return_value = "/tmp/compressed.zip"
    mock_upload.return_value = "s3://bucket/file.zip"

    # iterator() returns a two-item list: a one-finding batch, then True.
    # NOTE(review): presumably mirrors what the task's batching consumes - confirm.
    single_finding = MagicMock()
    finding_iterator = mock_findings.return_value.order_by.return_value.iterator
    finding_iterator.return_value = [[single_finding], True]
    mock_transform_finding.return_value = MagicMock(compliance={})

    # Registry of writer mocks, keyed by the short tag of the format created.
    created_writers = {}

    def recording_writer_class(tag):
        """Build a stand-in writer class that records its instances under ``tag``."""

        def build(*args, **kwargs):
            writer = MagicMock()
            writer.configure_mock(
                _data=[],
                transform=MagicMock(),
                batch_write_data_to_file=MagicMock(),
            )
            created_writers[tag] = writer
            return writer

        return build

    tracked_formats = {
        output_key: {
            "class": recording_writer_class(tag),
            "suffix": suffix,
            "kwargs": {},
        }
        for output_key, tag, suffix in (
            ("csv", "csv", ".csv"),
            ("json-asff", "asff", ".asff.json"),
            ("json-ocsf", "ocsf", ".ocsf.json"),
        )
    }

    # --- Act ---------------------------------------------------------------
    with patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", tracked_formats):
        result = generate_outputs_task(
            scan_id=self.scan_id,
            provider_id=self.provider_id,
            tenant_id=self.tenant_id,
        )

    # --- Assert ------------------------------------------------------------
    assert (
        "asff" not in created_writers
    ), "ASFF writer should NOT be created for non-AWS providers"
    assert "csv" in created_writers, "CSV writer should be created"
    assert "ocsf" in created_writers, "OCSF writer should be created"

    assert result == {"upload": True}
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_success(self, mock_upload):
    """``s3_integration_task`` yields True when the patched upload helper does.

    Also checks the helper is invoked exactly once with the tenant id,
    provider id and output directory as positional arguments.
    """
    target_dir = "/tmp/prowler_api_output/test"
    mock_upload.return_value = True

    task_kwargs = {
        "tenant_id": self.tenant_id,
        "provider_id": self.provider_id,
        "output_directory": target_dir,
    }
    outcome = s3_integration_task(**task_kwargs)

    assert outcome is True
    mock_upload.assert_called_once_with(
        self.tenant_id, self.provider_id, target_dir
    )
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_failure(self, mock_upload):
    """``s3_integration_task`` yields False when the patched upload helper does.

    Also checks the helper is invoked exactly once with the tenant id,
    provider id and output directory as positional arguments.
    """
    target_dir = "/tmp/prowler_api_output/test"
    mock_upload.return_value = False

    task_kwargs = {
        "tenant_id": self.tenant_id,
        "provider_id": self.provider_id,
        "output_directory": target_dir,
    }
    outcome = s3_integration_task(**task_kwargs)

    assert outcome is False
    mock_upload.assert_called_once_with(
        self.tenant_id, self.provider_id, target_dir
    )
@patch("tasks.tasks.upload_security_hub_integration")
def test_security_hub_integration_task_success(self, mock_upload):
    """``security_hub_integration_task`` yields True when the patched upload helper does.

    Also checks the helper is invoked exactly once with the tenant id,
    provider id and scan id as positional arguments.
    """
    scan_identifier = "test-scan-123"
    mock_upload.return_value = True

    task_kwargs = {
        "tenant_id": self.tenant_id,
        "provider_id": self.provider_id,
        "scan_id": scan_identifier,
    }
    outcome = security_hub_integration_task(**task_kwargs)

    assert outcome is True
    mock_upload.assert_called_once_with(
        self.tenant_id, self.provider_id, scan_identifier
    )
@patch("tasks.tasks.upload_security_hub_integration")
def test_security_hub_integration_task_failure(self, mock_upload):
    """``security_hub_integration_task`` yields False when the patched upload helper does.

    Also checks the helper is invoked exactly once with the tenant id,
    provider id and scan id as positional arguments.
    """
    scan_identifier = "test-scan-123"
    mock_upload.return_value = False

    task_kwargs = {
        "tenant_id": self.tenant_id,
        "provider_id": self.provider_id,
        "scan_id": scan_identifier,
    }
    outcome = security_hub_integration_task(**task_kwargs)

    assert outcome is False
    mock_upload.assert_called_once_with(
        self.tenant_id, self.provider_id, scan_identifier
    )
@@ -1,128 +0,0 @@
import random
from collections import defaultdict
import requests
from locust import events, task
from utils.helpers import APIUserBase, get_api_token, get_auth_headers
# Module-wide state: filled in by the test_start listener and read by
# _get_random_scan() and APIUser.on_start().
GLOBAL = dict(token=None, available_scans_info={})

# Compliance framework prefixes per provider type. _get_random_compliance_id()
# appends "_<provider>" to one of these to form the full compliance id.
SUPPORTED_COMPLIANCE_IDS = {
    # aws / gcp / azure list the same frameworks; each gets its own list object.
    **{
        provider: ["ens_rd2022", "cis_2.0", "prowler_threatscore", "soc2"]
        for provider in ("aws", "gcp", "azure")
    },
    "m365": ["cis_4.0", "iso27001_2022", "prowler_threatscore"],
}
def _get_random_scan() -> tuple:
    """Pick a random provider type, then a random scan entry recorded for it.

    Reads ``GLOBAL["available_scans_info"]``.

    Returns:
        A ``(provider_type, scan_info)`` pair.

    Raises:
        IndexError: from ``random.choice`` when no provider types are recorded
            or the chosen provider's list is empty.
    """
    scans_by_provider = GLOBAL["available_scans_info"]
    provider_type = random.choice(list(scans_by_provider))
    return provider_type, random.choice(scans_by_provider[provider_type])
def _get_random_compliance_id(provider: str) -> str:
    """Return ``"<framework>_<provider>"`` for a randomly chosen framework.

    Args:
        provider: Key into ``SUPPORTED_COMPLIANCE_IDS``.

    Raises:
        KeyError: if ``provider`` has no entry in ``SUPPORTED_COMPLIANCE_IDS``.
    """
    framework = random.choice(SUPPORTED_COMPLIANCE_IDS[provider])
    return f"{framework}_{provider}"
def _get_compliance_available_scans_by_provider_type(host: str, token: str) -> dict:
    """Collect one completed scan id per connected provider, grouped by provider type.

    Queries ``{host}/providers`` for connected providers and, for each one whose
    type is not excluded, ``{host}/scans`` for its completed scans. Only the
    first scan in each scans response is kept.

    Args:
        host: Base URL the endpoint paths are appended to.
        token: Token passed to ``get_auth_headers`` to build request headers.

    Returns:
        A ``defaultdict(list)`` mapping provider type to a list of scan ids.
        Providers with no completed scan contribute nothing.

    Raises:
        requests.HTTPError: if either endpoint answers with a 4xx/5xx status.
    """
    excluded_providers = {"kubernetes"}
    response_dict = defaultdict(list)

    provider_response = requests.get(
        f"{host}/providers?fields[providers]=id,provider&filter[connected]=true",
        headers=get_auth_headers(token),
    )
    # Fail here with the HTTP error instead of a misleading KeyError on "data".
    provider_response.raise_for_status()

    for provider in provider_response.json()["data"]:
        provider_id = provider["id"]
        provider_type = provider["attributes"]["provider"]
        if provider_type in excluded_providers:
            continue

        scan_response = requests.get(
            f"{host}/scans?fields[scans]=id&filter[provider]={provider_id}&filter[state]=completed",
            headers=get_auth_headers(token),
        )
        scan_response.raise_for_status()

        scan_data = scan_response.json()["data"]
        if not scan_data:
            continue
        # Keep only the first scan the API returned for this provider.
        response_dict[provider_type].append(scan_data[0]["id"])

    return response_dict
def _get_compliance_regions_from_scan(host: str, token: str, scan_id: str) -> list:
    """Fetch the regions reported by the compliance-overview metadata of a scan.

    Args:
        host: Base URL the endpoint path is appended to.
        token: Token passed to ``get_auth_headers`` to build request headers.
        scan_id: Value used for the ``filter[scan_id]`` query parameter.

    Returns:
        The value at ``data.attributes.regions`` in the JSON response.

    Raises:
        AssertionError: if the response status code is not 200. The exception
            type matches the ``assert`` statement this check replaces.
    """
    response = requests.get(
        f"{host}/compliance-overviews/metadata?filter[scan_id]={scan_id}",
        headers=get_auth_headers(token),
    )
    # Explicit raise instead of `assert`: assert statements are removed when
    # Python runs with -O, which would silently skip this check.
    if response.status_code != 200:
        raise AssertionError(f"Failed to get scan: {response.text}")
    return response.json()["data"]["attributes"]["regions"]
@events.test_start.add_listener
def on_test_start(environment, **kwargs):
    """Populate ``GLOBAL`` with an API token and the scans the tasks will query.

    Registered as a ``test_start`` listener. Obtains a token for
    ``environment.host``, looks up the available scans per provider type,
    fetches each scan's regions, and stores the result in
    ``GLOBAL["available_scans_info"]`` as
    ``{provider_type: [{"scan_id": ..., "regions": ...}, ...]}``.
    """
    host = environment.host
    token = get_api_token(host)
    GLOBAL["token"] = token

    enriched = defaultdict(list)
    scans_by_provider = _get_compliance_available_scans_by_provider_type(host, token)
    for provider, scan_ids in scans_by_provider.items():
        for scan_id in scan_ids:
            regions = _get_compliance_regions_from_scan(host, token, scan_id)
            enriched[provider].append({"scan_id": scan_id, "regions": regions})

    GLOBAL["available_scans_info"] = enriched
class APIUser(APIUserBase):
    """Simulated user issuing GET requests to the compliance-overviews endpoints.

    Each task picks a random scan (and, where needed, a random compliance id)
    and sends one authenticated GET. Tasks carry ``@task`` weights 3, 2 and 2,
    and one unweighted ``@task``.
    """

    def on_start(self):
        """Copy the token stored in ``GLOBAL`` onto this user."""
        self.token = GLOBAL["token"]

    @task(3)
    def compliance_overviews_default(self):
        """GET /compliance-overviews filtered by a random scan id."""
        provider_type, scan_info = _get_random_scan()
        scan_id = scan_info["scan_id"]
        self.client.get(
            f"/compliance-overviews?filter[scan_id]={scan_id}",
            headers=get_auth_headers(self.token),
            name=f"/compliance-overviews ({provider_type})",
        )

    @task(2)
    def compliance_overviews_region(self):
        """GET /compliance-overviews filtered by a random scan id and one of its regions."""
        provider_type, scan_info = _get_random_scan()
        scan_id = scan_info["scan_id"]
        region = random.choice(scan_info["regions"])
        self.client.get(
            f"/compliance-overviews?filter[scan_id]={scan_id}&filter[region]={region}",
            headers=get_auth_headers(self.token),
            name=f"/compliance-overviews?filter[region] ({provider_type})",
        )

    @task(2)
    def compliance_overviews_requirements(self):
        """GET /compliance-overviews/requirements for a random scan and compliance id."""
        provider_type, scan_info = _get_random_scan()
        compliance_id = _get_random_compliance_id(provider_type)
        scan_id = scan_info["scan_id"]
        self.client.get(
            "/compliance-overviews/requirements"
            f"?filter[scan_id]={scan_id}"
            f"&filter[compliance_id]={compliance_id}",
            headers=get_auth_headers(self.token),
            name=f"/compliance-overviews/requirements ({compliance_id})",
        )

    @task
    def compliance_overviews_attributes(self):
        """GET /compliance-overviews/attributes for a random compliance id."""
        # Only the provider type is needed here; the scan entry is discarded.
        provider_type, _ = _get_random_scan()
        compliance_id = _get_random_compliance_id(provider_type)
        self.client.get(
            f"/compliance-overviews/attributes?filter[compliance_id]={compliance_id}",
            headers=get_auth_headers(self.token),
            name=f"/compliance-overviews/attributes ({compliance_id})",
        )

Some files were not shown because too many files have changed in this diff Show More