feat(attack-surfaces): add new endpoints to retrieve overview data (#9309)

This commit is contained in:
Víctor Fernández Poyatos
2025-12-02 12:12:47 +01:00
committed by GitHub
parent 4661e01c26
commit 07e82bde56
16 changed files with 1388 additions and 92 deletions
+3
View File
@@ -4,6 +4,9 @@ All notable changes to the **Prowler API** are documented in this file.
## [1.16.0] (Unreleased)
### Added
- New endpoint to retrieve an overview of the attack surfaces [(#9309)](https://github.com/prowler-cloud/prowler/pull/9309)
### Changed
- Restore the compliance overview endpoint's mandatory filters [(#9330)](https://github.com/prowler-cloud/prowler/pull/9330)
+1 -1
View File
@@ -44,7 +44,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.15.0"
version = "1.16.0"
[project.scripts]
celery = "src.backend.config.settings.celery"
+11
View File
@@ -40,6 +40,7 @@ class ApiConfig(AppConfig):
self._ensure_crypto_keys()
load_prowler_compliance()
self._initialize_attack_surface_mapping()
def _ensure_crypto_keys(self):
"""
@@ -167,3 +168,13 @@ class ApiConfig(AppConfig):
f"Error generating JWT keys: {e}. Please set '{SIGNING_KEY_ENV}' and '{VERIFYING_KEY_ENV}' manually."
)
raise e
def _initialize_attack_surface_mapping(self):
from tasks.jobs.scan import ( # noqa: F401
_get_attack_surface_mapping_from_provider,
)
from api.models import Provider # noqa: F401
for provider_type, _label in Provider.ProviderChoices.choices:
_get_attack_surface_mapping_from_provider(provider_type)
+20
View File
@@ -23,6 +23,7 @@ from api.db_utils import (
StatusEnumField,
)
from api.models import (
AttackSurfaceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
@@ -1013,3 +1014,22 @@ class ThreatScoreSnapshotFilter(FilterSet):
"inserted_at": ["date", "gte", "lte"],
"overall_score": ["exact", "gte", "lte"],
}
class AttackSurfaceOverviewFilter(FilterSet):
"""Filter for attack surface overview aggregations by provider."""
provider_id = UUIDFilter(field_name="scan__provider__id", lookup_expr="exact")
provider_id__in = UUIDInFilter(field_name="scan__provider__id", lookup_expr="in")
provider_type = ChoiceFilter(
field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices
)
provider_type__in = ChoiceInFilter(
field_name="scan__provider__provider",
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
class Meta:
model = AttackSurfaceOverview
fields = {}
@@ -0,0 +1,89 @@
# Generated by Django 5.1.14 on 2025-11-19 13:03
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0059_compliance_overview_summary"),
]
operations = [
migrations.CreateModel(
name="AttackSurfaceOverview",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"attack_surface_type",
models.CharField(
choices=[
("internet-exposed", "Internet Exposed"),
("secrets", "Exposed Secrets"),
("privilege-escalation", "Privilege Escalation"),
("ec2-imdsv1", "EC2 IMDSv1 Enabled"),
],
max_length=50,
),
),
("total_findings", models.IntegerField(default=0)),
("failed_findings", models.IntegerField(default=0)),
("muted_failed_findings", models.IntegerField(default=0)),
],
options={
"db_table": "attack_surface_overviews",
"abstract": False,
},
),
migrations.AddField(
model_name="attacksurfaceoverview",
name="scan",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attack_surface_overviews",
related_query_name="attack_surface_overview",
to="api.scan",
),
),
migrations.AddField(
model_name="attacksurfaceoverview",
name="tenant",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
migrations.AddIndex(
model_name="attacksurfaceoverview",
index=models.Index(
fields=["tenant_id", "scan_id"], name="attack_surf_tenant_scan_idx"
),
),
migrations.AddConstraint(
model_name="attacksurfaceoverview",
constraint=models.UniqueConstraint(
fields=("tenant_id", "scan_id", "attack_surface_type"),
name="unique_attack_surface_per_scan",
),
),
migrations.AddConstraint(
model_name="attacksurfaceoverview",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_attacksurfaceoverview",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
]
+60
View File
@@ -2405,3 +2405,63 @@ class ThreatScoreSnapshot(RowLevelSecurityProtectedModel):
class JSONAPIMeta:
resource_name = "threatscore-snapshots"
class AttackSurfaceOverview(RowLevelSecurityProtectedModel):
"""
Pre-aggregated attack surface metrics per scan.
Stores counts for each attack surface type (internet-exposed, secrets,
privilege-escalation, ec2-imdsv1) to enable fast overview queries.
"""
class AttackSurfaceTypeChoices(models.TextChoices):
INTERNET_EXPOSED = "internet-exposed", _("Internet Exposed")
SECRETS = "secrets", _("Exposed Secrets")
PRIVILEGE_ESCALATION = "privilege-escalation", _("Privilege Escalation")
EC2_IMDSV1 = "ec2-imdsv1", _("EC2 IMDSv1 Enabled")
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
scan = models.ForeignKey(
Scan,
on_delete=models.CASCADE,
related_name="attack_surface_overviews",
related_query_name="attack_surface_overview",
)
attack_surface_type = models.CharField(
max_length=50,
choices=AttackSurfaceTypeChoices.choices,
)
# Finding counts
total_findings = models.IntegerField(default=0) # All findings (PASS + FAIL)
failed_findings = models.IntegerField(default=0) # Non-muted failed findings
muted_failed_findings = models.IntegerField(default=0) # Muted failed findings
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "attack_surface_overviews"
constraints = [
models.UniqueConstraint(
fields=("tenant_id", "scan_id", "attack_surface_type"),
name="unique_attack_surface_per_scan",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
indexes = [
models.Index(
fields=["tenant_id", "scan_id"],
name="attack_surf_tenant_scan_idx",
),
]
class JSONAPIMeta:
resource_name = "attack-surface-overviews"
+2 -2
View File
@@ -65,11 +65,11 @@ def get_providers(role: Role) -> QuerySet[Provider]:
A QuerySet of Provider objects filtered by the role's provider groups.
If the role has no provider groups, returns an empty queryset.
"""
tenant = role.tenant
tenant_id = role.tenant_id
provider_groups = role.provider_groups.all()
if not provider_groups.exists():
return Provider.objects.none()
return Provider.objects.filter(
tenant=tenant, provider_groups__in=provider_groups
tenant_id=tenant_id, provider_groups__in=provider_groups
).distinct()
+98 -1
View File
@@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: Prowler API
version: 1.15.0
version: 1.16.0
description: |-
Prowler API specification.
@@ -4497,6 +4497,60 @@ paths:
responses:
'204':
description: No response body
/api/v1/overviews/attack-surfaces:
get:
operationId: overviews_attack_surfaces_retrieve
description: Retrieve aggregated attack surface metrics from latest completed
scans per provider.
summary: Get attack surface overview
parameters:
- in: query
name: fields[attack-surface-overviews]
schema:
type: array
items:
type: string
enum:
- id
- total_findings
- failed_findings
- muted_failed_findings
- check_ids
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: query
name: filter[provider_id.in]
schema:
type: string
description: Filter by multiple provider IDs (comma-separated UUIDs)
- in: query
name: filter[provider_id]
schema:
type: string
format: uuid
description: Filter by specific provider ID
- in: query
name: filter[provider_type.in]
schema:
type: string
description: Filter by multiple provider types (comma-separated)
- in: query
name: filter[provider_type]
schema:
type: string
description: Filter by provider type (aws, azure, gcp, etc.)
tags:
- Overview
security:
- JWT or API Key: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/AttackSurfaceOverviewResponse'
description: ''
/api/v1/overviews/findings:
get:
operationId: overviews_findings_retrieve
@@ -10618,6 +10672,49 @@ paths:
description: ''
components:
schemas:
AttackSurfaceOverview:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-surface-overviews
id: {}
attributes:
type: object
properties:
id:
type: string
total_findings:
type: integer
failed_findings:
type: integer
muted_failed_findings:
type: integer
check_ids:
type: array
items:
type: string
readOnly: true
required:
- id
- total_findings
- failed_findings
- muted_failed_findings
AttackSurfaceOverviewResponse:
type: object
properties:
data:
$ref: '#/components/schemas/AttackSurfaceOverview'
required:
- data
ComplianceOverview:
type: object
required:
+385
View File
@@ -35,6 +35,7 @@ from rest_framework.response import Response
from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
AttackSurfaceOverview,
Finding,
Integration,
Invitation,
@@ -6856,6 +6857,390 @@ class TestOverviewViewSet:
assert combined_attributes["medium"] == 4
assert combined_attributes["critical"] == 3
def test_overview_attack_surface_no_data(self, authenticated_client):
response = authenticated_client.get(reverse("overview-attack-surface"))
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 4
for item in data:
assert item["attributes"]["total_findings"] == 0
assert item["attributes"]["failed_findings"] == 0
assert item["attributes"]["muted_failed_findings"] == 0
assert item["attributes"]["check_ids"] == []
def test_overview_attack_surface_with_data(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
create_attack_surface_overview,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
mapping = {
"internet-exposed": {"aws-check-1", "aws-check-2"},
"secrets": {"aws-secret-check"},
"privilege-escalation": {"aws-priv-check"},
"ec2-imdsv1": {"aws-imdsv1-check"},
}
scan = Scan.objects.create(
name="attack-surface-scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
create_attack_surface_overview(
tenant,
scan,
AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED,
total=20,
failed=10,
muted_failed=3,
)
create_attack_surface_overview(
tenant,
scan,
AttackSurfaceOverview.AttackSurfaceTypeChoices.SECRETS,
total=15,
failed=8,
muted_failed=2,
)
with patch(
"api.v1.views._get_attack_surface_mapping_from_provider",
return_value=mapping,
):
response = authenticated_client.get(reverse("overview-attack-surface"))
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 4
results_by_type = {item["id"]: item["attributes"] for item in data}
assert results_by_type["internet-exposed"]["total_findings"] == 20
assert results_by_type["internet-exposed"]["failed_findings"] == 10
assert set(results_by_type["internet-exposed"]["check_ids"]) == {
"aws-check-1",
"aws-check-2",
}
assert results_by_type["secrets"]["total_findings"] == 15
assert results_by_type["secrets"]["failed_findings"] == 8
assert set(results_by_type["secrets"]["check_ids"]) == {"aws-secret-check"}
assert results_by_type["privilege-escalation"]["total_findings"] == 0
assert set(results_by_type["privilege-escalation"]["check_ids"]) == {
"aws-priv-check"
}
assert results_by_type["ec2-imdsv1"]["total_findings"] == 0
assert set(results_by_type["ec2-imdsv1"]["check_ids"]) == {"aws-imdsv1-check"}
def test_overview_attack_surface_provider_filter(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
create_attack_surface_overview,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
scan1 = Scan.objects.create(
name="attack-surface-scan-1",
provider=provider1,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
scan2 = Scan.objects.create(
name="attack-surface-scan-2",
provider=provider2,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
mapping = {
"internet-exposed": {"shared-check", "shared-check"},
"secrets": set(),
"privilege-escalation": {"priv-check"},
"ec2-imdsv1": {"imdsv1-check"},
}
create_attack_surface_overview(
tenant,
scan1,
AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED,
total=10,
failed=5,
muted_failed=1,
)
create_attack_surface_overview(
tenant,
scan2,
AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED,
total=20,
failed=15,
muted_failed=3,
)
with patch(
"api.v1.views._get_attack_surface_mapping_from_provider",
return_value=mapping,
):
response = authenticated_client.get(
reverse("overview-attack-surface"),
{"filter[provider_id]": str(provider1.id)},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
results_by_type = {item["id"]: item["attributes"] for item in data}
assert results_by_type["internet-exposed"]["total_findings"] == 10
assert results_by_type["internet-exposed"]["failed_findings"] == 5
assert results_by_type["internet-exposed"]["check_ids"] == ["shared-check"]
def test_overview_services_region_filter(
self, authenticated_client, scan_summaries_fixture
):
response = authenticated_client.get(
reverse("overview-services"),
{"filter[region]": "region1"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 2
service_ids = {item["id"] for item in data}
assert service_ids == {"service1", "service2"}
def test_overview_services_provider_type_filter(
self, authenticated_client, tenants_fixture, providers_fixture
):
tenant = tenants_fixture[0]
aws_provider, _, gcp_provider, *_ = providers_fixture
aws_scan = Scan.objects.create(
name="aws-scan",
provider=aws_provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
gcp_scan = Scan.objects.create(
name="gcp-scan",
provider=gcp_provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
ScanSummary.objects.create(
tenant=tenant,
scan=aws_scan,
check_id="aws-check",
service="aws-service",
severity="high",
region="us-east-1",
_pass=5,
fail=2,
muted=1,
total=8,
)
ScanSummary.objects.create(
tenant=tenant,
scan=gcp_scan,
check_id="gcp-check",
service="gcp-service",
severity="medium",
region="us-central1",
_pass=3,
fail=1,
muted=0,
total=4,
)
response = authenticated_client.get(
reverse("overview-services"),
{"filter[provider_type]": "aws"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
service_ids = [item["id"] for item in data]
assert "aws-service" in service_ids
assert "gcp-service" not in service_ids
@pytest.mark.parametrize(
"status_filter,field_to_check",
[
("FAIL", "fail"),
("PASS", "_pass"),
],
)
def test_overview_findings_severity_status_filter(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
status_filter,
field_to_check,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = Scan.objects.create(
name="status-filter-scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
ScanSummary.objects.create(
tenant=tenant,
scan=scan,
check_id="status-check-high",
service="service-a",
severity="high",
region="us-east-1",
_pass=10,
fail=5,
muted=3,
total=18,
)
ScanSummary.objects.create(
tenant=tenant,
scan=scan,
check_id="status-check-medium",
service="service-a",
severity="medium",
region="us-east-1",
_pass=8,
fail=2,
muted=1,
total=11,
)
response = authenticated_client.get(
reverse("overview-findings_severity"),
{
"filter[provider_id]": str(provider.id),
"filter[status]": status_filter,
},
)
assert response.status_code == status.HTTP_200_OK
attrs = response.json()["data"]["attributes"]
if status_filter == "FAIL":
assert attrs["high"] == 5
assert attrs["medium"] == 2
else:
assert attrs["high"] == 10
assert attrs["medium"] == 8
def test_overview_threatscore_compliance_id_filter(
self, authenticated_client, tenants_fixture, providers_fixture
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = self._create_scan(tenant, provider, "compliance-filter-scan")
self._create_threatscore_snapshot(
tenant,
scan,
provider,
compliance_id="prowler_threatscore_aws",
overall_score="75.00",
score_delta="2.00",
section_scores={"1. IAM": "70.00"},
critical_requirements=[],
total_requirements=50,
passed_requirements=35,
failed_requirements=15,
manual_requirements=0,
total_findings=30,
passed_findings=20,
failed_findings=10,
)
self._create_threatscore_snapshot(
tenant,
scan,
provider,
compliance_id="cis_1.4_aws",
overall_score="65.00",
score_delta="1.00",
section_scores={"1. IAM": "60.00"},
critical_requirements=[],
total_requirements=40,
passed_requirements=25,
failed_requirements=15,
manual_requirements=0,
total_findings=25,
passed_findings=15,
failed_findings=10,
)
response = authenticated_client.get(
reverse("overview-threatscore"),
{"filter[compliance_id]": "prowler_threatscore_aws"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
assert data[0]["attributes"]["overall_score"] == "75.00"
assert data[0]["attributes"]["compliance_id"] == "prowler_threatscore_aws"
def test_overview_threatscore_provider_type_filter(
self, authenticated_client, tenants_fixture, providers_fixture
):
tenant = tenants_fixture[0]
aws_provider, _, gcp_provider, *_ = providers_fixture
aws_scan = self._create_scan(tenant, aws_provider, "aws-threatscore-scan")
gcp_scan = self._create_scan(tenant, gcp_provider, "gcp-threatscore-scan")
self._create_threatscore_snapshot(
tenant,
aws_scan,
aws_provider,
compliance_id="prowler_threatscore_aws",
overall_score="80.00",
score_delta="3.00",
section_scores={"1. IAM": "75.00"},
critical_requirements=[],
total_requirements=60,
passed_requirements=45,
failed_requirements=15,
manual_requirements=0,
total_findings=40,
passed_findings=30,
failed_findings=10,
)
self._create_threatscore_snapshot(
tenant,
gcp_scan,
gcp_provider,
compliance_id="prowler_threatscore_gcp",
overall_score="70.00",
score_delta="2.00",
section_scores={"1. IAM": "65.00"},
critical_requirements=[],
total_requirements=50,
passed_requirements=35,
failed_requirements=15,
manual_requirements=0,
total_findings=35,
passed_findings=25,
failed_findings=10,
)
response = authenticated_client.get(
reverse("overview-threatscore"),
{"filter[provider_type]": "aws"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
assert data[0]["attributes"]["overall_score"] == "80.00"
@pytest.mark.django_db
class TestScheduleViewSet:
+75 -70
View File
@@ -72,6 +72,42 @@ from api.v1.serializer_utils.processors import ProcessorConfigField
from api.v1.serializer_utils.providers import ProviderSecretField
from prowler.lib.mutelist.mutelist import Mutelist
# Base
class BaseModelSerializerV1(serializers.ModelSerializer):
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class BaseSerializerV1(serializers.Serializer):
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class BaseWriteSerializer(BaseModelSerializerV1):
def validate(self, data):
if hasattr(self, "initial_data"):
initial_data = set(self.initial_data.keys()) - {"id", "type"}
unknown_keys = initial_data - set(self.fields.keys())
if unknown_keys:
raise ValidationError(f"Invalid fields: {unknown_keys}")
return data
class RLSSerializer(BaseModelSerializerV1):
def create(self, validated_data):
tenant_id = self.context.get("tenant_id")
validated_data["tenant_id"] = tenant_id
return super().create(validated_data)
class StateEnumSerializerField(serializers.ChoiceField):
def __init__(self, **kwargs):
kwargs["choices"] = StateChoices.choices
super().__init__(**kwargs)
# Tokens
@@ -179,7 +215,7 @@ class TokenSocialLoginSerializer(BaseTokenSerializer):
# TODO: Check if we can change the parent class to TokenRefreshSerializer from rest_framework_simplejwt.serializers
class TokenRefreshSerializer(serializers.Serializer):
class TokenRefreshSerializer(BaseSerializerV1):
refresh = serializers.CharField()
# Output token
@@ -213,7 +249,7 @@ class TokenRefreshSerializer(serializers.Serializer):
raise ValidationError({"refresh": "Invalid or expired token"})
class TokenSwitchTenantSerializer(serializers.Serializer):
class TokenSwitchTenantSerializer(BaseSerializerV1):
tenant_id = serializers.UUIDField(
write_only=True, help_text="The tenant ID for which to request a new token."
)
@@ -237,41 +273,10 @@ class TokenSwitchTenantSerializer(serializers.Serializer):
return generate_tokens(user, tenant_id)
# Base
class BaseSerializerV1(serializers.ModelSerializer):
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class BaseWriteSerializer(BaseSerializerV1):
def validate(self, data):
if hasattr(self, "initial_data"):
initial_data = set(self.initial_data.keys()) - {"id", "type"}
unknown_keys = initial_data - set(self.fields.keys())
if unknown_keys:
raise ValidationError(f"Invalid fields: {unknown_keys}")
return data
class RLSSerializer(BaseSerializerV1):
def create(self, validated_data):
tenant_id = self.context.get("tenant_id")
validated_data["tenant_id"] = tenant_id
return super().create(validated_data)
class StateEnumSerializerField(serializers.ChoiceField):
def __init__(self, **kwargs):
kwargs["choices"] = StateChoices.choices
super().__init__(**kwargs)
# Users
class UserSerializer(BaseSerializerV1):
class UserSerializer(BaseModelSerializerV1):
"""
Serializer for the User model.
"""
@@ -402,7 +407,7 @@ class UserUpdateSerializer(BaseWriteSerializer):
return super().update(instance, validated_data)
class RoleResourceIdentifierSerializer(serializers.Serializer):
class RoleResourceIdentifierSerializer(BaseSerializerV1):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
@@ -585,7 +590,7 @@ class TaskSerializer(RLSSerializer, TaskBase):
# Tenants
class TenantSerializer(BaseSerializerV1):
class TenantSerializer(BaseModelSerializerV1):
"""
Serializer for the Tenant model.
"""
@@ -597,7 +602,7 @@ class TenantSerializer(BaseSerializerV1):
fields = ["id", "name", "memberships"]
class TenantIncludeSerializer(BaseSerializerV1):
class TenantIncludeSerializer(BaseModelSerializerV1):
class Meta:
model = Tenant
fields = ["id", "name"]
@@ -773,7 +778,7 @@ class ProviderGroupUpdateSerializer(ProviderGroupSerializer):
return super().update(instance, validated_data)
class ProviderResourceIdentifierSerializer(serializers.Serializer):
class ProviderResourceIdentifierSerializer(BaseSerializerV1):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
@@ -1110,7 +1115,7 @@ class ScanTaskSerializer(RLSSerializer):
]
class ScanReportSerializer(serializers.Serializer):
class ScanReportSerializer(BaseSerializerV1):
id = serializers.CharField(source="scan")
class Meta:
@@ -1118,7 +1123,7 @@ class ScanReportSerializer(serializers.Serializer):
fields = ["id"]
class ScanComplianceReportSerializer(serializers.Serializer):
class ScanComplianceReportSerializer(BaseSerializerV1):
id = serializers.CharField(source="scan")
name = serializers.CharField()
@@ -1267,7 +1272,7 @@ class ResourceIncludeSerializer(RLSSerializer):
return fields
class ResourceMetadataSerializer(serializers.Serializer):
class ResourceMetadataSerializer(BaseSerializerV1):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
types = serializers.ListField(child=serializers.CharField(), allow_empty=True)
@@ -1337,7 +1342,7 @@ class FindingIncludeSerializer(RLSSerializer):
# To be removed when the related endpoint is removed as well
class FindingDynamicFilterSerializer(serializers.Serializer):
class FindingDynamicFilterSerializer(BaseSerializerV1):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
@@ -1345,7 +1350,7 @@ class FindingDynamicFilterSerializer(serializers.Serializer):
resource_name = "finding-dynamic-filters"
class FindingMetadataSerializer(serializers.Serializer):
class FindingMetadataSerializer(BaseSerializerV1):
services = serializers.ListField(child=serializers.CharField(), allow_empty=True)
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
resource_types = serializers.ListField(
@@ -2039,7 +2044,7 @@ class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer
# Compliance overview
class ComplianceOverviewSerializer(serializers.Serializer):
class ComplianceOverviewSerializer(BaseSerializerV1):
"""
Serializer for compliance requirement status aggregated by compliance framework.
@@ -2061,7 +2066,7 @@ class ComplianceOverviewSerializer(serializers.Serializer):
resource_name = "compliance-overviews"
class ComplianceOverviewDetailSerializer(serializers.Serializer):
class ComplianceOverviewDetailSerializer(BaseSerializerV1):
"""
Serializer for detailed compliance requirement information.
@@ -2090,7 +2095,7 @@ class ComplianceOverviewDetailThreatscoreSerializer(ComplianceOverviewDetailSeri
total_findings = serializers.IntegerField()
class ComplianceOverviewAttributesSerializer(serializers.Serializer):
class ComplianceOverviewAttributesSerializer(BaseSerializerV1):
id = serializers.CharField()
compliance_name = serializers.CharField()
framework_description = serializers.CharField()
@@ -2104,7 +2109,7 @@ class ComplianceOverviewAttributesSerializer(serializers.Serializer):
resource_name = "compliance-requirements-attributes"
class ComplianceOverviewMetadataSerializer(serializers.Serializer):
class ComplianceOverviewMetadataSerializer(BaseSerializerV1):
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
class JSONAPIMeta:
@@ -2114,7 +2119,7 @@ class ComplianceOverviewMetadataSerializer(serializers.Serializer):
# Overviews
class OverviewProviderSerializer(serializers.Serializer):
class OverviewProviderSerializer(BaseSerializerV1):
id = serializers.CharField(source="provider")
findings = serializers.SerializerMethodField(read_only=True)
resources = serializers.SerializerMethodField(read_only=True)
@@ -2122,9 +2127,6 @@ class OverviewProviderSerializer(serializers.Serializer):
class JSONAPIMeta:
resource_name = "providers-overview"
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
@extend_schema_field(
{
"type": "object",
@@ -2158,18 +2160,15 @@ class OverviewProviderSerializer(serializers.Serializer):
}
class OverviewProviderCountSerializer(serializers.Serializer):
class OverviewProviderCountSerializer(BaseSerializerV1):
id = serializers.CharField(source="provider")
count = serializers.IntegerField()
class JSONAPIMeta:
resource_name = "providers-count-overview"
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class OverviewFindingSerializer(serializers.Serializer):
class OverviewFindingSerializer(BaseSerializerV1):
id = serializers.CharField(default="n/a")
new = serializers.IntegerField()
changed = serializers.IntegerField()
@@ -2188,15 +2187,12 @@ class OverviewFindingSerializer(serializers.Serializer):
class JSONAPIMeta:
resource_name = "findings-overview"
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["pass"] = self.fields.pop("_pass")
class OverviewSeveritySerializer(serializers.Serializer):
class OverviewSeveritySerializer(BaseSerializerV1):
id = serializers.CharField(default="n/a")
critical = serializers.IntegerField()
high = serializers.IntegerField()
@@ -2207,11 +2203,8 @@ class OverviewSeveritySerializer(serializers.Serializer):
class JSONAPIMeta:
resource_name = "findings-severity-overview"
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class OverviewServiceSerializer(serializers.Serializer):
class OverviewServiceSerializer(BaseSerializerV1):
id = serializers.CharField(source="service")
total = serializers.IntegerField()
_pass = serializers.IntegerField()
@@ -2225,8 +2218,20 @@ class OverviewServiceSerializer(serializers.Serializer):
super().__init__(*args, **kwargs)
self.fields["pass"] = self.fields.pop("_pass")
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
class AttackSurfaceOverviewSerializer(BaseSerializerV1):
"""Serializer for attack surface overview aggregations."""
id = serializers.CharField(source="attack_surface_type")
total_findings = serializers.IntegerField()
failed_findings = serializers.IntegerField()
muted_failed_findings = serializers.IntegerField()
check_ids = serializers.ListField(
child=serializers.CharField(), allow_empty=True, default=list, read_only=True
)
class JSONAPIMeta:
resource_name = "attack-surface-overviews"
class OverviewRegionSerializer(serializers.Serializer):
@@ -2256,7 +2261,7 @@ class OverviewRegionSerializer(serializers.Serializer):
# Schedules
class ScheduleDailyCreateSerializer(serializers.Serializer):
class ScheduleDailyCreateSerializer(BaseSerializerV1):
provider_id = serializers.UUIDField(required=True)
class JSONAPIMeta:
@@ -2592,7 +2597,7 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
return representation
class IntegrationJiraDispatchSerializer(serializers.Serializer):
class IntegrationJiraDispatchSerializer(BaseSerializerV1):
"""
Serializer for dispatching findings to JIRA integration.
"""
@@ -2755,14 +2760,14 @@ class ProcessorUpdateSerializer(BaseWriteSerializer):
# SSO
class SamlInitiateSerializer(serializers.Serializer):
class SamlInitiateSerializer(BaseSerializerV1):
email_domain = serializers.CharField()
class JSONAPIMeta:
resource_name = "saml-initiate"
class SamlMetadataSerializer(serializers.Serializer):
class SamlMetadataSerializer(BaseSerializerV1):
class JSONAPIMeta:
resource_name = "saml-meta"
+167 -17
View File
@@ -74,6 +74,7 @@ from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.jobs.scan import _get_attack_surface_mapping_from_provider
from tasks.tasks import (
backfill_scan_resource_summaries_task,
check_integration_connection_task,
@@ -97,6 +98,7 @@ from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.exceptions import TaskFailedException
from api.filters import (
AttackSurfaceOverviewFilter,
ComplianceOverviewFilter,
CustomDjangoFilterBackend,
FindingFilter,
@@ -125,6 +127,7 @@ from api.filters import (
UserFilter,
)
from api.models import (
AttackSurfaceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
@@ -170,6 +173,7 @@ from api.utils import (
from api.uuid_utils import datetime_to_uuid7, uuid7_start
from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin
from api.v1.serializers import (
AttackSurfaceOverviewSerializer,
ComplianceOverviewAttributesSerializer,
ComplianceOverviewDetailSerializer,
ComplianceOverviewDetailThreatscoreSerializer,
@@ -348,7 +352,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.15.0"
spectacular_settings.VERSION = "1.16.0"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -3889,6 +3893,37 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
),
filters=True,
),
attack_surface=extend_schema(
summary="Get attack surface overview",
description="Retrieve aggregated attack surface metrics from latest completed scans per provider.",
tags=["Overview"],
parameters=[
OpenApiParameter(
name="filter[provider_id]",
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Filter by specific provider ID",
),
OpenApiParameter(
name="filter[provider_id.in]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by multiple provider IDs (comma-separated UUIDs)",
),
OpenApiParameter(
name="filter[provider_type]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by provider type (aws, azure, gcp, etc.)",
),
OpenApiParameter(
name="filter[provider_type.in]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by multiple provider types (comma-separated)",
),
],
),
)
@method_decorator(CACHE_DECORATOR, name="list")
class OverviewViewSet(BaseRLSViewSet):
@@ -3923,6 +3958,8 @@ class OverviewViewSet(BaseRLSViewSet):
return OverviewRegionSerializer
elif self.action == "threatscore":
return ThreatScoreSnapshotSerializer
elif self.action == "attack_surface":
return AttackSurfaceOverviewSerializer
return super().get_serializer_class()
def get_filterset_class(self):
@@ -3971,6 +4008,68 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id=tenant_id, scan_id__in=latest_scan_ids
)
def _normalize_jsonapi_params(self, query_params, exclude_keys=None):
"""Convert JSON:API filter params (filter[X]) to flat params (X)."""
exclude_keys = exclude_keys or set()
normalized = QueryDict(mutable=True)
for key, values in query_params.lists():
normalized_key = (
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
)
if normalized_key not in exclude_keys:
normalized.setlist(normalized_key, values)
return normalized
def _ensure_allowed_providers(self):
"""Populate allowed providers for RBAC-aware queries once per request."""
if getattr(self, "_providers_initialized", False):
return
self.get_queryset()
self._providers_initialized = True
def _get_provider_filter(self, provider_field="provider"):
self._ensure_allowed_providers()
if hasattr(self, "allowed_providers"):
return {f"{provider_field}__in": self.allowed_providers}
return {}
def _apply_provider_filter(self, queryset, provider_field="provider"):
provider_filter = self._get_provider_filter(provider_field)
if provider_filter:
return queryset.filter(**provider_filter)
return queryset
def _apply_filterset(self, queryset, filterset_class, exclude_keys=None):
normalized_params = self._normalize_jsonapi_params(
self.request.query_params, exclude_keys=set(exclude_keys or [])
)
filterset = filterset_class(normalized_params, queryset=queryset)
return filterset.qs
def _latest_scan_ids_for_allowed_providers(self, tenant_id):
provider_filter = self._get_provider_filter()
return (
Scan.all_objects.filter(
tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter
)
.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
)
def _attack_surface_check_ids_by_provider_types(self, provider_types):
check_ids_by_type = {
attack_surface_type: set()
for attack_surface_type in AttackSurfaceOverview.AttackSurfaceTypeChoices.values
}
for provider_type in provider_types:
attack_surface_mapping = _get_attack_surface_mapping_from_provider(
provider_type=provider_type
)
for attack_surface_type, check_ids in attack_surface_mapping.items():
check_ids_by_type[attack_surface_type].update(check_ids)
return check_ids_by_type
@action(detail=False, methods=["get"], url_name="providers")
def providers(self, request):
tenant_id = self.request.tenant_id
@@ -4200,11 +4299,9 @@ class OverviewViewSet(BaseRLSViewSet):
snapshot_id = request.query_params.get("snapshot_id")
# Base queryset with RLS
base_queryset = ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id)
# Apply RBAC filtering
if hasattr(self, "allowed_providers"):
base_queryset = base_queryset.filter(provider__in=self.allowed_providers)
base_queryset = self._apply_provider_filter(
ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id)
)
# Case 1: Specific snapshot requested
if snapshot_id:
@@ -4220,17 +4317,9 @@ class OverviewViewSet(BaseRLSViewSet):
# Case 2: Latest snapshot per provider (default)
# Apply filters manually: this @action is outside the standard list endpoint flow,
# so DRF's filter backends don't execute and we must flatten JSON:API params ourselves.
normalized_params = QueryDict(mutable=True)
for param_key, values in request.query_params.lists():
normalized_key = param_key
if param_key.startswith("filter[") and param_key.endswith("]"):
normalized_key = param_key[7:-1]
if normalized_key == "snapshot_id":
continue
normalized_params.setlist(normalized_key, values)
filterset = ThreatScoreSnapshotFilter(normalized_params, queryset=base_queryset)
filtered_queryset = filterset.qs
filtered_queryset = self._apply_filterset(
base_queryset, ThreatScoreSnapshotFilter, exclude_keys={"snapshot_id"}
)
# Get distinct provider IDs from filtered queryset
# Pick the latest snapshot per provider using Postgres DISTINCT ON pattern.
@@ -4474,6 +4563,67 @@ class OverviewViewSet(BaseRLSViewSet):
return aggregated_snapshot
@action(
detail=False,
methods=["get"],
url_name="attack-surface",
url_path="attack-surfaces",
)
def attack_surface(self, request):
tenant_id = request.tenant_id
latest_scan_ids = self._latest_scan_ids_for_allowed_providers(tenant_id)
# Build base queryset and apply user filters via FilterSet
base_queryset = AttackSurfaceOverview.objects.filter(
tenant_id=tenant_id, scan_id__in=latest_scan_ids
)
filtered_queryset = self._apply_filterset(
base_queryset, AttackSurfaceOverviewFilter
)
provider_types = list(
filtered_queryset.values_list(
"scan__provider__provider", flat=True
).distinct()
)
attack_surface_check_ids = self._attack_surface_check_ids_by_provider_types(
provider_types
)
# Aggregate attack surface data
aggregation = filtered_queryset.values("attack_surface_type").annotate(
total_findings=Coalesce(Sum("total_findings"), 0),
failed_findings=Coalesce(Sum("failed_findings"), 0),
muted_failed_findings=Coalesce(Sum("muted_failed_findings"), 0),
)
results = {
attack_surface_type: {
"total_findings": 0,
"failed_findings": 0,
"muted_failed_findings": 0,
}
for attack_surface_type in AttackSurfaceOverview.AttackSurfaceTypeChoices.values
}
for item in aggregation:
results[item["attack_surface_type"]] = {
"total_findings": item["total_findings"],
"failed_findings": item["failed_findings"],
"muted_failed_findings": item["muted_failed_findings"],
}
response_data = [
{
"attack_surface_type": key,
**value,
"check_ids": attack_surface_check_ids.get(key, []),
}
for key, value in results.items()
]
return Response(
self.get_serializer(response_data, many=True).data,
status=status.HTTP_200_OK,
)
@extend_schema(tags=["Schedule"])
@extend_schema_view(
+16
View File
@@ -15,6 +15,7 @@ from tasks.jobs.backfill import backfill_resource_scan_summaries
from api.db_utils import rls_transaction
from api.models import (
AttackSurfaceOverview,
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
@@ -1469,6 +1470,21 @@ def mute_rules_fixture(tenants_fixture, create_test_user, findings_fixture):
return mute_rule1, mute_rule2
@pytest.fixture
def create_attack_surface_overview():
def _create(tenant, scan, attack_surface_type, total=10, failed=5, muted_failed=2):
return AttackSurfaceOverview.objects.create(
tenant=tenant,
scan=scan,
attack_surface_type=attack_surface_type,
total_findings=total,
failed_findings=failed,
muted_failed_findings=muted_failed,
)
return _create
def get_authorization_header(access_token: str) -> dict:
return {"Authorization": f"Bearer {access_token}"}
+153 -1
View File
@@ -12,7 +12,7 @@ from celery.utils.log import get_task_logger
from config.env import env
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db.models import Case, Count, IntegerField, Prefetch, Sum, When
from django.db.models import Case, Count, IntegerField, Prefetch, Q, Sum, When
from tasks.utils import CustomEncoder
from api.compliance import PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE
@@ -26,6 +26,7 @@ from api.db_utils import (
)
from api.exceptions import ProviderConnectionError
from api.models import (
AttackSurfaceOverview,
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Finding,
@@ -43,6 +44,7 @@ from api.models import (
from api.models import StatusChoices as FindingStatus
from api.utils import initialize_prowler_provider, return_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.models import CheckMetadata
from prowler.lib.outputs.finding import Finding as ProwlerFinding
from prowler.lib.scan.scan import Scan as ProwlerScan
@@ -75,6 +77,44 @@ FINDINGS_MICRO_BATCH_SIZE = env.int("DJANGO_FINDINGS_MICRO_BATCH_SIZE", default=
SCAN_DB_BATCH_SIZE = env.int("DJANGO_SCAN_DB_BATCH_SIZE", default=500)
ATTACK_SURFACE_PROVIDER_COMPATIBILITY = {
"internet-exposed": None, # Compatible with all providers
"secrets": None, # Compatible with all providers
"privilege-escalation": ["aws", "kubernetes"],
"ec2-imdsv1": ["aws"],
}
_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {}
def _get_attack_surface_mapping_from_provider(provider_type: str) -> dict:
global _ATTACK_SURFACE_MAPPING_CACHE
if provider_type in _ATTACK_SURFACE_MAPPING_CACHE:
return _ATTACK_SURFACE_MAPPING_CACHE[provider_type]
attack_surface_check_mappings = {
"internet-exposed": None,
"secrets": None,
"privilege-escalation": {
"iam_policy_allows_privilege_escalation",
"iam_inline_policy_allows_privilege_escalation",
},
"ec2-imdsv1": {
"ec2_instance_imdsv2_enabled"
}, # AWS only - IMDSv1 enabled findings
}
for category_name, check_ids in attack_surface_check_mappings.items():
if check_ids is None:
sdk_check_ids = CheckMetadata.list(
provider=provider_type, category=category_name
)
attack_surface_check_mappings[category_name] = sdk_check_ids
_ATTACK_SURFACE_MAPPING_CACHE[provider_type] = attack_surface_check_mappings
return attack_surface_check_mappings
def _create_finding_delta(
last_status: FindingStatus | None | str, new_status: FindingStatus | None
) -> Finding.DeltaChoices:
@@ -1196,3 +1236,115 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
except Exception as e:
logger.error(f"Error creating compliance requirements for scan {scan_id}: {e}")
raise e
def aggregate_attack_surface(tenant_id: str, scan_id: str):
"""
Aggregate findings into attack surface overview records.
Creates one AttackSurfaceOverview record per attack surface type
for the given scan, based on check_id mappings.
Args:
tenant_id: Tenant that owns the scan.
scan_id: Scan UUID whose findings should be aggregated.
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
scan_instance = Scan.all_objects.select_related("provider").get(pk=scan_id)
provider_type = scan_instance.provider.provider
provider_attack_surface_mapping = _get_attack_surface_mapping_from_provider(
provider_type=provider_type
)
# Filter out attack surfaces that are not compatible or have no resolved check IDs
supported_mappings: dict[str, list[str]] = {}
for attack_surface_type, check_ids in provider_attack_surface_mapping.items():
compatible_providers = ATTACK_SURFACE_PROVIDER_COMPATIBILITY.get(
attack_surface_type
)
if (
compatible_providers is not None
and provider_type not in compatible_providers
):
logger.info(
f"Skipping {attack_surface_type} - not supported for {provider_type}"
)
continue
if not check_ids:
logger.info(
f"Skipping {attack_surface_type} - no check IDs resolved for {provider_type}"
)
continue
supported_mappings[attack_surface_type] = list(check_ids)
if not supported_mappings:
logger.info(
f"No attack surface mappings available for scan {scan_id} and provider {provider_type}"
)
logger.info(f"No attack surface overview records created for scan {scan_id}")
return
# Map every check_id to its attack surface, so we can aggregate with a single query
check_id_to_surface: dict[str, str] = {}
for attack_surface_type, check_ids in supported_mappings.items():
for check_id in check_ids:
check_id_to_surface[check_id] = attack_surface_type
aggregated_counts = {
attack_surface_type: {"total": 0, "failed": 0, "muted": 0}
for attack_surface_type in supported_mappings.keys()
}
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
finding_stats = (
Finding.all_objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
check_id__in=list(check_id_to_surface.keys()),
)
.values("check_id")
.annotate(
total=Count("id"),
failed=Count("id", filter=Q(status="FAIL", muted=False)),
muted=Count("id", filter=Q(status="FAIL", muted=True)),
)
)
for stats in finding_stats:
attack_surface_type = check_id_to_surface.get(stats["check_id"])
if not attack_surface_type:
continue
aggregated_counts[attack_surface_type]["total"] += stats["total"] or 0
aggregated_counts[attack_surface_type]["failed"] += stats["failed"] or 0
aggregated_counts[attack_surface_type]["muted"] += stats["muted"] or 0
overview_objects = []
for attack_surface_type, counts in aggregated_counts.items():
total = counts["total"]
if not total:
continue
overview_objects.append(
AttackSurfaceOverview(
tenant_id=tenant_id,
scan_id=scan_id,
attack_surface_type=attack_surface_type,
total_findings=total,
failed_findings=counts["failed"],
muted_failed_findings=counts["muted"],
)
)
# Bulk create overview records
if overview_objects:
with rls_transaction(tenant_id):
AttackSurfaceOverview.objects.bulk_create(overview_objects, batch_size=500)
logger.info(
f"Created {len(overview_objects)} attack surface overview records for scan {scan_id}"
)
else:
logger.info(f"No attack surface overview records created for scan {scan_id}")
+19
View File
@@ -37,6 +37,7 @@ from tasks.jobs.lighthouse_providers import (
from tasks.jobs.muting import mute_historical_findings
from tasks.jobs.report import generate_compliance_reports_job
from tasks.jobs.scan import (
aggregate_attack_surface,
aggregate_findings,
create_compliance_requirements,
perform_prowler_scan,
@@ -69,6 +70,9 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
create_compliance_requirements_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
aggregate_attack_surface_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
chain(
perform_scan_summary_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs_task.si(
@@ -529,6 +533,21 @@ def create_compliance_requirements_task(tenant_id: str, scan_id: str):
return create_compliance_requirements(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(name="scan-attack-surface-overviews", queue="overview")
def aggregate_attack_surface_task(tenant_id: str, scan_id: str):
"""
Creates attack surface overview records for a scan.
This task processes findings and aggregates them into attack surface categories
(internet-exposed, secrets, privilege-escalation, ec2-imdsv1) for quick overview queries.
Args:
tenant_id (str): The tenant ID for which to create records.
scan_id (str): The ID of the scan for which to create records.
"""
return aggregate_attack_surface(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="lighthouse-connection-check")
@set_tenant
def check_lighthouse_connection_task(lighthouse_config_id: str, tenant_id: str = None):
+282
View File
@@ -9,14 +9,17 @@ from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.scan import (
_ATTACK_SURFACE_MAPPING_CACHE,
_aggregate_findings_by_region,
_copy_compliance_requirement_rows,
_create_compliance_summaries,
_create_finding_delta,
_get_attack_surface_mapping_from_provider,
_normalized_compliance_key,
_persist_compliance_requirement_rows,
_process_finding_micro_batch,
_store_resources,
aggregate_attack_surface,
aggregate_findings,
create_compliance_requirements,
perform_prowler_scan,
@@ -3474,3 +3477,282 @@ class TestAggregateFindingsByRegion:
assert check_status_by_region == {}
assert findings_count_by_compliance == {}
@pytest.mark.django_db
class TestAggregateAttackSurface:
"""Test aggregate_attack_surface function and related caching."""
def setup_method(self):
"""Clear cache before each test."""
_ATTACK_SURFACE_MAPPING_CACHE.clear()
def teardown_method(self):
"""Clear cache after each test."""
_ATTACK_SURFACE_MAPPING_CACHE.clear()
@patch("tasks.jobs.scan.CheckMetadata.list")
def test_get_attack_surface_mapping_caches_result(self, mock_check_metadata_list):
"""Test that _get_attack_surface_mapping_from_provider caches results."""
mock_check_metadata_list.return_value = {"check_internet_exposed_1"}
# First call should hit CheckMetadata.list
result1 = _get_attack_surface_mapping_from_provider("aws")
assert mock_check_metadata_list.call_count == 2 # internet-exposed, secrets
# Second call should use cache
result2 = _get_attack_surface_mapping_from_provider("aws")
assert mock_check_metadata_list.call_count == 2 # No additional calls
assert result1 is result2
assert "aws" in _ATTACK_SURFACE_MAPPING_CACHE
@patch("tasks.jobs.scan.CheckMetadata.list")
def test_get_attack_surface_mapping_different_providers(
self, mock_check_metadata_list
):
"""Test caching works independently for different providers."""
mock_check_metadata_list.return_value = {"check_1"}
_get_attack_surface_mapping_from_provider("aws")
aws_call_count = mock_check_metadata_list.call_count
_get_attack_surface_mapping_from_provider("gcp")
gcp_call_count = mock_check_metadata_list.call_count
# Both providers should have made calls
assert gcp_call_count > aws_call_count
assert "aws" in _ATTACK_SURFACE_MAPPING_CACHE
assert "gcp" in _ATTACK_SURFACE_MAPPING_CACHE
@patch("tasks.jobs.scan.CheckMetadata.list")
def test_get_attack_surface_mapping_returns_hardcoded_checks(
self, mock_check_metadata_list
):
"""Test that hardcoded check IDs are returned for privilege-escalation and ec2-imdsv1."""
mock_check_metadata_list.return_value = set()
result = _get_attack_surface_mapping_from_provider("aws")
# Hardcoded checks should be present
assert (
"iam_policy_allows_privilege_escalation" in result["privilege-escalation"]
)
assert (
"iam_inline_policy_allows_privilege_escalation"
in result["privilege-escalation"]
)
assert "ec2_instance_imdsv2_enabled" in result["ec2-imdsv1"]
@patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create")
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_attack_surface_creates_overview_records(
self,
mock_rls_transaction,
mock_get_mapping,
mock_findings_filter,
mock_bulk_create,
tenants_fixture,
scans_fixture,
):
"""Test that aggregate_attack_surface creates AttackSurfaceOverview records."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
scan.provider.provider = "aws"
scan.provider.save()
mock_get_mapping.return_value = {
"internet-exposed": {"check_internet_1", "check_internet_2"},
"secrets": {"check_secrets_1"},
"privilege-escalation": {"check_privesc_1"},
"ec2-imdsv1": {"check_imdsv1_1"},
}
# Mock findings aggregation
mock_queryset = MagicMock()
mock_queryset.values.return_value = mock_queryset
mock_queryset.annotate.return_value = [
{"check_id": "check_internet_1", "total": 10, "failed": 3, "muted": 1},
{"check_id": "check_secrets_1", "total": 5, "failed": 2, "muted": 0},
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
mock_bulk_create.assert_called_once()
args, kwargs = mock_bulk_create.call_args
objects = args[0]
# Should create records for internet-exposed and secrets (the ones with findings)
assert len(objects) == 2
assert kwargs["batch_size"] == 500
@patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create")
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_attack_surface_skips_unsupported_provider(
self,
mock_rls_transaction,
mock_get_mapping,
mock_findings_filter,
mock_bulk_create,
tenants_fixture,
scans_fixture,
):
"""Test that ec2-imdsv1 is skipped for non-AWS providers."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
scan.provider.provider = "gcp"
scan.provider.uid = "gcp-test-project-id"
scan.provider.save()
mock_get_mapping.return_value = {
"internet-exposed": {"check_internet_1"},
"secrets": {"check_secrets_1"},
"privilege-escalation": set(), # Not supported for GCP
"ec2-imdsv1": {"check_imdsv1_1"}, # Should be skipped for GCP
}
mock_queryset = MagicMock()
mock_queryset.values.return_value = mock_queryset
mock_queryset.annotate.return_value = [
{"check_id": "check_internet_1", "total": 5, "failed": 1, "muted": 0},
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
# ec2-imdsv1 check_ids should not be in the filter
filter_call = mock_findings_filter.call_args
check_ids_in_filter = filter_call[1]["check_id__in"]
assert "check_imdsv1_1" not in check_ids_in_filter
@patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create")
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_attack_surface_no_findings(
self,
mock_rls_transaction,
mock_get_mapping,
mock_findings_filter,
mock_bulk_create,
tenants_fixture,
scans_fixture,
):
"""Test that no records are created when there are no findings."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
mock_get_mapping.return_value = {
"internet-exposed": {"check_1"},
"secrets": {"check_2"},
"privilege-escalation": set(),
"ec2-imdsv1": set(),
}
mock_queryset = MagicMock()
mock_queryset.values.return_value = mock_queryset
mock_queryset.annotate.return_value = [] # No findings
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
mock_bulk_create.assert_not_called()
@patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create")
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_attack_surface_aggregates_counts_correctly(
self,
mock_rls_transaction,
mock_get_mapping,
mock_findings_filter,
mock_bulk_create,
tenants_fixture,
scans_fixture,
):
"""Test that counts from multiple check_ids are aggregated per attack surface type."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
scan.provider.provider = "aws"
scan.provider.save()
mock_get_mapping.return_value = {
"internet-exposed": {"check_internet_1", "check_internet_2"},
"secrets": set(),
"privilege-escalation": set(),
"ec2-imdsv1": set(),
}
mock_queryset = MagicMock()
mock_queryset.values.return_value = mock_queryset
mock_queryset.annotate.return_value = [
{"check_id": "check_internet_1", "total": 10, "failed": 3, "muted": 1},
{"check_id": "check_internet_2", "total": 5, "failed": 2, "muted": 0},
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
args, kwargs = mock_bulk_create.call_args
objects = args[0]
assert len(objects) == 1
overview = objects[0]
assert overview.attack_surface_type == "internet-exposed"
assert overview.total_findings == 15 # 10 + 5
assert overview.failed_findings == 5 # 3 + 2
assert overview.muted_failed_findings == 1 # 1 + 0
@patch("tasks.jobs.scan.Scan.all_objects.select_related")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_attack_surface_uses_select_related(
self, mock_rls_transaction, mock_select_related, tenants_fixture, scans_fixture
):
"""Test that select_related is used to avoid N+1 query."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
mock_scan = MagicMock()
mock_scan.provider.provider = "aws"
mock_select_related.return_value.get.return_value = mock_scan
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
with patch(
"tasks.jobs.scan._get_attack_surface_mapping_from_provider"
) as mock_map:
mock_map.return_value = {}
aggregate_attack_surface(str(tenant.id), str(scan.id))
mock_select_related.assert_called_once_with("provider")
@@ -529,6 +529,7 @@ class TestGenerateOutputs:
class TestScanCompleteTasks:
@patch("tasks.tasks.aggregate_attack_surface_task.apply_async")
@patch("tasks.tasks.create_compliance_requirements_task.apply_async")
@patch("tasks.tasks.perform_scan_summary_task.si")
@patch("tasks.tasks.generate_outputs_task.si")
@@ -541,6 +542,7 @@ class TestScanCompleteTasks:
mock_outputs_task,
mock_scan_summary_task,
mock_compliance_requirements_task,
mock_attack_surface_task,
):
"""Test that scan complete tasks are properly orchestrated with optimized reports."""
_perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id")
@@ -550,6 +552,11 @@ class TestScanCompleteTasks:
kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"},
)
# Verify attack surface task is called
mock_attack_surface_task.assert_called_once_with(
kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"},
)
# Verify scan summary task is called
mock_scan_summary_task.assert_called_once_with(
scan_id="scan-id",