Compare commits

...

5 Commits

Author SHA1 Message Date
pedrooot
1202b35ef2 feat(api): add custom check available method under model 2026-01-19 14:22:38 +01:00
pedrooot
757b4cca8d feat(api): update available logic 2026-01-19 13:50:17 +01:00
pedrooot
8c61befdfa feat(api): add provider available field 2026-01-19 12:19:06 +01:00
pedrooot
5548614024 chore(changelog): update with latest changes 2026-01-15 17:51:40 +01:00
pedrooot
4b0801faf9 feat(api): add provider status 2026-01-15 17:43:18 +01:00
13 changed files with 248 additions and 4 deletions

View File

@@ -8,6 +8,7 @@ All notable changes to the **Prowler API** are documented in this file.
- `/api/v1/overviews/compliance-watchlist` to retrieve the compliance watchlist [(#9596)](https://github.com/prowler-cloud/prowler/pull/9596)
- Support AlibabaCloud provider [(#9485)](https://github.com/prowler-cloud/prowler/pull/9485)
- `provider_id` and `provider_id__in` filter aliases for findings endpoints to enable consistent frontend parameter naming [(#9701)](https://github.com/prowler-cloud/prowler/pull/9701)
- `status` field to Provider model to track connection state (`pending`, `checking`, `connected`, `error`) [(#9804)](https://github.com/prowler-cloud/prowler/pull/9804)
---

View File

@@ -70,6 +70,25 @@ class ProviderDeletedException(Exception):
"""Raised when a provider has been deleted during scan/task execution."""
class ProviderNotAvailableError(APIException):
"""Raised when attempting to perform actions on an unavailable provider."""
status_code = status.HTTP_400_BAD_REQUEST
default_detail = (
"Cannot perform this action on an unavailable provider. "
"The provider no longer exists in the cloud environment."
)
default_code = "provider_unavailable"
def __init__(self, detail=None):
error_detail = {
"detail": detail or self.default_detail,
"status": str(self.status_code),
"code": self.default_code,
}
super().__init__(detail=[error_detail])
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):

View File

@@ -319,6 +319,11 @@ class ProviderFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
available = BooleanFilter(
help_text="""Filter by provider availability. Set to True to return only
available providers, or False to return only unavailable providers
(ephemeral accounts that no longer exist)."""
)
class Meta:
model = Provider
@@ -329,6 +334,7 @@ class ProviderFilter(FilterSet):
"alias": ["exact", "icontains", "in"],
"inserted_at": ["gte", "lte"],
"updated_at": ["gte", "lte"],
"available": ["exact"],
}
filter_overrides = {
ProviderEnumField: {

View File

@@ -0,0 +1,19 @@
# Migration to add available field to Provider model
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0067_tenant_compliance_summary"),
]
operations = [
migrations.AddField(
model_name="provider",
name="available",
field=models.BooleanField(
default=True,
help_text="Whether the provider account still exists. If False, connection checks are skipped.",
),
),
]

View File

@@ -48,7 +48,7 @@ from api.db_utils import (
generate_random_token,
one_week_from_now,
)
from api.exceptions import ModelValidationError
from api.exceptions import ModelValidationError, ProviderNotAvailableError
from api.rls import (
BaseSecurityConstraint,
RowLevelSecurityConstraint,
@@ -419,6 +419,10 @@ class Provider(RowLevelSecurityProtectedModel):
)
connected = models.BooleanField(null=True, blank=True)
connection_last_checked_at = models.DateTimeField(null=True, blank=True)
available = models.BooleanField(
default=True,
help_text="Whether the provider account still exists. If False, connection checks are skipped.",
)
metadata = models.JSONField(default=dict, blank=True)
scanner_args = models.JSONField(default=dict, blank=True)
@@ -430,6 +434,16 @@ class Provider(RowLevelSecurityProtectedModel):
self.full_clean()
super().save(*args, **kwargs)
def check_available(self):
"""
Check if the provider is available.
Raises:
ProviderNotAvailableError: If the provider is not available.
"""
if not self.available:
raise ProviderNotAvailableError()
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "providers"

View File

@@ -6645,6 +6645,14 @@ paths:
connections. If not specified, both connected and failed providers are
included. Providers with no connection attempt (status is null) are
excluded from this filter.
- in: query
name: filter[available]
schema:
type: boolean
description: |-
Filter by provider availability. Set to True to return only
available providers, or False to return only unavailable providers
(ephemeral accounts that no longer exist).
- in: query
name: filter[id]
schema:
@@ -16859,6 +16867,10 @@ components:
nullable: true
maxLength: 100
minLength: 3
available:
type: boolean
description: Whether the provider account still exists. If False, connection
checks are skipped.
connection:
type: object
properties:

View File

@@ -1712,6 +1712,20 @@ class TestProviderViewSet:
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_providers_connection_unavailable_provider(
self, authenticated_client, providers_fixture
):
"""Test that connection check returns 400 for unavailable provider."""
provider1, *_ = providers_fixture
provider1.available = False
provider1.save()
response = authenticated_client.post(
reverse("provider-connection", kwargs={"pk": provider1.id})
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["code"] == "provider_unavailable"
@pytest.mark.parametrize(
"filter_name, filter_value, expected_count",
(
@@ -2887,6 +2901,34 @@ class TestScanViewSet:
response.json()["errors"][0]["source"]["pointer"] == "/data/attributes/name"
)
def test_scans_create_unavailable_provider(
self, authenticated_client, providers_fixture
):
"""Test that creating a scan for an unavailable provider returns 400."""
provider1, *_ = providers_fixture
provider1.available = False
provider1.save()
scan_json_payload = {
"data": {
"type": "scans",
"attributes": {
"name": "Test Scan",
},
"relationships": {
"provider": {"data": {"type": "providers", "id": str(provider1.id)}}
},
}
}
response = authenticated_client.post(
reverse("scan-list"),
data=scan_json_payload,
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["code"] == "provider_unavailable"
def test_scans_partial_update(self, authenticated_client, scans_fixture):
scan1, *_ = scans_fixture
new_name = "Updated Scan Name"
@@ -8166,6 +8208,23 @@ class TestScheduleViewSet:
)
assert response.status_code == status.HTTP_409_CONFLICT
def test_schedule_daily_unavailable_provider(
self, authenticated_client, providers_fixture
):
"""Test that creating a schedule for an unavailable provider returns 400."""
provider, *_ = providers_fixture
provider.available = False
provider.save()
json_payload = {
"provider_id": str(provider.id),
}
response = authenticated_client.post(
reverse("schedule-daily"), data=json_payload, format="json"
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["code"] == "provider_unavailable"
@pytest.mark.django_db
class TestIntegrationViewSet:

View File

@@ -877,6 +877,7 @@ class ProviderSerializer(RLSSerializer):
"provider",
"uid",
"alias",
"available",
"connection",
# "scanner_args",
"secret",
@@ -917,6 +918,7 @@ class ProviderIncludeSerializer(RLSSerializer):
"provider",
"uid",
"alias",
"available",
"connection",
# "scanner_args",
]
@@ -2350,6 +2352,11 @@ class ScheduleDailyCreateSerializer(BaseSerializerV1):
class JSONAPIMeta:
resource_name = "daily-schedules"
def validate_provider_id(self, provider_id):
if not Provider.objects.filter(pk=provider_id).exists():
raise serializers.ValidationError("Provider not found.")
return provider_id
# TODO: DRY this when we have more time
def validate(self, data):
if hasattr(self, "initial_data"):

View File

@@ -1551,7 +1551,9 @@ class ProviderViewSet(DisablePaginationMixin, BaseRLSViewSet):
)
@action(detail=True, methods=["post"], url_name="connection")
def connection(self, request, pk=None):
get_object_or_404(Provider, pk=pk)
provider = get_object_or_404(Provider, pk=pk)
provider.check_available()
with transaction.atomic():
task = check_provider_connection_task.delay(
provider_id=pk, tenant_id=self.request.tenant_id
@@ -2142,6 +2144,12 @@ class ScanViewSet(BaseRLSViewSet):
def create(self, request, *args, **kwargs):
input_serializer = self.get_serializer(data=request.data)
input_serializer.is_valid(raise_exception=True)
# Check provider availability before creating scan
provider = input_serializer.validated_data.get("provider")
if provider:
provider.check_available()
with transaction.atomic():
scan = input_serializer.save()
with transaction.atomic():
@@ -5130,6 +5138,8 @@ class ScheduleViewSet(BaseRLSViewSet):
provider_id = serializer.validated_data["provider_id"]
provider_instance = get_object_or_404(Provider, pk=provider_id)
provider_instance.check_available()
with transaction.atomic():
task = schedule_provider_scan(provider_instance)

View File

@@ -29,16 +29,33 @@ def check_provider_connection(provider_id: str):
Model.DoesNotExist: If the provider does not exist.
"""
provider_instance = Provider.objects.get(pk=provider_id)
# Skip connection check if provider is marked as unavailable
if not provider_instance.available:
logger.info(
f"Skipping connection check for provider {provider_id}: marked as unavailable"
)
return {
"connected": False,
"error": "Provider is marked as unavailable",
"skipped": True,
}
try:
connection_result = prowler_provider_connection_test(provider_instance)
except Exception as e:
logger.warning(
f"Unexpected exception checking {provider_instance.provider} provider connection: {str(e)}"
)
provider_instance.connected = False
provider_instance.connection_last_checked_at = datetime.now(tz=timezone.utc)
provider_instance.save()
raise e
provider_instance.connected = connection_result.is_connected
provider_instance.connection_last_checked_at = datetime.now(tz=timezone.utc)
provider_instance.save()
connection_error = f"{connection_result.error}" if connection_result.error else None

View File

@@ -763,6 +763,20 @@ def perform_prowler_scan(
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
# Skip scan if provider is marked as unavailable
if not provider_instance.available:
logger.info(
f"Skipping scan for provider {provider_id}: marked as unavailable"
)
scan_instance.state = StateChoices.FAILED
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.save()
raise ProviderConnectionError(
f"Provider {provider_instance.provider} is marked as unavailable"
)
scan_instance.state = StateChoices.EXECUTING
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save()

View File

@@ -12,6 +12,21 @@ from tasks.jobs.connection import (
from api.models import Integration, LighthouseConfiguration, Provider
@pytest.mark.django_db
def test_provider_created_with_default_available(tenants_fixture):
"""Test that a newly created provider has available=True by default."""
provider = Provider.objects.create(
provider="aws",
uid="123456789012",
alias="aws-test",
tenant_id=tenants_fixture[0].id,
)
assert provider.available is True
assert provider.connected is None
assert provider.connection_last_checked_at is None
@pytest.mark.parametrize(
"provider_data",
[
@@ -62,6 +77,7 @@ def test_check_provider_connection_exception(
):
mock_provider_instance = MagicMock()
mock_provider_instance.provider = Provider.ProviderChoices.AWS.value
mock_provider_instance.available = True
mock_provider_get.return_value = mock_provider_instance
mock_provider_connection_test.return_value = MagicMock()
@@ -72,11 +88,33 @@ def test_check_provider_connection_exception(
assert result["connected"] is False
assert result["error"] is not None
mock_provider_instance.save.assert_called_once()
assert mock_provider_instance.save.call_count == 1
assert mock_provider_instance.connected is False
@pytest.mark.django_db
def test_check_provider_connection_skips_unavailable_provider(tenants_fixture):
"""Test that connection check is skipped when provider is marked as unavailable."""
provider = Provider.objects.create(
provider="aws",
uid="123456789012",
alias="aws-test",
available=False,
tenant_id=tenants_fixture[0].id,
)
result = check_provider_connection(provider_id=str(provider.id))
assert result["connected"] is False
assert result["error"] == "Provider is marked as unavailable"
assert result["skipped"] is True
provider.refresh_from_db()
# Provider should not have been modified
assert provider.connected is None
assert provider.connection_last_checked_at is None
@pytest.mark.parametrize(
"lighthouse_data",
[

View File

@@ -260,6 +260,34 @@ class TestPerformScan:
assert provider.connected is False
assert isinstance(provider.connection_last_checked_at, datetime)
@patch("tasks.jobs.scan.rls_transaction")
def test_perform_prowler_scan_unavailable_provider(
self,
mock_rls_transaction,
tenants_fixture,
scans_fixture,
providers_fixture,
):
"""Test that scan fails immediately for unavailable provider."""
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
# Mark provider as unavailable
provider.available = False
provider.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
provider_id = str(provider.id)
checks_to_execute = ["check1", "check2"]
with pytest.raises(ProviderConnectionError, match="marked as unavailable"):
perform_prowler_scan(tenant_id, scan_id, provider_id, checks_to_execute)
scan.refresh_from_db()
assert scan.state == StateChoices.FAILED
@pytest.mark.parametrize(
"last_status, new_status, expected_delta",
[