mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-04-01 13:47:21 +00:00
Compare commits
5 Commits
chore/GHA-
...
PROWLER-69
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1202b35ef2 | ||
|
|
757b4cca8d | ||
|
|
8c61befdfa | ||
|
|
5548614024 | ||
|
|
4b0801faf9 |
@@ -8,6 +8,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- `/api/v1/overviews/compliance-watchlist` to retrieve the compliance watchlist [(#9596)](https://github.com/prowler-cloud/prowler/pull/9596)
|
||||
- Support AlibabaCloud provider [(#9485)](https://github.com/prowler-cloud/prowler/pull/9485)
|
||||
- `provider_id` and `provider_id__in` filter aliases for findings endpoints to enable consistent frontend parameter naming [(#9701)](https://github.com/prowler-cloud/prowler/pull/9701)
|
||||
- `status` field to Provider model to track connection state (`pending`, `checking`, `connected`, `error`) [(#9804)](https://github.com/prowler-cloud/prowler/pull/9804)
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -70,6 +70,25 @@ class ProviderDeletedException(Exception):
|
||||
"""Raised when a provider has been deleted during scan/task execution."""
|
||||
|
||||
|
||||
class ProviderNotAvailableError(APIException):
|
||||
"""Raised when attempting to perform actions on an unavailable provider."""
|
||||
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
default_detail = (
|
||||
"Cannot perform this action on an unavailable provider. "
|
||||
"The provider no longer exists in the cloud environment."
|
||||
)
|
||||
default_code = "provider_unavailable"
|
||||
|
||||
def __init__(self, detail=None):
|
||||
error_detail = {
|
||||
"detail": detail or self.default_detail,
|
||||
"status": str(self.status_code),
|
||||
"code": self.default_code,
|
||||
}
|
||||
super().__init__(detail=[error_detail])
|
||||
|
||||
|
||||
def custom_exception_handler(exc, context):
|
||||
if isinstance(exc, django_validation_error):
|
||||
if hasattr(exc, "error_dict"):
|
||||
|
||||
@@ -319,6 +319,11 @@ class ProviderFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
available = BooleanFilter(
|
||||
help_text="""Filter by provider availability. Set to True to return only
|
||||
available providers, or False to return only unavailable providers
|
||||
(ephemeral accounts that no longer exist)."""
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = Provider
|
||||
@@ -329,6 +334,7 @@ class ProviderFilter(FilterSet):
|
||||
"alias": ["exact", "icontains", "in"],
|
||||
"inserted_at": ["gte", "lte"],
|
||||
"updated_at": ["gte", "lte"],
|
||||
"available": ["exact"],
|
||||
}
|
||||
filter_overrides = {
|
||||
ProviderEnumField: {
|
||||
|
||||
19
api/src/backend/api/migrations/0068_provider_available.py
Normal file
19
api/src/backend/api/migrations/0068_provider_available.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Migration to add available field to Provider model
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0067_tenant_compliance_summary"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="provider",
|
||||
name="available",
|
||||
field=models.BooleanField(
|
||||
default=True,
|
||||
help_text="Whether the provider account still exists. If False, connection checks are skipped.",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -48,7 +48,7 @@ from api.db_utils import (
|
||||
generate_random_token,
|
||||
one_week_from_now,
|
||||
)
|
||||
from api.exceptions import ModelValidationError
|
||||
from api.exceptions import ModelValidationError, ProviderNotAvailableError
|
||||
from api.rls import (
|
||||
BaseSecurityConstraint,
|
||||
RowLevelSecurityConstraint,
|
||||
@@ -419,6 +419,10 @@ class Provider(RowLevelSecurityProtectedModel):
|
||||
)
|
||||
connected = models.BooleanField(null=True, blank=True)
|
||||
connection_last_checked_at = models.DateTimeField(null=True, blank=True)
|
||||
available = models.BooleanField(
|
||||
default=True,
|
||||
help_text="Whether the provider account still exists. If False, connection checks are skipped.",
|
||||
)
|
||||
metadata = models.JSONField(default=dict, blank=True)
|
||||
scanner_args = models.JSONField(default=dict, blank=True)
|
||||
|
||||
@@ -430,6 +434,16 @@ class Provider(RowLevelSecurityProtectedModel):
|
||||
self.full_clean()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
def check_available(self):
|
||||
"""
|
||||
Check if the provider is available.
|
||||
|
||||
Raises:
|
||||
ProviderNotAvailableError: If the provider is not available.
|
||||
"""
|
||||
if not self.available:
|
||||
raise ProviderNotAvailableError()
|
||||
|
||||
class Meta(RowLevelSecurityProtectedModel.Meta):
|
||||
db_table = "providers"
|
||||
|
||||
|
||||
@@ -6645,6 +6645,14 @@ paths:
|
||||
connections. If not specified, both connected and failed providers are
|
||||
included. Providers with no connection attempt (status is null) are
|
||||
excluded from this filter.
|
||||
- in: query
|
||||
name: filter[available]
|
||||
schema:
|
||||
type: boolean
|
||||
description: |-
|
||||
Filter by provider availability. Set to True to return only
|
||||
available providers, or False to return only unavailable providers
|
||||
(ephemeral accounts that no longer exist).
|
||||
- in: query
|
||||
name: filter[id]
|
||||
schema:
|
||||
@@ -16859,6 +16867,10 @@ components:
|
||||
nullable: true
|
||||
maxLength: 100
|
||||
minLength: 3
|
||||
available:
|
||||
type: boolean
|
||||
description: Whether the provider account still exists. If False, connection
|
||||
checks are skipped.
|
||||
connection:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@@ -1712,6 +1712,20 @@ class TestProviderViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_providers_connection_unavailable_provider(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that connection check returns 400 for unavailable provider."""
|
||||
provider1, *_ = providers_fixture
|
||||
provider1.available = False
|
||||
provider1.save()
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("provider-connection", kwargs={"pk": provider1.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["errors"][0]["code"] == "provider_unavailable"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter_name, filter_value, expected_count",
|
||||
(
|
||||
@@ -2887,6 +2901,34 @@ class TestScanViewSet:
|
||||
response.json()["errors"][0]["source"]["pointer"] == "/data/attributes/name"
|
||||
)
|
||||
|
||||
def test_scans_create_unavailable_provider(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that creating a scan for an unavailable provider returns 400."""
|
||||
provider1, *_ = providers_fixture
|
||||
provider1.available = False
|
||||
provider1.save()
|
||||
|
||||
scan_json_payload = {
|
||||
"data": {
|
||||
"type": "scans",
|
||||
"attributes": {
|
||||
"name": "Test Scan",
|
||||
},
|
||||
"relationships": {
|
||||
"provider": {"data": {"type": "providers", "id": str(provider1.id)}}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("scan-list"),
|
||||
data=scan_json_payload,
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["errors"][0]["code"] == "provider_unavailable"
|
||||
|
||||
def test_scans_partial_update(self, authenticated_client, scans_fixture):
|
||||
scan1, *_ = scans_fixture
|
||||
new_name = "Updated Scan Name"
|
||||
@@ -8166,6 +8208,23 @@ class TestScheduleViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
def test_schedule_daily_unavailable_provider(
|
||||
self, authenticated_client, providers_fixture
|
||||
):
|
||||
"""Test that creating a schedule for an unavailable provider returns 400."""
|
||||
provider, *_ = providers_fixture
|
||||
provider.available = False
|
||||
provider.save()
|
||||
|
||||
json_payload = {
|
||||
"provider_id": str(provider.id),
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("schedule-daily"), data=json_payload, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["errors"][0]["code"] == "provider_unavailable"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestIntegrationViewSet:
|
||||
|
||||
@@ -877,6 +877,7 @@ class ProviderSerializer(RLSSerializer):
|
||||
"provider",
|
||||
"uid",
|
||||
"alias",
|
||||
"available",
|
||||
"connection",
|
||||
# "scanner_args",
|
||||
"secret",
|
||||
@@ -917,6 +918,7 @@ class ProviderIncludeSerializer(RLSSerializer):
|
||||
"provider",
|
||||
"uid",
|
||||
"alias",
|
||||
"available",
|
||||
"connection",
|
||||
# "scanner_args",
|
||||
]
|
||||
@@ -2350,6 +2352,11 @@ class ScheduleDailyCreateSerializer(BaseSerializerV1):
|
||||
class JSONAPIMeta:
|
||||
resource_name = "daily-schedules"
|
||||
|
||||
def validate_provider_id(self, provider_id):
|
||||
if not Provider.objects.filter(pk=provider_id).exists():
|
||||
raise serializers.ValidationError("Provider not found.")
|
||||
return provider_id
|
||||
|
||||
# TODO: DRY this when we have more time
|
||||
def validate(self, data):
|
||||
if hasattr(self, "initial_data"):
|
||||
|
||||
@@ -1551,7 +1551,9 @@ class ProviderViewSet(DisablePaginationMixin, BaseRLSViewSet):
|
||||
)
|
||||
@action(detail=True, methods=["post"], url_name="connection")
|
||||
def connection(self, request, pk=None):
|
||||
get_object_or_404(Provider, pk=pk)
|
||||
provider = get_object_or_404(Provider, pk=pk)
|
||||
provider.check_available()
|
||||
|
||||
with transaction.atomic():
|
||||
task = check_provider_connection_task.delay(
|
||||
provider_id=pk, tenant_id=self.request.tenant_id
|
||||
@@ -2142,6 +2144,12 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
def create(self, request, *args, **kwargs):
|
||||
input_serializer = self.get_serializer(data=request.data)
|
||||
input_serializer.is_valid(raise_exception=True)
|
||||
|
||||
# Check provider availability before creating scan
|
||||
provider = input_serializer.validated_data.get("provider")
|
||||
if provider:
|
||||
provider.check_available()
|
||||
|
||||
with transaction.atomic():
|
||||
scan = input_serializer.save()
|
||||
with transaction.atomic():
|
||||
@@ -5130,6 +5138,8 @@ class ScheduleViewSet(BaseRLSViewSet):
|
||||
provider_id = serializer.validated_data["provider_id"]
|
||||
|
||||
provider_instance = get_object_or_404(Provider, pk=provider_id)
|
||||
provider_instance.check_available()
|
||||
|
||||
with transaction.atomic():
|
||||
task = schedule_provider_scan(provider_instance)
|
||||
|
||||
|
||||
@@ -29,16 +29,33 @@ def check_provider_connection(provider_id: str):
|
||||
Model.DoesNotExist: If the provider does not exist.
|
||||
"""
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
|
||||
# Skip connection check if provider is marked as unavailable
|
||||
if not provider_instance.available:
|
||||
logger.info(
|
||||
f"Skipping connection check for provider {provider_id}: marked as unavailable"
|
||||
)
|
||||
return {
|
||||
"connected": False,
|
||||
"error": "Provider is marked as unavailable",
|
||||
"skipped": True,
|
||||
}
|
||||
|
||||
try:
|
||||
connection_result = prowler_provider_connection_test(provider_instance)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected exception checking {provider_instance.provider} provider connection: {str(e)}"
|
||||
)
|
||||
provider_instance.connected = False
|
||||
provider_instance.connection_last_checked_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
provider_instance.save()
|
||||
raise e
|
||||
|
||||
provider_instance.connected = connection_result.is_connected
|
||||
provider_instance.connection_last_checked_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
provider_instance.save()
|
||||
|
||||
connection_error = f"{connection_result.error}" if connection_result.error else None
|
||||
|
||||
@@ -763,6 +763,20 @@ def perform_prowler_scan(
|
||||
with rls_transaction(tenant_id):
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
scan_instance = Scan.objects.get(pk=scan_id)
|
||||
|
||||
# Skip scan if provider is marked as unavailable
|
||||
if not provider_instance.available:
|
||||
logger.info(
|
||||
f"Skipping scan for provider {provider_id}: marked as unavailable"
|
||||
)
|
||||
scan_instance.state = StateChoices.FAILED
|
||||
scan_instance.started_at = datetime.now(tz=timezone.utc)
|
||||
scan_instance.completed_at = datetime.now(tz=timezone.utc)
|
||||
scan_instance.save()
|
||||
raise ProviderConnectionError(
|
||||
f"Provider {provider_instance.provider} is marked as unavailable"
|
||||
)
|
||||
|
||||
scan_instance.state = StateChoices.EXECUTING
|
||||
scan_instance.started_at = datetime.now(tz=timezone.utc)
|
||||
scan_instance.save()
|
||||
|
||||
@@ -12,6 +12,21 @@ from tasks.jobs.connection import (
|
||||
from api.models import Integration, LighthouseConfiguration, Provider
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_provider_created_with_default_available(tenants_fixture):
|
||||
"""Test that a newly created provider has available=True by default."""
|
||||
provider = Provider.objects.create(
|
||||
provider="aws",
|
||||
uid="123456789012",
|
||||
alias="aws-test",
|
||||
tenant_id=tenants_fixture[0].id,
|
||||
)
|
||||
|
||||
assert provider.available is True
|
||||
assert provider.connected is None
|
||||
assert provider.connection_last_checked_at is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_data",
|
||||
[
|
||||
@@ -62,6 +77,7 @@ def test_check_provider_connection_exception(
|
||||
):
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_provider_instance.provider = Provider.ProviderChoices.AWS.value
|
||||
mock_provider_instance.available = True
|
||||
mock_provider_get.return_value = mock_provider_instance
|
||||
|
||||
mock_provider_connection_test.return_value = MagicMock()
|
||||
@@ -72,11 +88,33 @@ def test_check_provider_connection_exception(
|
||||
|
||||
assert result["connected"] is False
|
||||
assert result["error"] is not None
|
||||
|
||||
mock_provider_instance.save.assert_called_once()
|
||||
assert mock_provider_instance.save.call_count == 1
|
||||
assert mock_provider_instance.connected is False
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_check_provider_connection_skips_unavailable_provider(tenants_fixture):
|
||||
"""Test that connection check is skipped when provider is marked as unavailable."""
|
||||
provider = Provider.objects.create(
|
||||
provider="aws",
|
||||
uid="123456789012",
|
||||
alias="aws-test",
|
||||
available=False,
|
||||
tenant_id=tenants_fixture[0].id,
|
||||
)
|
||||
|
||||
result = check_provider_connection(provider_id=str(provider.id))
|
||||
|
||||
assert result["connected"] is False
|
||||
assert result["error"] == "Provider is marked as unavailable"
|
||||
assert result["skipped"] is True
|
||||
|
||||
provider.refresh_from_db()
|
||||
# Provider should not have been modified
|
||||
assert provider.connected is None
|
||||
assert provider.connection_last_checked_at is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lighthouse_data",
|
||||
[
|
||||
|
||||
@@ -260,6 +260,34 @@ class TestPerformScan:
|
||||
assert provider.connected is False
|
||||
assert isinstance(provider.connection_last_checked_at, datetime)
|
||||
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_perform_prowler_scan_unavailable_provider(
|
||||
self,
|
||||
mock_rls_transaction,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that scan fails immediately for unavailable provider."""
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
# Mark provider as unavailable
|
||||
provider.available = False
|
||||
provider.save()
|
||||
|
||||
tenant_id = str(tenant.id)
|
||||
scan_id = str(scan.id)
|
||||
provider_id = str(provider.id)
|
||||
checks_to_execute = ["check1", "check2"]
|
||||
|
||||
with pytest.raises(ProviderConnectionError, match="marked as unavailable"):
|
||||
perform_prowler_scan(tenant_id, scan_id, provider_id, checks_to_execute)
|
||||
|
||||
scan.refresh_from_db()
|
||||
assert scan.state == StateChoices.FAILED
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"last_status, new_status, expected_delta",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user