Compare commits

...

4 Commits

Author SHA1 Message Date
Andoni A. ad3b57d405 feat(image): add image provider API support with registry scan mode
- Add IMAGE provider type to API models, serializers, views, and utils
- Add registry credential handling (username/password, token, filters)
- Add UID validation for registry URLs with port validation
- Add connection testing via OCI registry adapter
- Add registry scan mode with OCI, Docker Hub, and ECR adapter layer
- Add per-image progress tracking and Trivy native auth
- Skip compliance/reports for IMAGE provider scans
- Add migration, OpenAPI spec updates, and comprehensive tests
2026-02-18 18:21:20 +01:00
Andoni A. f7adbc2dcf feat(image): add docker login and pull for private registry authentication
Trivy's remote source cannot authenticate against Docker Hub (and some
other registries) even after docker login. This adds a docker login +
docker pull flow before scanning so Trivy can access private images
from the local Docker daemon.

- Add _docker_login, _docker_pull, _docker_logout, cleanup methods
- Add _extract_registry to determine registry from image reference
- Wrap run() in try/finally to ensure cleanup on success or error
- Wire registry credentials from CLI args to ImageProvider
- Add ImageDockerLoginError and ImageDockerNotFoundError exceptions
2026-02-06 15:08:42 +01:00
Andoni A. 41e972e088 chore(image): remove POC mention from CHANGELOG and drop provider README 2026-02-06 14:04:18 +01:00
Andoni A. 31d9a23225 feat(image): add container image provider for CLI scanning
Add a new Image provider that uses Trivy for container image vulnerability
and secret scanning, integrated into the Prowler CLI.

- ImageProvider class with Trivy integration for vuln/secret/misconfig scanning
- CLI support via `prowler image -I <image>` with severity filters, timeout,
  ignore-unfixed, and image-list-file options
- CheckReportImage model for image-specific findings
- Custom exceptions (9000-9005) with clear remediation messages
- Error handling for Trivy failures (non-zero exit, binary not found)
- Batch processing of findings with progress bar
- test_connection() for registry accessibility checks
- Comprehensive test coverage
2026-02-06 13:59:48 +01:00
48 changed files with 5739 additions and 358 deletions
+8
View File
@@ -2,6 +2,14 @@
All notable changes to the **Prowler API** are documented in this file.
## [Unreleased]
### 🚀 Added
- Image provider support with registry credential handling, UID validation, and connection testing
---
## [1.19.0] (Prowler v5.18.0)
### 🚀 Added
+513 -193
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -24,7 +24,7 @@ dependencies = [
"drf-spectacular-jsonapi==0.5.1",
"gunicorn==23.0.0",
"lxml==5.3.2",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@feat/PROWLER-940-stage-2-a-image-provider-api-support",
"psycopg2-binary==2.9.9",
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
@@ -49,7 +49,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.19.0"
version = "1.20.0"
[project.scripts]
celery = "src.backend.config.settings.celery"
@@ -0,0 +1,39 @@
# Django migration: add the "image" value to the provider enum for Image provider support
from django.db import migrations
import api.db_utils
# Registers "image" as an accepted value of the `provider` field on the
# `provider` model, both in the Django field choices and in the database type.
class Migration(migrations.Migration):
# Must run after the migration that introduced the Cloudflare provider.
dependencies = [
("api", "0075_cloudflare_provider"),
]
operations = [
# Re-declare the field with the full list of choices, now ending in "image".
migrations.AlterField(
model_name="provider",
name="provider",
field=api.db_utils.ProviderEnumField(
choices=[
("aws", "AWS"),
("azure", "Azure"),
("gcp", "GCP"),
("kubernetes", "Kubernetes"),
("m365", "M365"),
("github", "GitHub"),
("mongodbatlas", "MongoDB Atlas"),
("iac", "IaC"),
("oraclecloud", "Oracle Cloud Infrastructure"),
("alibabacloud", "Alibaba Cloud"),
("cloudflare", "Cloudflare"),
("image", "Image"),
],
default="aws",
),
),
# Add the matching label to the database type named `provider`.
# IF NOT EXISTS keeps the statement from failing when the label is already there.
# The reverse operation is a no-op, so rolling back leaves the label in place.
migrations.RunSQL(
"ALTER TYPE provider ADD VALUE IF NOT EXISTS 'image';",
reverse_sql=migrations.RunSQL.noop,
),
]
+23 -17
View File
@@ -288,6 +288,7 @@ class Provider(RowLevelSecurityProtectedModel):
ORACLECLOUD = "oraclecloud", _("Oracle Cloud Infrastructure")
ALIBABACLOUD = "alibabacloud", _("Alibaba Cloud")
CLOUDFLARE = "cloudflare", _("Cloudflare")
IMAGE = "image", _("Image")
@staticmethod
def validate_aws_uid(value):
@@ -410,6 +411,26 @@ class Provider(RowLevelSecurityProtectedModel):
pointer="/data/attributes/uid",
)
@staticmethod
def validate_image_uid(value):
pattern = r"^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?(:\d{1,5})?(/[a-zA-Z0-9._-]+)*/?$"
if not re.match(pattern, value):
raise ModelValidationError(
detail="Image provider ID must be a valid registry URL "
"(e.g., docker.io, ghcr.io, registry.example.com:5000).",
code="image-uid",
pointer="/data/attributes/uid",
)
port_match = re.search(r":(\d{1,5})(?=/|$)", value)
if port_match:
port = int(port_match.group(1))
if not 1 <= port <= 65535:
raise ModelValidationError(
detail="Port number must be between 1 and 65535.",
code="image-uid",
pointer="/data/attributes/uid",
)
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
@@ -433,6 +454,8 @@ class Provider(RowLevelSecurityProtectedModel):
def clean(self):
    """Run model validation, normalizing and checking the provider UID.

    For Image providers, a leading ``http://`` or ``https://`` is removed
    from a non-empty UID first. The UID is then passed to the
    ``validate_<provider>_uid`` validator looked up on this instance.
    """
    super().clean()
    is_image_provider = self.provider == self.ProviderChoices.IMAGE.value
    if is_image_provider and self.uid:
        # Registries are stored without a scheme, so drop it before validating.
        self.uid = re.sub(r"^https?://", "", self.uid)
    uid_validator = getattr(self, f"validate_{self.provider}_uid")
    uid_validator(self.uid)
def save(self, *args, **kwargs):
@@ -681,8 +704,6 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
update_tag = models.BigIntegerField(
null=True, blank=True, help_text="Cartography update tag (epoch)"
)
graph_database = models.CharField(max_length=63, null=True, blank=True)
is_graph_database_deleted = models.BooleanField(default=False)
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
class Meta(RowLevelSecurityProtectedModel.Meta):
@@ -709,21 +730,6 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
fields=["tenant_id", "scan_id"],
name="aps_scan_lookup_idx",
),
models.Index(
fields=["tenant_id", "provider_id"],
name="aps_active_graph_idx",
include=["graph_database", "id"],
condition=Q(is_graph_database_deleted=False),
),
models.Index(
fields=["tenant_id", "provider_id", "-completed_at"],
name="aps_completed_graph_idx",
include=["graph_database", "id"],
condition=Q(
state=StateChoices.COMPLETED,
is_graph_database_deleted=False,
),
),
]
class JSONAPIMeta:
+43 -7
View File
@@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: Prowler API
version: 1.19.0
version: 1.20.0
description: |-
Prowler API specification.
@@ -616,7 +616,7 @@ paths:
operationId: attack_paths_scans_queries_retrieve
description: Retrieve the catalog of Attack Paths queries available for this
Attack Paths scan.
summary: List attack paths queries
summary: List Attack Paths queries
parameters:
- in: query
name: fields[attack-paths-scans]
@@ -714,7 +714,7 @@ paths:
description: Bad request (e.g., Unknown Attack Paths query for the selected
provider)
'404':
description: No attack paths found for the given query and parameters
description: No Attack Paths found for the given query and parameters
'500':
description: Attack Paths query execution failed due to a database error
/api/v1/compliance-overviews:
@@ -12438,6 +12438,8 @@ components:
type: string
name:
type: string
short_description:
type: string
description:
type: string
provider:
@@ -12446,12 +12448,42 @@ components:
type: array
items:
$ref: '#/components/schemas/AttackPathsQueryParameter'
attribution:
allOf:
- $ref: '#/components/schemas/AttackPathsQueryAttribution'
nullable: true
required:
- id
- name
- short_description
- description
- provider
- parameters
AttackPathsQueryAttribution:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-attributions
id: {}
attributes:
type: object
properties:
text:
type: string
link:
type: string
required:
- text
- link
AttackPathsQueryParameter:
type: object
required:
@@ -17316,7 +17348,8 @@ components:
required:
- api_key
- api_email
writeOnly: true
- type: object
- type: object
required:
- secret
relationships:
@@ -19346,7 +19379,8 @@ components:
required:
- api_key
- api_email
writeOnly: true
- type: object
- type: object
required:
- secret_type
- secret
@@ -19732,7 +19766,8 @@ components:
required:
- api_key
- api_email
writeOnly: true
- type: object
- type: object
required:
- secret_type
- secret
@@ -20130,7 +20165,8 @@ components:
required:
- api_key
- api_email
writeOnly: true
- type: object
- type: object
required:
- secret
relationships:
+147
View File
@@ -24,6 +24,7 @@ from prowler.providers.cloudflare.cloudflare_provider import CloudflareProvider
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.github.github_provider import GithubProvider
from prowler.providers.iac.iac_provider import IacProvider
from prowler.providers.image.image_provider import ImageProvider
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
from prowler.providers.m365.m365_provider import M365Provider
from prowler.providers.mongodbatlas.mongodbatlas_provider import MongodbatlasProvider
@@ -120,6 +121,7 @@ class TestReturnProwlerProvider:
(Provider.ProviderChoices.IAC.value, IacProvider),
(Provider.ProviderChoices.ALIBABACLOUD.value, AlibabacloudProvider),
(Provider.ProviderChoices.CLOUDFLARE.value, CloudflareProvider),
(Provider.ProviderChoices.IMAGE.value, ImageProvider),
],
)
def test_return_prowler_provider(self, provider_type, expected_provider):
@@ -186,6 +188,54 @@ class TestProwlerProviderConnectionTest:
assert isinstance(connection.error, Provider.secret.RelatedObjectDoesNotExist)
assert str(connection.error) == "Provider has no secret."
@patch(
    "prowler.providers.image.lib.registry.factory.create_registry_adapter",
)
@patch("api.utils.return_prowler_provider")
def test_prowler_provider_connection_test_image_success(
    self, mock_return_prowler_provider, mock_create_adapter
):
    """A registry whose repositories can be listed is reported as connected."""
    registry_adapter = MagicMock()
    mock_create_adapter.return_value = registry_adapter

    image_provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid="ghcr.io",
    )
    image_provider.secret.secret = {"registry_token": "tok"}

    result = prowler_provider_connection_test(image_provider)

    assert result.is_connected is True
    assert result.error is None
    # Only the token is stored in the secret, so username/password are None.
    mock_create_adapter.assert_called_once_with(
        registry_url="ghcr.io",
        username=None,
        password=None,
        token="tok",
    )
    registry_adapter.list_repositories.assert_called_once()
@patch(
    "prowler.providers.image.lib.registry.factory.create_registry_adapter",
)
@patch("api.utils.return_prowler_provider")
def test_prowler_provider_connection_test_image_failure(
    self, mock_return_prowler_provider, mock_create_adapter
):
    """An adapter error while listing repositories yields a failed connection."""
    failing_adapter = MagicMock()
    failing_adapter.list_repositories.side_effect = Exception("401 Unauthorized")
    mock_create_adapter.return_value = failing_adapter

    image_provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid="ghcr.io",
    )
    image_provider.secret.secret = {"registry_token": "bad-token"}

    result = prowler_provider_connection_test(image_provider)

    assert result.is_connected is False
    # The error is compared against the exception's message text.
    assert result.error == "401 Unauthorized"
class TestGetProwlerProviderKwargs:
@pytest.mark.parametrize(
@@ -330,6 +380,103 @@ class TestGetProwlerProviderKwargs:
}
assert result == expected_result
def test_get_prowler_provider_kwargs_image_provider(self):
    """Registry URL and username/password credentials are forwarded as kwargs."""
    registry = "ghcr.io"
    provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid=registry,
        secret=MagicMock(
            secret={
                "registry_username": "user",
                "registry_password": "pass",
            }
        ),
    )

    kwargs = get_prowler_provider_kwargs(provider)

    assert kwargs == {
        "registry": registry,
        "registry_username": "user",
        "registry_password": "pass",
    }
def test_get_prowler_provider_kwargs_image_provider_with_filters(self):
    """Token auth and the image/tag/max-images scan filters are all forwarded."""
    registry = "docker.io"
    provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid=registry,
        secret=MagicMock(
            secret={
                "registry_token": "ghp_abc123",
                "image_filter": "my-app.*",
                "tag_filter": "v[0-9]+",
                "max_images": 50,
            }
        ),
    )

    kwargs = get_prowler_provider_kwargs(provider)

    assert kwargs == {
        "registry": registry,
        "registry_token": "ghp_abc123",
        "image_filter": "my-app.*",
        "tag_filter": "v[0-9]+",
        "max_images": 50,
    }
def test_get_prowler_provider_kwargs_image_provider_no_auth(self):
    """An empty secret (public registry) produces only the registry kwarg."""
    registry = "docker.io"
    provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid=registry,
        secret=MagicMock(secret={}),
    )

    kwargs = get_prowler_provider_kwargs(provider)

    assert kwargs == {"registry": registry}
def test_get_prowler_provider_kwargs_image_provider_ignores_mutelist(self):
    """Mutelist content is not forwarded to the Image provider.

    Even with a mutelist processor that carries a non-empty "Mutelist"
    configuration, the resulting kwargs must hold only the registry and
    its credentials.
    """
    registry = "ghcr.io"
    mutelist_processor = MagicMock(configuration={"Mutelist": {"key": "value"}})
    provider = MagicMock(
        provider=Provider.ProviderChoices.IMAGE.value,
        uid=registry,
        secret=MagicMock(secret={"registry_token": "test_token"}),
    )

    kwargs = get_prowler_provider_kwargs(provider, mutelist_processor)

    assert "mutelist_content" not in kwargs
    assert kwargs == {
        "registry": registry,
        "registry_token": "test_token",
    }
def test_get_prowler_provider_kwargs_unsupported_provider(self):
# Setup
provider_uid = "provider_uid"
+305 -69
View File
@@ -1179,6 +1179,26 @@ class TestProviderViewSet:
"uid": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4",
"alias": "Cloudflare Account",
},
{
"provider": "image",
"uid": "ghcr.io",
"alias": "GitHub Registry",
},
{
"provider": "image",
"uid": "docker.io",
"alias": "Docker Hub",
},
{
"provider": "image",
"uid": "123456789012.dkr.ecr.us-east-1.amazonaws.com",
"alias": "ECR",
},
{
"provider": "image",
"uid": "https://registry.example.com:5000",
"alias": "Private",
},
]
),
)
@@ -1598,6 +1618,26 @@ class TestProviderViewSet:
"cloudflare-uid",
"uid",
),
# Image UID validation - too short (below min_length)
(
{
"provider": "image",
"uid": "ab",
"alias": "test",
},
"min_length",
"uid",
),
# Image UID validation - invalid characters (space)
(
{
"provider": "image",
"uid": "not valid!",
"alias": "test",
},
"image-uid",
"uid",
),
]
),
)
@@ -1753,6 +1793,38 @@ class TestProviderViewSet:
assert "Content-Location" in response.headers
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.check_provider_connection_task.delay")
def test_providers_connection_image(
    self,
    mock_provider_connection,
    mock_task_get,
    authenticated_client,
    providers_fixture,
    tasks_fixture,
):
    """POSTing to an Image provider's connection endpoint dispatches the check task."""
    stored_task = tasks_fixture[0]
    queued_task = Mock(id=stored_task.id, status="PENDING")
    mock_provider_connection.return_value = queued_task
    mock_task_get.return_value = stored_task

    # Guard against fixture reordering: this slot must hold the Image provider.
    image_provider = providers_fixture[10]
    assert image_provider.provider == "image"
    assert image_provider.connected is None
    assert image_provider.connection_last_checked_at is None

    connection_url = reverse(
        "provider-connection", kwargs={"pk": image_provider.id}
    )
    response = authenticated_client.post(connection_url)

    assert response.status_code == status.HTTP_202_ACCEPTED
    mock_provider_connection.assert_called_once_with(
        provider_id=str(image_provider.id), tenant_id=ANY
    )
    assert "Content-Location" in response.headers
    assert response.headers["Content-Location"] == f"/api/v1/tasks/{queued_task.id}"
def test_providers_connection_invalid_provider(
self, authenticated_client, providers_fixture
):
@@ -1775,17 +1847,17 @@ class TestProviderViewSet:
),
("alias", "aws_testing_1", 1),
("alias.icontains", "aws", 2),
("inserted_at", TODAY, 10),
("inserted_at", TODAY, 11),
(
"inserted_at.gte",
"2024-01-01",
10,
11,
),
("inserted_at.lte", "2024-01-01", 0),
(
"updated_at.gte",
"2024-01-01",
10,
11,
),
("updated_at.lte", "2024-01-01", 0),
]
@@ -2392,6 +2464,39 @@ class TestProviderSecretViewSet:
"api_email": "user@example.com",
},
),
# Image with Docker login
(
Provider.ProviderChoices.IMAGE.value,
ProviderSecret.TypeChoices.STATIC,
{
"registry_username": "user",
"registry_password": "pass",
},
),
# Image with token
(
Provider.ProviderChoices.IMAGE.value,
ProviderSecret.TypeChoices.STATIC,
{
"registry_token": "ghp_abc123",
},
),
# Image with no auth + filters
(
Provider.ProviderChoices.IMAGE.value,
ProviderSecret.TypeChoices.STATIC,
{
"image_filter": "my-app.*",
"tag_filter": "v[0-9]+",
"max_images": 50,
},
),
# Image with empty secret (public registry)
(
Provider.ProviderChoices.IMAGE.value,
ProviderSecret.TypeChoices.STATIC,
{},
),
],
)
def test_provider_secrets_create_valid(
@@ -3830,6 +3935,7 @@ class TestAttackPathsScanViewSet:
AttackPathsQueryDefinition(
id="aws-rds",
name="RDS inventory",
short_description="List account RDS assets.",
description="List account RDS assets",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
@@ -3887,11 +3993,11 @@ class TestAttackPathsScanViewSet:
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database="tenant-db",
)
query_definition = AttackPathsQueryDefinition(
id="aws-rds",
name="RDS inventory",
short_description="List account RDS assets.",
description="List account RDS assets",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
@@ -3917,10 +4023,16 @@ class TestAttackPathsScanViewSet:
],
}
expected_db_name = f"db-tenant-{attack_paths_scan.provider.tenant_id}"
with (
patch(
"api.v1.views.get_query_by_id", return_value=query_definition
) as mock_get_query,
patch(
"api.v1.views.graph_database.get_database_name",
return_value=expected_db_name,
) as mock_get_db_name,
patch(
"api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
return_value=prepared_parameters,
@@ -3942,17 +4054,18 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds")
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
mock_prepare.assert_called_once_with(
query_definition,
{},
attack_paths_scan.provider.uid,
)
mock_execute.assert_called_once_with(
attack_paths_scan,
expected_db_name,
query_definition,
prepared_parameters,
)
mock_clear_cache.assert_called_once_with(attack_paths_scan.graph_database)
mock_clear_cache.assert_called_once_with(expected_db_name)
result = response.json()["data"]
attributes = result["attributes"]
assert attributes["nodes"] == graph_payload["nodes"]
@@ -3983,31 +4096,6 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "must be completed" in response.json()["errors"][0]["detail"]
def test_run_attack_paths_query_requires_graph_database(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database=None,
)
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
),
data=self._run_payload(),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "does not reference a graph database" in str(response.json())
def test_run_attack_paths_query_unknown_query(
self,
authenticated_client,
@@ -4049,6 +4137,7 @@ class TestAttackPathsScanViewSet:
query_definition = AttackPathsQueryDefinition(
id="aws-empty",
name="empty",
short_description="",
description="",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
@@ -10841,25 +10930,20 @@ class TestTenantFinishACSView:
assert "sso_saml_failed=true" in response.url
def test_dispatch_skips_role_mapping_when_single_manage_account_user(
self, create_test_user, tenants_fixture, saml_setup, settings, monkeypatch
self,
create_test_user,
tenants_fixture,
admin_role_fixture,
saml_setup,
settings,
monkeypatch,
):
"""Test that role mapping is skipped when tenant has only one user with MANAGE_ACCOUNT role"""
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
user = create_test_user
tenant = tenants_fixture[0]
# Create a single role with manage_account=True for the user
admin_role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant=tenant,
manage_account=True,
manage_users=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
admin_role = admin_role_fixture
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=admin_role, tenant_id=tenant.id
)
@@ -10930,35 +11014,26 @@ class TestTenantFinishACSView:
.exists()
)
def test_dispatch_applies_role_mapping_when_multiple_manage_account_users(
self, create_test_user, tenants_fixture, saml_setup, settings, monkeypatch
def test_dispatch_skips_role_mapping_when_last_manage_account_user_maps_to_existing_role(
self,
create_test_user,
tenants_fixture,
admin_role_fixture,
roles_fixture,
saml_setup,
settings,
monkeypatch,
):
"""Test that role mapping is applied when tenant has multiple users with MANAGE_ACCOUNT role"""
"""Test that role mapping is skipped when it would remove the last MANAGE_ACCOUNT user"""
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
user = create_test_user
tenant = tenants_fixture[0]
# Create a second user with manage_account=True
second_admin = User.objects.using(MainRouter.admin_db).create(
email="admin2@prowler.com", name="Second Admin"
)
admin_role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant=tenant,
manage_account=True,
manage_users=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
admin_role = admin_role_fixture
viewer_role = roles_fixture[3]
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=admin_role, tenant_id=tenant.id
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=second_admin, role=admin_role, tenant_id=tenant.id
)
social_account = SocialAccount(
user=user,
@@ -10967,7 +11042,7 @@ class TestTenantFinishACSView:
"firstName": ["John"],
"lastName": ["Doe"],
"organization": ["testing_company"],
"userType": ["viewer"], # This SHOULD be applied
"userType": [viewer_role.name],
},
)
@@ -11005,10 +11080,91 @@ class TestTenantFinishACSView:
assert response.status_code == 302
# Verify the viewer role was created and assigned (role mapping was applied)
viewer_role = Role.objects.using(MainRouter.admin_db).get(
name="viewer", tenant=tenant
assert (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=user, role=admin_role, tenant_id=tenant.id)
.exists()
)
assert not (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=user, role=viewer_role, tenant_id=tenant.id)
.exists()
)
def test_dispatch_applies_role_mapping_when_multiple_manage_account_users(
self,
create_test_user,
tenants_fixture,
admin_role_fixture,
roles_fixture,
saml_setup,
settings,
monkeypatch,
):
"""Test that role mapping is applied when tenant has multiple users with MANAGE_ACCOUNT role"""
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
user = create_test_user
tenant = tenants_fixture[0]
# Create a second user with manage_account=True
second_admin = User.objects.using(MainRouter.admin_db).create(
email="admin2@prowler.com", name="Second Admin"
)
admin_role = admin_role_fixture
viewer_role = roles_fixture[3]
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=admin_role, tenant_id=tenant.id
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=second_admin, role=admin_role, tenant_id=tenant.id
)
social_account = SocialAccount(
user=user,
provider="saml",
extra_data={
"firstName": ["John"],
"lastName": ["Doe"],
"organization": ["testing_company"],
"userType": [viewer_role.name], # This SHOULD be applied
},
)
request = RequestFactory().get(
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
)
request.user = user
request.session = {}
with (
patch(
"allauth.socialaccount.providers.saml.views.get_app_or_404"
) as mock_get_app_or_404,
patch(
"allauth.socialaccount.models.SocialApp.objects.get"
) as mock_socialapp_get,
patch(
"allauth.socialaccount.models.SocialAccount.objects.get"
) as mock_sa_get,
patch("api.models.SAMLDomainIndex.objects.get") as mock_saml_domain_get,
patch("api.models.SAMLConfiguration.objects.get") as mock_saml_config_get,
patch("api.models.User.objects.get") as mock_user_get,
):
mock_get_app_or_404.return_value = MagicMock(
provider="saml", client_id="testtenant", name="Test App", settings={}
)
mock_sa_get.return_value = social_account
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
mock_saml_config_get.return_value = MagicMock()
mock_user_get.return_value = user
view = TenantFinishACSView.as_view()
response = view(request, organization_slug="testtenant")
assert response.status_code == 302
# Verify the viewer role was assigned (role mapping was applied)
assert (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=user, role=viewer_role, tenant_id=tenant.id)
@@ -11022,6 +11178,86 @@ class TestTenantFinishACSView:
.exists()
)
def test_dispatch_applies_role_mapping_for_non_admin_user_with_single_admin(
self,
create_test_user,
tenants_fixture,
admin_role_fixture,
roles_fixture,
saml_setup,
settings,
monkeypatch,
):
"""Test that role mapping is applied for a non-admin user when a single admin exists"""
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
admin_user = create_test_user
tenant = tenants_fixture[0]
non_admin_user = User.objects.using(MainRouter.admin_db).create(
email="viewer@prowler.com", name="Viewer"
)
admin_role = admin_role_fixture
viewer_role = roles_fixture[3]
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=admin_user, role=admin_role, tenant_id=tenant.id
)
social_account = SocialAccount(
user=non_admin_user,
provider="saml",
extra_data={
"firstName": ["Jane"],
"lastName": ["Doe"],
"organization": ["testing_company"],
"userType": [viewer_role.name],
},
)
request = RequestFactory().get(
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
)
request.user = non_admin_user
request.session = {}
with (
patch(
"allauth.socialaccount.providers.saml.views.get_app_or_404"
) as mock_get_app_or_404,
patch(
"allauth.socialaccount.models.SocialApp.objects.get"
) as mock_socialapp_get,
patch(
"allauth.socialaccount.models.SocialAccount.objects.get"
) as mock_sa_get,
patch("api.models.SAMLDomainIndex.objects.get") as mock_saml_domain_get,
patch("api.models.SAMLConfiguration.objects.get") as mock_saml_config_get,
patch("api.models.User.objects.get") as mock_user_get,
):
mock_get_app_or_404.return_value = MagicMock(
provider="saml", client_id="testtenant", name="Test App", settings={}
)
mock_sa_get.return_value = social_account
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
mock_saml_config_get.return_value = MagicMock()
mock_user_get.return_value = non_admin_user
view = TenantFinishACSView.as_view()
response = view(request, organization_slug="testtenant")
assert response.status_code == 302
assert (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=non_admin_user, role=viewer_role, tenant_id=tenant.id)
.exists()
)
assert (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=admin_user, role=admin_role, tenant_id=tenant.id)
.exists()
)
@pytest.mark.django_db
class TestLighthouseConfigViewSet:
+45 -4
View File
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.github.github_provider import GithubProvider
from prowler.providers.iac.iac_provider import IacProvider
from prowler.providers.image.image_provider import ImageProvider
from prowler.providers.kubernetes.kubernetes_provider import KubernetesProvider
from prowler.providers.m365.m365_provider import M365Provider
from prowler.providers.mongodbatlas.mongodbatlas_provider import (
@@ -78,9 +79,11 @@ def return_prowler_provider(
AlibabacloudProvider
| AwsProvider
| AzureProvider
| CloudflareProvider
| GcpProvider
| GithubProvider
| IacProvider
| ImageProvider
| KubernetesProvider
| M365Provider
| MongodbatlasProvider
@@ -92,7 +95,7 @@ def return_prowler_provider(
provider (Provider): The provider object containing the provider type and associated secrets.
Returns:
AlibabacloudProvider | AwsProvider | AzureProvider | CloudflareProvider | GcpProvider | GithubProvider | IacProvider | KubernetesProvider | M365Provider | MongodbatlasProvider | OraclecloudProvider: The corresponding provider class.
AlibabacloudProvider | AwsProvider | AzureProvider | CloudflareProvider | GcpProvider | GithubProvider | IacProvider | ImageProvider | KubernetesProvider | M365Provider | MongodbatlasProvider | OraclecloudProvider: The corresponding provider class.
Raises:
ValueError: If the provider type specified in `provider.provider` is not supported.
@@ -152,6 +155,10 @@ def return_prowler_provider(
)
prowler_provider = CloudflareProvider
case Provider.ProviderChoices.IMAGE.value:
from prowler.providers.image.image_provider import ImageProvider
prowler_provider = ImageProvider
case _:
raise ValueError(f"Provider type {provider.provider} not supported")
return prowler_provider
@@ -208,11 +215,30 @@ def get_prowler_provider_kwargs(
**prowler_provider_kwargs,
"filter_accounts": [provider.uid],
}
elif provider.provider == Provider.ProviderChoices.IMAGE.value:
prowler_provider_kwargs = {
"registry": provider.uid,
}
secret = provider.secret.secret
for key in (
"registry_username",
"registry_password",
"registry_token",
"image_filter",
"tag_filter",
):
if key in secret:
prowler_provider_kwargs[key] = secret[key]
if "max_images" in secret:
prowler_provider_kwargs["max_images"] = int(secret["max_images"])
if mutelist_processor:
mutelist_content = mutelist_processor.configuration.get("Mutelist", {})
# IaC provider doesn't support mutelist (uses Trivy's built-in logic)
if mutelist_content and provider.provider != Provider.ProviderChoices.IAC.value:
# IaC and Image providers don't support mutelist (Trivy handles its own logic)
if mutelist_content and provider.provider not in (
Provider.ProviderChoices.IAC.value,
Provider.ProviderChoices.IMAGE.value,
):
prowler_provider_kwargs["mutelist_content"] = mutelist_content
return prowler_provider_kwargs
@@ -229,6 +255,7 @@ def initialize_prowler_provider(
| GcpProvider
| GithubProvider
| IacProvider
| ImageProvider
| KubernetesProvider
| M365Provider
| MongodbatlasProvider
@@ -241,7 +268,7 @@ def initialize_prowler_provider(
mutelist_processor (Processor): The mutelist processor object containing the mutelist configuration.
Returns:
AlibabacloudProvider | AwsProvider | AzureProvider | CloudflareProvider | GcpProvider | GithubProvider | IacProvider | KubernetesProvider | M365Provider | MongodbatlasProvider | OraclecloudProvider: An instance of the corresponding provider class
AlibabacloudProvider | AwsProvider | AzureProvider | CloudflareProvider | GcpProvider | GithubProvider | IacProvider | ImageProvider | KubernetesProvider | M365Provider | MongodbatlasProvider | OraclecloudProvider: An instance of the corresponding provider class
initialized with the provider's secrets.
"""
prowler_provider = return_prowler_provider(provider)
@@ -276,6 +303,20 @@ def prowler_provider_connection_test(provider: Provider) -> Connection:
if "access_token" in prowler_provider_kwargs:
iac_test_kwargs["access_token"] = prowler_provider_kwargs["access_token"]
return prowler_provider.test_connection(**iac_test_kwargs)
elif provider.provider == Provider.ProviderChoices.IMAGE.value:
from prowler.providers.image.lib.registry.factory import create_registry_adapter
try:
adapter = create_registry_adapter(
registry_url=provider.uid,
username=prowler_provider_kwargs.get("registry_username"),
password=prowler_provider_kwargs.get("registry_password"),
token=prowler_provider_kwargs.get("registry_token"),
)
adapter.list_repositories()
return Connection(is_connected=True, error=None)
except Exception as e:
return Connection(is_connected=False, error=str(e))
else:
return prowler_provider.test_connection(
**prowler_provider_kwargs,
@@ -373,6 +373,37 @@ from rest_framework_json_api import serializers
},
"required": ["api_key", "api_email"],
},
{
"type": "object",
"title": "Image Registry Credentials",
"properties": {
"registry_username": {
"type": "string",
"description": "Username for Docker login authentication.",
},
"registry_password": {
"type": "string",
"description": "Password for Docker login authentication.",
},
"registry_token": {
"type": "string",
"description": "Bearer token for registry authentication.",
},
"image_filter": {
"type": "string",
"description": "Regex pattern to filter repository names during registry enumeration.",
},
"tag_filter": {
"type": "string",
"description": "Regex pattern to filter image tags during registry enumeration.",
},
"max_images": {
"type": "integer",
"minimum": 0,
"description": "Maximum number of images to scan (0 = unlimited).",
},
},
},
]
}
)
+43
View File
@@ -1176,6 +1176,14 @@ class AttackPathsScanSerializer(RLSSerializer):
return provider.uid if provider else None
class AttackPathsQueryAttributionSerializer(BaseSerializerV1):
    """Serialize the attribution block of an Attack Paths query.

    Exposes two plain string fields, ``text`` and ``link``.
    """

    # NOTE(review): presumably the human-readable credit line - confirm.
    text = serializers.CharField()
    # NOTE(review): presumably a URL pointing at the credited source; it is
    # declared as a plain CharField, so no URL validation is applied here.
    link = serializers.CharField()

    class JSONAPIMeta:
        # Resource type name used for this serializer in JSON:API payloads.
        resource_name = "attack-paths-query-attributions"
class AttackPathsQueryParameterSerializer(BaseSerializerV1):
name = serializers.CharField()
label = serializers.CharField()
@@ -1190,7 +1198,9 @@ class AttackPathsQueryParameterSerializer(BaseSerializerV1):
class AttackPathsQuerySerializer(BaseSerializerV1):
id = serializers.CharField()
name = serializers.CharField()
short_description = serializers.CharField()
description = serializers.CharField()
attribution = AttackPathsQueryAttributionSerializer(allow_null=True, required=False)
provider = serializers.CharField()
parameters = AttackPathsQueryParameterSerializer(many=True)
@@ -1515,6 +1525,8 @@ class BaseWriteProviderSecretSerializer(BaseWriteSerializer):
"or both 'api_key' and 'api_email'."
}
)
elif provider_type == Provider.ProviderChoices.IMAGE.value:
serializer = ImageProviderSecret(data=secret)
else:
raise serializers.ValidationError(
{"provider": f"Provider type not supported {provider_type}"}
@@ -1653,6 +1665,37 @@ class IacProviderSecret(serializers.Serializer):
resource_name = "provider-secrets"
class ImageProviderSecret(serializers.Serializer):
    """Validate the secret payload stored for an Image (registry) provider.

    Every field is optional. When credentials are supplied they must be
    either a complete ``registry_username``/``registry_password`` pair or a
    single ``registry_token`` - never a mix of both. Missing and empty values
    are both treated as "not provided".
    """

    registry_username = serializers.CharField(required=False)
    registry_password = serializers.CharField(required=False)
    registry_token = serializers.CharField(required=False)
    image_filter = serializers.CharField(required=False)
    tag_filter = serializers.CharField(required=False)
    max_images = serializers.IntegerField(required=False, min_value=0)

    def validate(self, attrs):
        """Reject inconsistent combinations of registry credentials.

        Args:
            attrs: The already field-validated secret attributes.

        Returns:
            The attributes as returned by the parent ``validate``.

        Raises:
            serializers.ValidationError: If basic credentials and a token are
                combined, or if only one half of the username/password pair
                is present.
        """
        username_given = bool(attrs.get("registry_username"))
        password_given = bool(attrs.get("registry_password"))
        token_given = bool(attrs.get("registry_token"))

        # Token auth and basic auth are mutually exclusive.
        if token_given and (username_given or password_given):
            raise serializers.ValidationError(
                "You cannot provide both registry_username/registry_password and registry_token."
            )

        # Basic auth needs both halves; report whichever one is missing.
        if username_given != password_given:
            if username_given:
                raise serializers.ValidationError(
                    "registry_password is required when registry_username is provided."
                )
            raise serializers.ValidationError(
                "registry_username is required when registry_password is provided."
            )

        return super().validate(attrs)

    class Meta:
        resource_name = "provider-secrets"
class OracleCloudProviderSecret(serializers.Serializer):
user = serializers.CharField()
fingerprint = serializers.CharField()
+33 -26
View File
@@ -392,7 +392,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.19.0"
spectacular_settings.VERSION = "1.20.0"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -763,27 +763,40 @@ class TenantFinishACSView(FinishACSView):
.tenant
)
# Check if tenant has only one user with MANAGE_ACCOUNT role
users_with_manage_account = (
role_name = (
extra.get("userType", ["no_permissions"])[0].strip()
if extra.get("userType")
else "no_permissions"
)
role = (
Role.objects.using(MainRouter.admin_db)
.filter(name=role_name, tenant=tenant)
.first()
)
# Only skip mapping if it would remove the last MANAGE_ACCOUNT user
remaining_manage_account_users = (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(role__manage_account=True, tenant_id=tenant.id)
.exclude(user_id=user_id)
.values("user")
.distinct()
.count()
)
user_has_manage_account = (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(role__manage_account=True, tenant_id=tenant.id, user_id=user_id)
.exists()
)
role_manage_account = role.manage_account if role else False
would_remove_last_manage_account = (
user_has_manage_account
and remaining_manage_account_users == 0
and not role_manage_account
)
# Only apply role mapping from userType if tenant does NOT have exactly one user with MANAGE_ACCOUNT
if users_with_manage_account != 1:
role_name = (
extra.get("userType", ["no_permissions"])[0].strip()
if extra.get("userType")
else "no_permissions"
)
try:
role = Role.objects.using(MainRouter.admin_db).get(
name=role_name, tenant=tenant
)
except Role.DoesNotExist:
if not would_remove_last_manage_account:
if role is None:
role = Role.objects.using(MainRouter.admin_db).create(
name=role_name,
tenant=tenant,
@@ -2415,15 +2428,6 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
}
)
if not attack_paths_scan.graph_database:
logger.error(
f"The Attack Paths Scan {attack_paths_scan.id} does not reference a graph database"
)
return Response(
{"detail": "The Attack Paths scan does not reference a graph database"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
payload = attack_paths_views_helpers.normalize_run_payload(request.data)
serializer = AttackPathsQueryRunRequestSerializer(data=payload)
serializer.is_valid(raise_exception=True)
@@ -2437,6 +2441,9 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
{"id": "Unknown Attack Paths query for the selected provider"}
)
database_name = graph_database.get_database_name(
attack_paths_scan.provider.tenant_id
)
parameters = attack_paths_views_helpers.prepare_query_parameters(
query_definition,
serializer.validated_data.get("parameters", {}),
@@ -2444,9 +2451,9 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
)
graph = attack_paths_views_helpers.execute_attack_paths_query(
attack_paths_scan, query_definition, parameters
database_name, query_definition, parameters
)
graph_database.clear_cache(attack_paths_scan.graph_database)
graph_database.clear_cache(database_name)
status_code = status.HTTP_200_OK
if not graph.get("nodes"):
+8 -2
View File
@@ -537,6 +537,12 @@ def providers_fixture(tenants_fixture):
alias="cloudflare_testing",
tenant_id=tenant.id,
)
provider11 = Provider.objects.create(
provider="image",
uid="ghcr.io",
alias="image_testing",
tenant_id=tenant.id,
)
return (
provider1,
@@ -549,6 +555,7 @@ def providers_fixture(tenants_fixture):
provider8,
provider9,
provider10,
provider11,
)
@@ -1618,7 +1625,6 @@ def create_attack_paths_scan():
scan=None,
state=StateChoices.COMPLETED,
progress=0,
graph_database="tenant-db",
**extra_fields,
):
scan_instance = scan or Scan.objects.create(
@@ -1635,7 +1641,6 @@ def create_attack_paths_scan():
"scan": scan_instance,
"state": state,
"progress": progress,
"graph_database": graph_database,
}
payload.update(extra_fields)
@@ -1663,6 +1668,7 @@ def attack_paths_query_definition_factory():
definition_payload = {
"id": "aws-test",
"name": "Attack Paths Test Query",
"short_description": "Synthetic short description for tests.",
"description": "Synthetic Attack Paths definition for tests.",
"provider": "aws",
"cypher": "RETURN 1",
+14
View File
@@ -35,6 +35,11 @@ from prowler.lib.outputs.compliance.cis.cis_github import GithubCIS
from prowler.lib.outputs.compliance.cis.cis_kubernetes import KubernetesCIS
from prowler.lib.outputs.compliance.cis.cis_m365 import M365CIS
from prowler.lib.outputs.compliance.cis.cis_oraclecloud import OracleCloudCIS
from prowler.lib.outputs.compliance.csa.csa_alibabacloud import AlibabaCloudCSA
from prowler.lib.outputs.compliance.csa.csa_aws import AWSCSA
from prowler.lib.outputs.compliance.csa.csa_azure import AzureCSA
from prowler.lib.outputs.compliance.csa.csa_gcp import GCPCSA
from prowler.lib.outputs.compliance.csa.csa_oraclecloud import OracleCloudCSA
from prowler.lib.outputs.compliance.ens.ens_aws import AWSENS
from prowler.lib.outputs.compliance.ens.ens_azure import AzureENS
from prowler.lib.outputs.compliance.ens.ens_gcp import GCPENS
@@ -90,6 +95,7 @@ COMPLIANCE_CLASS_MAP = {
(lambda name: name == "prowler_threatscore_aws", ProwlerThreatScoreAWS),
(lambda name: name == "ccc_aws", CCC_AWS),
(lambda name: name.startswith("c5_"), AWSC5),
(lambda name: name.startswith("csa_"), AWSCSA),
],
"azure": [
(lambda name: name.startswith("cis_"), AzureCIS),
@@ -99,6 +105,7 @@ COMPLIANCE_CLASS_MAP = {
(lambda name: name == "ccc_azure", CCC_Azure),
(lambda name: name == "prowler_threatscore_azure", ProwlerThreatScoreAzure),
(lambda name: name == "c5_azure", AzureC5),
(lambda name: name.startswith("csa_"), AzureCSA),
],
"gcp": [
(lambda name: name.startswith("cis_"), GCPCIS),
@@ -108,6 +115,7 @@ COMPLIANCE_CLASS_MAP = {
(lambda name: name == "prowler_threatscore_gcp", ProwlerThreatScoreGCP),
(lambda name: name == "ccc_gcp", CCC_GCP),
(lambda name: name == "c5_gcp", GCPC5),
(lambda name: name.startswith("csa_"), GCPCSA),
],
"kubernetes": [
(lambda name: name.startswith("cis_"), KubernetesCIS),
@@ -129,11 +137,17 @@ COMPLIANCE_CLASS_MAP = {
# IaC provider doesn't have specific compliance frameworks yet
# Trivy handles its own compliance checks
],
"image": [
# Image provider doesn't have specific compliance frameworks yet
# Trivy handles its own compliance checks
],
"oraclecloud": [
(lambda name: name.startswith("cis_"), OracleCloudCIS),
(lambda name: name.startswith("csa_"), OracleCloudCSA),
],
"alibabacloud": [
(lambda name: name.startswith("cis_"), AlibabaCloudCIS),
(lambda name: name.startswith("csa_"), AlibabaCloudCSA),
(
lambda name: name == "prowler_threatscore_alibabacloud",
ProwlerThreatScoreAlibaba,
+35 -17
View File
@@ -133,13 +133,41 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
scan_id (str): The ID of the scan that was performed.
provider_id (str): The primary key of the Provider instance that was scanned.
"""
chain(
create_compliance_requirements_task.si(tenant_id=tenant_id, scan_id=scan_id),
update_provider_compliance_scores_task.si(tenant_id=tenant_id, scan_id=scan_id),
).apply_async()
aggregate_attack_surface_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
with rls_transaction(tenant_id):
provider_type = Provider.objects.get(id=provider_id).provider
has_compliance = provider_type not in (
Provider.ProviderChoices.IAC.value,
Provider.ProviderChoices.IMAGE.value,
)
if has_compliance:
chain(
create_compliance_requirements_task.si(
tenant_id=tenant_id, scan_id=scan_id
),
update_provider_compliance_scores_task.si(
tenant_id=tenant_id, scan_id=scan_id
),
).apply_async()
aggregate_attack_surface_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
final_group_tasks = [
check_integrations_task.si(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
),
]
if has_compliance:
final_group_tasks.append(
generate_compliance_reports_task.si(
tenant_id=tenant_id, scan_id=scan_id, provider_id=provider_id
),
)
chain(
perform_scan_summary_task.si(tenant_id=tenant_id, scan_id=scan_id),
group(
@@ -148,17 +176,7 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
),
group(
# Use optimized task that generates both reports with shared queries
generate_compliance_reports_task.si(
tenant_id=tenant_id, scan_id=scan_id, provider_id=provider_id
),
check_integrations_task.si(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
),
),
group(*final_group_tasks),
).apply_async()
if can_provider_run_attack_paths_scan(tenant_id, provider_id):
Generated
+6 -1
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.3.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
[[package]]
name = "about-time"
@@ -5874,6 +5874,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"},
@@ -5882,6 +5883,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"},
@@ -5890,6 +5892,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"},
@@ -5898,6 +5901,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"},
@@ -5906,6 +5910,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"},
{file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"},
+1
View File
@@ -70,6 +70,7 @@ All notable changes to the **Prowler SDK** are documented in this file.
- CIS 5.0 compliance framework for the Azure provider [(#9777)](https://github.com/prowler-cloud/prowler/pull/9777)
- `Cloudflare` Bot protection, WAF, Privacy, Anti-Scraping and Zone configuration checks [(#9425)](https://github.com/prowler-cloud/prowler/pull/9425)
- `Cloudflare` `waf` and `dns record` checks [(#9426)](https://github.com/prowler-cloud/prowler/pull/9426)
- Container Image provider using Trivy for vulnerability and secret scanning
### Changed
+15 -7
View File
@@ -119,6 +119,8 @@ from prowler.providers.common.quick_inventory import run_provider_quick_inventor
from prowler.providers.gcp.models import GCPOutputOptions
from prowler.providers.github.models import GithubOutputOptions
from prowler.providers.iac.models import IACOutputOptions
from prowler.providers.image.exceptions.exceptions import ImageBaseException
from prowler.providers.image.models import ImageOutputOptions
from prowler.providers.kubernetes.models import KubernetesOutputOptions
from prowler.providers.llm.models import LLMOutputOptions
from prowler.providers.m365.models import M365OutputOptions
@@ -206,8 +208,8 @@ def prowler():
# Load compliance frameworks
logger.debug("Loading compliance frameworks from .json files")
# Skip compliance frameworks for IAC and LLM providers
if provider != "iac" and provider != "llm":
# Skip compliance frameworks for IAC, LLM, and Image providers
if provider not in ("iac", "llm", "image"):
bulk_compliance_frameworks = Compliance.get_bulk(provider)
# Complete checks metadata with the compliance framework specification
bulk_checks_metadata = update_checks_metadata_with_compliance(
@@ -264,8 +266,8 @@ def prowler():
if not args.only_logs:
global_provider.print_credentials()
# Skip service and check loading for IAC and LLM providers
if provider != "iac" and provider != "llm":
# Skip service and check loading for IAC, LLM, and Image providers
if provider not in ("iac", "llm", "image"):
# Import custom checks from folder
if checks_folder:
custom_checks = parse_checks_from_folder(global_provider, checks_folder)
@@ -352,6 +354,8 @@ def prowler():
)
elif provider == "iac":
output_options = IACOutputOptions(args, bulk_checks_metadata)
elif provider == "image":
output_options = ImageOutputOptions(args, bulk_checks_metadata)
elif provider == "llm":
output_options = LLMOutputOptions(args, bulk_checks_metadata)
elif provider == "oraclecloud":
@@ -375,8 +379,8 @@ def prowler():
# Execute checks
findings = []
if provider == "iac" or provider == "llm":
# For IAC and LLM providers, run the scan directly
if provider in ("iac", "llm", "image"):
# For IAC, LLM, and Image providers, run the scan directly
if provider == "llm":
def streaming_callback(findings_batch):
@@ -386,7 +390,11 @@ def prowler():
findings = global_provider.run_scan(streaming_callback=streaming_callback)
else:
# Original behavior for IAC or non-verbose LLM
findings = global_provider.run()
try:
findings = global_provider.run()
except ImageBaseException as error:
logger.critical(f"{error}")
sys.exit(1)
# Note: IaC doesn't support granular progress tracking since Trivy runs as a black box
# and returns all findings at once. Progress tracking would just be 0% → 100%.
+4 -1
View File
@@ -75,7 +75,10 @@ def get_available_compliance_frameworks(provider=None):
if provider:
providers = [provider]
for provider in providers:
with os.scandir(f"{actual_directory}/../compliance/{provider}") as files:
compliance_dir = f"{actual_directory}/../compliance/{provider}"
if not os.path.isdir(compliance_dir):
continue
with os.scandir(compliance_dir) as files:
for file in files:
if file.is_file() and file.name.endswith(".json"):
available_compliance_frameworks.append(
+40
View File
@@ -163,6 +163,7 @@ class CheckMetadata(BaseModel):
check_id
and values.get("Provider") != "iac"
and values.get("Provider") != "llm"
and values.get("Provider") != "image"
):
service_from_check_id = check_id.split("_")[0]
if service_name != service_from_check_id:
@@ -183,6 +184,7 @@ class CheckMetadata(BaseModel):
check_id
and values.get("Provider") != "iac"
and values.get("Provider") != "llm"
and values.get("Provider") != "image"
):
if "-" in check_id:
raise ValueError(
@@ -865,6 +867,44 @@ class CheckReportIAC(Check_Report):
)
@dataclass
class CheckReportImage(Check_Report):
    """Finding report for a container image scanned with Trivy."""

    resource_name: str
    image_digest: str
    package_name: str
    installed_version: str
    fixed_version: str

    def __init__(
        self,
        metadata: Optional[dict] = None,
        finding: Optional[dict] = None,
        image_name: str = "",
    ) -> None:
        """
        Build the report from a single Trivy vulnerability/secret result.

        Args:
            metadata (Dict): Check metadata; an empty dict is used when omitted.
            finding (dict): One vulnerability/secret entry from Trivy's JSON
                output; an empty dict is used when omitted.
            image_name (str): Name of the container image being scanned.
        """
        metadata = {} if metadata is None else metadata
        finding = {} if finding is None else finding

        super().__init__(metadata, finding)

        # The raw Trivy entry is kept as the resource payload.
        self.resource = finding
        self.resource_name = image_name

        # Copy the package details out of the Trivy entry, defaulting to "".
        # NOTE(review): image_digest is populated from the "PkgID" key, which
        # by its name looks like a package identifier rather than an image
        # digest - confirm this is the intended source.
        trivy_fields = (
            ("image_digest", "PkgID"),
            ("package_name", "PkgName"),
            ("installed_version", "InstalledVersion"),
            ("fixed_version", "FixedVersion"),
        )
        for attribute, trivy_key in trivy_fields:
            setattr(self, attribute, finding.get(trivy_key, ""))
@dataclass
class CheckReportLLM(Check_Report):
"""Contains the LLM Check's finding information."""
+2 -2
View File
@@ -14,8 +14,8 @@ def recover_checks_from_provider(
Returns a list of tuples with the following format (check_name, check_path)
"""
try:
# Bypass check loading for IAC provider since it uses Trivy directly
if provider == "iac" or provider == "llm":
# Bypass check loading for IAC, LLM, and Image providers since they use external tools directly
if provider in ("iac", "llm", "image"):
return []
checks = []
+3 -2
View File
@@ -27,10 +27,10 @@ class ProwlerArgumentParser:
self.parser = argparse.ArgumentParser(
prog="prowler",
formatter_class=RawTextHelpFormatter,
usage="prowler [-h] [--version] {aws,azure,gcp,kubernetes,m365,github,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,dashboard,iac} ...",
usage="prowler [-h] [--version] {aws,azure,gcp,kubernetes,m365,github,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,dashboard,iac,image} ...",
epilog="""
Available Cloud Providers:
{aws,azure,gcp,kubernetes,m365,github,iac,llm,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack}
{aws,azure,gcp,kubernetes,m365,github,iac,llm,image,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack}
aws AWS Provider
azure Azure Provider
gcp GCP Provider
@@ -43,6 +43,7 @@ Available Cloud Providers:
alibabacloud Alibaba Cloud Provider
iac IaC Provider (Beta)
llm LLM Provider (Beta)
image Container Image Provider (PoC)
nhn NHN Provider (Unofficial)
mongodbatlas MongoDB Atlas Provider (Beta)
+17
View File
@@ -380,6 +380,23 @@ class Finding(BaseModel):
output_data["resource_uid"] = check_output.resource_id
output_data["region"] = check_output.region
elif provider.type == "image":
output_data["auth_method"] = provider.auth_method
output_data["account_uid"] = "image"
output_data["account_name"] = "image"
output_data["resource_name"] = getattr(
check_output, "resource_name", ""
)
output_data["resource_uid"] = getattr(check_output, "resource_name", "")
output_data["region"] = getattr(check_output, "region", "container")
output_data["package_name"] = getattr(check_output, "package_name", "")
output_data["installed_version"] = getattr(
check_output, "installed_version", ""
)
output_data["fixed_version"] = getattr(
check_output, "fixed_version", ""
)
# check_output Unique ID
# TODO: move this to a function
# TODO: in Azure, GCP and K8s there are findings without resource_name
+6
View File
@@ -93,6 +93,12 @@ def display_summary_table(
if provider.identity.project_name
else provider.identity.project_id
)
elif provider.type == "image":
entity_type = "Image"
if len(provider.images) == 1:
audited_entities = provider.images[0]
else:
audited_entities = f"{len(provider.images)} images"
# Check if there are findings and that they are not all MANUAL
if findings and not all(finding.status == "MANUAL" for finding in findings):
+71 -8
View File
@@ -27,6 +27,7 @@ from prowler.lib.scan.exceptions.exceptions import (
from prowler.providers.common.models import Audit_Metadata, ProviderOutputOptions
from prowler.providers.common.provider import Provider
from prowler.providers.iac.iac_provider import IacProvider
from prowler.providers.image.image_provider import ImageProvider
class Scan:
@@ -92,10 +93,10 @@ class Scan:
except ValueError:
raise ScanInvalidStatusError(f"Invalid status provided: {s}.")
# Special setup for IaC provider - override inputs to work with traditional flow
if provider.type == "iac":
# IaC doesn't use traditional Prowler checks, so clear all input parameters
# to avoid validation errors and let it flow through the normal logic
# Special setup for IaC/Image providers - override inputs to work with traditional flow
if provider.type in ("iac", "image"):
# These providers don't use traditional Prowler checks, so clear all input
# parameters to avoid validation errors and let them flow through the normal logic
checks = None
services = None
excluded_checks = None
@@ -160,8 +161,8 @@ class Scan:
)
# Load checks to execute
if provider.type == "iac":
self._checks_to_execute = ["iac_scan"] # Dummy check name for IaC
if provider.type in ("iac", "image"):
self._checks_to_execute = [f"{provider.type}_scan"] # Dummy check name
else:
self._checks_to_execute = sorted(
load_checks_to_execute(
@@ -200,8 +201,8 @@ class Scan:
self._number_of_checks_to_execute = len(self._checks_to_execute)
# Set up service-based checks tracking
if provider.type == "iac":
service_checks_to_execute = {"iac": set(["iac_scan"])}
if provider.type in ("iac", "image"):
service_checks_to_execute = {provider.type: set([f"{provider.type}_scan"])}
else:
service_checks_to_execute = get_service_checks_to_execute(
self._checks_to_execute
@@ -346,6 +347,68 @@ class Scan:
self._duration = int((end_time - start_time).total_seconds())
return
# Special handling for Image provider
elif self._provider.type == "image":
if isinstance(self._provider, ImageProvider):
logger.info("Running Image scan with Trivy...")
total_images = len(self._provider.images)
self._number_of_checks_to_execute = max(total_images, 1)
for i, (image_name, image_reports) in enumerate(
self._provider.scan_per_image()
):
# Build resource UID from image name + SHA (all reports share the same SHA)
image_sha = image_reports[0].image_sha if image_reports else ""
resource_uid = (
f"{image_name}:{image_sha}" if image_sha else image_name
)
findings = []
for report in image_reports:
finding_uid = (
f"{report.check_metadata.CheckID}"
f"-{image_name}"
f"-{report.resource_id}"
)
status_enum = (
Status.FAIL if report.status == "FAIL" else Status.PASS
)
if report.muted:
status_enum = Status.MUTED
finding = Finding(
auth_method="Registry",
timestamp=datetime.datetime.now(timezone.utc),
account_uid=getattr(self._provider, "registry", None)
or "image",
account_name="Container Registry",
metadata=report.check_metadata,
uid=finding_uid,
status=status_enum,
status_extended=report.status_extended,
muted=report.muted,
resource_uid=resource_uid,
resource_metadata=report.resource,
resource_name=image_name,
resource_details=report.resource_details,
resource_tags={},
region=report.region,
compliance={},
raw=report.resource,
)
findings.append(finding)
if self._status:
findings = [f for f in findings if f.status in self._status]
self._number_of_checks_completed = i + 1
yield (self.progress, findings)
end_time = datetime.datetime.now()
self._duration = int((end_time - start_time).total_seconds())
return
for check_name in checks_to_execute:
try:
# Recover service from check name
+14
View File
@@ -274,6 +274,20 @@ class Provider(ABC):
config_path=arguments.config_file,
fixer_config=fixer_config,
)
elif "image" in provider_class_name.lower():
provider_class(
images=arguments.images,
image_list_file=arguments.image_list_file,
scanners=arguments.scanners,
trivy_severity=arguments.trivy_severity,
ignore_unfixed=arguments.ignore_unfixed,
timeout=arguments.timeout,
config_path=arguments.config_file,
fixer_config=fixer_config,
registry_username=arguments.registry_username,
registry_password=arguments.registry_password,
registry_token=arguments.registry_token,
)
elif "mongodbatlas" in provider_class_name.lower():
provider_class(
atlas_public_key=arguments.atlas_public_key,
View File
@@ -0,0 +1,229 @@
from prowler.exceptions.exceptions import ProwlerException
# Exception codes from 11000 to 11999 are reserved for Image exceptions
class ImageBaseException(ProwlerException):
    """Base class for Image provider errors.

    Concrete subclasses pass their numeric code to ``__init__``; the default
    message and remediation text are looked up in ``IMAGE_ERROR_CODES`` using
    the pair ``(code, runtime class name)``.
    """

    # Catalogue of Image provider errors keyed by (numeric code, exception
    # class name). Each entry holds the default "message" and "remediation"
    # strings. Codes 11011 and 11012 have no entry in this table.
    IMAGE_ERROR_CODES = {
        (11000, "ImageNoImagesProvidedError"): {
            "message": "No container images provided for scanning.",
            "remediation": "Provide at least one image using --image or --image-list-file.",
        },
        (11001, "ImageListFileNotFoundError"): {
            "message": "Image list file not found.",
            "remediation": "Ensure the image list file exists at the specified path.",
        },
        (11002, "ImageListFileReadError"): {
            "message": "Error reading image list file.",
            "remediation": "Check file permissions and format. The file should contain one image per line.",
        },
        (11003, "ImageFindingProcessingError"): {
            "message": "Error processing image scan finding.",
            "remediation": "Check the Trivy output format and ensure the finding structure is valid.",
        },
        (11004, "ImageTrivyBinaryNotFoundError"): {
            "message": "Trivy binary not found.",
            "remediation": "Install Trivy from https://trivy.dev/latest/getting-started/installation/",
        },
        (11005, "ImageScanError"): {
            "message": "Error scanning container image.",
            "remediation": "Check the image name and ensure it is accessible.",
        },
        (11006, "ImageInvalidTimeoutError"): {
            "message": "Invalid timeout format.",
            "remediation": "Use a valid timeout like '5m', '300s', or '1h'.",
        },
        (11007, "ImageInvalidScannerError"): {
            "message": "Invalid scanner type.",
            "remediation": "Use valid scanners: vuln, secret, misconfig, license.",
        },
        (11008, "ImageInvalidSeverityError"): {
            "message": "Invalid severity level.",
            "remediation": "Use valid severities: CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN.",
        },
        (11009, "ImageInvalidNameError"): {
            "message": "Invalid container image name.",
            "remediation": "Use a valid image reference (e.g., 'alpine:3.18', 'registry.example.com/repo/image:tag').",
        },
        (11010, "ImageInvalidConfigScannerError"): {
            "message": "Invalid image config scanner type.",
            "remediation": "Use valid image config scanners: misconfig, secret.",
        },
        (11013, "ImageRegistryAuthError"): {
            "message": "Registry authentication failed.",
            "remediation": "Check REGISTRY_USERNAME/REGISTRY_PASSWORD or REGISTRY_TOKEN environment variables.",
        },
        (11014, "ImageRegistryCatalogError"): {
            "message": "Registry does not support catalog listing.",
            "remediation": "Use --image or --image-list instead of --registry.",
        },
        (11015, "ImageRegistryNetworkError"): {
            "message": "Network error communicating with registry.",
            "remediation": "Check registry URL and network connectivity.",
        },
        (11016, "ImageMaxImagesExceededError"): {
            "message": "Discovered images exceed --max-images limit.",
            "remediation": "Use --image-filter or --tag-filter to narrow results, or increase --max-images.",
        },
        (11017, "ImageInvalidFilterError"): {
            "message": "Invalid regex filter pattern.",
            "remediation": "Check the regex syntax for --image-filter or --tag-filter.",
        },
    }

    def __init__(self, code, file=None, original_exception=None, message=None):
        """Resolve the catalogue entry for ``code`` and initialise the base class.

        Args:
            code: Numeric error code; paired with the runtime class name to
                find the entry in ``IMAGE_ERROR_CODES``.
            file: Forwarded unchanged to ``ProwlerException``.
            original_exception: Forwarded unchanged to ``ProwlerException``.
            message: Optional text replacing the entry's default ``"message"``.
                It is only applied when an entry was found for the key.
        """
        error_info = self.IMAGE_ERROR_CODES.get((code, self.__class__.__name__))
        if error_info and message:
            # Build a new dict so the shared class-level table is never mutated.
            error_info = {**error_info, "message": message}
        super().__init__(
            code,
            source="Image",
            file=file,
            original_exception=original_exception,
            error_info=error_info,
        )
class ImageNoImagesProvidedError(ImageBaseException):
    """Exception raised when no container images are provided for scanning.

    Always reports error code 11000.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11000, file=file, original_exception=original_exception, message=message
        )
class ImageListFileNotFoundError(ImageBaseException):
    """Exception raised when the image list file is not found.

    Always reports error code 11001.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11001, file=file, original_exception=original_exception, message=message
        )
class ImageListFileReadError(ImageBaseException):
    """Exception raised when the image list file cannot be read.

    Always reports error code 11002.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11002, file=file, original_exception=original_exception, message=message
        )
class ImageFindingProcessingError(ImageBaseException):
    """Exception raised when a finding cannot be processed.

    Always reports error code 11003.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11003, file=file, original_exception=original_exception, message=message
        )
class ImageTrivyBinaryNotFoundError(ImageBaseException):
    """Exception raised when the Trivy binary is not found.

    Always reports error code 11004.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11004, file=file, original_exception=original_exception, message=message
        )
class ImageScanError(ImageBaseException):
    """Exception raised when a general scan error occurs.

    Always reports error code 11005.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11005, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidTimeoutError(ImageBaseException):
    """Exception raised when an invalid timeout format is provided.

    Always reports error code 11006.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11006, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidScannerError(ImageBaseException):
    """Exception raised when an invalid scanner type is provided.

    Always reports error code 11007.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11007, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidSeverityError(ImageBaseException):
    """Exception raised when an invalid severity level is provided.

    Always reports error code 11008.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11008, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidNameError(ImageBaseException):
    """Exception raised when an invalid container image name is provided.

    Always reports error code 11009.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11009, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidConfigScannerError(ImageBaseException):
    """Exception raised when an invalid image config scanner type is provided.

    Always reports error code 11010.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11010, file=file, original_exception=original_exception, message=message
        )
class ImageRegistryAuthError(ImageBaseException):
    """Exception raised when registry authentication fails.

    Always reports error code 11013.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11013, file=file, original_exception=original_exception, message=message
        )
class ImageRegistryCatalogError(ImageBaseException):
    """Exception raised when registry does not support catalog listing.

    Always reports error code 11014.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11014, file=file, original_exception=original_exception, message=message
        )
class ImageRegistryNetworkError(ImageBaseException):
    """Exception raised when a network error occurs communicating with a registry.

    Always reports error code 11015.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11015, file=file, original_exception=original_exception, message=message
        )
class ImageMaxImagesExceededError(ImageBaseException):
    """Exception raised when discovered images exceed --max-images limit.

    Always reports error code 11016.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11016, file=file, original_exception=original_exception, message=message
        )
class ImageInvalidFilterError(ImageBaseException):
    """Exception raised when an invalid regex filter pattern is provided.

    Always reports error code 11017.
    """

    def __init__(self, file=None, original_exception=None, message=None):
        # The code is fixed for this class; the other arguments are forwarded
        # unchanged to ImageBaseException.
        super().__init__(
            11017, file=file, original_exception=original_exception, message=message
        )
+983
View File
@@ -0,0 +1,983 @@
from __future__ import annotations
import json
import os
import re
import subprocess
import sys
from typing import Generator
from alive_progress import alive_bar
from colorama import Fore, Style
from prowler.config.config import (
default_config_file_path,
load_and_validate_config_file,
)
from prowler.lib.check.models import CheckReportImage
from prowler.lib.logger import logger
from prowler.lib.utils.utils import print_boxes
from prowler.providers.common.models import Audit_Metadata, Connection
from prowler.providers.common.provider import Provider
from prowler.providers.image.exceptions.exceptions import (
ImageFindingProcessingError,
ImageInvalidConfigScannerError,
ImageInvalidFilterError,
ImageInvalidNameError,
ImageInvalidScannerError,
ImageInvalidSeverityError,
ImageInvalidTimeoutError,
ImageListFileNotFoundError,
ImageListFileReadError,
ImageMaxImagesExceededError,
ImageNoImagesProvidedError,
ImageScanError,
ImageTrivyBinaryNotFoundError,
)
from prowler.providers.image.lib.arguments.arguments import (
IMAGE_CONFIG_SCANNERS_CHOICES,
SCANNERS_CHOICES,
SEVERITY_CHOICES,
)
from prowler.providers.image.lib.registry.dockerhub_adapter import DockerHubAdapter
from prowler.providers.image.lib.registry.factory import create_registry_adapter
class ImageProvider(Provider):
"""
Container Image Provider using Trivy for vulnerability and secret scanning.
This is a Tool/Wrapper provider that delegates all scanning logic to Trivy's
`trivy image` command and converts the output to Prowler's finding format.
"""
_type: str = "image"
FINDING_BATCH_SIZE: int = 100
MAX_IMAGE_LIST_LINES: int = 10_000
MAX_IMAGE_NAME_LENGTH: int = 500
_IMAGE_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9.\-_/:@]+$")
_SHELL_METACHARACTERS = frozenset(";|&$`\n\r")
audit_metadata: Audit_Metadata
def __init__(
    self,
    images: list[str] | None = None,
    image_list_file: str | None = None,
    scanners: list[str] | None = None,
    image_config_scanners: list[str] | None = None,
    trivy_severity: list[str] | None = None,
    ignore_unfixed: bool = False,
    timeout: str = "5m",
    config_path: str | None = None,
    config_content: dict | None = None,
    fixer_config: dict | None = None,
    registry_username: str | None = None,
    registry_password: str | None = None,
    registry_token: str | None = None,
    registry: str | None = None,
    image_filter: str | None = None,
    tag_filter: str | None = None,
    max_images: int = 0,
    registry_insecure: bool = False,
    registry_list_images: bool = False,
):
    """Configure the provider, resolve the image list and register it globally.

    Args:
        images: Explicit image references to scan. The list is copied, so
            images added from a list file or a registry never leak back
            into the caller's list.
        image_list_file: Optional path to a file with one image per line.
        scanners: Trivy scanners; defaults to ``["vuln", "secret"]``.
        image_config_scanners: Trivy image-config scanners; defaults to none.
        trivy_severity: Severity filter passed to Trivy; defaults to none.
        ignore_unfixed: Adds ``--ignore-unfixed`` to the Trivy command.
        timeout: Trivy timeout such as ``"5m"``, ``"300s"`` or ``"1h"``.
        config_path: Audit config file; the default path is used when empty.
        config_content: Ready-made audit config; takes precedence over
            ``config_path``.
        fixer_config: Fixer configuration; defaults to an empty dict.
        registry_username: Falls back to ``REGISTRY_USERNAME`` in the env.
        registry_password: Falls back to ``REGISTRY_PASSWORD`` in the env.
        registry_token: Falls back to ``REGISTRY_TOKEN`` in the env.
        registry: Registry URL; when set, images are enumerated from it.
        image_filter: Regex applied to repository names during enumeration.
        tag_filter: Regex applied to tags during enumeration.
        max_images: Upper bound on discovered images; ``0`` disables it.
        registry_insecure: Disables TLS verification for the registry adapter.
        registry_list_images: Only print the registry listing; in that mode
            construction stops right after the listing.

    Raises:
        ImageInvalidFilterError: A filter is not a valid regular expression.
        ImageNoImagesProvidedError: No image is left to scan.
    """
    logger.info("Instantiating Image Provider...")

    def _compile_filter(flag: str, pattern: str | None):
        # Shared by both filters so the error wording stays in one place and
        # the underlying re.error is preserved on the raised exception.
        if not pattern:
            return None
        try:
            return re.compile(pattern)
        except re.error as exc:
            raise ImageInvalidFilterError(
                file=__file__,
                original_exception=exc,
                message=f"Invalid {flag} regex '{pattern}': {exc}",
            ) from exc

    # Copy the caller's list: it is appended to below (list file, registry).
    self.images = list(images) if images is not None else []
    self.image_list_file = image_list_file
    self.scanners = scanners if scanners is not None else ["vuln", "secret"]
    self.image_config_scanners = (
        image_config_scanners if image_config_scanners is not None else []
    )
    self.trivy_severity = trivy_severity if trivy_severity is not None else []
    self.ignore_unfixed = ignore_unfixed
    self.timeout = timeout
    self.region = "container"
    self.audited_account = "image-scan"
    self._session = None
    self._identity = "prowler"
    self._listing_only = False

    # Registry authentication (follows IaC pattern: explicit params, env vars internal)
    self.registry_username = registry_username or os.environ.get(
        "REGISTRY_USERNAME"
    )
    self.registry_password = registry_password or os.environ.get(
        "REGISTRY_PASSWORD"
    )
    self.registry_token = registry_token or os.environ.get("REGISTRY_TOKEN")

    if self.registry_username and self.registry_password:
        self._auth_method = "Docker login"
        logger.info("Using docker login for registry authentication")
    elif self.registry_token:
        self._auth_method = "Registry token"
        logger.info("Using registry token for authentication")
    else:
        self._auth_method = "No auth"

    # Registry scan mode
    self.registry = registry
    self.image_filter = image_filter
    self.tag_filter = tag_filter
    self.max_images = max_images
    self.registry_insecure = registry_insecure
    self.registry_list_images = registry_list_images

    # Compile regex filters (image filter first, so its error wins)
    self._image_filter_re = _compile_filter("--image-filter", self.image_filter)
    self._tag_filter_re = _compile_filter("--tag-filter", self.tag_filter)

    self._validate_inputs()

    # Load images from file if provided
    if image_list_file:
        self._load_images_from_file(image_list_file)

    # Registry scan mode: enumerate images from registry
    if self.registry:
        self._enumerate_registry()
        if self._listing_only:
            # Listing mode only prints; no scan state is needed.
            return

    for image in self.images:
        self._validate_image_name(image)

    if not self.images:
        raise ImageNoImagesProvidedError(
            file=__file__,
            message="No images provided for scanning.",
        )

    # Audit Config
    if config_content:
        self._audit_config = config_content
    else:
        if not config_path:
            config_path = default_config_file_path
        self._audit_config = load_and_validate_config_file(self._type, config_path)

    # Fixer Config
    self._fixer_config = fixer_config if fixer_config is not None else {}

    # Mutelist (not needed for Image provider since Trivy has its own logic)
    self._mutelist = None

    self.audit_metadata = Audit_Metadata(
        provider=self._type,
        account_id=self.audited_account,
        account_name="image",
        region=self.region,
        services_scanned=0,
        expected_checks=[],
        completed_checks=0,
        audit_progress=0,
    )

    Provider.set_global_provider(self)
def _load_images_from_file(self, file_path: str) -> None:
    """Append image names read from ``file_path`` to ``self.images``.

    The file holds one image per line. Blank lines and lines starting with
    ``#`` are ignored; names longer than ``MAX_IMAGE_NAME_LENGTH`` are
    skipped with a warning.

    Args:
        file_path: Path of the image list file.

    Raises:
        ImageListFileNotFoundError: The file does not exist.
        ImageListFileReadError: The file has more than
            ``MAX_IMAGE_LIST_LINES`` lines or cannot be read.
    """
    # Counted separately: self.images may already hold images passed
    # explicitly, which must not be reported as loaded from the file.
    loaded_count = 0
    try:
        with open(file_path, "r") as f:
            for line_number, raw_line in enumerate(f, start=1):
                if line_number > self.MAX_IMAGE_LIST_LINES:
                    raise ImageListFileReadError(
                        file=file_path,
                        message=f"Image list file exceeds maximum of {self.MAX_IMAGE_LIST_LINES} lines.",
                    )
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                if len(line) > self.MAX_IMAGE_NAME_LENGTH:
                    logger.warning(
                        f"Skipping image name exceeding {self.MAX_IMAGE_NAME_LENGTH} chars at line {line_number} in {file_path}"
                    )
                    continue
                self.images.append(line)
                loaded_count += 1
        logger.info(f"Loaded {loaded_count} images from {file_path}")
    except FileNotFoundError as error:
        raise ImageListFileNotFoundError(
            file=file_path,
            message=f"Image list file not found: {file_path}",
        ) from error
    except (ImageListFileReadError, ImageListFileNotFoundError):
        # Already the right type; do not wrap it again below.
        raise
    except Exception as error:
        raise ImageListFileReadError(
            file=file_path,
            original_exception=error,
            message=f"Error reading image list file: {error}",
        ) from error
def _validate_inputs(self) -> None:
    """Check the timeout string and every scanner/severity value.

    Raises:
        ImageInvalidTimeoutError: ``self.timeout`` is not digits followed by
            one of ``s``, ``m`` or ``h``.
        ImageInvalidScannerError: A scanner is not in ``SCANNERS_CHOICES``.
        ImageInvalidConfigScannerError: An image config scanner is not in
            ``IMAGE_CONFIG_SCANNERS_CHOICES``.
        ImageInvalidSeverityError: A severity is not in ``SEVERITY_CHOICES``.
    """
    timeout_ok = re.fullmatch(r"\d+[smh]", self.timeout) is not None
    if not timeout_ok:
        raise ImageInvalidTimeoutError(
            file=__file__,
            message=f"Invalid timeout format: '{self.timeout}'. Expected pattern like '5m', '300s', or '1h'.",
        )

    # Each loop stops at the first value outside its allowed set.
    for scanner in self.scanners:
        if scanner in SCANNERS_CHOICES:
            continue
        raise ImageInvalidScannerError(
            file=__file__,
            message=f"Invalid scanner: '{scanner}'. Valid options: {', '.join(SCANNERS_CHOICES)}.",
        )

    for config_scanner in self.image_config_scanners:
        if config_scanner in IMAGE_CONFIG_SCANNERS_CHOICES:
            continue
        raise ImageInvalidConfigScannerError(
            file=__file__,
            message=f"Invalid image config scanner: '{config_scanner}'. Valid options: {', '.join(IMAGE_CONFIG_SCANNERS_CHOICES)}.",
        )

    for severity in self.trivy_severity:
        if severity in SEVERITY_CHOICES:
            continue
        raise ImageInvalidSeverityError(
            file=__file__,
            message=f"Invalid severity: '{severity}'. Valid options: {', '.join(SEVERITY_CHOICES)}.",
        )
def _validate_image_name(self, name: str) -> None:
    """Reject image names that are empty, too long or malformed.

    The checks run in order: emptiness, maximum length, shell
    metacharacters, then the reference pattern. The first failing check
    decides the error message.

    Raises:
        ImageInvalidNameError: ``name`` fails any of the checks.
    """
    if not name:
        reason = "Image name must not be empty."
    elif len(name) > self.MAX_IMAGE_NAME_LENGTH:
        reason = f"Image name exceeds maximum length of {self.MAX_IMAGE_NAME_LENGTH} characters: '{name[:50]}...'"
    elif not self._SHELL_METACHARACTERS.isdisjoint(name):
        reason = f"Image name contains invalid characters: '{name}'"
    elif self._IMAGE_NAME_PATTERN.fullmatch(name) is None:
        reason = f"Image name does not match valid OCI reference format: '{name}'"
    else:
        return

    raise ImageInvalidNameError(file=__file__, message=reason)
@property
def auth_method(self) -> str:
    """Return the authentication label chosen in ``__init__``.

    One of ``"Docker login"``, ``"Registry token"`` or ``"No auth"``.
    """
    return self._auth_method
@property
def type(self) -> str:
    """Return the provider type stored in the class attribute ``_type``."""
    return self._type
@property
def identity(self) -> str:
    """Return the identity string assigned in ``__init__``."""
    return self._identity
@property
def session(self) -> None:
    """Return the session placeholder, which ``__init__`` sets to ``None``."""
    return self._session
@property
def audit_config(self) -> dict:
    """Return the audit configuration resolved in ``__init__``."""
    # NOTE(review): ``_audit_config`` is not assigned when ``__init__``
    # returns early in registry listing mode.
    return self._audit_config
@property
def fixer_config(self) -> dict:
    """Return the fixer configuration stored in ``__init__``."""
    # NOTE(review): ``_fixer_config`` is not assigned when ``__init__``
    # returns early in registry listing mode.
    return self._fixer_config
def setup_session(self) -> None:
    """Image provider doesn't need a session since it uses Trivy directly"""
    # Nothing is created here; the method always returns None.
    return None
@staticmethod
def _extract_registry(image: str) -> str | None:
    """Extract the registry host from an image reference.

    The first path component is treated as a registry host when it contains
    a ``.`` (domain) or a ``:`` (port), or when it is exactly ``localhost``,
    matching the convention Docker uses to tell a registry host apart from
    a Docker Hub namespace.

    Args:
        image: Image reference such as ``ghcr.io/org/app:1`` or ``alpine``.

    Returns:
        The registry host, or ``None`` for references without a registry
        prefix (Docker Hub images).
    """
    first_component, separator, _ = image.partition("/")
    if not separator:
        # No "/" at all: a bare name such as "alpine:3.18".
        return None
    looks_like_host = (
        "." in first_component
        or ":" in first_component
        or first_component == "localhost"
    )
    return first_component if looks_like_host else None
def cleanup(self) -> None:
    """Clean up any resources after scanning.

    Currently a no-op: the body contains no statements. ``run()`` and
    ``scan_per_image()`` call it from their ``finally`` blocks.
    """
def _process_finding(
    self,
    finding: dict,
    image: str,
    trivy_target: str,
    image_sha: str = "",
) -> CheckReportImage:
    """
    Convert one raw Trivy finding into a CheckReportImage.

    The kind of finding is inferred from its keys: ``VulnerabilityID`` marks
    a vulnerability, ``RuleID`` a secret, and anything else falls back to
    the generic ``ID`` / ``Status`` fields.

    Args:
        finding: The finding object from Trivy output
        image: The clean container image name (e.g., "alpine:3.18")
        trivy_target: The Trivy target string (e.g., "alpine:3.18 (alpine 3.18.0)")
        image_sha: Short SHA from Trivy Metadata.ImageID for resource uniqueness

    Returns:
        CheckReportImage: The processed check report

    Raises:
        ImageFindingProcessingError: Any failure while building the report.
    """
    try:
        # Classify the finding and pick its id, description, status, categories.
        if "VulnerabilityID" in finding:
            check_id = finding["VulnerabilityID"]
            description = finding.get("Description", finding.get("Title", ""))
            status = "FAIL"
            categories = ["vulnerability"]
        elif "RuleID" in finding:
            check_id = finding["RuleID"]
            description = finding.get("Title", "Secret detected")
            status = "FAIL"
            categories = ["secrets"]
        else:
            check_id = finding.get("ID", "UNKNOWN")
            description = finding.get("Description", "")
            status = finding.get("Status", "FAIL")
            categories = []

        # A fixed version wins over a free-text resolution.
        fixed_version = finding.get("FixedVersion")
        if fixed_version:
            recommendation_text = f"Upgrade {finding.get('PkgName', 'package')} to version {fixed_version}"
        else:
            recommendation_text = finding.get("Resolution") or ""

        # Trivy severities are upper case; UNKNOWN maps to "informational".
        severity = finding.get("Severity", "UNKNOWN").lower()
        severity = "informational" if severity == "unknown" else severity

        remediation = {
            "Code": {
                "NativeIaC": "",
                "Terraform": "",
                "CLI": "",
                "Other": "",
            },
            "Recommendation": {
                "Text": recommendation_text,
                "Url": finding.get("PrimaryURL", ""),
            },
        }
        check_metadata = {
            "Provider": "image",
            "CheckID": check_id,
            "CheckTitle": finding.get("Title", check_id),
            "CheckType": ["Container Image Security"],
            "ServiceName": "container-image",
            "SubServiceName": "",
            "ResourceIdTemplate": "",
            "Severity": severity,
            "ResourceType": "container-image",
            "ResourceGroup": "container",
            "Description": description,
            "Risk": finding.get(
                "Description", "Vulnerability detected in container image"
            ),
            "RelatedUrl": "",
            "Remediation": remediation,
            "Categories": categories,
            "DependsOn": [],
            "RelatedTo": [],
            "Notes": "",
        }

        # CheckReportImage receives the metadata as a JSON string.
        report = CheckReportImage(
            metadata=json.dumps(check_metadata),
            finding=finding,
            image_name=image,
        )
        report.status = status
        report.status_extended = self._build_status_extended(finding)
        report.region = self.region
        report.image_sha = image_sha
        report.resource_details = trivy_target
        return report
    except Exception as error:
        raise ImageFindingProcessingError(
            file=__file__,
            original_exception=error,
            message=f"Error processing finding: {error}",
        )
def _build_status_extended(self, finding: dict) -> str:
    """Compose the one-line status text for a finding.

    Segments are added in this order when present: vulnerability id,
    package (with installed version), fix information, and title. When no
    segment applies, the finding's ``Description`` is returned, or
    ``"Finding detected"`` if that is missing too.
    """
    segments = []

    vulnerability_id = finding.get("VulnerabilityID")
    if vulnerability_id:
        segments.append(f"{vulnerability_id}")

    package = finding.get("PkgName")
    if package:
        installed = finding.get("InstalledVersion")
        package_label = package + f"@{installed}" if installed else package
        segments.append(f"in package {package_label}")

    fixed_version = finding.get("FixedVersion")
    if fixed_version:
        segments.append(f"(fix available: {fixed_version})")
    elif finding.get("Status") == "will_not_fix":
        segments.append("(no fix available)")

    title = finding.get("Title")
    if title:
        segments.append(f"- {title}")

    if not segments:
        return finding.get("Description", "Finding detected")
    return " ".join(segments)
def run(self) -> list[CheckReportImage]:
    """Scan every configured image and return all findings as one list.

    ``cleanup()`` is called afterwards whether the scan succeeds or raises.
    """
    try:
        # Flatten the batches produced by run_scan() in their original order.
        return [report for batch in self.run_scan() for report in batch]
    finally:
        self.cleanup()
def scan_per_image(
    self,
) -> Generator[tuple[str, list[CheckReportImage]], None, None]:
    """Scan images one at a time, yielding ``(image_name, findings)`` pairs.

    Yielding after each image lets the caller track progress, whereas
    ``run()`` returns everything at once. ``ImageScanError`` and
    ``ImageTrivyBinaryNotFoundError`` abort the iteration; any other error
    is logged and the next image is tried. ``cleanup()`` runs when the
    generator finishes or is closed.
    """
    try:
        for image in self.images:
            try:
                findings_for_image = [
                    report
                    for batch in self._scan_single_image(image)
                    for report in batch
                ]
                yield (image, findings_for_image)
            except (ImageScanError, ImageTrivyBinaryNotFoundError):
                raise
            except Exception as error:
                logger.error(f"Error scanning image {image}: {error}")
                continue
    finally:
        self.cleanup()
def run_scan(self) -> Generator[list[CheckReportImage], None, None]:
    """
    Run Trivy on every configured image, in order.

    ``ImageScanError`` and ``ImageTrivyBinaryNotFoundError`` propagate to
    the caller; any other error is logged and the next image is scanned.

    Yields:
        list[CheckReportImage]: Batches of findings
    """
    for image in self.images:
        try:
            for findings_batch in self._scan_single_image(image):
                yield findings_batch
        except (ImageScanError, ImageTrivyBinaryNotFoundError):
            raise
        except Exception as error:
            logger.error(f"Error scanning image {image}: {error}")
            continue
def _scan_single_image(
    self, image: str
) -> Generator[list[CheckReportImage], None, None]:
    """
    Scan a single container image with Trivy and yield its findings.

    Findings are collected from the ``Vulnerabilities``, ``Secrets`` and
    ``Misconfigurations`` sections of each Trivy result, in that order, and
    yielded in batches of at most ``FINDING_BATCH_SIZE``.

    Args:
        image: The container image name/tag to scan

    Yields:
        list[CheckReportImage]: Batches of findings

    Raises:
        ImageScanError: Trivy exited with a non-zero return code.
        ImageTrivyBinaryNotFoundError: The ``trivy`` executable is missing.
    """
    try:
        logger.info(f"Scanning container image: {image}")

        # Build Trivy command
        trivy_command = [
            "trivy",
            "image",
            "--format",
            "json",
            "--scanners",
            ",".join(self.scanners),
            "--timeout",
            self.timeout,
        ]
        if self.image_config_scanners:
            trivy_command.extend(
                ["--image-config-scanners", ",".join(self.image_config_scanners)]
            )
        if self.trivy_severity:
            trivy_command.extend(["--severity", ",".join(self.trivy_severity)])
        if self.ignore_unfixed:
            trivy_command.append("--ignore-unfixed")
        trivy_command.append(image)

        # Execute Trivy
        process = self._execute_trivy(trivy_command, image)

        # Log stderr output
        if process.stderr:
            self._log_trivy_stderr(process.stderr)

        # Check for Trivy failure
        if process.returncode != 0:
            error_msg = self._extract_trivy_errors(process.stderr)
            categorized_msg = self._categorize_trivy_error(error_msg)
            raise ImageScanError(
                file=__file__,
                message=f"Trivy scan failed for {image}: {categorized_msg}",
            )

        # Parse JSON output
        try:
            output = json.loads(process.stdout)
        except json.JSONDecodeError as error:
            logger.error(f"Failed to parse Trivy output for {image}: {error}")
            logger.debug(f"Trivy stdout: {process.stdout[:500]}")
            return

        # ``or`` also covers keys that are present with a JSON null value.
        results = output.get("Results") or []
        if not results:
            logger.info(f"No findings for image: {image}")
            return

        # Extract image digest for resource uniqueness: prefer ImageID,
        # otherwise the digest part of the first RepoDigests entry.
        trivy_metadata = output.get("Metadata") or {}
        image_id = trivy_metadata.get("ImageID") or ""
        if not image_id:
            repo_digests = trivy_metadata.get("RepoDigests") or []
            if repo_digests and "@" in repo_digests[0]:
                image_id = repo_digests[0].split("@")[-1]
        short_sha = image_id.replace("sha256:", "")[:12] if image_id else ""

        # Process findings in batches
        batch = []
        for result in results:
            target = result.get("Target", image)
            for section in ("Vulnerabilities", "Secrets", "Misconfigurations"):
                for raw_finding in result.get(section) or []:
                    batch.append(
                        self._process_finding(
                            raw_finding, image, target, image_sha=short_sha
                        )
                    )
                    if len(batch) >= self.FINDING_BATCH_SIZE:
                        yield batch
                        batch = []

        # Yield remaining findings
        if batch:
            yield batch

    except (ImageScanError, ImageTrivyBinaryNotFoundError):
        raise
    except Exception as error:
        # The type check catches a missing binary on every platform; the
        # message check is kept for errors that only carry the POSIX text.
        trivy_missing = isinstance(error, FileNotFoundError) or (
            "No such file or directory: 'trivy'" in str(error)
        )
        if trivy_missing:
            raise ImageTrivyBinaryNotFoundError(
                file=__file__,
                original_exception=error,
                message="Trivy binary not found. Please install Trivy from https://trivy.dev/latest/getting-started/installation/",
            ) from error
        logger.error(f"Error scanning image {image}: {error}")
def _build_trivy_env(self) -> dict:
    """Return a copy of the process environment with registry credentials added.

    Username and password take precedence: when both are set they are
    exported as ``TRIVY_USERNAME`` / ``TRIVY_PASSWORD``. Otherwise a token,
    if present, is exported as ``TRIVY_REGISTRY_TOKEN``.
    """
    trivy_env = os.environ.copy()
    has_basic_auth = bool(self.registry_username and self.registry_password)
    if has_basic_auth:
        trivy_env.update(
            TRIVY_USERNAME=self.registry_username,
            TRIVY_PASSWORD=self.registry_password,
        )
    elif self.registry_token:
        trivy_env["TRIVY_REGISTRY_TOKEN"] = self.registry_token
    return trivy_env
def _execute_trivy(self, command: list, image: str) -> subprocess.CompletedProcess:
    """Run the Trivy command, showing a progress bar on interactive terminals.

    Args:
        command: Full Trivy argument list.
        image: Image name, used only for progress and log messages.

    Returns:
        The completed process with captured text stdout and stderr.
    """
    env = self._build_trivy_env()

    def launch() -> subprocess.CompletedProcess:
        # Single definition of how Trivy is invoked.
        return subprocess.run(command, capture_output=True, text=True, env=env)

    try:
        if not sys.stdout.isatty():
            logger.info(f"Scanning {image}...")
            completed = launch()
            logger.info(f"Scan completed for {image}")
            return completed

        with alive_bar(
            ctrl_c=False,
            bar="blocks",
            spinner="classic",
            stats=False,
            enrich_print=False,
        ) as bar:
            bar.title = f"-> Scanning {image}..."
            completed = launch()
            bar.title = f"-> Scan completed for {image}"
            return completed
    except (AttributeError, OSError):
        # NOTE(review): this also catches an OSError raised by the Trivy
        # launch itself, in which case the command is attempted a second time.
        logger.info(f"Scanning {image}...")
        return launch()
def _log_trivy_stderr(self, stderr: str) -> None:
    """Forward Trivy's stderr lines to the logger at a matching level.

    The second whitespace-separated field of a line is read as the level
    and everything after it as the message. Unrecognised levels and lines
    with fewer than three fields are logged at info level.
    """
    emit_by_level = {
        "ERROR": logger.error,
        "WARN": logger.warning,
        "INFO": logger.info,
        "DEBUG": logger.debug,
    }
    for raw_line in stderr.strip().split("\n"):
        if not raw_line.strip():
            continue
        fields = raw_line.split()
        if len(fields) < 3:
            # Too short to carry a level; log the line untouched.
            logger.info(raw_line)
            continue
        emit = emit_by_level.get(fields[1], logger.info)
        emit(" ".join(fields[2:]))
@staticmethod
def _extract_trivy_errors(stderr: str) -> str:
    """Extract the ERROR- and FATAL-level messages from Trivy stderr output.

    Matching messages are joined with ``"; "``. When there are none, the
    last non-empty line is returned instead. Results are capped at 500
    characters, and ``"Unknown error"`` is returned for empty input.
    """
    if not stderr:
        return "Unknown error"

    lines = stderr.strip().split("\n")

    severe_messages = []
    for raw_line in lines:
        fields = raw_line.split()
        if len(fields) >= 3 and fields[1] in ("ERROR", "FATAL"):
            severe_messages.append(" ".join(fields[2:]))
    if severe_messages:
        return "; ".join(severe_messages)[:500]

    # Fallback: no ERROR/FATAL lines found, use the last non-empty line.
    last_line = next(
        (candidate.strip() for candidate in reversed(lines) if candidate.strip()),
        None,
    )
    return last_line[:500] if last_line is not None else "Unknown error"
@staticmethod
def _categorize_trivy_error(error_msg: str) -> str:
    """Prefix a Trivy error message with a hint about its likely cause.

    Keywords are matched case-insensitively. Categories are tried in order
    (auth, not found, rate limit, network) and the first match wins. A
    message matching no category is returned unchanged.
    """
    lowered = error_msg.lower()
    hints = (
        (
            ("401", "403", "unauthorized", "denied"),
            "Auth failure — check `docker login`",
        ),
        (
            ("404", "manifest unknown", "not found"),
            "Image not found — check name/tag/registry",
        ),
        (
            ("429", "rate limit", "too many requests"),
            "Rate limited — wait or authenticate",
        ),
        (
            ("timeout", "connection refused", "no such host"),
            "Network issue — check connectivity",
        ),
    )
    for keywords, hint in hints:
        for keyword in keywords:
            if keyword in lowered:
                return f"{hint}: {error_msg}"
    return error_msg
def _enumerate_registry(self) -> None:
    """Discover images in ``self.registry`` and add them to ``self.images``.

    Repositories and tags are filtered with the compiled image/tag regexes.
    In list mode the listing is printed, ``_listing_only`` is set and
    nothing is added. Otherwise discovered images not already present are
    appended to ``self.images``.

    Raises:
        ImageMaxImagesExceededError: More images were discovered than
            ``max_images`` allows (a value of 0 disables the limit).
    """
    adapter = create_registry_adapter(
        registry_url=self.registry,
        username=self.registry_username,
        password=self.registry_password,
        token=self.registry_token,
        verify_ssl=not self.registry_insecure,
    )

    repositories = adapter.list_repositories()
    logger.info(
        f"Discovered {len(repositories)} repositories from registry {self.registry}"
    )

    # Apply image filter
    if self._image_filter_re:
        repositories = [
            name for name in repositories if self._image_filter_re.search(name)
        ]
        logger.info(
            f"{len(repositories)} repositories match --image-filter '{self.image_filter}'"
        )

    if not repositories:
        logger.warning(
            f"No repositories found in registry {self.registry} (after filtering)"
        )
        return

    # Docker Hub images don't need a host prefix; other OCI registries need
    # the full host/repo:tag reference, with any URL scheme removed.
    is_dockerhub = isinstance(adapter, DockerHubAdapter)
    registry_host = None
    if not is_dockerhub:
        registry_host = self.registry.rstrip("/")
        for scheme in ("https://", "http://"):
            if registry_host.startswith(scheme):
                registry_host = registry_host[len(scheme) :]
                break

    discovered_images = []
    tags_by_repo: dict[str, list[str]] = {}
    for repo in repositories:
        tags = adapter.list_tags(repo)
        # Apply tag filter
        if self._tag_filter_re:
            tags = [tag for tag in tags if self._tag_filter_re.search(tag)]
        if not tags:
            continue
        tags_by_repo[repo] = tags
        for tag in tags:
            if is_dockerhub:
                discovered_images.append(f"{repo}:{tag}")
            else:
                discovered_images.append(f"{registry_host}/{repo}:{tag}")

    # Registry list mode: print listing and return early
    if self.registry_list_images:
        self._print_registry_listing(tags_by_repo, len(discovered_images))
        self._listing_only = True
        return

    # Check max-images limit
    if self.max_images and len(discovered_images) > self.max_images:
        raise ImageMaxImagesExceededError(
            file=__file__,
            message=f"Discovered {len(discovered_images)} images, exceeding --max-images {self.max_images}. Use --image-filter or --tag-filter to narrow results.",
        )

    # Deduplicate with explicit images
    already_listed = set(self.images)
    for image_ref in discovered_images:
        if image_ref in already_listed:
            continue
        self.images.append(image_ref)
        already_listed.add(image_ref)

    logger.info(
        f"Discovered {len(discovered_images)} images from registry {self.registry} "
        f"({len(repositories)} repositories). Total images to scan: {len(self.images)}"
    )
def _print_registry_listing(
    self, repos_tags: dict[str, list[str]], total_images: int
) -> None:
    """Print the registry's repositories and their tags to stdout.

    Args:
        repos_tags: Mapping of repository name to its tags.
        total_images: Total number of discovered images, shown in the header.
    """
    # NOTE(review): SOURCE lost its indentation, so whether the final blank
    # print() runs once after the loop or once per repository is unclear;
    # it is placed after the loop here. Confirm against the original file.
    num_repos = len(repos_tags)
    repo_word = "repository" if num_repos == 1 else "repositories"
    image_word = "image" if total_images == 1 else "images"
    header = (
        f"\n{Style.BRIGHT}Registry:{Style.RESET_ALL} "
        f"{Fore.CYAN}{self.registry}{Style.RESET_ALL} "
        f"({num_repos} {repo_word}, "
        f"{total_images} {image_word})\n"
    )
    print(header)
    for repo in repos_tags:
        tags = repos_tags[repo]
        print(f" {Fore.YELLOW}{repo}{Style.RESET_ALL} ({len(tags)} tags)")
        print(f" {', '.join(tags)}")
    print()
def print_credentials(self) -> None:
    """Print the scan configuration as a boxed summary.

    Up to three images are listed by name; beyond that only the count is
    shown. Optional settings appear only when they are set.
    """
    # NOTE(review): SOURCE lost its indentation, so whether the filter lines
    # are nested under the registry check is unclear; they are treated as
    # independent checks here. Confirm against the original file.

    def entry(label: str, value) -> str:
        # Every line has the same "label: <yellow value>" shape.
        return f"{label}: {Fore.YELLOW}{value}{Style.RESET_ALL}"

    report_title = f"{Style.BRIGHT}Scanning container images:{Style.RESET_ALL}"

    if len(self.images) <= 3:
        report_lines = [entry("Image", img) for img in self.images]
    else:
        report_lines = [entry("Images", f"{len(self.images)} images")]

    report_lines.append(entry("Scanners", ", ".join(self.scanners)))

    if self.image_config_scanners:
        report_lines.append(
            entry("Image config scanners", ", ".join(self.image_config_scanners))
        )
    if self.trivy_severity:
        report_lines.append(entry("Severity filter", ", ".join(self.trivy_severity)))
    if self.ignore_unfixed:
        report_lines.append(entry("Ignore unfixed", "Yes"))

    report_lines.append(entry("Timeout", self.timeout))
    report_lines.append(entry("Authentication method", self.auth_method))

    if self.registry:
        report_lines.append(entry("Registry", self.registry))
    if self.image_filter:
        report_lines.append(entry("Image filter", self.image_filter))
    if self.tag_filter:
        report_lines.append(entry("Tag filter", self.tag_filter))

    print_boxes(report_lines, report_title)
@staticmethod
def test_connection(
    image: str | None = None,
    raise_on_exception: bool = True,
    provider_id: str | None = None,
    registry_username: str | None = None,
    registry_password: str | None = None,
    registry_token: str | None = None,
) -> "Connection":
    """
    Test access to a container image by running ``trivy image`` against it.

    Args:
        image: Container image to test
        raise_on_exception: Whether unexpected exceptions are re-raised
            instead of being reported in the returned Connection
        provider_id: Fallback for image name when ``image`` is not given
        registry_username: Registry username for basic auth
        registry_password: Registry password for basic auth
        registry_token: Registry token for token-based auth

    Returns:
        Connection: Connection object with success status. Missing/invalid
        image names, Trivy failures, timeouts and a missing Trivy binary are
        always reported through the returned object, never raised.
    """
    try:
        if provider_id and not image:
            image = provider_id

        if not image:
            return Connection(is_connected=False, error="Image name is required")

        # The image is handed to the trivy CLI as a positional argument; a
        # value starting with "-" would be parsed as a trivy option instead
        # of an image reference, so reject it before spawning the process.
        if image.startswith("-"):
            return Connection(
                is_connected=False,
                error="Invalid image name: it must not start with '-'.",
            )

        # Build env with registry credentials (username/password wins over token)
        env = dict(os.environ)
        if registry_username and registry_password:
            env["TRIVY_USERNAME"] = registry_username
            env["TRIVY_PASSWORD"] = registry_password
        elif registry_token:
            env["TRIVY_REGISTRY_TOKEN"] = registry_token

        # Test by running trivy with --skip-db-update to just test image access
        process = subprocess.run(
            [
                "trivy",
                "image",
                "--skip-db-update",
                "--download-db-only=false",
                image,
            ],
            capture_output=True,
            text=True,
            timeout=60,
            env=env,
        )

        if process.returncode == 0:
            return Connection(is_connected=True)

        # Classify the failure from Trivy's stderr output.
        error_msg = process.stderr or "Unknown error"
        if "401" in error_msg or "unauthorized" in error_msg.lower():
            return Connection(
                is_connected=False,
                error="Authentication failed. Check registry credentials.",
            )
        if "not found" in error_msg.lower() or "404" in error_msg:
            return Connection(
                is_connected=False,
                error="Image not found in registry.",
            )
        return Connection(
            is_connected=False,
            error=f"Failed to access image: {error_msg[:200]}",
        )
    except subprocess.TimeoutExpired:
        return Connection(
            is_connected=False,
            error="Connection timed out",
        )
    except FileNotFoundError:
        return Connection(
            is_connected=False,
            error="Trivy binary not found. Please install Trivy.",
        )
    except Exception as error:
        if raise_on_exception:
            raise
        return Connection(
            is_connected=False,
            error=f"Unexpected error: {str(error)}",
        )
@@ -0,0 +1,183 @@
# Values accepted by --scanners.
SCANNERS_CHOICES = ["vuln", "secret", "misconfig", "license"]

# Values accepted by --image-config-scanners.
IMAGE_CONFIG_SCANNERS_CHOICES = ["misconfig", "secret"]

# Values accepted by --trivy-severity.
SEVERITY_CHOICES = ["CRITICAL", "HIGH", "MEDIUM", "LOW", "UNKNOWN"]
def init_parser(self):
    """Register the ``image`` subcommand and its options on the CLI parser.

    Uses ``self.subparsers`` to create the subcommand (inheriting from
    ``self.common_providers_parser``) and adds three argument groups:
    image selection, scan configuration and registry scan mode.
    """
    parser = self.subparsers.add_parser(
        "image",
        help="Container Image Provider",
        parents=[self.common_providers_parser],
    )

    # --- Image Selection -------------------------------------------------
    selection = parser.add_argument_group("Image Selection")
    selection.add_argument(
        "--image", "-I",
        action="append",
        dest="images",
        default=[],
        help="Container image to scan. Can be specified multiple times. Examples: nginx:latest, alpine:3.18, myregistry.io/myapp:v1.0",
    )
    selection.add_argument(
        "--image-list",
        dest="image_list_file",
        default=None,
        help="Path to a file containing list of images to scan (one per line). Lines starting with # are treated as comments.",
    )

    # --- Scan Configuration ----------------------------------------------
    scan_config = parser.add_argument_group("Scan Configuration")
    scan_config.add_argument(
        "--scanners", "--scanner",
        dest="scanners",
        nargs="+",
        choices=SCANNERS_CHOICES,
        default=["vuln", "secret"],
        help="Trivy scanners to use. Default: vuln, secret. Available: vuln, secret, misconfig, license",
    )
    scan_config.add_argument(
        "--image-config-scanners",
        dest="image_config_scanners",
        nargs="+",
        choices=IMAGE_CONFIG_SCANNERS_CHOICES,
        default=[],
        help="Trivy image config scanners (scans Dockerfile-level metadata). Available: misconfig, secret",
    )
    scan_config.add_argument(
        "--trivy-severity",
        dest="trivy_severity",
        nargs="+",
        choices=SEVERITY_CHOICES,
        default=[],
        help="Filter Trivy findings by severity. Default: all severities. Available: CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN",
    )
    scan_config.add_argument(
        "--ignore-unfixed",
        action="store_true",
        dest="ignore_unfixed",
        default=False,
        help="Ignore vulnerabilities without available fixes.",
    )
    scan_config.add_argument(
        "--timeout",
        dest="timeout",
        default="5m",
        help="Trivy scan timeout. Default: 5m. Examples: 10m, 1h",
    )

    # --- Registry Scan Mode ----------------------------------------------
    registry_opts = parser.add_argument_group("Registry Scan Mode")
    registry_opts.add_argument(
        "--registry",
        dest="registry",
        default=None,
        help="Registry URL to enumerate and scan all images. Examples: myregistry.io, docker.io/myorg, 123456789.dkr.ecr.us-east-1.amazonaws.com",
    )
    registry_opts.add_argument(
        "--image-filter",
        dest="image_filter",
        default=None,
        help="Regex to filter repository names during registry enumeration (re.search). Example: '^prod/.*'",
    )
    registry_opts.add_argument(
        "--tag-filter",
        dest="tag_filter",
        default=None,
        help=r"Regex to filter tags during registry enumeration (re.search). Example: '^(latest|v\d+\.\d+\.\d+)$'",
    )
    registry_opts.add_argument(
        "--max-images",
        dest="max_images",
        type=int,
        default=0,
        help="Maximum number of images to scan from registry. 0 = unlimited. Aborts if exceeded.",
    )
    registry_opts.add_argument(
        "--registry-insecure",
        action="store_true",
        dest="registry_insecure",
        default=False,
        help="Skip TLS verification for registry connections (for self-signed certificates).",
    )
    registry_opts.add_argument(
        "--registry-list",
        action="store_true",
        dest="registry_list_images",
        default=False,
        help="List all repositories and tags from the registry, then exit without scanning. Useful for discovering available images before building --image-filter or --tag-filter.",
    )
def validate_arguments(arguments):
    """Validate Image provider arguments.

    Args:
        arguments: Parsed CLI namespace. Attributes that are missing fall
            back to the same defaults the parser declares.

    Returns:
        tuple[bool, str]: ``(True, "")`` when the combination of arguments is
        valid, otherwise ``(False, message)`` describing the first problem.
    """
    # Local import: only this function's lines are being replaced.
    import re

    images = getattr(arguments, "images", [])
    image_list_file = getattr(arguments, "image_list_file", None)
    registry = getattr(arguments, "registry", None)
    image_filter = getattr(arguments, "image_filter", None)
    tag_filter = getattr(arguments, "tag_filter", None)
    max_images = getattr(arguments, "max_images", 0)
    registry_insecure = getattr(arguments, "registry_insecure", False)
    registry_list_images = getattr(arguments, "registry_list_images", False)

    if registry_list_images and not registry:
        return (False, "--registry-list requires --registry.")

    if not images and not image_list_file and not registry:
        return (
            False,
            "At least one image source must be specified using --image (-I), --image-list, or --registry.",
        )

    # Registry-only flags require --registry
    if not registry:
        if image_filter:
            return (False, "--image-filter requires --registry.")
        if tag_filter:
            return (False, "--tag-filter requires --registry.")
        if max_images:
            return (False, "--max-images requires --registry.")
        if registry_insecure:
            return (False, "--registry-insecure requires --registry.")
        return (True, "")

    # 0 means "unlimited"; a negative cap has no meaning.
    if max_images and max_images < 0:
        return (False, "--max-images must be 0 (unlimited) or a positive integer.")

    # The filters are documented as regular expressions; reject a malformed
    # pattern here instead of failing later during enumeration.
    for flag, pattern in (("--image-filter", image_filter), ("--tag-filter", tag_filter)):
        if not pattern:
            continue
        try:
            re.compile(pattern)
        except re.error as error:
            return (False, f"{flag} is not a valid regular expression: {error}")

    # Docker Hub namespace validation. Host names are case-insensitive, so
    # "Docker.IO" must be treated exactly like "docker.io".
    url = registry.rstrip("/")
    hub_match = re.match(
        r"(?:https?://)?(?:registry-1\.)?docker\.io(?=/|$)", url, re.IGNORECASE
    )
    if hub_match and not url[hub_match.end():].strip("/"):
        return (
            False,
            "Docker Hub requires a namespace. Use --registry docker.io/{org_or_user}.",
        )

    return (True, "")
@@ -0,0 +1,141 @@
"""Registry adapter abstract base class."""
from __future__ import annotations
import re
import time
from abc import ABC, abstractmethod
import requests
from prowler.config.config import prowler_version
from prowler.lib.logger import logger
from prowler.providers.image.exceptions.exceptions import ImageRegistryNetworkError
_MAX_RETRIES = 3
_BACKOFF_BASE = 1
_USER_AGENT = f"Prowler/{prowler_version} (registry-adapter)"
class RegistryAdapter(ABC):
    """Abstract base class for registry adapters.

    Holds the registry URL and credentials, keeps secrets out of ``repr()``
    and pickled state, and provides the shared HTTP helper with retries.
    """

    def __init__(
        self,
        registry_url: str,
        username: str | None = None,
        password: str | None = None,
        token: str | None = None,
        verify_ssl: bool = True,
    ) -> None:
        self.registry_url = registry_url
        self.username = username
        self._password = password
        self._token = token
        self.verify_ssl = verify_ssl

    @property
    def password(self) -> str | None:
        return self._password

    @property
    def token(self) -> str | None:
        return self._token

    def __getstate__(self) -> dict:
        """Return the instance state with password and token masked."""
        state = self.__dict__.copy()
        state["_password"] = "***" if state.get("_password") else None
        state["_token"] = "***" if state.get("_token") else None
        return state

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"registry_url={self.registry_url!r}, "
            f"username={self.username!r}, "
            f"password={'<redacted>' if self._password else None}, "
            f"token={'<redacted>' if self._token else None})"
        )

    @abstractmethod
    def list_repositories(self) -> list[str]:
        """Enumerate all repository names in the registry."""
        ...

    @abstractmethod
    def list_tags(self, repository: str) -> list[str]:
        """Enumerate all tags for a repository."""
        ...

    def _request_with_retry(self, method: str, url: str, **kwargs) -> requests.Response:
        """Send an HTTP request, retrying on 429, 5xx and connection errors.

        Up to ``_MAX_RETRIES`` attempts are made with exponential backoff
        between them. Responses with any other status are returned to the
        caller unchanged.

        Args:
            method: HTTP method.
            url: Target URL.
            **kwargs: Passed to ``requests.request``. ``context_label`` is
                consumed here and used in log/error messages (defaults to the
                registry URL). ``timeout`` defaults to 30 and ``verify`` to
                ``self.verify_ssl``.

        Raises:
            ImageRegistryNetworkError: On timeout, or when every attempt
                ended in rate limiting, a server error or a connection error.
        """
        context_label = kwargs.pop("context_label", None) or self.registry_url
        kwargs.setdefault("timeout", 30)
        kwargs.setdefault("verify", self.verify_ssl)
        # Work on a copy so the caller's headers dict is never mutated.
        headers = dict(kwargs.get("headers") or {})
        headers.setdefault("User-Agent", _USER_AGENT)
        kwargs["headers"] = headers

        last_exception = None
        last_status = None
        last_body = None
        for attempt in range(1, _MAX_RETRIES + 1):
            # Backing off only makes sense when another attempt will follow.
            will_retry = attempt < _MAX_RETRIES
            wait = _BACKOFF_BASE * (2 ** (attempt - 1))
            try:
                resp = requests.request(method, url, **kwargs)
            except requests.exceptions.ConnectionError as exc:
                last_exception = exc
                # The most recent failure decides the final error message.
                last_status = None
                if will_retry:
                    logger.warning(
                        f"Connection error to {context_label}, retrying in {wait}s (attempt {attempt}/{_MAX_RETRIES})"
                    )
                    time.sleep(wait)
                continue
            except requests.exceptions.Timeout as exc:
                raise ImageRegistryNetworkError(
                    file=__file__,
                    message=f"Connection timed out to {context_label}.",
                    original_exception=exc,
                ) from exc

            if resp.status_code == 429:
                last_status = 429
                if will_retry:
                    logger.warning(
                        f"Rate limited by {context_label}, retrying in {wait}s (attempt {attempt}/{_MAX_RETRIES})"
                    )
                    time.sleep(wait)
                continue

            if resp.status_code >= 500:
                last_status = resp.status_code
                last_body = (resp.text or "")[:500]
                if will_retry:
                    logger.warning(
                        f"Server error from {context_label} (HTTP {resp.status_code}), "
                        f"retrying in {wait}s (attempt {attempt}/{_MAX_RETRIES}): {last_body}"
                    )
                    time.sleep(wait)
                continue

            return resp

        if last_status == 429:
            raise ImageRegistryNetworkError(
                file=__file__,
                message=f"Rate limited by {context_label} after {_MAX_RETRIES} attempts.",
            )
        if last_status is not None and last_status >= 500:
            raise ImageRegistryNetworkError(
                file=__file__,
                message=f"Server error from {context_label} (HTTP {last_status}) after {_MAX_RETRIES} attempts: {last_body}",
            )
        raise ImageRegistryNetworkError(
            file=__file__,
            message=f"Failed to connect to {context_label} after {_MAX_RETRIES} attempts.",
            original_exception=last_exception,
        )

    @staticmethod
    def _next_page_url(resp: requests.Response) -> str | None:
        """Return the ``rel="next"`` target of the ``Link`` header, if any.

        The target is returned exactly as the server sent it, so it may be a
        relative reference that the caller has to resolve.
        """
        link_header = resp.headers.get("Link", "")
        if not link_header:
            return None
        match = re.search(r'<([^>]+)>;\s*rel="next"', link_header)
        if match:
            return match.group(1)
        return None
@@ -0,0 +1,221 @@
"""Docker Hub registry adapter."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from prowler.lib.logger import logger
from prowler.providers.image.exceptions.exceptions import (
ImageRegistryAuthError,
ImageRegistryCatalogError,
ImageRegistryNetworkError,
)
from prowler.providers.image.lib.registry.base import RegistryAdapter
if TYPE_CHECKING:
import requests
_HUB_API = "https://hub.docker.com"
_REGISTRY_HOST = "https://registry-1.docker.io"
_AUTH_URL = "https://auth.docker.io/token"
class DockerHubAdapter(RegistryAdapter):
    """Adapter for Docker Hub using the Hub REST API + OCI tag listing."""

    # Optional scheme followed by an optional Docker Hub host. Matched
    # case-insensitively because host names are case-insensitive and the
    # adapter factory also detects Docker Hub URLs ignoring case.
    _HUB_PREFIX_RE = re.compile(
        r"^(?:https?://)?(?:(?:registry-1\.)?docker\.io(?=/|$))?", re.IGNORECASE
    )

    def __init__(
        self,
        registry_url: str,
        username: str | None = None,
        password: str | None = None,
        token: str | None = None,
        verify_ssl: bool = True,
    ) -> None:
        """Create the adapter; TLS verification is always kept enabled."""
        if not verify_ssl:
            logger.warning(
                "Docker Hub always uses TLS verification; --registry-insecure is ignored for Docker Hub registries."
            )
        super().__init__(registry_url, username, password, token, verify_ssl=True)
        self.namespace = self._extract_namespace(registry_url)
        self._hub_jwt: str | None = None
        # Pull tokens are scoped per repository, so they are cached by name.
        self._registry_tokens: dict[str, str] = {}

    @staticmethod
    def _extract_namespace(registry_url: str) -> str:
        """Return the first path segment after the scheme / Docker Hub host.

        Returns an empty string when the URL carries no namespace.
        """
        url = registry_url.rstrip("/")
        # Every part of the pattern is optional, so it always matches.
        prefix = DockerHubAdapter._HUB_PREFIX_RE.match(url)
        remainder = url[prefix.end():].lstrip("/")
        return remainder.split("/")[0]

    def list_repositories(self) -> list[str]:
        """List ``namespace/name`` for every repository in the namespace.

        Raises:
            ImageRegistryCatalogError: If no namespace was given or it does
                not exist on Docker Hub.
            ImageRegistryAuthError: If login fails or the listing is denied.
        """
        if not self.namespace:
            raise ImageRegistryCatalogError(
                file=__file__,
                message="Docker Hub requires a namespace. Use --registry docker.io/{org_or_user}.",
            )
        self._hub_login()
        repositories: list[str] = []
        # The namespaces endpoint is used once logged in; the repositories
        # endpoint is used for anonymous listing.
        if self._hub_jwt:
            url = f"{_HUB_API}/v2/namespaces/{self.namespace}/repositories"
        else:
            url = f"{_HUB_API}/v2/repositories/{self.namespace}/"
        params: dict = {"page_size": 100}
        while url:
            resp = self._hub_request("GET", url, params=params)
            self._check_hub_response(resp, "repository listing")
            data = resp.json()
            for repo in data.get("results", []):
                name = repo.get("name", "")
                if name:
                    repositories.append(f"{self.namespace}/{name}")
            url = data.get("next")
            # The "next" URL already carries its own query string.
            params = {}
        return repositories

    def list_tags(self, repository: str) -> list[str]:
        """List the tags of ``repository`` through the registry API.

        A non-200, non-auth response is logged and ends pagination, returning
        the tags collected so far.

        Raises:
            ImageRegistryAuthError: On HTTP 401/403 or token retrieval failure.
        """
        token = self._get_registry_token(repository)
        tags: list[str] = []
        url = f"{_REGISTRY_HOST}/v2/{repository}/tags/list"
        params: dict = {"n": 100}
        while url:
            resp = self._registry_request("GET", url, token, params=params)
            if resp.status_code in (401, 403):
                raise ImageRegistryAuthError(
                    file=__file__,
                    message=f"Authentication failed for tag listing of {repository} on Docker Hub. Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
                )
            if resp.status_code != 200:
                logger.warning(
                    f"Failed to list tags for {repository} (HTTP {resp.status_code}): {resp.text[:200]}"
                )
                break
            data = resp.json()
            # "tags" may be present but null.
            tags.extend(data.get("tags", []) or [])
            url = self._next_tag_page_url(resp)
            params = {}
        return tags

    def _hub_login(self) -> None:
        """Obtain a Hub JWT when credentials are available; no-op otherwise."""
        if self._hub_jwt:
            return
        if not self.username or not self.password:
            return
        logger.debug(f"Docker Hub login attempt for username: {self.username!r}")
        resp = self._request_with_retry(
            "POST",
            f"{_HUB_API}/v2/users/login",
            json={"username": self.username, "password": self.password},
            context_label="Docker Hub",
        )
        if resp.status_code != 200:
            body_preview = resp.text[:200] if resp.text else "(empty body)"
            raise ImageRegistryAuthError(
                file=__file__,
                message=(
                    f"Docker Hub login failed (HTTP {resp.status_code}). "
                    f"Check REGISTRY_USERNAME and REGISTRY_PASSWORD. "
                    f"Response: {body_preview}"
                ),
            )
        self._hub_jwt = resp.json().get("token")
        if not self._hub_jwt:
            raise ImageRegistryAuthError(
                file=__file__,
                message="Docker Hub login returned an empty JWT token. Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )

    def _get_registry_token(self, repository: str) -> str:
        """Return a pull-scoped registry token for ``repository`` (cached)."""
        if repository in self._registry_tokens:
            return self._registry_tokens[repository]
        params = {
            "service": "registry.docker.io",
            "scope": f"repository:{repository}:pull",
        }
        auth = None
        if self.username and self.password:
            auth = (self.username, self.password)
        resp = self._request_with_retry(
            "GET",
            _AUTH_URL,
            params=params,
            auth=auth,
            context_label="Docker Hub",
        )
        if resp.status_code != 200:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Failed to obtain Docker Hub registry token for {repository} (HTTP {resp.status_code}). Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        token = resp.json().get("token", "")
        if not token:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Docker Hub registry token endpoint returned an empty token for {repository}. Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        self._registry_tokens[repository] = token
        return token

    def _hub_request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Call the Hub API, adding the JWT bearer header when logged in."""
        headers = kwargs.pop("headers", {})
        if self._hub_jwt:
            headers["Authorization"] = f"Bearer {self._hub_jwt}"
        kwargs["headers"] = headers
        return self._request_with_retry(
            method, url, context_label="Docker Hub", **kwargs
        )

    def _registry_request(
        self, method: str, url: str, token: str, **kwargs
    ) -> requests.Response:
        """Call the registry API with the given bearer ``token``."""
        headers = kwargs.pop("headers", {})
        headers["Authorization"] = f"Bearer {token}"
        kwargs["headers"] = headers
        return self._request_with_retry(
            method, url, context_label="Docker Hub", **kwargs
        )

    def _check_hub_response(self, resp: requests.Response, context: str) -> None:
        """Raise the matching error for any non-200 Hub API response."""
        if resp.status_code == 200:
            return
        if resp.status_code in (401, 403):
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Authentication failed for {context} on Docker Hub (HTTP {resp.status_code}). Check REGISTRY_USERNAME and REGISTRY_PASSWORD environment variables.",
            )
        if resp.status_code == 404:
            raise ImageRegistryCatalogError(
                file=__file__,
                message=f"Namespace '{self.namespace}' not found on Docker Hub. Check the namespace in --registry docker.io/{{namespace}}.",
            )
        raise ImageRegistryNetworkError(
            file=__file__,
            message=f"Unexpected error during {context} on Docker Hub (HTTP {resp.status_code}): {resp.text[:200]}",
        )

    @staticmethod
    def _next_tag_page_url(resp: requests.Response) -> str | None:
        """Return the absolute URL of the next tag page, or None.

        Reuses the base class ``Link`` header parser and anchors
        server-relative targets at the registry host.
        """
        next_url = RegistryAdapter._next_page_url(resp)
        if next_url and next_url.startswith("/"):
            return f"{_REGISTRY_HOST}{next_url}"
        return next_url
@@ -0,0 +1,40 @@
"""Factory for auto-detecting registry type and returning the appropriate adapter."""
from __future__ import annotations
import re
from prowler.providers.image.lib.registry.base import RegistryAdapter
from prowler.providers.image.lib.registry.dockerhub_adapter import DockerHubAdapter
from prowler.providers.image.lib.registry.oci_adapter import OciRegistryAdapter
_DOCKER_HUB_PATTERN = re.compile(
r"^(https?://)?(docker\.io|registry-1\.docker\.io)(/|$)", re.IGNORECASE
)
def create_registry_adapter(
    registry_url: str,
    username: str | None = None,
    password: str | None = None,
    token: str | None = None,
    verify_ssl: bool = True,
) -> RegistryAdapter:
    """Return the adapter matching ``registry_url``.

    Docker Hub URLs get a ``DockerHubAdapter``; every other registry gets
    the generic ``OciRegistryAdapter``. All arguments are forwarded to the
    chosen adapter unchanged.
    """
    # ECR and other non-Docker-Hub registries implement the OCI Distribution
    # Spec, so they are handled by the generic OCI adapter.
    adapter_cls = (
        DockerHubAdapter
        if _DOCKER_HUB_PATTERN.search(registry_url)
        else OciRegistryAdapter
    )
    return adapter_cls(
        registry_url=registry_url,
        username=username,
        password=password,
        token=token,
        verify_ssl=verify_ssl,
    )
@@ -0,0 +1,228 @@
"""Generic OCI Distribution Spec registry adapter."""
from __future__ import annotations
import base64
import ipaddress
import re
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from prowler.lib.logger import logger
from prowler.providers.image.exceptions.exceptions import (
ImageRegistryAuthError,
ImageRegistryCatalogError,
ImageRegistryNetworkError,
)
from prowler.providers.image.lib.registry.base import RegistryAdapter
if TYPE_CHECKING:
import requests
class OciRegistryAdapter(RegistryAdapter):
    """Adapter for registries implementing OCI Distribution Spec."""

    def __init__(
        self,
        registry_url: str,
        username: str | None = None,
        password: str | None = None,
        token: str | None = None,
        verify_ssl: bool = True,
    ) -> None:
        super().__init__(registry_url, username, password, token, verify_ssl)
        self._base_url = self._normalise_url(registry_url)
        self._bearer_token: str | None = None
        self._basic_auth_verified = False

    @staticmethod
    def _normalise_url(url: str) -> str:
        """Strip trailing slashes and default to ``https://`` if no scheme."""
        url = url.rstrip("/")
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"
        return url

    def list_repositories(self) -> list[str]:
        """List all repositories via the ``/v2/_catalog`` endpoint.

        Raises:
            ImageRegistryCatalogError: If the registry answers 404 for the
                catalog endpoint.
            ImageRegistryAuthError: On HTTP 401/403.
            ImageRegistryNetworkError: On any other non-200 response.
        """
        self._ensure_auth()
        repositories: list[str] = []
        url = f"{self._base_url}/v2/_catalog"
        params: dict = {"n": 200}
        while url:
            resp = self._authed_request("GET", url, params=params)
            if resp.status_code == 404:
                raise ImageRegistryCatalogError(
                    file=__file__,
                    message=f"Registry at {self.registry_url} does not support catalog listing (/_catalog returned 404). Use --image or --image-list instead.",
                )
            self._check_response(resp, "catalog listing")
            data = resp.json()
            repositories.extend(data.get("repositories", []))
            url = self._resolve_next_url(url, resp)
            # The next-page URL carries its own query string.
            params = {}
        return repositories

    def list_tags(self, repository: str) -> list[str]:
        """List all tags of ``repository`` via ``/v2/<repository>/tags/list``.

        Raises:
            ImageRegistryAuthError: On HTTP 401/403.
            ImageRegistryNetworkError: On any other non-200 response.
        """
        self._ensure_auth(repository=repository)
        tags: list[str] = []
        url = f"{self._base_url}/v2/{repository}/tags/list"
        params: dict = {"n": 200}
        while url:
            resp = self._authed_request("GET", url, params=params)
            self._check_response(resp, f"tag listing for {repository}")
            data = resp.json()
            # "tags" may be present but null.
            tags.extend(data.get("tags", []) or [])
            url = self._resolve_next_url(url, resp)
            params = {}
        return tags

    def _resolve_next_url(
        self, current_url: str, resp: requests.Response
    ) -> str | None:
        """Return the absolute URL of the next page, or None when done.

        The ``Link`` header target may be a relative reference such as
        ``/v2/_catalog?last=...``; it is resolved against the URL that was
        just requested so the follow-up request always has scheme and host.
        """
        # Local import: the module's import block is outside this class.
        from urllib.parse import urljoin

        next_url = self._next_page_url(resp)
        if not next_url:
            return None
        return urljoin(current_url, next_url)

    def _ensure_auth(self, repository: str | None = None) -> None:
        """Decide how to authenticate by pinging ``/v2/``.

        Does nothing once a bearer token is held or basic auth has been
        selected. A user-supplied token is used directly as bearer token.

        Raises:
            ImageRegistryAuthError: If credentials are required but missing,
                or access is denied (HTTP 403).
            ImageRegistryNetworkError: On an unexpected ping status.
        """
        if self._bearer_token:
            return
        if self._basic_auth_verified:
            return
        if self.token:
            self._bearer_token = self.token
            return
        ping_url = f"{self._base_url}/v2/"
        resp = self._request_with_retry("GET", ping_url)
        if resp.status_code == 200:
            return
        if resp.status_code == 401:
            www_auth = resp.headers.get("Www-Authenticate", "")
            if not www_auth.lower().startswith("bearer"):
                # Basic auth challenge (e.g., AWS ECR)
                if self.username and self.password:
                    self._basic_auth_verified = True
                    return
                raise ImageRegistryAuthError(
                    file=__file__,
                    message=(
                        f"Registry {self.registry_url} requires authentication "
                        f"but no credentials provided. "
                        f"Set REGISTRY_USERNAME and REGISTRY_PASSWORD."
                    ),
                )
            # Bearer token exchange (standard OCI flow)
            self._bearer_token = self._obtain_bearer_token(www_auth, repository)
            return
        if resp.status_code == 403:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Access denied to registry {self.registry_url} (HTTP 403). Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        raise ImageRegistryNetworkError(
            file=__file__,
            message=f"Unexpected HTTP {resp.status_code} from registry {self.registry_url} during auth check.",
        )

    def _obtain_bearer_token(
        self, www_authenticate: str, repository: str | None = None
    ) -> str:
        """Exchange a ``Www-Authenticate`` bearer challenge for a token.

        The scope from the challenge is used when present; otherwise a pull
        scope for ``repository`` is requested if one was given.

        Raises:
            ImageRegistryAuthError: If the challenge has no realm, the realm
                is rejected, or the token endpoint returns no token.
        """
        match = re.search(r'realm="([^"]+)"', www_authenticate)
        if not match:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Cannot parse token endpoint from registry {self.registry_url}. Www-Authenticate: {www_authenticate[:200]}",
            )
        realm = match.group(1)
        self._validate_realm_url(realm)
        params: dict = {}
        service_match = re.search(r'service="([^"]+)"', www_authenticate)
        if service_match:
            params["service"] = service_match.group(1)
        scope_match = re.search(r'scope="([^"]+)"', www_authenticate)
        if scope_match:
            params["scope"] = scope_match.group(1)
        elif repository:
            params["scope"] = f"repository:{repository}:pull"
        auth = None
        if self.username and self.password:
            auth = (self.username, self.password)
        resp = self._request_with_retry("GET", realm, params=params, auth=auth)
        if resp.status_code != 200:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Failed to obtain bearer token from {realm} (HTTP {resp.status_code}). Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        data = resp.json()
        token = data.get("token") or data.get("access_token", "")
        if not token:
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Token endpoint {realm} returned an empty token. Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        return token

    @staticmethod
    def _validate_realm_url(realm: str) -> None:
        """Reject token realms with a non-HTTP scheme or a private IP host.

        Only IP-literal hosts are checked; host names are not resolved.

        Raises:
            ImageRegistryAuthError: If the realm is not acceptable.
        """
        parsed = urlparse(realm)
        if parsed.scheme not in ("http", "https"):
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Bearer token realm has disallowed scheme: {parsed.scheme}. Only http/https are allowed.",
            )
        if parsed.scheme == "http":
            logger.warning(f"Bearer token realm uses HTTP (not HTTPS): {realm}")
        hostname = parsed.hostname or ""
        try:
            addr = ipaddress.ip_address(hostname)
            if addr.is_private or addr.is_loopback or addr.is_link_local:
                raise ImageRegistryAuthError(
                    file=__file__,
                    message=f"Bearer token realm points to a private/loopback address: {hostname}. This may indicate an SSRF attempt.",
                )
        except ValueError:
            # Not an IP literal (a host name): nothing to check here.
            pass

    def _resolve_basic_credentials(self) -> tuple[str | None, str | None]:
        """Decode pre-encoded base64 auth tokens (e.g., from aws ecr get-authorization-token).

        Returns (username, password) decoded if the password is a base64 token
        containing 'username:real_password', otherwise returned as-is.
        """
        if not self.password:
            return self.username, self.password
        try:
            decoded = base64.b64decode(self.password).decode("utf-8")
            if decoded.startswith(f"{self.username}:"):
                return self.username, decoded[len(self.username) + 1 :]
        except (ValueError, UnicodeDecodeError):
            logger.debug("Password is not a base64-encoded auth token, using as-is")
        return self.username, self.password

    def _authed_request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Send an authenticated request, re-authenticating once on 401."""
        resp = self._do_authed_request(method, url, **kwargs)
        if resp.status_code == 401 and self._bearer_token:
            logger.debug(
                f"Bearer token rejected (HTTP 401), re-authenticating to {self.registry_url}"
            )
            self._bearer_token = None
            self._ensure_auth()
            resp = self._do_authed_request(method, url, **kwargs)
        return resp

    def _do_authed_request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Attach bearer or basic credentials and send the request."""
        headers = kwargs.pop("headers", {})
        if self._bearer_token:
            headers["Authorization"] = f"Bearer {self._bearer_token}"
        elif self.username and self.password:
            user, pwd = self._resolve_basic_credentials()
            kwargs.setdefault("auth", (user, pwd))
        kwargs["headers"] = headers
        return self._request_with_retry(method, url, **kwargs)

    def _check_response(self, resp: requests.Response, context: str) -> None:
        """Raise the matching error for any non-200 response."""
        if resp.status_code == 200:
            return
        if resp.status_code in (401, 403):
            raise ImageRegistryAuthError(
                file=__file__,
                message=f"Authentication failed for {context} on {self.registry_url} (HTTP {resp.status_code}). Check REGISTRY_USERNAME and REGISTRY_PASSWORD.",
            )
        raise ImageRegistryNetworkError(
            file=__file__,
            message=f"Unexpected error during {context} on {self.registry_url} (HTTP {resp.status_code}): {resp.text[:200]}",
        )
+21
View File
@@ -0,0 +1,21 @@
from prowler.config.config import output_file_timestamp
from prowler.providers.common.models import ProviderOutputOptions
class ImageOutputOptions(ProviderOutputOptions):
    """
    Output options for container image scans.

    Only the report filename logic is customized: an explicit
    --output-filename is honored, otherwise a timestamped default is used.

    Attributes inherited from ProviderOutputOptions:
        - output_filename (str): The base filename used for generated reports.
        - output_directory (str): The directory to store the output files.
    """

    def __init__(self, arguments, bulk_checks_metadata):
        super().__init__(arguments, bulk_checks_metadata)

        # Fall back to a timestamped name when no filename was requested.
        requested_name = getattr(arguments, "output_filename", None)
        self.output_filename = (
            requested_name
            if requested_name
            else f"prowler-output-image-{output_file_timestamp}"
        )
+143
View File
@@ -0,0 +1,143 @@
import json
# --- Individual Trivy findings ------------------------------------------

# Vulnerability with a fix available.
SAMPLE_VULNERABILITY_FINDING = {
    "VulnerabilityID": "CVE-2024-1234",
    "PkgID": "openssl@1.1.1k-r0",
    "PkgName": "openssl",
    "InstalledVersion": "1.1.1k-r0",
    "FixedVersion": "1.1.1l-r0",
    "Severity": "HIGH",
    "Title": "OpenSSL Buffer Overflow",
    "Description": "A buffer overflow vulnerability in OpenSSL allows remote attackers to execute arbitrary code.",
    "PrimaryURL": "https://avd.aquasec.com/nvd/cve-2024-1234",
}

# Secret detected in an image layer.
SAMPLE_SECRET_FINDING = {
    "RuleID": "aws-access-key-id",
    "Category": "AWS",
    "Severity": "CRITICAL",
    "Title": "AWS Access Key ID",
    "StartLine": 10,
    "EndLine": 10,
    "Match": "AKIA...",
}

# Misconfiguration finding.
SAMPLE_MISCONFIGURATION_FINDING = {
    "ID": "DS001",
    "Title": "Dockerfile should not use latest tag",
    "Description": "Using latest tag can cause unpredictable builds.",
    "Severity": "MEDIUM",
    "Resolution": "Use a specific version tag instead of latest",
    "PrimaryURL": "https://avd.aquasec.com/misconfig/ds001",
}

# Vulnerability whose severity is UNKNOWN (and has no fixed version).
SAMPLE_UNKNOWN_SEVERITY_FINDING = {
    "VulnerabilityID": "CVE-2024-9999",
    "PkgID": "test-pkg@0.0.1",
    "PkgName": "test-pkg",
    "InstalledVersion": "0.0.1",
    "Severity": "UNKNOWN",
    "Title": "Unknown severity issue",
    "Description": "An issue with unknown severity.",
}

# --- Image identity -----------------------------------------------------

# First 12 characters of a sha256 digest, and the full image ID built on it.
SAMPLE_IMAGE_SHA = "c1aabb73d233"
SAMPLE_IMAGE_ID = f"sha256:{SAMPLE_IMAGE_SHA}abcdef1234567890"


def _scan_result(target, os_type, secrets=(), misconfigurations=()):
    """Build one Trivy ``Results`` entry carrying the sample vulnerability."""
    return {
        "Target": target,
        "Type": os_type,
        "Vulnerabilities": [SAMPLE_VULNERABILITY_FINDING],
        "Secrets": list(secrets),
        "Misconfigurations": list(misconfigurations),
    }


# --- Full Trivy JSON documents ------------------------------------------

# Single vulnerability, with ImageID and RepoDigests metadata.
SAMPLE_TRIVY_IMAGE_OUTPUT = {
    "Metadata": {
        "ImageID": SAMPLE_IMAGE_ID,
        "RepoDigests": [f"alpine@{SAMPLE_IMAGE_ID}"],
    },
    "Results": [_scan_result("alpine:3.18 (alpine 3.18.0)", "alpine")],
}

# One finding of each type in the same result.
SAMPLE_TRIVY_MULTI_TYPE_OUTPUT = {
    "Metadata": {
        "ImageID": SAMPLE_IMAGE_ID,
        "RepoDigests": [f"myimage@{SAMPLE_IMAGE_ID}"],
    },
    "Results": [
        _scan_result(
            "myimage:latest (debian 12)",
            "debian",
            secrets=[SAMPLE_SECRET_FINDING],
            misconfigurations=[SAMPLE_MISCONFIGURATION_FINDING],
        )
    ],
}

# Only RepoDigests (no ImageID), for fallback testing.
SAMPLE_TRIVY_REPO_DIGEST_ONLY_OUTPUT = {
    "Metadata": {
        "RepoDigests": ["alpine@sha256:e5f6g7h8i9j0abcdef1234567890"],
    },
    "Results": [_scan_result("alpine:3.18 (alpine 3.18.0)", "alpine")],
}

# No Metadata section at all.
SAMPLE_TRIVY_NO_METADATA_OUTPUT = {
    "Results": [_scan_result("alpine:3.18 (alpine 3.18.0)", "alpine")],
}


# --- Serialized outputs, as Trivy would print them -----------------------


def get_sample_trivy_json_output():
    """Serialize the single-vulnerability document to a JSON string."""
    return json.dumps(SAMPLE_TRIVY_IMAGE_OUTPUT)


def get_empty_trivy_output():
    """Serialize a document with an empty ``Results`` list."""
    return json.dumps({"Results": []})


def get_invalid_trivy_output():
    """Return text that is not valid JSON."""
    return "invalid json output"


def get_multi_type_trivy_output():
    """Serialize the document containing every finding type."""
    return json.dumps(SAMPLE_TRIVY_MULTI_TYPE_OUTPUT)


def get_repo_digest_only_trivy_output():
    """Serialize the document that has RepoDigests but no ImageID."""
    return json.dumps(SAMPLE_TRIVY_REPO_DIGEST_ONLY_OUTPUT)


def get_no_metadata_trivy_output():
    """Serialize the document that has no Metadata section."""
    return json.dumps(SAMPLE_TRIVY_NO_METADATA_OUTPUT)
@@ -0,0 +1,930 @@
import os
import tempfile
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
from prowler.lib.check.models import CheckReportImage
from prowler.providers.image.exceptions.exceptions import (
ImageInvalidConfigScannerError,
ImageInvalidNameError,
ImageInvalidScannerError,
ImageInvalidSeverityError,
ImageInvalidTimeoutError,
ImageListFileNotFoundError,
ImageListFileReadError,
ImageNoImagesProvidedError,
ImageScanError,
ImageTrivyBinaryNotFoundError,
)
from prowler.providers.image.image_provider import ImageProvider
from tests.providers.image.image_fixtures import (
SAMPLE_IMAGE_SHA,
SAMPLE_MISCONFIGURATION_FINDING,
SAMPLE_SECRET_FINDING,
SAMPLE_UNKNOWN_SEVERITY_FINDING,
SAMPLE_VULNERABILITY_FINDING,
get_empty_trivy_output,
get_invalid_trivy_output,
get_multi_type_trivy_output,
get_no_metadata_trivy_output,
get_repo_digest_only_trivy_output,
get_sample_trivy_json_output,
)
def _make_provider(**kwargs):
    """Build an ImageProvider for tests.

    Defaults to a single ``alpine:3.18`` image and an empty config; any
    keyword argument overrides or extends those defaults.
    """
    provider_kwargs = {
        "images": ["alpine:3.18"],
        "config_content": {},
        **kwargs,
    }
    return ImageProvider(**provider_kwargs)
class TestImageProvider:
    """Tests for ImageProvider initialization, finding processing and scanning."""

    def test_image_provider(self):
        """Test default initialization."""
        provider = _make_provider()

        assert provider._type == "image"
        assert provider.type == "image"
        assert provider.images == ["alpine:3.18"]
        assert provider.scanners == ["vuln", "secret"]
        assert provider.image_config_scanners == []
        assert provider.trivy_severity == []
        assert provider.ignore_unfixed is False
        assert provider.timeout == "5m"
        assert provider.region == "container"
        assert provider.audited_account == "image-scan"
        assert provider.identity == "prowler"
        assert provider.auth_method == "No auth"
        assert provider.session is None
        assert provider.audit_config == {}
        assert provider.fixer_config == {}
        assert provider._mutelist is None

    def test_image_provider_custom_params(self):
        """Test initialization with custom parameters."""
        provider = _make_provider(
            images=["nginx:1.25", "redis:7"],
            scanners=["vuln", "secret", "misconfig"],
            trivy_severity=["HIGH", "CRITICAL"],
            ignore_unfixed=True,
            timeout="10m",
            fixer_config={"key": "value"},
        )

        assert provider.images == ["nginx:1.25", "redis:7"]
        assert provider.scanners == ["vuln", "secret", "misconfig"]
        assert provider.trivy_severity == ["HIGH", "CRITICAL"]
        assert provider.ignore_unfixed is True
        assert provider.timeout == "10m"
        assert provider.fixer_config == {"key": "value"}

    def test_image_provider_with_image_list_file(self):
        """Test loading images from a file, skipping comments and blank lines."""
        # A temporary directory is removed on exit even when an assertion
        # fails, so the list file never leaks onto the host.
        with tempfile.TemporaryDirectory() as tmp_dir:
            list_path = os.path.join(tmp_dir, "images.txt")
            with open(list_path, "w") as f:
                f.write("# Comment line\n")
                f.write("alpine:3.18\n")
                f.write("\n")
                f.write(" nginx:latest \n")
                f.write("# Another comment\n")
                f.write("redis:7\n")

            provider = _make_provider(
                images=None,
                image_list_file=list_path,
            )

            assert "alpine:3.18" in provider.images
            assert "nginx:latest" in provider.images
            assert "redis:7" in provider.images
            assert len(provider.images) == 3

    def test_image_provider_no_images(self):
        """Test that ImageNoImagesProvidedError is raised when no images are given."""
        with pytest.raises(ImageNoImagesProvidedError):
            _make_provider(images=[])

    def test_image_provider_image_list_file_not_found(self):
        """Test that ImageListFileNotFoundError is raised for missing file."""
        with pytest.raises(ImageListFileNotFoundError):
            _make_provider(
                images=None,
                image_list_file="/nonexistent/path/images.txt",
            )

    def test_process_finding_vulnerability(self):
        """Test processing a vulnerability finding."""
        provider = _make_provider()
        report = provider._process_finding(
            SAMPLE_VULNERABILITY_FINDING,
            "alpine:3.18",
            "alpine:3.18 (alpine 3.18.0)",
            image_sha="c1aabb73d233",
        )

        assert isinstance(report, CheckReportImage)
        assert report.status == "FAIL"
        assert report.check_metadata.CheckID == "CVE-2024-1234"
        assert report.check_metadata.Severity == "high"
        assert report.check_metadata.ServiceName == "container-image"
        assert report.check_metadata.ResourceType == "container-image"
        assert report.check_metadata.ResourceGroup == "container"
        assert report.package_name == "openssl"
        assert report.installed_version == "1.1.1k-r0"
        assert report.fixed_version == "1.1.1l-r0"
        assert report.resource_name == "alpine:3.18"
        assert report.image_sha == "c1aabb73d233"
        assert report.resource_details == "alpine:3.18 (alpine 3.18.0)"
        assert report.region == "container"
        assert report.check_metadata.Categories == ["vulnerability"]
        assert report.check_metadata.RelatedUrl == ""

    def test_process_finding_secret(self):
        """Test processing a secret finding (identified by RuleID)."""
        provider = _make_provider()
        report = provider._process_finding(
            SAMPLE_SECRET_FINDING,
            "myimage:latest",
            "myimage:latest (debian 12)",
        )

        assert isinstance(report, CheckReportImage)
        assert report.status == "FAIL"
        assert report.check_metadata.CheckID == "aws-access-key-id"
        assert report.check_metadata.Severity == "critical"
        assert report.check_metadata.ServiceName == "container-image"
        assert report.check_metadata.Categories == ["secrets"]

    def test_process_finding_misconfiguration(self):
        """Test processing a misconfiguration finding (identified by ID)."""
        provider = _make_provider()
        report = provider._process_finding(
            SAMPLE_MISCONFIGURATION_FINDING,
            "myimage:latest",
            "myimage:latest (debian 12)",
        )

        assert isinstance(report, CheckReportImage)
        assert report.check_metadata.CheckID == "DS001"
        assert report.check_metadata.Severity == "medium"
        assert report.check_metadata.ServiceName == "container-image"
        assert report.check_metadata.Categories == []

    def test_process_finding_unknown_severity(self):
        """Test that UNKNOWN severity is mapped to informational."""
        provider = _make_provider()
        report = provider._process_finding(
            SAMPLE_UNKNOWN_SEVERITY_FINDING,
            "myimage:latest",
            "myimage:latest (alpine 3.18.0)",
        )

        assert report.check_metadata.Severity == "informational"

    @patch("subprocess.run")
    def test_run_scan_success(self, mock_subprocess):
        """Test successful scan with mocked subprocess."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )

        reports = [report for batch in provider.run_scan() for report in batch]

        assert len(reports) == 1
        assert reports[0].check_metadata.CheckID == "CVE-2024-1234"
        assert reports[0].image_sha == SAMPLE_IMAGE_SHA
        assert reports[0].resource_name == "alpine:3.18"
        assert reports[0].check_metadata.ServiceName == "container-image"

    @patch("subprocess.run")
    def test_run_scan_empty_output(self, mock_subprocess):
        """Test scan with empty Trivy output produces no findings."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_empty_trivy_output(), stderr=""
        )

        reports = [report for batch in provider.run_scan() for report in batch]

        assert len(reports) == 0

    @patch("subprocess.run")
    def test_run_scan_invalid_json(self, mock_subprocess):
        """Test scan with malformed output doesn't crash."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_invalid_trivy_output(), stderr=""
        )

        reports = [report for batch in provider.run_scan() for report in batch]

        assert len(reports) == 0

    @patch("subprocess.run")
    def test_run_scan_trivy_not_found(self, mock_subprocess):
        """Test that ImageTrivyBinaryNotFoundError is raised when trivy is missing."""
        provider = _make_provider()
        mock_subprocess.side_effect = FileNotFoundError(
            "[Errno 2] No such file or directory: 'trivy'"
        )

        with pytest.raises(ImageTrivyBinaryNotFoundError):
            for _ in provider._scan_single_image("alpine:3.18"):
                pass

    @patch("subprocess.run")
    def test_run_scan_multiple_images(self, mock_subprocess):
        """Test scanning multiple images makes separate subprocess calls."""
        provider = _make_provider(images=["alpine:3.18", "nginx:latest"])
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )

        # Only the number of subprocess invocations matters here, so the
        # generator is drained without keeping the findings.
        for _ in provider.run_scan():
            pass

        assert mock_subprocess.call_count == 2

    @patch("subprocess.run")
    def test_run_scan_multi_type_output(self, mock_subprocess):
        """Test scan with vulnerabilities, secrets, and misconfigurations."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_multi_type_trivy_output(), stderr=""
        )

        reports = [report for batch in provider.run_scan() for report in batch]

        assert len(reports) == 3
        check_ids = [r.check_metadata.CheckID for r in reports]
        assert "CVE-2024-1234" in check_ids
        assert "aws-access-key-id" in check_ids
        assert "DS001" in check_ids

    def test_print_credentials(self):
        """Test that print_credentials outputs image names."""
        provider = _make_provider()
        with mock.patch("builtins.print") as mock_print:
            provider.print_credentials()
            # Join the first positional argument of every print() call so the
            # assertion does not depend on which line carries the image name.
            output = " ".join(
                str(call.args[0]) for call in mock_print.call_args_list if call.args
            )
            assert "alpine:3.18" in output

    @patch("subprocess.run")
    def test_test_connection_success(self, mock_subprocess):
        """Test successful connection returns is_connected=True."""
        mock_subprocess.return_value = MagicMock(returncode=0, stderr="")

        result = ImageProvider.test_connection(image="alpine:3.18")

        assert result.is_connected is True

    @patch("subprocess.run")
    def test_test_connection_auth_failure(self, mock_subprocess):
        """Test 401 error returns auth failure."""
        mock_subprocess.return_value = MagicMock(
            returncode=1, stderr="401 unauthorized"
        )

        result = ImageProvider.test_connection(image="private/image:latest")

        assert result.is_connected is False
        assert "Authentication failed" in result.error

    @patch("subprocess.run")
    def test_test_connection_not_found(self, mock_subprocess):
        """Test 404 error returns not found."""
        mock_subprocess.return_value = MagicMock(returncode=1, stderr="404 not found")

        result = ImageProvider.test_connection(image="nonexistent/image:latest")

        assert result.is_connected is False
        assert "not found" in result.error

    def test_build_status_extended(self):
        """Test status message content for different finding types."""
        provider = _make_provider()

        # Vulnerability with fix
        status = provider._build_status_extended(SAMPLE_VULNERABILITY_FINDING)
        assert "CVE-2024-1234" in status
        assert "openssl" in status
        assert "fix available" in status

        # Finding with no special fields
        status = provider._build_status_extended({"Description": "Simple finding"})
        assert status == "Simple finding"

        # Finding with will_not_fix status
        finding_no_fix = {
            "VulnerabilityID": "CVE-2024-0000",
            "PkgName": "libc",
            "Status": "will_not_fix",
            "Title": "Some vuln",
        }
        status = provider._build_status_extended(finding_no_fix)
        assert "no fix available" in status

    def test_validate_arguments(self):
        """Test valid and invalid argument combinations."""
        # Valid: images provided
        provider = _make_provider(images=["alpine:3.18"])
        assert provider.images == ["alpine:3.18"]

        # Invalid: empty images and no file
        with pytest.raises(ImageNoImagesProvidedError):
            _make_provider(images=[])

        # Valid: custom scanners
        provider = _make_provider(scanners=["vuln"])
        assert provider.scanners == ["vuln"]

    def test_setup_session(self):
        """Test that setup_session returns None."""
        provider = _make_provider()
        assert provider.setup_session() is None

    @patch("subprocess.run")
    def test_run_method(self, mock_subprocess):
        """Test that run() collects all batches into a list."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )

        reports = provider.run()

        assert isinstance(reports, list)
        assert len(reports) == 1

    @patch("subprocess.run")
    def test_scan_single_image_trivy_nonzero_exit(self, mock_subprocess):
        """Test that a non-zero Trivy exit code raises ImageScanError."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="fatal error: unable to pull image",
        )

        with pytest.raises(ImageScanError):
            for _ in provider._scan_single_image("alpine:3.18"):
                pass

    @patch("subprocess.run")
    def test_scan_single_image_auth_failure(self, mock_subprocess):
        """Test that a 401 unauthorized stderr raises ImageScanError with message."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="ERROR 401 unauthorized: authentication required",
        )

        with pytest.raises(ImageScanError, match="401 unauthorized"):
            for _ in provider._scan_single_image("private/image:latest"):
                pass

    @patch("subprocess.run")
    def test_sha_extraction_from_image_id(self, mock_subprocess):
        """Test that image_sha is extracted from Trivy Metadata.ImageID."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )

        reports = [
            report
            for batch in provider._scan_single_image("alpine:3.18")
            for report in batch
        ]

        assert len(reports) == 1
        assert reports[0].image_sha == SAMPLE_IMAGE_SHA

    @patch("subprocess.run")
    def test_sha_extraction_fallback_to_repo_digests(self, mock_subprocess):
        """Test that image_sha falls back to RepoDigests when ImageID is absent."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_repo_digest_only_trivy_output(), stderr=""
        )

        reports = [
            report
            for batch in provider._scan_single_image("alpine:3.18")
            for report in batch
        ]

        assert len(reports) == 1
        assert reports[0].image_sha == "e5f6g7h8i9j0"

    @patch("subprocess.run")
    def test_sha_extraction_no_metadata(self, mock_subprocess):
        """Test that image_sha is empty when no Metadata is present."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_no_metadata_trivy_output(), stderr=""
        )

        reports = [
            report
            for batch in provider._scan_single_image("alpine:3.18")
            for report in batch
        ]

        assert len(reports) == 1
        assert reports[0].image_sha == ""

    @patch("subprocess.run")
    def test_run_scan_propagates_scan_error(self, mock_subprocess):
        """Test that run_scan() re-raises ImageScanError instead of swallowing it."""
        provider = _make_provider()
        mock_subprocess.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="image not found",
        )

        with pytest.raises(ImageScanError):
            for _ in provider.run_scan():
                pass
class TestImageProviderRegistryAuth:
    """Tests for registry credential resolution and how it reaches Trivy."""

    @pytest.fixture(autouse=True)
    def _clean_registry_env(self, monkeypatch):
        """Remove ambient registry/Trivy credentials before every test.

        The provider falls back to ``REGISTRY_*`` environment variables (see
        the ``*_from_env_vars`` tests below), so credentials exported in a
        developer or CI shell would otherwise change the outcome of the
        "no auth" and "token" tests. Tests that need env vars still set them
        with ``patch.dict``, which is applied after this fixture runs.
        """
        for var in (
            "REGISTRY_USERNAME",
            "REGISTRY_PASSWORD",
            "REGISTRY_TOKEN",
            "TRIVY_USERNAME",
            "TRIVY_PASSWORD",
            "TRIVY_REGISTRY_TOKEN",
        ):
            monkeypatch.delenv(var, raising=False)

    def test_no_auth_by_default(self):
        """Test that no auth is set when no credentials are provided."""
        provider = _make_provider()

        assert provider.registry_username is None
        assert provider.registry_password is None
        assert provider.registry_token is None
        assert provider.auth_method == "No auth"

    def test_basic_auth_with_explicit_params(self):
        """Test basic auth via explicit constructor params."""
        provider = _make_provider(
            registry_username="myuser",
            registry_password="mypass",
        )

        assert provider.registry_username == "myuser"
        assert provider.registry_password == "mypass"
        assert provider.auth_method == "Docker login"

    def test_token_auth_with_explicit_param(self):
        """Test token auth via explicit constructor param."""
        provider = _make_provider(registry_token="my-token-123")

        assert provider.registry_token == "my-token-123"
        assert provider.auth_method == "Registry token"

    def test_basic_auth_takes_precedence_over_token(self):
        """Test that username/password takes precedence over token."""
        provider = _make_provider(
            registry_username="myuser",
            registry_password="mypass",
            registry_token="my-token",
        )

        assert provider.auth_method == "Docker login"

    @patch.dict(
        os.environ, {"REGISTRY_USERNAME": "envuser", "REGISTRY_PASSWORD": "envpass"}
    )
    def test_basic_auth_from_env_vars(self):
        """Test that env vars are used as fallback for basic auth."""
        provider = _make_provider()

        assert provider.registry_username == "envuser"
        assert provider.registry_password == "envpass"
        assert provider.auth_method == "Docker login"

    @patch.dict(os.environ, {"REGISTRY_TOKEN": "env-token"})
    def test_token_auth_from_env_var(self):
        """Test that env var is used as fallback for token auth."""
        provider = _make_provider()

        assert provider.registry_token == "env-token"
        assert provider.auth_method == "Registry token"

    @patch.dict(
        os.environ, {"REGISTRY_USERNAME": "envuser", "REGISTRY_PASSWORD": "envpass"}
    )
    def test_explicit_params_override_env_vars(self):
        """Test that explicit params take precedence over env vars."""
        provider = _make_provider(
            registry_username="explicit",
            registry_password="explicit-pass",
        )

        assert provider.registry_username == "explicit"
        assert provider.registry_password == "explicit-pass"

    def test_build_trivy_env_no_auth(self):
        """Test that _build_trivy_env returns base env when no auth."""
        provider = _make_provider()

        env = provider._build_trivy_env()

        assert "TRIVY_USERNAME" not in env
        assert "TRIVY_PASSWORD" not in env
        assert "TRIVY_REGISTRY_TOKEN" not in env

    def test_build_trivy_env_basic_auth_sets_env_vars(self):
        """Test that _build_trivy_env injects TRIVY_USERNAME/PASSWORD for native Trivy auth."""
        provider = _make_provider(
            registry_username="myuser",
            registry_password="mypass",
        )

        env = provider._build_trivy_env()

        assert env["TRIVY_USERNAME"] == "myuser"
        assert env["TRIVY_PASSWORD"] == "mypass"

    def test_build_trivy_env_token_auth(self):
        """Test that _build_trivy_env injects registry token."""
        provider = _make_provider(registry_token="my-token")

        env = provider._build_trivy_env()

        assert env["TRIVY_REGISTRY_TOKEN"] == "my-token"

    @patch("subprocess.run")
    def test_execute_trivy_sets_trivy_env_with_basic_auth(self, mock_subprocess):
        """Test that _execute_trivy sets TRIVY_USERNAME/PASSWORD for native Trivy auth."""
        provider = _make_provider(
            registry_username="myuser",
            registry_password="mypass",
        )
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )

        provider._execute_trivy(["trivy", "image", "alpine:3.18"], "alpine:3.18")

        # ``env`` must be passed to subprocess.run as a keyword argument.
        env = mock_subprocess.call_args.kwargs["env"]
        assert env["TRIVY_USERNAME"] == "myuser"
        assert env["TRIVY_PASSWORD"] == "mypass"

    @patch("subprocess.run")
    def test_test_connection_with_basic_auth(self, mock_subprocess):
        """Test test_connection uses Trivy native auth with TRIVY_USERNAME/PASSWORD env vars."""
        mock_subprocess.return_value = MagicMock(returncode=0, stderr="")

        result = ImageProvider.test_connection(
            image="private.registry.io/myapp:v1",
            registry_username="myuser",
            registry_password="mypass",
        )

        assert result.is_connected is True
        # A single call means no separate login command was spawned.
        assert mock_subprocess.call_count == 1
        trivy_call = mock_subprocess.call_args
        assert trivy_call.args[0][0] == "trivy"
        env = trivy_call.kwargs["env"]
        assert env["TRIVY_USERNAME"] == "myuser"
        assert env["TRIVY_PASSWORD"] == "mypass"

    @patch("subprocess.run")
    def test_test_connection_with_token(self, mock_subprocess):
        """Test test_connection passes token via env."""
        mock_subprocess.return_value = MagicMock(returncode=0, stderr="")

        result = ImageProvider.test_connection(
            image="private.registry.io/myapp:v1",
            registry_token="my-token",
        )

        assert result.is_connected is True
        env = mock_subprocess.call_args.kwargs["env"]
        assert env["TRIVY_REGISTRY_TOKEN"] == "my-token"

    def test_print_credentials_shows_auth_method(self):
        """Test that print_credentials outputs the auth method."""
        provider = _make_provider(
            registry_username="myuser",
            registry_password="mypass",
        )
        with mock.patch("builtins.print") as mock_print:
            provider.print_credentials()
            output = " ".join(
                str(call.args[0]) for call in mock_print.call_args_list if call.args
            )
            assert "Docker login" in output
class TestExtractRegistry:
    """Checks for ImageProvider._extract_registry across image reference shapes."""

    def test_docker_hub_simple(self):
        """A name:tag reference yields no registry."""
        registry = ImageProvider._extract_registry("alpine:3.18")
        assert registry is None

    def test_docker_hub_with_namespace(self):
        """A namespace/name:tag reference yields no registry."""
        registry = ImageProvider._extract_registry("andoniaf/test-private:tag")
        assert registry is None

    def test_ghcr(self):
        """A ghcr.io reference yields the ghcr.io host."""
        registry = ImageProvider._extract_registry("ghcr.io/user/image:tag")
        assert registry == "ghcr.io"

    def test_ecr(self):
        """An ECR reference yields the full ECR hostname."""
        ecr_host = "123456789012.dkr.ecr.us-east-1.amazonaws.com"
        registry = ImageProvider._extract_registry(
            "123456789012.dkr.ecr.us-east-1.amazonaws.com/repo:tag"
        )
        assert registry == ecr_host

    def test_localhost_with_port(self):
        """A localhost reference keeps its port in the registry."""
        registry = ImageProvider._extract_registry("localhost:5000/myimage:latest")
        assert registry == "localhost:5000"

    def test_custom_registry_with_port(self):
        """A custom host reference keeps its port in the registry."""
        registry = ImageProvider._extract_registry("myregistry.io:5000/image:tag")
        assert registry == "myregistry.io:5000"

    def test_digest_reference(self):
        """A digest-pinned reference still yields the registry host."""
        registry = ImageProvider._extract_registry("ghcr.io/user/image@sha256:abc123")
        assert registry == "ghcr.io"

    def test_bare_image_name(self):
        """A bare image name yields no registry."""
        registry = ImageProvider._extract_registry("nginx")
        assert registry is None
class TestCleanup:
    """Checks for ImageProvider.cleanup."""

    def test_cleanup_idempotent(self):
        """Calling cleanup twice in a row must not raise."""
        provider = _make_provider()
        for _ in range(2):
            provider.cleanup()
class TestImageProviderInputValidation:
    """Checks that constructor arguments are validated and reach the Trivy command."""

    def test_invalid_timeout_format_raises_error(self):
        """A timeout that is not a number plus unit is rejected."""
        bad_timeout = "invalid"
        with pytest.raises(ImageInvalidTimeoutError):
            _make_provider(timeout=bad_timeout)

    def test_invalid_timeout_no_unit_raises_error(self):
        """A bare number without a unit suffix is rejected."""
        bad_timeout = "300"
        with pytest.raises(ImageInvalidTimeoutError):
            _make_provider(timeout=bad_timeout)

    def test_invalid_timeout_wrong_unit_raises_error(self):
        """A timeout with the unit ``d`` is rejected."""
        bad_timeout = "5d"
        with pytest.raises(ImageInvalidTimeoutError):
            _make_provider(timeout=bad_timeout)

    def test_valid_timeout_seconds(self):
        """A timeout expressed in seconds is stored unchanged."""
        timeout = "300s"
        assert _make_provider(timeout=timeout).timeout == "300s"

    def test_valid_timeout_hours(self):
        """A timeout expressed in hours is stored unchanged."""
        timeout = "1h"
        assert _make_provider(timeout=timeout).timeout == "1h"

    def test_invalid_scanner_raises_error(self):
        """An unknown scanner name raises ImageInvalidScannerError."""
        requested = ["vuln", "bad"]
        with pytest.raises(ImageInvalidScannerError):
            _make_provider(scanners=requested)

    def test_invalid_severity_raises_error(self):
        """An unknown severity level raises ImageInvalidSeverityError."""
        requested = ["HIGH", "SUPER_HIGH"]
        with pytest.raises(ImageInvalidSeverityError):
            _make_provider(trivy_severity=requested)

    def test_valid_all_scanners(self):
        """Every supported scanner name is accepted together."""
        expected = ["vuln", "secret", "misconfig", "license"]
        # Pass a copy so the comparison is against an independent list.
        provider = _make_provider(scanners=list(expected))
        assert provider.scanners == expected

    def test_valid_all_severities(self):
        """Every supported severity level is accepted together."""
        expected = ["CRITICAL", "HIGH", "MEDIUM", "LOW", "UNKNOWN"]
        # Pass a copy so the comparison is against an independent list.
        provider = _make_provider(trivy_severity=list(expected))
        assert provider.trivy_severity == expected

    def test_image_config_scanners_defaults_to_empty(self):
        """With no argument, image_config_scanners is an empty list."""
        assert _make_provider().image_config_scanners == []

    def test_valid_image_config_scanners(self):
        """Supported image config scanners are stored unchanged."""
        expected = ["misconfig", "secret"]
        provider = _make_provider(image_config_scanners=list(expected))
        assert provider.image_config_scanners == expected

    def test_invalid_image_config_scanner_raises_error(self):
        """``vuln`` as an image config scanner raises ImageInvalidConfigScannerError."""
        requested = ["misconfig", "vuln"]
        with pytest.raises(ImageInvalidConfigScannerError):
            _make_provider(image_config_scanners=requested)

    @patch("subprocess.run")
    def test_trivy_command_includes_image_config_scanners(self, mock_subprocess):
        """The flag and its comma-joined value appear in the Trivy command."""
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_empty_trivy_output(), stderr=""
        )
        provider = _make_provider(image_config_scanners=["misconfig", "secret"])

        # Drain the generator so the subprocess call actually happens.
        list(provider._scan_single_image("alpine:3.18"))

        trivy_cmd = mock_subprocess.call_args[0][0]
        assert "--image-config-scanners" in trivy_cmd
        flag_pos = trivy_cmd.index("--image-config-scanners")
        assert trivy_cmd[flag_pos + 1] == "misconfig,secret"

    @patch("subprocess.run")
    def test_trivy_command_omits_image_config_scanners_when_empty(
        self, mock_subprocess
    ):
        """The flag is absent from the Trivy command when no scanners are set."""
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_empty_trivy_output(), stderr=""
        )
        provider = _make_provider(image_config_scanners=[])

        # Drain the generator so the subprocess call actually happens.
        list(provider._scan_single_image("alpine:3.18"))

        trivy_cmd = mock_subprocess.call_args[0][0]
        assert "--image-config-scanners" not in trivy_cmd
class TestImageProviderErrorCategorization:
    """Checks for ImageProvider._categorize_trivy_error."""

    def test_categorize_auth_failure(self):
        """A 401 message is labelled as an auth failure."""
        stderr = "401 unauthorized: access denied"
        category = ImageProvider._categorize_trivy_error(stderr)
        assert "Auth failure" in category

    def test_categorize_not_found(self):
        """A manifest-unknown message is labelled as image not found."""
        stderr = "manifest unknown: image not found"
        category = ImageProvider._categorize_trivy_error(stderr)
        assert "Image not found" in category

    def test_categorize_rate_limit(self):
        """A 429 message is labelled as rate limited."""
        stderr = "429 too many requests"
        category = ImageProvider._categorize_trivy_error(stderr)
        assert "Rate limited" in category

    def test_categorize_network_issue(self):
        """A connection-refused message is labelled as a network issue."""
        stderr = "connection refused to registry"
        category = ImageProvider._categorize_trivy_error(stderr)
        assert "Network issue" in category

    def test_categorize_unknown_error(self):
        """An unrecognized message is passed through unchanged."""
        stderr = "some unknown trivy error"
        assert ImageProvider._categorize_trivy_error(stderr) == stderr
class TestImageProviderNameValidation:
    """Tests for image-name validation and image-list-file size limits."""

    @pytest.mark.parametrize(
        "bad_name",
        [
            "alpine;rm -rf /",
            "image|cat /etc/passwd",
            "image&background",
            "image$VAR",
            "image`whoami`",
            "image\ninjected",
            "image\rinjected",
        ],
    )
    def test_image_provider_invalid_image_name_shell_chars(self, bad_name):
        """Test that image names with shell metacharacters raise ImageInvalidNameError."""
        with pytest.raises(ImageInvalidNameError):
            _make_provider(images=[bad_name])

    def test_image_provider_invalid_image_name_empty(self):
        """Test that an empty string image name raises ImageInvalidNameError."""
        with pytest.raises(ImageInvalidNameError):
            _make_provider(images=[""])

    @pytest.mark.parametrize(
        "valid_name",
        [
            "alpine:3.18",
            "nginx:latest",
            "registry.example.com/repo/image:tag",
            "ghcr.io/owner/image:v1.2.3",
            "myimage@sha256:abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890",
            "localhost:5000/myimage:latest",
        ],
    )
    def test_image_provider_valid_image_names(self, valid_name):
        """Test that various valid image name formats pass validation."""
        provider = _make_provider(images=[valid_name])
        assert valid_name in provider.images

    def test_image_provider_image_name_too_long(self):
        """Test that a name exceeding 500 chars raises ImageInvalidNameError."""
        long_name = "a" * 501
        with pytest.raises(ImageInvalidNameError):
            _make_provider(images=[long_name])

    def test_image_provider_file_too_many_lines(self):
        """Test that a file with more than MAX_IMAGE_LIST_LINES raises ImageListFileReadError."""
        # The temporary directory is removed on exit, so the 10,001-line file
        # does not accumulate on the host across test runs.
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_path = os.path.join(tmp_dir, "images.txt")
            with open(file_path, "w") as f:
                f.writelines(f"image{i}:latest\n" for i in range(10_001))

            with pytest.raises(ImageListFileReadError):
                _make_provider(images=None, image_list_file=file_path)
class TestScanPerImage:
    """Checks for ImageProvider.scan_per_image."""

    @patch("subprocess.run")
    def test_yields_per_image(self, mock_subprocess):
        """Each configured image produces one (name, findings) pair."""
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )
        provider = _make_provider(images=["alpine:3.18", "nginx:latest"])

        per_image = list(provider.scan_per_image())

        assert len(per_image) == 2
        for image_name, image_findings in per_image:
            assert isinstance(image_name, str)
            assert isinstance(image_findings, list)
            for finding in image_findings:
                assert isinstance(finding, CheckReportImage)

    @patch("subprocess.run")
    def test_reraises_scan_error(self, mock_subprocess):
        """A non-zero Trivy exit surfaces as ImageScanError to the caller."""
        mock_subprocess.return_value = MagicMock(
            returncode=1, stdout="", stderr="scan failed"
        )
        provider = _make_provider(images=["alpine:3.18"])

        with pytest.raises(ImageScanError):
            for _ in provider.scan_per_image():
                pass

    @patch("subprocess.run")
    def test_skips_generic_error(self, mock_subprocess):
        """A RuntimeError for one image yields empty findings and the scan continues."""

        def fake_run(cmd, **kwargs):
            # Fail only for the command that targets the bad image.
            if "bad:image" not in cmd:
                return MagicMock(
                    returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
                )
            raise RuntimeError("unexpected error")

        mock_subprocess.side_effect = fake_run
        provider = _make_provider(images=["bad:image", "alpine:3.18"])

        per_image = list(provider.scan_per_image())

        assert len(per_image) == 2
        (failed_name, failed_findings), (ok_name, ok_findings) = per_image
        assert failed_name == "bad:image"
        assert failed_findings == []
        assert ok_name == "alpine:3.18"
        assert len(ok_findings) > 0

    @patch("subprocess.run")
    def test_calls_cleanup(self, mock_subprocess):
        """cleanup() is invoked exactly once when scan_per_image is exhausted."""
        mock_subprocess.return_value = MagicMock(
            returncode=0, stdout=get_sample_trivy_json_output(), stderr=""
        )
        provider = _make_provider(images=["alpine:3.18"])

        with mock.patch.object(provider, "cleanup") as cleanup_spy:
            for _ in provider.scan_per_image():
                pass

        cleanup_spy.assert_called_once()
@@ -0,0 +1,223 @@
from argparse import Namespace
from prowler.providers.image.lib.arguments.arguments import validate_arguments
class TestValidateArguments:
def test_no_source_fails(self):
args = Namespace(
images=[],
image_list_file=None,
registry=None,
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--image" in msg
def test_image_only_passes(self):
args = Namespace(
images=["nginx:latest"],
image_list_file=None,
registry=None,
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, _ = validate_arguments(args)
assert ok
def test_image_list_only_passes(self):
args = Namespace(
images=[],
image_list_file="images.txt",
registry=None,
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, _ = validate_arguments(args)
assert ok
def test_registry_only_passes(self):
args = Namespace(
images=[],
image_list_file=None,
registry="myregistry.io",
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, _ = validate_arguments(args)
assert ok
def test_image_filter_without_registry_fails(self):
args = Namespace(
images=["nginx:latest"],
image_list_file=None,
registry=None,
image_filter="^prod",
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--image-filter requires --registry" in msg
def test_tag_filter_without_registry_fails(self):
args = Namespace(
images=["nginx:latest"],
image_list_file=None,
registry=None,
image_filter=None,
tag_filter="^v",
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--tag-filter requires --registry" in msg
def test_max_images_without_registry_fails(self):
args = Namespace(
images=["nginx:latest"],
image_list_file=None,
registry=None,
image_filter=None,
tag_filter=None,
max_images=50,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--max-images requires --registry" in msg
def test_registry_insecure_without_registry_fails(self):
args = Namespace(
images=[],
image_list_file="i.txt",
registry=None,
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=True,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--registry-insecure requires --registry" in msg
def test_docker_hub_no_namespace_fails(self):
args = Namespace(
images=[],
image_list_file=None,
registry="docker.io",
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "namespace" in msg.lower()
def test_docker_hub_with_namespace_passes(self):
args = Namespace(
images=[],
image_list_file=None,
registry="docker.io/myorg",
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, _ = validate_arguments(args)
assert ok
def test_docker_hub_https_no_namespace_fails(self):
args = Namespace(
images=[],
image_list_file=None,
registry="https://docker.io",
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=False,
)
ok, msg = validate_arguments(args)
assert not ok
assert "namespace" in msg.lower()
def test_registry_with_filters_passes(self):
args = Namespace(
images=[],
image_list_file=None,
registry="myregistry.io",
image_filter="^prod",
tag_filter="^v",
max_images=100,
registry_insecure=True,
registry_list_images=False,
)
ok, _ = validate_arguments(args)
assert ok
def test_registry_list_without_registry_fails(self):
args = Namespace(
images=["nginx:latest"],
image_list_file=None,
registry=None,
image_filter=None,
tag_filter=None,
max_images=0,
registry_insecure=False,
registry_list_images=True,
)
ok, msg = validate_arguments(args)
assert not ok
assert "--registry-list requires --registry" in msg
def test_registry_list_with_registry_passes(self):
    """Accept ``--registry-list`` when a registry is provided."""
    options = {
        "images": [],
        "image_list_file": None,
        "registry": "myregistry.io",
        "image_filter": None,
        "tag_filter": None,
        "max_images": 0,
        "registry_insecure": False,
        "registry_list_images": True,
    }
    is_valid, _message = validate_arguments(Namespace(**options))
    assert is_valid
def test_combined_registry_and_image_passes(self):
    """Accept an explicit image together with a registry."""
    options = {
        "images": ["nginx:latest"],
        "image_list_file": None,
        "registry": "myregistry.io",
        "image_filter": None,
        "tag_filter": None,
        "max_images": 0,
        "registry_insecure": False,
        "registry_list_images": False,
    }
    is_valid, _message = validate_arguments(Namespace(**options))
    assert is_valid
@@ -0,0 +1,243 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from prowler.providers.image.exceptions.exceptions import (
ImageRegistryAuthError,
ImageRegistryCatalogError,
ImageRegistryNetworkError,
)
from prowler.providers.image.lib.registry.dockerhub_adapter import DockerHubAdapter
class TestDockerHubAdapterInit:
    """Checks ``DockerHubAdapter._extract_namespace`` against several URL shapes."""

    def test_extract_namespace_simple(self):
        """``docker.io/<ns>`` yields ``<ns>``."""
        namespace = DockerHubAdapter._extract_namespace("docker.io/myorg")
        assert namespace == "myorg"

    def test_extract_namespace_https(self):
        """An ``https://`` prefix does not affect the extracted namespace."""
        namespace = DockerHubAdapter._extract_namespace("https://docker.io/myorg")
        assert namespace == "myorg"

    def test_extract_namespace_registry1(self):
        """The ``registry-1.docker.io`` host is handled like ``docker.io``."""
        namespace = DockerHubAdapter._extract_namespace("registry-1.docker.io/myorg")
        assert namespace == "myorg"

    def test_extract_namespace_empty(self):
        """A host with no path yields an empty namespace."""
        namespace = DockerHubAdapter._extract_namespace("docker.io")
        assert namespace == ""

    def test_extract_namespace_with_slash(self):
        """A trailing slash is not part of the namespace."""
        namespace = DockerHubAdapter._extract_namespace("docker.io/myorg/")
        assert namespace == "myorg"
class TestDockerHubListRepositories:
    """Repository listing through ``DockerHubAdapter.list_repositories``."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_repos(self, mock_request):
        """With credentials, repository names come back prefixed by the namespace."""
        # First HTTP call: Hub login (goes through requests.request via
        # _request_with_retry). Second HTTP call: the repository listing.
        hub_login = MagicMock(
            status_code=200, **{"json.return_value": {"token": "jwt"}}
        )
        listing = MagicMock(
            status_code=200,
            **{
                "json.return_value": {
                    "results": [{"name": "app1"}, {"name": "app2"}],
                    "next": None,
                }
            },
        )
        mock_request.side_effect = [hub_login, listing]

        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")

        assert hub.list_repositories() == ["myorg/app1", "myorg/app2"]

    def test_list_repos_no_namespace_raises(self):
        """Listing without a namespace raises a catalog error mentioning it."""
        hub = DockerHubAdapter("docker.io")
        with pytest.raises(ImageRegistryCatalogError, match="namespace"):
            hub.list_repositories()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_repos_public_no_credentials(self, mock_request):
        """Without credentials the public /v2/repositories/{ns}/ endpoint is requested."""
        mock_request.return_value = MagicMock(
            status_code=200,
            **{
                "json.return_value": {
                    "results": [{"name": "repo1"}, {"name": "repo2"}],
                    "next": None,
                }
            },
        )

        hub = DockerHubAdapter("docker.io/publicns")
        listed = hub.list_repositories()

        assert listed == ["publicns/repo1", "publicns/repo2"]
        # The URL is the second positional argument of the last request.
        requested_url = mock_request.call_args.args[1]
        assert "/v2/repositories/publicns/" in requested_url
        assert "/v2/namespaces/" not in requested_url
class TestDockerHubListTags:
    """Tag listing through ``DockerHubAdapter.list_tags``."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_tags(self, mock_request):
        """Tags from the response body are returned as a plain list."""
        # First HTTP call: token exchange (goes through requests.request via
        # _request_with_retry). Second HTTP call: the tag listing.
        token_exchange = MagicMock(
            status_code=200, **{"json.return_value": {"token": "registry-token"}}
        )
        tag_listing = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"tags": ["latest", "v1.0"]}},
        )
        mock_request.side_effect = [token_exchange, tag_listing]

        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")

        assert hub.list_tags("myorg/myapp") == ["latest", "v1.0"]

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_tags_auth_failure(self, mock_request):
        """A 401 on the tag listing raises ``ImageRegistryAuthError``."""
        # The token exchange succeeds; only the listing itself is rejected.
        token_exchange = MagicMock(
            status_code=200, **{"json.return_value": {"token": "tok"}}
        )
        rejected_listing = MagicMock(status_code=401)
        mock_request.side_effect = [token_exchange, rejected_listing]

        hub = DockerHubAdapter("docker.io/myorg")
        with pytest.raises(ImageRegistryAuthError):
            hub.list_tags("myorg/myapp")
class TestDockerHubLogin:
    """Behaviour of ``DockerHubAdapter._hub_login``."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_login_failure(self, mock_request):
        """A 401 login response raises an auth error matching "login failed"."""
        mock_request.return_value = MagicMock(
            status_code=401, text="invalid credentials"
        )
        hub = DockerHubAdapter("docker.io/myorg", username="bad", password="creds")
        with pytest.raises(ImageRegistryAuthError, match="login failed"):
            hub._hub_login()

    def test_login_skipped_without_credentials(self):
        """Without credentials the login does not raise and leaves the JWT unset."""
        hub = DockerHubAdapter("docker.io/myorg")
        hub._hub_login()  # Should not raise
        assert hub._hub_jwt is None

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_login_401_includes_response_body(self, mock_request):
        """The auth error message carries the body text of the 401 response."""
        mock_request.return_value = MagicMock(
            status_code=401,
            text='{"detail":"Incorrect authentication credentials"}',
        )
        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")
        with pytest.raises(
            ImageRegistryAuthError, match="Incorrect authentication credentials"
        ):
            hub._hub_login()

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_login_500_retried_then_raises_network_error(
        self, mock_request, mock_sleep
    ):
        """A persistent 500 ends in a network error after three requests."""
        mock_request.return_value = MagicMock(status_code=500)
        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")
        with pytest.raises(ImageRegistryNetworkError, match="Server error"):
            hub._hub_login()
        # time.sleep is patched, so the retries do not actually wait.
        assert mock_request.call_count == 3
class TestDockerHubRetry:
    """``_request_with_retry`` behaviour, exercised through ``DockerHubAdapter``."""

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_retry_on_429(self, mock_request, mock_sleep):
        """A 429 followed by a 200 ends with the 200 response."""
        mock_request.side_effect = [
            MagicMock(status_code=429),
            MagicMock(status_code=200),
        ]
        hub = DockerHubAdapter("docker.io/myorg")
        response = hub._request_with_retry(
            "GET", "https://hub.docker.com/v2/namespaces/myorg/repositories"
        )
        assert response.status_code == 200

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_connection_error_retries(self, mock_request, mock_sleep):
        """Repeated connection errors raise a network error after three requests."""
        mock_request.side_effect = requests.exceptions.ConnectionError("fail")
        hub = DockerHubAdapter("docker.io/myorg")
        with pytest.raises(ImageRegistryNetworkError):
            hub._request_with_retry("GET", "https://hub.docker.com")
        assert mock_request.call_count == 3

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_retry_on_500(self, mock_request, mock_sleep):
        """A 500 followed by a 200 takes two requests and one sleep."""
        mock_request.side_effect = [
            MagicMock(status_code=500),
            MagicMock(status_code=200),
        ]
        hub = DockerHubAdapter("docker.io/myorg")
        response = hub._request_with_retry("GET", "https://hub.docker.com")
        assert response.status_code == 200
        assert mock_request.call_count == 2
        mock_sleep.assert_called_once()

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_retry_exhausted_on_500_raises_network_error(
        self, mock_request, mock_sleep
    ):
        """A persistent 500 raises a network error naming the status and attempt count."""
        mock_request.return_value = MagicMock(status_code=500)
        hub = DockerHubAdapter("docker.io/myorg")
        with pytest.raises(
            ImageRegistryNetworkError, match="Server error.*HTTP 500.*3 attempts"
        ):
            hub._request_with_retry("GET", "https://hub.docker.com")
        assert mock_request.call_count == 3

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_4xx_not_retried(self, mock_request, mock_sleep):
        """A 403 is returned as-is after one request and no sleep."""
        mock_request.return_value = MagicMock(status_code=403)
        hub = DockerHubAdapter("docker.io/myorg")
        response = hub._request_with_retry("GET", "https://hub.docker.com")
        assert response.status_code == 403
        assert mock_request.call_count == 1
        mock_sleep.assert_not_called()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_request_sends_user_agent(self, mock_request):
        """Outgoing requests carry the Prowler registry-adapter User-Agent."""
        from prowler.config.config import prowler_version

        mock_request.return_value = MagicMock(status_code=200)
        hub = DockerHubAdapter("docker.io/myorg")
        hub._request_with_retry("GET", "https://hub.docker.com")

        # Headers are passed to requests.request as a keyword argument.
        sent_headers = mock_request.call_args.kwargs["headers"]
        expected_agent = f"Prowler/{prowler_version} (registry-adapter)"
        assert sent_headers["User-Agent"] == expected_agent

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_retry_500_includes_response_body(self, mock_request, mock_sleep):
        """The network error for a persistent 500 includes the response body text."""
        mock_request.return_value = MagicMock(
            status_code=500, text="<html>Cloudflare error</html>"
        )
        hub = DockerHubAdapter("docker.io/myorg")
        with pytest.raises(ImageRegistryNetworkError, match="Cloudflare error"):
            hub._request_with_retry("GET", "https://hub.docker.com")
class TestDockerHubEmptyTokens:
    """Successful responses that carry no usable token must raise auth errors."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_empty_hub_jwt_raises(self, mock_request):
        """An empty-string token in the login response is rejected."""
        mock_request.return_value = MagicMock(
            status_code=200, **{"json.return_value": {"token": ""}}
        )
        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")
        with pytest.raises(ImageRegistryAuthError, match="empty JWT"):
            hub._hub_login()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_none_hub_jwt_raises(self, mock_request):
        """A login response with no token key at all is rejected."""
        mock_request.return_value = MagicMock(
            status_code=200, **{"json.return_value": {}}
        )
        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")
        with pytest.raises(ImageRegistryAuthError, match="empty JWT"):
            hub._hub_login()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_empty_registry_token_raises(self, mock_request):
        """An empty-string registry token is rejected."""
        mock_request.return_value = MagicMock(
            status_code=200, **{"json.return_value": {"token": ""}}
        )
        hub = DockerHubAdapter("docker.io/myorg", username="u", password="p")
        with pytest.raises(ImageRegistryAuthError, match="empty token"):
            hub._get_registry_token("myorg/myapp")
@@ -0,0 +1,34 @@
from prowler.providers.image.lib.registry.dockerhub_adapter import DockerHubAdapter
from prowler.providers.image.lib.registry.factory import create_registry_adapter
from prowler.providers.image.lib.registry.oci_adapter import OciRegistryAdapter
class TestCreateRegistryAdapter:
    """Adapter selection and argument forwarding in ``create_registry_adapter``."""

    def test_docker_hub_returns_dockerhub_adapter(self):
        """A ``docker.io`` URL produces a ``DockerHubAdapter``."""
        created = create_registry_adapter("docker.io/myorg")
        assert isinstance(created, DockerHubAdapter)

    def test_oci_returns_oci_adapter(self):
        """A non-Docker-Hub host produces an ``OciRegistryAdapter``."""
        created = create_registry_adapter("myregistry.io")
        assert isinstance(created, OciRegistryAdapter)

    def test_ecr_returns_oci_adapter(self):
        """An ECR hostname produces an ``OciRegistryAdapter``."""
        ecr_host = "123456789.dkr.ecr.us-east-1.amazonaws.com"
        created = create_registry_adapter(ecr_host)
        assert isinstance(created, OciRegistryAdapter)

    def test_passes_credentials(self):
        """Credentials and the SSL flag are exposed on the created adapter."""
        created = create_registry_adapter(
            "myregistry.io",
            username="user",
            password="pass",
            token="tok",
            verify_ssl=False,
        )
        assert created.username == "user"
        assert created.password == "pass"
        assert created.token == "tok"
        assert created.verify_ssl is False

    def test_registry_1_docker_io(self):
        """The ``registry-1.docker.io`` host produces a ``DockerHubAdapter``."""
        created = create_registry_adapter("registry-1.docker.io/myorg")
        assert isinstance(created, DockerHubAdapter)
@@ -0,0 +1,418 @@
import base64
from unittest.mock import MagicMock, patch
import pytest
import requests
from prowler.providers.image.exceptions.exceptions import (
ImageRegistryAuthError,
ImageRegistryCatalogError,
ImageRegistryNetworkError,
)
from prowler.providers.image.lib.registry.oci_adapter import OciRegistryAdapter
class TestOciAdapterInit:
    """URL normalisation and credential storage in ``OciRegistryAdapter``."""

    def test_normalise_url_adds_https(self):
        """A URL without a scheme gets ``https://`` prepended."""
        base_url = OciRegistryAdapter("myregistry.io")._base_url
        assert base_url == "https://myregistry.io"

    def test_normalise_url_keeps_http(self):
        """An explicit ``http://`` scheme is preserved."""
        base_url = OciRegistryAdapter("http://myregistry.io")._base_url
        assert base_url == "http://myregistry.io"

    def test_normalise_url_strips_trailing_slash(self):
        """A trailing slash is dropped from the base URL."""
        base_url = OciRegistryAdapter("https://myregistry.io/")._base_url
        assert base_url == "https://myregistry.io"

    def test_stores_credentials(self):
        """Constructor arguments are readable back from the adapter."""
        oci = OciRegistryAdapter(
            "reg.io", username="u", password="p", token="t", verify_ssl=False
        )
        assert oci.username == "u"
        assert oci.password == "p"
        assert oci.token == "t"
        assert oci.verify_ssl is False
class TestOciAdapterAuth:
    """Authentication flows of ``OciRegistryAdapter``: token, bearer and basic."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_with_token(self, mock_request):
        """A pre-supplied token becomes the bearer token without any request."""
        oci = OciRegistryAdapter("reg.io", token="my-token")
        oci._ensure_auth()
        assert oci._bearer_token == "my-token"
        mock_request.assert_not_called()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_anonymous_ok(self, mock_request):
        """A 200 on the ping leaves the bearer token unset."""
        mock_request.return_value = MagicMock(status_code=200)
        oci = OciRegistryAdapter("reg.io")
        oci._ensure_auth()
        assert oci._bearer_token is None

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_bearer_challenge(self, mock_request):
        """A Bearer challenge followed by a token response stores that token."""
        challenge = MagicMock(
            status_code=401,
            headers={
                "Www-Authenticate": 'Bearer realm="https://auth.example.com/token",service="registry"'
            },
        )
        token_exchange = MagicMock(
            status_code=200, **{"json.return_value": {"token": "bearer-tok"}}
        )
        mock_request.side_effect = [challenge, token_exchange]

        oci = OciRegistryAdapter("reg.io", username="u", password="p")
        oci._ensure_auth()

        assert oci._bearer_token == "bearer-tok"

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_403_raises(self, mock_request):
        """A 403 on the ping raises ``ImageRegistryAuthError``."""
        mock_request.return_value = MagicMock(status_code=403)
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError):
            oci._ensure_auth()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_basic_challenge_with_creds(self, mock_request):
        """A Basic challenge with credentials marks basic auth as verified."""
        mock_request.return_value = MagicMock(
            status_code=401,
            headers={"Www-Authenticate": 'Basic realm="https://ecr.aws"'},
        )
        oci = OciRegistryAdapter("ecr.aws", username="AWS", password="tok")
        oci._ensure_auth()
        assert oci._basic_auth_verified is True
        assert oci._bearer_token is None

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_ensure_auth_basic_challenge_no_creds(self, mock_request):
        """A Basic challenge without credentials raises an auth error."""
        mock_request.return_value = MagicMock(
            status_code=401,
            headers={"Www-Authenticate": 'Basic realm="https://ecr.aws"'},
        )
        oci = OciRegistryAdapter("ecr.aws")
        with pytest.raises(ImageRegistryAuthError):
            oci._ensure_auth()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_basic_auth_used_in_requests(self, mock_request):
        """After a Basic challenge, requests use the ``auth`` kwarg, not a header."""
        challenge = MagicMock(
            status_code=401,
            headers={"Www-Authenticate": 'Basic realm="https://ecr.aws"'},
        )
        catalog = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"repositories": ["myapp"]}},
        )
        mock_request.side_effect = [challenge, catalog]

        oci = OciRegistryAdapter("ecr.aws", username="AWS", password="tok")
        oci._ensure_auth()
        oci._authed_request("GET", "https://ecr.aws/v2/_catalog")

        # Index 1 is the catalog request; index 0 was the ping.
        catalog_kwargs = mock_request.call_args_list[1].kwargs
        assert catalog_kwargs.get("auth") == ("AWS", "tok")
        assert "Authorization" not in catalog_kwargs.get("headers", {})

    def test_resolve_basic_credentials_decodes_base64_token(self):
        """A base64 ``AWS:<password>`` value is decoded to the raw password."""
        raw_password = "real-jwt-password"
        token_bytes = f"AWS:{raw_password}".encode()
        encoded = base64.b64encode(token_bytes).decode()
        oci = OciRegistryAdapter("ecr.aws", username="AWS", password=encoded)
        resolved_user, resolved_password = oci._resolve_basic_credentials()
        assert resolved_user == "AWS"
        assert resolved_password == raw_password

    def test_resolve_basic_credentials_passthrough_raw_password(self):
        """A plain password is returned unchanged."""
        oci = OciRegistryAdapter("ecr.aws", username="AWS", password="plain-pass")
        resolved_user, resolved_password = oci._resolve_basic_credentials()
        assert resolved_user == "AWS"
        assert resolved_password == "plain-pass"

    def test_resolve_basic_credentials_passthrough_invalid_base64(self):
        """A password that is not valid base64 is returned unchanged."""
        oci = OciRegistryAdapter(
            "ecr.aws", username="AWS", password="not!valid~base64"
        )
        resolved_user, resolved_password = oci._resolve_basic_credentials()
        assert resolved_user == "AWS"
        assert resolved_password == "not!valid~base64"

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_basic_auth_decodes_ecr_token_in_request(self, mock_request):
        """The decoded password, not the base64 form, is sent as basic auth."""
        raw_password = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.abc"
        token_bytes = f"AWS:{raw_password}".encode()
        encoded = base64.b64encode(token_bytes).decode()
        challenge = MagicMock(
            status_code=401,
            headers={"Www-Authenticate": 'Basic realm="https://ecr.aws"'},
        )
        catalog = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"repositories": ["myapp"]}},
        )
        mock_request.side_effect = [challenge, catalog]

        oci = OciRegistryAdapter("ecr.aws", username="AWS", password=encoded)
        oci._ensure_auth()
        oci._authed_request("GET", "https://ecr.aws/v2/_catalog")

        # Index 1 is the catalog request; index 0 was the ping.
        catalog_kwargs = mock_request.call_args_list[1].kwargs
        assert catalog_kwargs.get("auth") == ("AWS", raw_password)

    def test_resolve_basic_credentials_none_password(self):
        """A ``None`` password is returned as ``None``."""
        oci = OciRegistryAdapter("ecr.aws", username="AWS", password=None)
        resolved_user, resolved_password = oci._resolve_basic_credentials()
        assert resolved_user == "AWS"
        assert resolved_password is None

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_authed_request_retries_on_401_with_bearer(self, mock_request):
        """A 401 with a bearer token triggers re-auth and a second attempt."""
        oci = OciRegistryAdapter("reg.io", username="u", password="p")
        oci._bearer_token = "expired-token"

        # Expected request sequence:
        #   1. the original request, rejected with 401 (expired token)
        #   2. the _ensure_auth ping, answered with a Bearer challenge
        #   3. the token exchange, which succeeds
        #   4. the original request again, accepted with the new token
        rejected = MagicMock(status_code=401)
        challenge = MagicMock(
            status_code=401,
            headers={
                "Www-Authenticate": 'Bearer realm="https://auth.reg.io/token",service="registry"'
            },
        )
        token_exchange = MagicMock(
            status_code=200, **{"json.return_value": {"token": "new-token"}}
        )
        accepted = MagicMock(status_code=200)
        mock_request.side_effect = [rejected, challenge, token_exchange, accepted]

        response = oci._authed_request("GET", "https://reg.io/v2/myapp/tags/list")

        assert response.status_code == 200
        assert oci._bearer_token == "new-token"
        assert mock_request.call_count == 4

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_authed_request_no_retry_on_401_without_bearer(self, mock_request):
        """With basic auth (no bearer token) a 401 is returned without retrying."""
        oci = OciRegistryAdapter("reg.io", username="u", password="p")
        # No bearer token is set; the adapter is using basic auth.
        oci._basic_auth_verified = True
        mock_request.return_value = MagicMock(status_code=401)

        response = oci._authed_request("GET", "https://reg.io/v2/_catalog")

        assert response.status_code == 401
        # Exactly one request: there is no retry for basic auth.
        assert mock_request.call_count == 1
class TestOciAdapterListRepositories:
    """Catalog listing through ``OciRegistryAdapter.list_repositories``."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_repos_single_page(self, mock_request):
        """A single catalog page is returned as-is."""
        ping = MagicMock(status_code=200)
        catalog = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"repositories": ["app/frontend", "app/backend"]}},
        )
        mock_request.side_effect = [ping, catalog]

        oci = OciRegistryAdapter("reg.io")

        assert oci.list_repositories() == ["app/frontend", "app/backend"]

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_repos_paginated(self, mock_request):
        """A ``Link`` header with rel="next" is followed and the pages are concatenated."""
        ping = MagicMock(status_code=200)
        first_page = MagicMock(
            status_code=200,
            headers={"Link": '<https://reg.io/v2/_catalog?n=200&last=b>; rel="next"'},
            **{"json.return_value": {"repositories": ["a"]}},
        )
        last_page = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"repositories": ["b"]}},
        )
        mock_request.side_effect = [ping, first_page, last_page]

        oci = OciRegistryAdapter("reg.io")

        assert oci.list_repositories() == ["a", "b"]

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_repos_404_raises(self, mock_request):
        """A 404 from the catalog request raises ``ImageRegistryCatalogError``."""
        mock_request.side_effect = [
            MagicMock(status_code=200),  # ping
            MagicMock(status_code=404),  # catalog
        ]
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryCatalogError):
            oci.list_repositories()
class TestOciAdapterListTags:
    """Tag listing through ``OciRegistryAdapter.list_tags``."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_tags(self, mock_request):
        """Tags from the response body are returned as a list."""
        ping = MagicMock(status_code=200)
        tag_listing = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"tags": ["latest", "v1.0"]}},
        )
        mock_request.side_effect = [ping, tag_listing]

        oci = OciRegistryAdapter("reg.io")

        assert oci.list_tags("myapp") == ["latest", "v1.0"]

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_list_tags_null_tags(self, mock_request):
        """A ``null`` tags value in the body yields an empty list."""
        ping = MagicMock(status_code=200)
        tag_listing = MagicMock(
            status_code=200,
            headers={},
            **{"json.return_value": {"tags": None}},
        )
        mock_request.side_effect = [ping, tag_listing]

        oci = OciRegistryAdapter("reg.io")

        assert oci.list_tags("myapp") == []
class TestOciAdapterRetry:
    """``_request_with_retry`` behaviour, exercised through ``OciRegistryAdapter``."""

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_retry_on_429(self, mock_request, mock_sleep):
        """A 429 followed by a 200 ends with the 200 after one sleep."""
        mock_request.side_effect = [
            MagicMock(status_code=429),
            MagicMock(status_code=200),
        ]
        oci = OciRegistryAdapter("reg.io")
        response = oci._request_with_retry("GET", "https://reg.io/v2/")
        assert response.status_code == 200
        mock_sleep.assert_called_once()

    @patch("prowler.providers.image.lib.registry.base.time.sleep")
    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_connection_error_retries(self, mock_request, mock_sleep):
        """Repeated connection errors raise a network error after three requests."""
        mock_request.side_effect = requests.exceptions.ConnectionError("failed")
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryNetworkError):
            oci._request_with_retry("GET", "https://reg.io/v2/")
        assert mock_request.call_count == 3

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_timeout_raises_immediately(self, mock_request):
        """A timeout raises a network error after a single request."""
        mock_request.side_effect = requests.exceptions.Timeout("timeout")
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryNetworkError):
            oci._request_with_retry("GET", "https://reg.io/v2/")
        assert mock_request.call_count == 1
class TestOciAdapterNextPageUrl:
    """Parsing of the ``Link`` response header by ``_next_page_url``."""

    def test_no_link_header(self):
        """No ``Link`` header means there is no next page."""
        response = MagicMock(headers={})
        next_url = OciRegistryAdapter._next_page_url(response)
        assert next_url is None

    def test_link_header_with_next(self):
        """The URL inside a rel="next" link is returned."""
        link_value = '<https://reg.io/v2/_catalog?n=200&last=b>; rel="next"'
        response = MagicMock(headers={"Link": link_value})
        next_url = OciRegistryAdapter._next_page_url(response)
        assert next_url == "https://reg.io/v2/_catalog?n=200&last=b"

    def test_link_header_no_next(self):
        """A link whose rel is not "next" is ignored."""
        link_value = '<https://reg.io/v2/_catalog?n=200>; rel="prev"'
        response = MagicMock(headers={"Link": link_value})
        next_url = OciRegistryAdapter._next_page_url(response)
        assert next_url is None
class TestOciAdapterSSRF:
    """Realm URL validation performed by ``_validate_realm_url``."""

    def test_reject_file_scheme(self):
        """A ``file://`` realm is rejected as a disallowed scheme."""
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError, match="disallowed scheme"):
            oci._validate_realm_url("file:///etc/passwd")

    def test_reject_ftp_scheme(self):
        """An ``ftp://`` realm is rejected as a disallowed scheme."""
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError, match="disallowed scheme"):
            oci._validate_realm_url("ftp://evil.com/token")

    def test_reject_private_ip(self):
        """A realm at ``10.0.0.1`` is rejected as private/loopback."""
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError, match="private/loopback"):
            oci._validate_realm_url("https://10.0.0.1/token")

    def test_reject_loopback(self):
        """A realm at ``127.0.0.1`` is rejected as private/loopback."""
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError, match="private/loopback"):
            oci._validate_realm_url("https://127.0.0.1/token")

    def test_reject_link_local(self):
        """A realm at ``169.254.169.254`` is rejected as private/loopback."""
        oci = OciRegistryAdapter("reg.io")
        with pytest.raises(ImageRegistryAuthError, match="private/loopback"):
            oci._validate_realm_url("https://169.254.169.254/latest/meta-data")

    def test_accept_public_https(self):
        """An HTTPS realm on a public hostname is accepted."""
        oci = OciRegistryAdapter("reg.io")
        # Passing means the call below does not raise.
        oci._validate_realm_url("https://auth.example.com/token")

    def test_accept_hostname_not_ip(self):
        """A hostname (as opposed to an IP literal) is accepted."""
        oci = OciRegistryAdapter("reg.io")
        # Hostnames (not IPs) should pass even if they resolve to private IPs.
        oci._validate_realm_url("https://internal.corp.com/token")
class TestOciAdapterEmptyToken:
    """Token exchanges that succeed but return no usable token must fail."""

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_empty_bearer_token_raises(self, mock_request):
        """Empty ``token`` and ``access_token`` strings raise an auth error."""
        challenge = MagicMock(
            status_code=401,
            headers={
                "Www-Authenticate": 'Bearer realm="https://auth.example.com/token",service="registry"'
            },
        )
        token_exchange = MagicMock(
            status_code=200,
            **{"json.return_value": {"token": "", "access_token": ""}},
        )
        mock_request.side_effect = [challenge, token_exchange]

        oci = OciRegistryAdapter("reg.io", username="u", password="p")
        with pytest.raises(ImageRegistryAuthError, match="empty token"):
            oci._ensure_auth()

    @patch("prowler.providers.image.lib.registry.base.requests.request")
    def test_none_bearer_token_raises(self, mock_request):
        """A token response with no token keys raises an auth error."""
        challenge = MagicMock(
            status_code=401,
            headers={
                "Www-Authenticate": 'Bearer realm="https://auth.example.com/token",service="registry"'
            },
        )
        token_exchange = MagicMock(status_code=200, **{"json.return_value": {}})
        mock_request.side_effect = [challenge, token_exchange]

        oci = OciRegistryAdapter("reg.io", username="u", password="p")
        with pytest.raises(ImageRegistryAuthError, match="empty token"):
            oci._ensure_auth()
class TestOciAdapterNarrowExcept:
    """Credential resolution when the base64 payload is not valid UTF-8."""

    def test_invalid_utf8_base64_falls_through(self):
        """A password decoding to invalid UTF-8 bytes is returned unchanged."""
        # b"\xff\xfe" is valid base64 input but not valid UTF-8 once decoded.
        undecodable = base64.b64encode(b"\xff\xfe").decode()
        oci = OciRegistryAdapter("ecr.aws", username="AWS", password=undecodable)
        resolved_user, resolved_password = oci._resolve_basic_credentials()
        assert resolved_user == "AWS"
        assert resolved_password == undecodable
class TestCredentialRedaction:
    """Secrets must not leak through ``__getstate__`` or ``repr``."""

    def test_getstate_redacts_credentials(self):
        """Password and token are masked; username and URL are kept."""
        oci = OciRegistryAdapter(
            "reg.io", username="u", password="secret", token="tok"
        )
        pickled_state = oci.__getstate__()
        assert pickled_state["_password"] == "***"
        assert pickled_state["_token"] == "***"
        assert pickled_state["username"] == "u"
        assert pickled_state["registry_url"] == "reg.io"

    def test_getstate_none_credentials(self):
        """Unset credentials stay ``None`` in the state."""
        pickled_state = OciRegistryAdapter("reg.io").__getstate__()
        assert pickled_state["_password"] is None
        assert pickled_state["_token"] is None

    def test_repr_redacts_credentials(self):
        """``repr`` shows a redaction marker and neither secret value."""
        oci = OciRegistryAdapter(
            "reg.io", username="u", password="s3cret_pw", token="s3cret_tk"
        )
        rendered = repr(oci)
        assert "s3cret_pw" not in rendered
        assert "s3cret_tk" not in rendered
        assert "<redacted>" in rendered

    def test_properties_still_work(self):
        """The ``password`` and ``token`` properties return the real values."""
        oci = OciRegistryAdapter("reg.io", password="secret", token="tok")
        assert oci.password == "secret"
        assert oci.token == "tok"
@@ -0,0 +1,234 @@
import os
from unittest.mock import MagicMock, patch
import pytest
from prowler.providers.image.exceptions.exceptions import (
ImageInvalidFilterError,
ImageMaxImagesExceededError,
)
from prowler.providers.image.image_provider import ImageProvider
from prowler.providers.image.lib.registry.dockerhub_adapter import DockerHubAdapter
# Minimal environment used when constructing providers in these tests: only
# PATH and HOME are carried over (empty string when unset in the real env).
_CLEAN_ENV = {name: os.environ.get(name, "") for name in ("PATH", "HOME")}
def _build_provider(**overrides):
    """Construct an ``ImageProvider`` with registry-scan defaults.

    Keyword arguments in ``overrides`` replace the matching default
    constructor arguments. The provider is built while ``os.environ`` is
    temporarily replaced by ``_CLEAN_ENV``.

    Returns:
        The constructed ``ImageProvider`` instance.
    """
    provider_kwargs = {
        "images": [],
        "registry": "myregistry.io",
        "image_filter": None,
        "tag_filter": None,
        "max_images": 0,
        "registry_insecure": False,
        "registry_list_images": False,
        "config_content": {"image": {}},
        **overrides,
    }
    # clear=True drops every other variable for the duration of the call.
    with patch.dict(os.environ, _CLEAN_ENV, clear=True):
        return ImageProvider(**provider_kwargs)
class TestRegistryEnumeration:
    """Image references produced from repository/tag enumeration and filters."""

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_enumerate_oci_registry(self, mock_factory):
        """Every repository/tag pair becomes ``<registry>/<repo>:<tag>``."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["app/frontend", "app/backend"]
        # One tag list per repository, in listing order.
        fake_registry.list_tags.side_effect = [["latest", "v1.0"], ["latest"]]

        provider = _build_provider()

        expected_refs = (
            "myregistry.io/app/frontend:latest",
            "myregistry.io/app/frontend:v1.0",
            "myregistry.io/app/backend:latest",
        )
        for image_ref in expected_refs:
            assert image_ref in provider.images
        assert len(provider.images) == 3

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_image_filter(self, mock_factory):
        """Only repositories matching ``image_filter`` are kept."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = [
            "prod/app",
            "dev/app",
            "staging/app",
        ]
        fake_registry.list_tags.return_value = ["latest"]

        provider = _build_provider(image_filter="^prod/")

        assert len(provider.images) == 1
        assert "myregistry.io/prod/app:latest" in provider.images

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_tag_filter(self, mock_factory):
        """Only tags matching ``tag_filter`` are kept."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["myapp"]
        fake_registry.list_tags.return_value = ["latest", "v1.0", "v2.0", "dev-abc123"]

        provider = _build_provider(tag_filter=r"^v\d+\.\d+$")

        assert len(provider.images) == 2
        for image_ref in ("myregistry.io/myapp:v1.0", "myregistry.io/myapp:v2.0"):
            assert image_ref in provider.images

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_combined_filters(self, mock_factory):
        """Image and tag filters are applied together."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["prod/app", "dev/app"]
        fake_registry.list_tags.return_value = ["latest", "v1.0"]

        provider = _build_provider(image_filter="^prod/", tag_filter="^v")

        assert len(provider.images) == 1
        assert "myregistry.io/prod/app:v1.0" in provider.images
class TestMaxImages:
    """Enforcement of the ``max_images`` limit during enumeration."""

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_max_images_exceeded(self, mock_factory):
        """Three repositories with two tags each exceed a limit of two."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["app1", "app2", "app3"]
        fake_registry.list_tags.return_value = ["latest", "v1.0"]

        with pytest.raises(ImageMaxImagesExceededError):
            _build_provider(max_images=2)

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_max_images_not_exceeded(self, mock_factory):
        """A single image stays under a limit of ten."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["app1"]
        fake_registry.list_tags.return_value = ["latest"]

        provider = _build_provider(max_images=10)

        assert len(provider.images) == 1
class TestDeduplication:
    """Explicit images that also appear in the registry are listed once."""

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_deduplication_with_explicit_images(self, mock_factory):
        """An image given explicitly and also enumerated appears a single time."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["myapp"]
        fake_registry.list_tags.return_value = ["latest"]

        duplicated_ref = "myregistry.io/myapp:latest"
        provider = _build_provider(images=[duplicated_ref])

        assert provider.images.count(duplicated_ref) == 1
class TestInvalidFilters:
    """Malformed regular expressions in filters raise ``ImageInvalidFilterError``."""

    def test_invalid_image_filter_regex(self):
        """An unterminated character class in ``image_filter`` is rejected."""
        broken_pattern = "[invalid"
        with pytest.raises(ImageInvalidFilterError):
            _build_provider(image_filter=broken_pattern)

    def test_invalid_tag_filter_regex(self):
        """An unclosed group in ``tag_filter`` is rejected."""
        broken_pattern = "(unclosed"
        with pytest.raises(ImageInvalidFilterError):
            _build_provider(tag_filter=broken_pattern)
class TestRegistryInsecure:
    """Propagation of ``registry_insecure`` to the adapter factory."""

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_insecure_passes_verify_false(self, mock_factory):
        """``registry_insecure=True`` reaches the factory as ``verify_ssl=False``."""
        fake_registry = mock_factory.return_value
        fake_registry.list_repositories.return_value = ["app"]
        fake_registry.list_tags.return_value = ["latest"]

        _build_provider(registry_insecure=True)

        mock_factory.assert_called_once()
        factory_kwargs = mock_factory.call_args.kwargs
        assert factory_kwargs["verify_ssl"] is False
class TestEmptyRegistry:
    """Behaviour when the registry catalog contains no repositories."""

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_empty_catalog_with_explicit_images(self, mock_factory):
        """With an empty catalog, only the explicit image remains."""
        mock_factory.return_value.list_repositories.return_value = []

        provider = _build_provider(images=["nginx:latest"])

        assert provider.images == ["nginx:latest"]
class TestRegistryList:
    """Checks the listing-only mode enabled by ``registry_list_images=True``.

    Each test builds a provider against a mocked registry adapter, asserts
    that ``_listing_only`` is set, and inspects what was written to stdout.
    """

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_registry_list_prints_and_returns(self, mock_factory, capsys):
        """Repositories, tags and the summary counts all appear on stdout."""
        registry_adapter = MagicMock()
        registry_adapter.list_repositories.return_value = [
            "app/frontend",
            "app/backend",
        ]
        # One tag list per repository, consumed in call order.
        registry_adapter.list_tags.side_effect = [["latest", "v1.0"], ["latest"]]
        mock_factory.return_value = registry_adapter

        built = _build_provider(registry_list_images=True)

        assert built._listing_only is True
        stdout = capsys.readouterr().out
        expected_fragments = (
            "app/frontend",
            "app/backend",
            "latest",
            "v1.0",
            "2 repositories",
            "3 images",
        )
        for fragment in expected_fragments:
            assert fragment in stdout

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_registry_list_respects_image_filter(self, mock_factory, capsys):
        """Only repositories matching ``image_filter`` are printed."""
        mock_factory.return_value = MagicMock(
            **{
                "list_repositories.return_value": ["prod/app", "dev/app"],
                "list_tags.return_value": ["latest"],
            }
        )

        built = _build_provider(registry_list_images=True, image_filter="^prod/")

        assert built._listing_only is True
        stdout = capsys.readouterr().out
        assert "prod/app" in stdout
        assert "dev/app" not in stdout
        assert "1 repository" in stdout

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_registry_list_respects_tag_filter(self, mock_factory, capsys):
        """Only tags matching ``tag_filter`` are printed."""
        mock_factory.return_value = MagicMock(
            **{
                "list_repositories.return_value": ["myapp"],
                "list_tags.return_value": ["latest", "v1.0", "dev-abc"],
            }
        )

        built = _build_provider(
            registry_list_images=True,
            tag_filter=r"^v\d+\.\d+$",
        )

        assert built._listing_only is True
        stdout = capsys.readouterr().out
        assert "v1.0" in stdout
        assert "dev-abc" not in stdout
        assert "1 image)" in stdout

    @patch("prowler.providers.image.image_provider.create_registry_adapter")
    def test_registry_list_skips_max_images(self, mock_factory, capsys):
        """Listing mode does not raise even when ``max_images`` is exceeded."""
        mock_factory.return_value = MagicMock(
            **{
                "list_repositories.return_value": ["app1", "app2", "app3"],
                "list_tags.return_value": ["latest", "v1.0"],
            }
        )

        # Three repositories with two tags each is well over max_images=1;
        # in listing mode the limit is expected not to be enforced.
        built = _build_provider(registry_list_images=True, max_images=1)

        assert built._listing_only is True
        stdout = capsys.readouterr().out
        assert "6 images" in stdout
class TestDockerHubEnumeration:
@patch("prowler.providers.image.image_provider.create_registry_adapter")
def test_dockerhub_images_use_repo_tag_format(self, mock_factory):
    """Images enumerated via a Docker Hub adapter are ``repo:tag`` with no host prefix."""
    hub_adapter = MagicMock(spec=DockerHubAdapter)
    hub_adapter.list_repositories.return_value = ["myorg/app1", "myorg/app2"]
    # One tag list per repository, consumed in call order.
    hub_adapter.list_tags.side_effect = [["latest", "v1.0"], ["latest"]]
    mock_factory.return_value = hub_adapter

    built = _build_provider(registry="docker.io/myorg")

    # Every repo/tag combination is present in bare ``repo:tag`` form.
    expected_refs = ("myorg/app1:latest", "myorg/app1:v1.0", "myorg/app2:latest")
    for ref in expected_refs:
        assert ref in built.images
    # None of the resulting references carries the registry host.
    for img in built.images:
        assert not img.startswith("docker.io/"), f"Unexpected host prefix in {img}"
    assert len(built.images) == 3