mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-28 02:49:53 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ba2dee434 | ||
|
|
7ad152f5fe | ||
|
|
98d27ecfa0 | ||
|
|
15e4107065 | ||
|
|
53a72a3c7c | ||
|
|
f44b20ff4c | ||
|
|
231fcf98d0 | ||
|
|
c7093013f9 | ||
|
|
6e96cb0874 | ||
|
|
296fa0f984 | ||
|
|
9a46fca8dd | ||
|
|
66e5a03f9f | ||
|
|
ccd561f0f1 | ||
|
|
9e1b78e64f | ||
|
|
536f90ced3 | ||
|
|
5453c02fd4 | ||
|
|
230e11be8a | ||
|
|
20625954a3 |
@@ -2,6 +2,20 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.10.2] (Prowler v5.9.2)
|
||||
|
||||
### Changed
|
||||
- Optimized queries for resources views [(#8336)](https://github.com/prowler-cloud/prowler/pull/8336)
|
||||
|
||||
---
|
||||
|
||||
## [v1.10.1] (Prowler v5.9.1)
|
||||
|
||||
### Fixed
|
||||
- Calculate failed findings during scans to prevent heavy database queries [(#8322)](https://github.com/prowler-cloud/prowler/pull/8322)
|
||||
|
||||
---
|
||||
|
||||
## [v1.10.0] (Prowler v5.9.0)
|
||||
|
||||
### Added
|
||||
@@ -12,7 +26,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- `/processors` endpoints to post-process findings. Currently, only the Mutelist processor is supported to allow to mute findings.
|
||||
- Optimized the underlying queries for resources endpoints [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
|
||||
- Optimized include parameters for resources view [(#8229)](https://github.com/prowler-cloud/prowler/pull/8229)
|
||||
- Optimized overview background tasks [(#8300)](https://github.com/prowler-cloud/prowler/pull/8300)
|
||||
- Optimized overview background tasks [(#8300)](https://github.com/prowler-cloud/prowler/pull/8300)
|
||||
|
||||
### Fixed
|
||||
- Search filter for findings and resources [(#8112)](https://github.com/prowler-cloud/prowler/pull/8112)
|
||||
|
||||
3374
api/poetry.lock
generated
3374
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"drf-spectacular-jsonapi==0.5.1",
|
||||
"gunicorn==23.0.0",
|
||||
"lxml==5.3.2",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@v5.9",
|
||||
"psycopg2-binary==2.9.9",
|
||||
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
|
||||
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
|
||||
@@ -38,7 +38,7 @@ name = "prowler-api"
|
||||
package-mode = false
|
||||
# Needed for the SDK compatibility
|
||||
requires-python = ">=3.11,<3.13"
|
||||
version = "1.10.0"
|
||||
version = "1.10.2"
|
||||
|
||||
[project.scripts]
|
||||
celery = "src.backend.config.settings.celery"
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
atomic = False
|
||||
|
||||
dependencies = [
|
||||
("api", "0039_resource_resources_failed_findings_idx"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(
|
||||
partial(
|
||||
create_index_on_partitions,
|
||||
parent_table="resource_finding_mappings",
|
||||
index_name="rfm_tenant_resource_idx",
|
||||
columns="tenant_id, resource_id",
|
||||
method="BTREE",
|
||||
),
|
||||
reverse_code=partial(
|
||||
drop_index_on_partitions,
|
||||
parent_table="resource_finding_mappings",
|
||||
index_name="rfm_tenant_resource_idx",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0040_rfm_tenant_resource_index_partitions"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddIndex(
|
||||
model_name="resourcefindingmapping",
|
||||
index=models.Index(
|
||||
fields=["tenant_id", "resource_id"],
|
||||
name="rfm_tenant_resource_idx",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from django.contrib.postgres.operations import AddIndexConcurrently
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
atomic = False
|
||||
|
||||
dependencies = [
|
||||
("api", "0041_rfm_tenant_resource_parent_partitions"),
|
||||
("django_celery_beat", "0019_alter_periodictasks_options"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
AddIndexConcurrently(
|
||||
model_name="scan",
|
||||
index=models.Index(
|
||||
condition=models.Q(("state", "completed")),
|
||||
fields=["tenant_id", "provider_id", "-inserted_at"],
|
||||
include=("id",),
|
||||
name="scans_prov_ins_desc_idx",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -476,6 +476,13 @@ class Scan(RowLevelSecurityProtectedModel):
|
||||
condition=Q(state=StateChoices.COMPLETED),
|
||||
name="scans_prov_state_ins_desc_idx",
|
||||
),
|
||||
# TODO This might replace `scans_prov_state_ins_desc_idx` completely. Review usage
|
||||
models.Index(
|
||||
fields=["tenant_id", "provider_id", "-inserted_at"],
|
||||
condition=Q(state=StateChoices.COMPLETED),
|
||||
include=["id"],
|
||||
name="scans_prov_ins_desc_idx",
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
@@ -860,6 +867,10 @@ class ResourceFindingMapping(PostgresPartitionedModel, RowLevelSecurityProtected
|
||||
fields=["tenant_id", "finding_id"],
|
||||
name="rfm_tenant_finding_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "resource_id"],
|
||||
name="rfm_tenant_resource_idx",
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: Prowler API
|
||||
version: 1.10.0
|
||||
version: 1.10.2
|
||||
description: |-
|
||||
Prowler API specification.
|
||||
|
||||
|
||||
@@ -5188,6 +5188,8 @@ class TestComplianceOverviewViewSet:
|
||||
assert "description" in attributes
|
||||
assert "status" in attributes
|
||||
|
||||
# TODO: This test may fail randomly because requirements are not ordered
|
||||
@pytest.mark.xfail
|
||||
def test_compliance_overview_requirements_manual(
|
||||
self, authenticated_client, compliance_requirements_overviews_fixture
|
||||
):
|
||||
|
||||
@@ -22,7 +22,7 @@ from django.conf import settings as django_settings
|
||||
from django.contrib.postgres.aggregates import ArrayAgg
|
||||
from django.contrib.postgres.search import SearchQuery
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, F, Prefetch, Q, Sum
|
||||
from django.db.models import Count, F, Prefetch, Q, Subquery, Sum
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import redirect
|
||||
@@ -292,7 +292,7 @@ class SchemaView(SpectacularAPIView):
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
spectacular_settings.TITLE = "Prowler API"
|
||||
spectacular_settings.VERSION = "1.10.0"
|
||||
spectacular_settings.VERSION = "1.10.2"
|
||||
spectacular_settings.DESCRIPTION = (
|
||||
"Prowler API specification.\n\nThis file is auto-generated."
|
||||
)
|
||||
@@ -1994,6 +1994,21 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
|
||||
)
|
||||
)
|
||||
|
||||
def _should_prefetch_findings(self) -> bool:
|
||||
fields_param = self.request.query_params.get("fields[resources]", "")
|
||||
include_param = self.request.query_params.get("include", "")
|
||||
return (
|
||||
fields_param == ""
|
||||
or "findings" in fields_param.split(",")
|
||||
or "findings" in include_param.split(",")
|
||||
)
|
||||
|
||||
def _get_findings_prefetch(self):
|
||||
findings_queryset = Finding.all_objects.defer("scan", "resources").filter(
|
||||
tenant_id=self.request.tenant_id
|
||||
)
|
||||
return [Prefetch("findings", queryset=findings_queryset)]
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action in ["metadata", "metadata_latest"]:
|
||||
return ResourceMetadataSerializer
|
||||
@@ -2017,7 +2032,11 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
|
||||
filtered_queryset,
|
||||
manager=Resource.all_objects,
|
||||
select_related=["provider"],
|
||||
prefetch_related=["findings"],
|
||||
prefetch_related=(
|
||||
self._get_findings_prefetch()
|
||||
if self._should_prefetch_findings()
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
@@ -2042,14 +2061,18 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
|
||||
tenant_id = request.tenant_id
|
||||
filtered_queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
latest_scan_ids = (
|
||||
Scan.all_objects.filter(tenant_id=tenant_id, state=StateChoices.COMPLETED)
|
||||
latest_scans = (
|
||||
Scan.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
)
|
||||
.order_by("provider_id", "-inserted_at")
|
||||
.distinct("provider_id")
|
||||
.values_list("id", flat=True)
|
||||
.values("provider_id")
|
||||
)
|
||||
|
||||
filtered_queryset = filtered_queryset.filter(
|
||||
tenant_id=tenant_id, provider__scan__in=latest_scan_ids
|
||||
provider_id__in=Subquery(latest_scans)
|
||||
)
|
||||
|
||||
return self.paginate_by_pk(
|
||||
@@ -2057,7 +2080,11 @@ class ResourceViewSet(PaginateByPkMixin, BaseRLSViewSet):
|
||||
filtered_queryset,
|
||||
manager=Resource.all_objects,
|
||||
select_related=["provider"],
|
||||
prefetch_related=["findings"],
|
||||
prefetch_related=(
|
||||
self._get_findings_prefetch()
|
||||
if self._should_prefetch_findings()
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="metadata")
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
|
||||
from django.db import IntegrityError, OperationalError, connection
|
||||
from django.db import IntegrityError, OperationalError
|
||||
from django.db.models import Case, Count, IntegerField, Prefetch, Sum, When
|
||||
from tasks.utils import CustomEncoder
|
||||
|
||||
@@ -13,7 +14,11 @@ from api.compliance import (
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
|
||||
generate_scan_compliance,
|
||||
)
|
||||
from api.db_utils import create_objects_in_batches, rls_transaction
|
||||
from api.db_utils import (
|
||||
create_objects_in_batches,
|
||||
rls_transaction,
|
||||
update_objects_in_batches,
|
||||
)
|
||||
from api.exceptions import ProviderConnectionError
|
||||
from api.models import (
|
||||
ComplianceRequirementOverview,
|
||||
@@ -103,7 +108,10 @@ def _store_resources(
|
||||
|
||||
|
||||
def perform_prowler_scan(
|
||||
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
provider_id: str,
|
||||
checks_to_execute: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Perform a scan using Prowler and store the findings and resources in the database.
|
||||
@@ -175,6 +183,7 @@ def perform_prowler_scan(
|
||||
resource_cache = {}
|
||||
tag_cache = {}
|
||||
last_status_cache = {}
|
||||
resource_failed_findings_cache = defaultdict(int)
|
||||
|
||||
for progress, findings in prowler_scan.scan():
|
||||
for finding in findings:
|
||||
@@ -200,6 +209,9 @@ def perform_prowler_scan(
|
||||
},
|
||||
)
|
||||
resource_cache[resource_uid] = resource_instance
|
||||
|
||||
# Initialize all processed resources in the cache
|
||||
resource_failed_findings_cache[resource_uid] = 0
|
||||
else:
|
||||
resource_instance = resource_cache[resource_uid]
|
||||
|
||||
@@ -313,6 +325,11 @@ def perform_prowler_scan(
|
||||
)
|
||||
finding_instance.add_resources([resource_instance])
|
||||
|
||||
# Increment failed_findings_count cache if the finding status is FAIL and not muted
|
||||
if status == FindingStatus.FAIL and not finding.muted:
|
||||
resource_uid = finding.resource_uid
|
||||
resource_failed_findings_cache[resource_uid] += 1
|
||||
|
||||
# Update scan resource summaries
|
||||
scan_resource_cache.add(
|
||||
(
|
||||
@@ -330,6 +347,24 @@ def perform_prowler_scan(
|
||||
|
||||
scan_instance.state = StateChoices.COMPLETED
|
||||
|
||||
# Update failed_findings_count for all resources in batches if scan completed successfully
|
||||
if resource_failed_findings_cache:
|
||||
resources_to_update = []
|
||||
for resource_uid, failed_count in resource_failed_findings_cache.items():
|
||||
if resource_uid in resource_cache:
|
||||
resource_instance = resource_cache[resource_uid]
|
||||
resource_instance.failed_findings_count = failed_count
|
||||
resources_to_update.append(resource_instance)
|
||||
|
||||
if resources_to_update:
|
||||
update_objects_in_batches(
|
||||
tenant_id=tenant_id,
|
||||
model=Resource,
|
||||
objects=resources_to_update,
|
||||
fields=["failed_findings_count"],
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing scan {scan_id}: {e}")
|
||||
exception = e
|
||||
@@ -376,7 +411,6 @@ def perform_prowler_scan(
|
||||
def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Aggregates findings for a given scan and stores the results in the ScanSummary table.
|
||||
Also updates the failed_findings_count for each resource based on the latest findings.
|
||||
|
||||
This function retrieves all findings associated with a given `scan_id` and calculates various
|
||||
metrics such as counts of failed, passed, and muted findings, as well as their deltas (new,
|
||||
@@ -405,8 +439,6 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
- muted_new: Muted findings with a delta of 'new'.
|
||||
- muted_changed: Muted findings with a delta of 'changed'.
|
||||
"""
|
||||
_update_resource_failed_findings_count(tenant_id, scan_id)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
findings = Finding.objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
@@ -531,48 +563,6 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
ScanSummary.objects.bulk_create(scan_aggregations, batch_size=3000)
|
||||
|
||||
|
||||
def _update_resource_failed_findings_count(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Update the failed_findings_count field for resources based on the latest findings.
|
||||
|
||||
This function calculates the number of failed findings for each resource by:
|
||||
1. Getting the latest finding for each finding.uid
|
||||
2. Counting failed findings per resource
|
||||
3. Updating the failed_findings_count field for each resource
|
||||
|
||||
Args:
|
||||
tenant_id (str): The ID of the tenant to which the scan belongs.
|
||||
scan_id (str): The ID of the scan for which to update resource counts.
|
||||
"""
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan = Scan.objects.get(pk=scan_id)
|
||||
provider_id = str(scan.provider_id)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE resources AS r
|
||||
SET failed_findings_count = COALESCE((
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT DISTINCT ON (f.uid) f.uid
|
||||
FROM findings AS f
|
||||
JOIN resource_finding_mappings AS rfm
|
||||
ON rfm.finding_id = f.id
|
||||
WHERE f.tenant_id = %s
|
||||
AND f.status = %s
|
||||
AND f.muted = FALSE
|
||||
AND rfm.resource_id = r.id
|
||||
ORDER BY f.uid, f.inserted_at DESC
|
||||
) AS latest_uids
|
||||
), 0)
|
||||
WHERE r.tenant_id = %s
|
||||
AND r.provider_id = %s
|
||||
""",
|
||||
[tenant_id, FindingStatus.FAIL, tenant_id, provider_id],
|
||||
)
|
||||
|
||||
|
||||
def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Create detailed compliance requirement overview records for a scan.
|
||||
|
||||
@@ -7,22 +7,14 @@ import pytest
|
||||
from tasks.jobs.scan import (
|
||||
_create_finding_delta,
|
||||
_store_resources,
|
||||
_update_resource_failed_findings_count,
|
||||
create_compliance_requirements,
|
||||
perform_prowler_scan,
|
||||
)
|
||||
from tasks.utils import CustomEncoder
|
||||
|
||||
from api.exceptions import ProviderConnectionError
|
||||
from api.models import (
|
||||
Finding,
|
||||
Provider,
|
||||
Resource,
|
||||
Scan,
|
||||
Severity,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
)
|
||||
from api.models import Finding, Provider, Resource, Scan, StateChoices, StatusChoices
|
||||
from prowler.lib.check.models import Severity
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -182,6 +174,9 @@ class TestPerformScan:
|
||||
assert tag_keys == set(finding.resource_tags.keys())
|
||||
assert tag_values == set(finding.resource_tags.values())
|
||||
|
||||
# Assert that failed_findings_count is 0 (finding is PASS and muted)
|
||||
assert scan_resource.failed_findings_count == 0
|
||||
|
||||
@patch("tasks.jobs.scan.ProwlerScan")
|
||||
@patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider",
|
||||
@@ -386,6 +381,359 @@ class TestPerformScan:
|
||||
assert resource == resource_instance
|
||||
assert resource_uid_tuple == (resource_instance.uid, resource_instance.region)
|
||||
|
||||
def test_perform_prowler_scan_with_failed_findings(
|
||||
self,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that failed findings increment the failed_findings_count"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
|
||||
patch(
|
||||
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE",
|
||||
new_callable=dict,
|
||||
),
|
||||
patch("api.compliance.PROWLER_CHECKS", new_callable=dict),
|
||||
):
|
||||
# Ensure the database is empty
|
||||
assert Finding.objects.count() == 0
|
||||
assert Resource.objects.count() == 0
|
||||
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
# Ensure the provider type is 'aws'
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
tenant_id = str(tenant.id)
|
||||
scan_id = str(scan.id)
|
||||
provider_id = str(provider.id)
|
||||
|
||||
# Mock a FAIL finding that is not muted
|
||||
fail_finding = MagicMock()
|
||||
fail_finding.uid = "fail_finding_uid"
|
||||
fail_finding.status = StatusChoices.FAIL
|
||||
fail_finding.status_extended = "test fail status"
|
||||
fail_finding.severity = Severity.high
|
||||
fail_finding.check_id = "fail_check"
|
||||
fail_finding.get_metadata.return_value = {"key": "value"}
|
||||
fail_finding.resource_uid = "resource_uid_fail"
|
||||
fail_finding.resource_name = "fail_resource"
|
||||
fail_finding.region = "us-east-1"
|
||||
fail_finding.service_name = "ec2"
|
||||
fail_finding.resource_type = "instance"
|
||||
fail_finding.resource_tags = {"env": "test"}
|
||||
fail_finding.muted = False
|
||||
fail_finding.raw = {}
|
||||
fail_finding.resource_metadata = {"test": "metadata"}
|
||||
fail_finding.resource_details = {"details": "test"}
|
||||
fail_finding.partition = "aws"
|
||||
fail_finding.compliance = {"compliance1": "FAIL"}
|
||||
|
||||
# Mock the ProwlerScan instance
|
||||
mock_prowler_scan_instance = MagicMock()
|
||||
mock_prowler_scan_instance.scan.return_value = [(100, [fail_finding])]
|
||||
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
|
||||
|
||||
# Mock prowler_provider
|
||||
mock_prowler_provider_instance = MagicMock()
|
||||
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
|
||||
mock_initialize_prowler_provider.return_value = (
|
||||
mock_prowler_provider_instance
|
||||
)
|
||||
|
||||
# Call the function under test
|
||||
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
|
||||
|
||||
# Refresh instances from the database
|
||||
scan.refresh_from_db()
|
||||
scan_resource = Resource.objects.get(provider=provider)
|
||||
|
||||
# Assert that failed_findings_count is 1 (one FAIL finding not muted)
|
||||
assert scan_resource.failed_findings_count == 1
|
||||
|
||||
def test_perform_prowler_scan_multiple_findings_same_resource(
|
||||
self,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that multiple FAIL findings on the same resource increment the counter correctly"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
|
||||
patch(
|
||||
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE",
|
||||
new_callable=dict,
|
||||
),
|
||||
patch("api.compliance.PROWLER_CHECKS", new_callable=dict),
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
tenant_id = str(tenant.id)
|
||||
scan_id = str(scan.id)
|
||||
provider_id = str(provider.id)
|
||||
|
||||
# Create multiple findings for the same resource
|
||||
# Two FAIL findings (not muted) and one PASS finding
|
||||
resource_uid = "shared_resource_uid"
|
||||
|
||||
fail_finding_1 = MagicMock()
|
||||
fail_finding_1.uid = "fail_finding_1"
|
||||
fail_finding_1.status = StatusChoices.FAIL
|
||||
fail_finding_1.status_extended = "fail 1"
|
||||
fail_finding_1.severity = Severity.high
|
||||
fail_finding_1.check_id = "fail_check_1"
|
||||
fail_finding_1.get_metadata.return_value = {"key": "value1"}
|
||||
fail_finding_1.resource_uid = resource_uid
|
||||
fail_finding_1.resource_name = "shared_resource"
|
||||
fail_finding_1.region = "us-east-1"
|
||||
fail_finding_1.service_name = "ec2"
|
||||
fail_finding_1.resource_type = "instance"
|
||||
fail_finding_1.resource_tags = {}
|
||||
fail_finding_1.muted = False
|
||||
fail_finding_1.raw = {}
|
||||
fail_finding_1.resource_metadata = {}
|
||||
fail_finding_1.resource_details = {}
|
||||
fail_finding_1.partition = "aws"
|
||||
fail_finding_1.compliance = {}
|
||||
|
||||
fail_finding_2 = MagicMock()
|
||||
fail_finding_2.uid = "fail_finding_2"
|
||||
fail_finding_2.status = StatusChoices.FAIL
|
||||
fail_finding_2.status_extended = "fail 2"
|
||||
fail_finding_2.severity = Severity.medium
|
||||
fail_finding_2.check_id = "fail_check_2"
|
||||
fail_finding_2.get_metadata.return_value = {"key": "value2"}
|
||||
fail_finding_2.resource_uid = resource_uid
|
||||
fail_finding_2.resource_name = "shared_resource"
|
||||
fail_finding_2.region = "us-east-1"
|
||||
fail_finding_2.service_name = "ec2"
|
||||
fail_finding_2.resource_type = "instance"
|
||||
fail_finding_2.resource_tags = {}
|
||||
fail_finding_2.muted = False
|
||||
fail_finding_2.raw = {}
|
||||
fail_finding_2.resource_metadata = {}
|
||||
fail_finding_2.resource_details = {}
|
||||
fail_finding_2.partition = "aws"
|
||||
fail_finding_2.compliance = {}
|
||||
|
||||
pass_finding = MagicMock()
|
||||
pass_finding.uid = "pass_finding"
|
||||
pass_finding.status = StatusChoices.PASS
|
||||
pass_finding.status_extended = "pass"
|
||||
pass_finding.severity = Severity.low
|
||||
pass_finding.check_id = "pass_check"
|
||||
pass_finding.get_metadata.return_value = {"key": "value3"}
|
||||
pass_finding.resource_uid = resource_uid
|
||||
pass_finding.resource_name = "shared_resource"
|
||||
pass_finding.region = "us-east-1"
|
||||
pass_finding.service_name = "ec2"
|
||||
pass_finding.resource_type = "instance"
|
||||
pass_finding.resource_tags = {}
|
||||
pass_finding.muted = False
|
||||
pass_finding.raw = {}
|
||||
pass_finding.resource_metadata = {}
|
||||
pass_finding.resource_details = {}
|
||||
pass_finding.partition = "aws"
|
||||
pass_finding.compliance = {}
|
||||
|
||||
# Mock the ProwlerScan instance
|
||||
mock_prowler_scan_instance = MagicMock()
|
||||
mock_prowler_scan_instance.scan.return_value = [
|
||||
(100, [fail_finding_1, fail_finding_2, pass_finding])
|
||||
]
|
||||
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
|
||||
|
||||
# Mock prowler_provider
|
||||
mock_prowler_provider_instance = MagicMock()
|
||||
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
|
||||
mock_initialize_prowler_provider.return_value = (
|
||||
mock_prowler_provider_instance
|
||||
)
|
||||
|
||||
# Call the function under test
|
||||
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
|
||||
|
||||
# Refresh instances from the database
|
||||
scan_resource = Resource.objects.get(provider=provider, uid=resource_uid)
|
||||
|
||||
# Assert that failed_findings_count is 2 (two FAIL findings, one PASS)
|
||||
assert scan_resource.failed_findings_count == 2
|
||||
|
||||
def test_perform_prowler_scan_with_muted_findings(
|
||||
self,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""Test that muted FAIL findings do not increment the failed_findings_count"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
|
||||
patch(
|
||||
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE",
|
||||
new_callable=dict,
|
||||
),
|
||||
patch("api.compliance.PROWLER_CHECKS", new_callable=dict),
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
tenant_id = str(tenant.id)
|
||||
scan_id = str(scan.id)
|
||||
provider_id = str(provider.id)
|
||||
|
||||
# Mock a FAIL finding that is muted
|
||||
muted_fail_finding = MagicMock()
|
||||
muted_fail_finding.uid = "muted_fail_finding"
|
||||
muted_fail_finding.status = StatusChoices.FAIL
|
||||
muted_fail_finding.status_extended = "muted fail"
|
||||
muted_fail_finding.severity = Severity.high
|
||||
muted_fail_finding.check_id = "muted_fail_check"
|
||||
muted_fail_finding.get_metadata.return_value = {"key": "value"}
|
||||
muted_fail_finding.resource_uid = "muted_resource_uid"
|
||||
muted_fail_finding.resource_name = "muted_resource"
|
||||
muted_fail_finding.region = "us-east-1"
|
||||
muted_fail_finding.service_name = "ec2"
|
||||
muted_fail_finding.resource_type = "instance"
|
||||
muted_fail_finding.resource_tags = {}
|
||||
muted_fail_finding.muted = True
|
||||
muted_fail_finding.raw = {}
|
||||
muted_fail_finding.resource_metadata = {}
|
||||
muted_fail_finding.resource_details = {}
|
||||
muted_fail_finding.partition = "aws"
|
||||
muted_fail_finding.compliance = {}
|
||||
|
||||
# Mock the ProwlerScan instance
|
||||
mock_prowler_scan_instance = MagicMock()
|
||||
mock_prowler_scan_instance.scan.return_value = [(100, [muted_fail_finding])]
|
||||
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
|
||||
|
||||
# Mock prowler_provider
|
||||
mock_prowler_provider_instance = MagicMock()
|
||||
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
|
||||
mock_initialize_prowler_provider.return_value = (
|
||||
mock_prowler_provider_instance
|
||||
)
|
||||
|
||||
# Call the function under test
|
||||
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
|
||||
|
||||
# Refresh instances from the database
|
||||
scan_resource = Resource.objects.get(provider=provider)
|
||||
|
||||
# Assert that failed_findings_count is 0 (FAIL finding is muted)
|
||||
assert scan_resource.failed_findings_count == 0
|
||||
|
||||
def test_perform_prowler_scan_reset_failed_findings_count(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
resources_fixture,
|
||||
):
|
||||
"""Test that failed_findings_count is reset to 0 at the beginning of each scan"""
|
||||
# Use existing resource from fixture and set initial failed_findings_count
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
resource = resources_fixture[0]
|
||||
|
||||
# Set a non-zero failed_findings_count initially
|
||||
resource.failed_findings_count = 5
|
||||
resource.save()
|
||||
|
||||
# Create a new scan
|
||||
scan = Scan.objects.create(
|
||||
name="Reset Test Scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.AVAILABLE,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
|
||||
patch(
|
||||
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE",
|
||||
new_callable=dict,
|
||||
),
|
||||
patch("api.compliance.PROWLER_CHECKS", new_callable=dict),
|
||||
):
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
tenant_id = str(tenant.id)
|
||||
scan_id = str(scan.id)
|
||||
provider_id = str(provider.id)
|
||||
|
||||
# Mock a PASS finding for the existing resource
|
||||
pass_finding = MagicMock()
|
||||
pass_finding.uid = "reset_test_finding"
|
||||
pass_finding.status = StatusChoices.PASS
|
||||
pass_finding.status_extended = "reset test pass"
|
||||
pass_finding.severity = Severity.low
|
||||
pass_finding.check_id = "reset_test_check"
|
||||
pass_finding.get_metadata.return_value = {"key": "value"}
|
||||
pass_finding.resource_uid = resource.uid
|
||||
pass_finding.resource_name = resource.name
|
||||
pass_finding.region = resource.region
|
||||
pass_finding.service_name = resource.service
|
||||
pass_finding.resource_type = resource.type
|
||||
pass_finding.resource_tags = {}
|
||||
pass_finding.muted = False
|
||||
pass_finding.raw = {}
|
||||
pass_finding.resource_metadata = {}
|
||||
pass_finding.resource_details = {}
|
||||
pass_finding.partition = "aws"
|
||||
pass_finding.compliance = {}
|
||||
|
||||
# Mock the ProwlerScan instance
|
||||
mock_prowler_scan_instance = MagicMock()
|
||||
mock_prowler_scan_instance.scan.return_value = [(100, [pass_finding])]
|
||||
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
|
||||
|
||||
# Mock prowler_provider
|
||||
mock_prowler_provider_instance = MagicMock()
|
||||
mock_prowler_provider_instance.get_regions.return_value = [resource.region]
|
||||
mock_initialize_prowler_provider.return_value = (
|
||||
mock_prowler_provider_instance
|
||||
)
|
||||
|
||||
# Call the function under test
|
||||
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
|
||||
|
||||
# Refresh resource from the database
|
||||
resource.refresh_from_db()
|
||||
|
||||
# Assert that failed_findings_count was reset to 0 during the scan
|
||||
assert resource.failed_findings_count == 0
|
||||
|
||||
|
||||
# TODO Add tests for aggregations
|
||||
|
||||
@@ -697,68 +1045,3 @@ class TestCreateComplianceRequirements:
|
||||
|
||||
assert "requirements_created" in result
|
||||
assert result["requirements_created"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestUpdateResourceFailedFindingsCount:
|
||||
def test_execute_sql_update(
|
||||
self, tenants_fixture, scans_fixture, providers_fixture, resources_fixture
|
||||
):
|
||||
resource = resources_fixture[0]
|
||||
tenant_id = resource.tenant_id
|
||||
scan_id = resource.provider.scans.first().id
|
||||
|
||||
# Common kwargs for all failing findings
|
||||
base_kwargs = {
|
||||
"tenant_id": tenant_id,
|
||||
"scan_id": scan_id,
|
||||
"delta": None,
|
||||
"status": StatusChoices.FAIL,
|
||||
"status_extended": "test status extended",
|
||||
"impact": Severity.critical,
|
||||
"impact_extended": "test impact extended",
|
||||
"severity": Severity.critical,
|
||||
"raw_result": {
|
||||
"status": StatusChoices.FAIL,
|
||||
"impact": Severity.critical,
|
||||
"severity": Severity.critical,
|
||||
},
|
||||
"tags": {"test": "dev-qa"},
|
||||
"check_id": "test_check_id",
|
||||
"check_metadata": {
|
||||
"CheckId": "test_check_id",
|
||||
"Description": "test description apple sauce",
|
||||
"servicename": "ec2",
|
||||
},
|
||||
"first_seen_at": "2024-01-02T00:00:00Z",
|
||||
}
|
||||
|
||||
# UIDs to create (two with same UID, one unique)
|
||||
uids = ["test_finding_uid_1", "test_finding_uid_1", "test_finding_uid_2"]
|
||||
|
||||
# Create findings and associate with the resource
|
||||
for uid in uids:
|
||||
finding = Finding.objects.create(uid=uid, **base_kwargs)
|
||||
finding.add_resources([resource])
|
||||
|
||||
resource.refresh_from_db()
|
||||
assert resource.failed_findings_count == 0
|
||||
|
||||
_update_resource_failed_findings_count(tenant_id=tenant_id, scan_id=scan_id)
|
||||
resource.refresh_from_db()
|
||||
|
||||
# Only two since two findings share the same UID
|
||||
assert resource.failed_findings_count == 2
|
||||
|
||||
@patch("tasks.jobs.scan.Scan.objects.get")
|
||||
def test_scan_not_found(
|
||||
self,
|
||||
mock_scan_get,
|
||||
):
|
||||
mock_scan_get.side_effect = Scan.DoesNotExist
|
||||
|
||||
with pytest.raises(Scan.DoesNotExist):
|
||||
_update_resource_failed_findings_count(
|
||||
"8614ca97-8370-4183-a7f7-e96a6c7d2c93",
|
||||
"4705bed5-8782-4e8b-bab6-55e8043edaa6",
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
@@ -145,11 +146,11 @@ def _get_script_arguments():
|
||||
|
||||
def _run_prowler(prowler_args):
|
||||
_debug("Running prowler with args: {0}".format(prowler_args), 1)
|
||||
_prowler_command = "{prowler}/prowler {args}".format(
|
||||
prowler=PATH_TO_PROWLER, args=prowler_args
|
||||
_prowler_command = shlex.split(
|
||||
"{prowler}/prowler {args}".format(prowler=PATH_TO_PROWLER, args=prowler_args)
|
||||
)
|
||||
_debug("Running command: {0}".format(_prowler_command), 2)
|
||||
_process = subprocess.Popen(_prowler_command, stdout=subprocess.PIPE, shell=True)
|
||||
_debug("Running command: {0}".format(" ".join(_prowler_command)), 2)
|
||||
_process = subprocess.Popen(_prowler_command, stdout=subprocess.PIPE)
|
||||
_output, _error = _process.communicate()
|
||||
_debug("Raw prowler output: {0}".format(_output), 3)
|
||||
_debug("Raw prowler error: {0}".format(_error), 3)
|
||||
|
||||
791
poetry.lock
generated
791
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,32 @@
|
||||
|
||||
All notable changes to the **Prowler SDK** are documented in this file.
|
||||
|
||||
## [v5.10.0] (Prowler UNRELEASED)
|
||||
|
||||
### Added
|
||||
- Add `bedrock_api_key_no_administrative_privileges` check for AWS provider [(#8321)](https://github.com/prowler-cloud/prowler/pull/8321)
|
||||
- Support App Key Content in GitHub provider [(#8271)](https://github.com/prowler-cloud/prowler/pull/8271)
|
||||
|
||||
---
|
||||
|
||||
## [v5.9.3] (Prowler UNRELEASED)
|
||||
|
||||
### Fixed
|
||||
- Add more validations to Azure Storage models when some values are None to avoid serialization issues [(#8325)](https://github.com/prowler-cloud/prowler/pull/8325)
|
||||
- `sns_topics_not_publicly_accessible` false positive with `aws:SourceArn` conditions [(#8326)](https://github.com/prowler-cloud/prowler/issues/8326)
|
||||
- Remove typo from description req 1.2.3 - Prowler ThreatScore m365 [(#8384)](https://github.com/prowler-cloud/prowler/pull/8384)
|
||||
- Way of counting FAILED/PASS reqs from `kisa_isms_p_2023_aws` table [(#8382)](https://github.com/prowler-cloud/prowler/pull/8382)
|
||||
- Avoid multiple module error calls in M365 provider [(#8353)](https://github.com/prowler-cloud/prowler/pull/8353)
|
||||
- Tweaks from Prowler ThreatScore in order to handle the correct reqs [(#8401)](https://github.com/prowler-cloud/prowler/pull/8401)
|
||||
---
|
||||
|
||||
## [v5.9.2] (Prowler v5.9.2)
|
||||
|
||||
### Fixed
|
||||
- Use the correct resource name in `defender_domain_dkim_enabled` check [(#8334)](https://github.com/prowler-cloud/prowler/pull/8334)
|
||||
|
||||
---
|
||||
|
||||
## [v5.9.0] (Prowler v5.9.0)
|
||||
|
||||
### Added
|
||||
|
||||
@@ -6,24 +6,6 @@
|
||||
"Requirements": [
|
||||
{
|
||||
"Id": "1.1.1",
|
||||
"Description": "Ensure Security Defaults is enabled on Microsoft Entra ID",
|
||||
"Checks": [
|
||||
"entra_security_defaults_enabled"
|
||||
],
|
||||
"Attributes": [
|
||||
{
|
||||
"Title": "Security Defaults enabled on Entra ID",
|
||||
"Section": "1. IAM",
|
||||
"SubSection": "1.1 Authentication",
|
||||
"AttributeDescription": "Microsoft Entra ID Security Defaults offer preconfigured security settings designed to protect organizations from common identity attacks at no additional cost. These settings enforce basic security measures such as MFA registration, risk-based authentication prompts, and blocking legacy authentication clients that do not support MFA. Security defaults are available to all organizations and can be enabled via the Azure portal to strengthen authentication security.",
|
||||
"AdditionalInformation": "Security defaults provide built-in protections to reduce the risk of unauthorized access until organizations configure their own identity security policies. By requiring MFA, blocking weak authentication methods, and adapting authentication challenges based on risk factors, these settings create a stronger security foundation without additional licensing requirements.",
|
||||
"LevelOfRisk": 4,
|
||||
"Weight": 100
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.2",
|
||||
"Description": "Ensure that 'Multi-Factor Auth Status' is 'Enabled' for all Privileged Users",
|
||||
"Checks": [
|
||||
"entra_privileged_user_has_mfa"
|
||||
@@ -41,7 +23,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.3",
|
||||
"Id": "1.1.2",
|
||||
"Description": "Ensure that 'Multi-Factor Auth Status' is 'Enabled' for all Non-Privileged Users",
|
||||
"Checks": [
|
||||
"entra_non_privileged_user_has_mfa"
|
||||
@@ -59,7 +41,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.4",
|
||||
"Id": "1.1.3",
|
||||
"Description": "Ensure Multi-factor Authentication is Required for Windows Azure Service Management API",
|
||||
"Checks": [
|
||||
"entra_conditional_access_policy_require_mfa_for_management_api"
|
||||
@@ -77,7 +59,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.5",
|
||||
"Id": "1.1.4",
|
||||
"Description": "Ensure Multi-factor Authentication is Required to access Microsoft Admin Portals",
|
||||
"Checks": [
|
||||
"defender_ensure_defender_for_server_is_on"
|
||||
@@ -95,7 +77,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.6",
|
||||
"Id": "1.1.5",
|
||||
"Description": "Ensure only MFA enabled identities can access privileged Virtual Machine",
|
||||
"Checks": [
|
||||
"entra_user_with_vm_access_has_mfa"
|
||||
|
||||
@@ -310,6 +310,24 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.1.18",
|
||||
"Description": "Ensure that only administrative roles have access to Microsoft Admin Portals",
|
||||
"Checks": [
|
||||
"entra_admin_portals_access_restriction"
|
||||
],
|
||||
"Attributes": [
|
||||
{
|
||||
"Title": "Only administrative roles have access to Microsoft Admin Portals",
|
||||
"Section": "1. IAM",
|
||||
"SubSection": "1.1 Authentication",
|
||||
"AttributeDescription": "Restrict access to Microsoft Admin Portals exclusively to administrative roles to prevent unauthorized modifications, privilege escalation, and security misconfigurations",
|
||||
"AdditionalInformation": "Granting non-administrative users access to Microsoft Admin Portals exposes the environment to unauthorized changes, potential elevation of privileges, and misconfigured security settings. This could allow attackers to alter configurations, disable protections, or gain access to sensitive information.",
|
||||
"LevelOfRisk": 4,
|
||||
"Weight": 100
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"Id": "1.2.1",
|
||||
"Description": "Ensure that only organizationally managed/approved public groups exist",
|
||||
@@ -348,7 +366,7 @@
|
||||
},
|
||||
{
|
||||
"Id": "1.2.3",
|
||||
"Description": "entra_managed_device_required_for_mfa_registration",
|
||||
"Description": "Ensure the admin consent workflow is enabled.",
|
||||
"Checks": [
|
||||
"entra_admin_consent_workflow_enabled"
|
||||
],
|
||||
|
||||
@@ -12,7 +12,7 @@ from prowler.lib.logger import logger
|
||||
|
||||
timestamp = datetime.today()
|
||||
timestamp_utc = datetime.now(timezone.utc).replace(tzinfo=timezone.utc)
|
||||
prowler_version = "5.9.0"
|
||||
prowler_version = "5.9.3"
|
||||
html_logo_url = "https://github.com/prowler-cloud/prowler/"
|
||||
square_logo_img = "https://prowler.com/wp-content/uploads/logo-html.png"
|
||||
aws_logo = "https://user-images.githubusercontent.com/38561120/235953920-3e3fba08-0795-41dc-b480-9bea57db9f2e.png"
|
||||
|
||||
@@ -13,6 +13,7 @@ def get_kisa_ismsp_table(
|
||||
compliance_overview: bool,
|
||||
):
|
||||
sections = {}
|
||||
sections_status = {}
|
||||
kisa_ismsp_compliance_table = {
|
||||
"Provider": [],
|
||||
"Section": [],
|
||||
@@ -36,7 +37,10 @@ def get_kisa_ismsp_table(
|
||||
# Check if Section exists
|
||||
if section not in sections:
|
||||
sections[section] = {
|
||||
"Status": f"{Fore.GREEN}PASS{Style.RESET_ALL}",
|
||||
"Status": {
|
||||
"PASS": 0,
|
||||
"FAIL": 0,
|
||||
},
|
||||
"Muted": 0,
|
||||
}
|
||||
if finding.muted:
|
||||
@@ -46,14 +50,29 @@ def get_kisa_ismsp_table(
|
||||
else:
|
||||
if finding.status == "FAIL" and index not in fail_count:
|
||||
fail_count.append(index)
|
||||
sections[section]["Status"]["FAIL"] += 1
|
||||
elif finding.status == "PASS" and index not in pass_count:
|
||||
pass_count.append(index)
|
||||
sections[section]["Status"]["PASS"] += 1
|
||||
|
||||
# Add results to table
|
||||
sections = dict(sorted(sections.items()))
|
||||
for section in sections:
|
||||
if sections[section]["Status"]["FAIL"] > 0:
|
||||
sections_status[section] = (
|
||||
f"{Fore.RED}FAIL({sections[section]['Status']['FAIL']}){Style.RESET_ALL}"
|
||||
)
|
||||
else:
|
||||
if sections[section]["Status"]["PASS"] > 0:
|
||||
sections_status[section] = (
|
||||
f"{Fore.GREEN}PASS({sections[section]['Status']['PASS']}){Style.RESET_ALL}"
|
||||
)
|
||||
else:
|
||||
sections_status[section] = f"{Fore.GREEN}PASS{Style.RESET_ALL}"
|
||||
for section in sections:
|
||||
kisa_ismsp_compliance_table["Provider"].append(compliance.Provider)
|
||||
kisa_ismsp_compliance_table["Section"].append(section)
|
||||
kisa_ismsp_compliance_table["Status"].append(sections_status[section])
|
||||
kisa_ismsp_compliance_table["Muted"].append(
|
||||
f"{orange_color}{sections[section]['Muted']}{Style.RESET_ALL}"
|
||||
)
|
||||
|
||||
@@ -223,6 +223,108 @@ def check_full_service_access(service: str, policy: dict) -> bool:
|
||||
return all_target_service_actions.issubset(actions_allowed_on_all_resources)
|
||||
|
||||
|
||||
def has_public_principal(statement: dict) -> bool:
|
||||
"""
|
||||
Check if a policy statement has a public principal.
|
||||
|
||||
Args:
|
||||
statement (dict): IAM policy statement
|
||||
|
||||
Returns:
|
||||
bool: True if the statement has a public principal, False otherwise
|
||||
"""
|
||||
principal = statement.get("Principal", "")
|
||||
return (
|
||||
"*" in principal
|
||||
or "arn:aws:iam::*:root" in principal
|
||||
or (
|
||||
isinstance(principal, dict)
|
||||
and (
|
||||
"*" in principal.get("AWS", "")
|
||||
or "arn:aws:iam::*:root" in principal.get("AWS", "")
|
||||
or (
|
||||
isinstance(principal.get("AWS"), list)
|
||||
and (
|
||||
"*" in principal["AWS"]
|
||||
or "arn:aws:iam::*:root" in principal["AWS"]
|
||||
)
|
||||
)
|
||||
or "*" in principal.get("CanonicalUser", "")
|
||||
or "arn:aws:iam::*:root" in principal.get("CanonicalUser", "")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def has_restrictive_source_arn_condition(
|
||||
statement: dict, source_account: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a policy statement has a restrictive aws:SourceArn condition.
|
||||
|
||||
A SourceArn condition is considered restrictive if:
|
||||
1. It doesn't contain overly permissive wildcards (like "*" or "arn:aws:s3:::*")
|
||||
2. When source_account is provided, the ARN either contains no account field (like S3 buckets)
|
||||
or contains the source_account
|
||||
|
||||
Args:
|
||||
statement (dict): IAM policy statement
|
||||
source_account (str): The account to check restrictions for (optional)
|
||||
|
||||
Returns:
|
||||
bool: True if the statement has a restrictive aws:SourceArn condition, False otherwise
|
||||
"""
|
||||
if "Condition" not in statement:
|
||||
return False
|
||||
|
||||
for condition_operator in statement["Condition"]:
|
||||
for condition_key, condition_value in statement["Condition"][
|
||||
condition_operator
|
||||
].items():
|
||||
if condition_key.lower() == "aws:sourcearn":
|
||||
arn_values = (
|
||||
condition_value
|
||||
if isinstance(condition_value, list)
|
||||
else [condition_value]
|
||||
)
|
||||
|
||||
for arn_value in arn_values:
|
||||
if (
|
||||
arn_value == "*" # Global wildcard
|
||||
or arn_value.count("*")
|
||||
>= 3 # Too many wildcards (e.g., arn:aws:*:*:*:*)
|
||||
or (
|
||||
isinstance(arn_value, str)
|
||||
and (
|
||||
arn_value.endswith(
|
||||
":::*"
|
||||
) # Service-wide wildcard (e.g., arn:aws:s3:::*)
|
||||
or arn_value.endswith(
|
||||
":*"
|
||||
) # Resource wildcard (e.g., arn:aws:sns:us-east-1:123456789012:*)
|
||||
)
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
if source_account:
|
||||
arn_parts = arn_value.split(":")
|
||||
if len(arn_parts) > 4 and arn_parts[4] and arn_parts[4] != "*":
|
||||
if arn_parts[4].isdigit():
|
||||
if source_account not in arn_value:
|
||||
return False
|
||||
else:
|
||||
if arn_parts[4] != source_account:
|
||||
return False
|
||||
elif len(arn_parts) > 4 and arn_parts[4] == "*":
|
||||
return False
|
||||
# else: ARN doesn't contain account field (like S3 bucket), so it's restrictive
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_condition_restricting_from_private_ip(condition_statement: dict) -> bool:
|
||||
"""Check if the policy condition is coming from a private IP address.
|
||||
|
||||
@@ -303,61 +405,49 @@ def is_policy_public(
|
||||
for statement in policy.get("Statement", []):
|
||||
# Only check allow statements
|
||||
if statement["Effect"] == "Allow":
|
||||
has_public_access = has_public_principal(statement)
|
||||
|
||||
principal = statement.get("Principal", "")
|
||||
if (
|
||||
"*" in principal
|
||||
or "arn:aws:iam::*:root" in principal
|
||||
or (
|
||||
isinstance(principal, dict)
|
||||
and (
|
||||
"*" in principal.get("AWS", "")
|
||||
or "arn:aws:iam::*:root" in principal.get("AWS", "")
|
||||
or (
|
||||
isinstance(principal.get("AWS"), str)
|
||||
and source_account
|
||||
and not is_cross_account_allowed
|
||||
and source_account not in principal.get("AWS", "")
|
||||
)
|
||||
or (
|
||||
isinstance(principal.get("AWS"), list)
|
||||
and (
|
||||
"*" in principal["AWS"]
|
||||
or "arn:aws:iam::*:root" in principal["AWS"]
|
||||
or (
|
||||
source_account
|
||||
and not is_cross_account_allowed
|
||||
and not any(
|
||||
source_account in principal_aws
|
||||
for principal_aws in principal["AWS"]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
or "*" in principal.get("CanonicalUser", "")
|
||||
or "arn:aws:iam::*:root"
|
||||
in principal.get("CanonicalUser", "")
|
||||
or check_cross_service_confused_deputy
|
||||
and (
|
||||
# Check if function can be invoked by other AWS services if check_cross_service_confused_deputy is True
|
||||
(
|
||||
".amazonaws.com" in principal.get("Service", "")
|
||||
or ".amazon.com" in principal.get("Service", "")
|
||||
or "*" in principal.get("Service", "")
|
||||
)
|
||||
and (
|
||||
"secretsmanager.amazonaws.com"
|
||||
not in principal.get(
|
||||
"Service", ""
|
||||
) # AWS ensures that resources called by SecretsManager are executed in the same AWS account
|
||||
or "eks.amazonaws.com"
|
||||
not in principal.get(
|
||||
"Service", ""
|
||||
) # AWS ensures that resources called by EKS are executed in the same AWS account
|
||||
)
|
||||
)
|
||||
if not has_public_access and isinstance(principal, dict):
|
||||
# Check for cross-account access when not allowed
|
||||
if (
|
||||
isinstance(principal.get("AWS"), str)
|
||||
and source_account
|
||||
and not is_cross_account_allowed
|
||||
and source_account not in principal.get("AWS", "")
|
||||
) or (
|
||||
isinstance(principal.get("AWS"), list)
|
||||
and source_account
|
||||
and not is_cross_account_allowed
|
||||
and not any(
|
||||
source_account in principal_aws
|
||||
for principal_aws in principal["AWS"]
|
||||
)
|
||||
)
|
||||
) and (
|
||||
):
|
||||
has_public_access = True
|
||||
|
||||
# Check for cross-service confused deputy
|
||||
if check_cross_service_confused_deputy and (
|
||||
# Check if function can be invoked by other AWS services if check_cross_service_confused_deputy is True
|
||||
(
|
||||
".amazonaws.com" in principal.get("Service", "")
|
||||
or ".amazon.com" in principal.get("Service", "")
|
||||
or "*" in principal.get("Service", "")
|
||||
)
|
||||
and (
|
||||
"secretsmanager.amazonaws.com"
|
||||
not in principal.get(
|
||||
"Service", ""
|
||||
) # AWS ensures that resources called by SecretsManager are executed in the same AWS account
|
||||
or "eks.amazonaws.com"
|
||||
not in principal.get(
|
||||
"Service", ""
|
||||
) # AWS ensures that resources called by EKS are executed in the same AWS account
|
||||
)
|
||||
):
|
||||
has_public_access = True
|
||||
|
||||
if has_public_access and (
|
||||
not not_allowed_actions # If not_allowed_actions is empty, the function will not consider the actions in the policy
|
||||
or (
|
||||
statement.get(
|
||||
@@ -498,9 +588,29 @@ def is_condition_block_restrictive(
|
||||
"aws:sourcevpc" != value
|
||||
and "aws:sourcevpce" != value
|
||||
):
|
||||
if source_account not in item:
|
||||
is_condition_key_restrictive = False
|
||||
break
|
||||
if value == "aws:sourcearn":
|
||||
# Use the specialized function to properly validate SourceArn restrictions
|
||||
# Create a minimal statement to test with our function
|
||||
test_statement = {
|
||||
"Condition": {
|
||||
condition_operator: {
|
||||
value: condition_statement[
|
||||
condition_operator
|
||||
][value]
|
||||
}
|
||||
}
|
||||
}
|
||||
is_condition_key_restrictive = (
|
||||
has_restrictive_source_arn_condition(
|
||||
test_statement, source_account
|
||||
)
|
||||
)
|
||||
if not is_condition_key_restrictive:
|
||||
break
|
||||
else:
|
||||
if source_account not in item:
|
||||
is_condition_key_restrictive = False
|
||||
break
|
||||
|
||||
if is_condition_key_restrictive:
|
||||
is_condition_valid = True
|
||||
@@ -516,11 +626,31 @@ def is_condition_block_restrictive(
|
||||
if is_cross_account_allowed:
|
||||
is_condition_valid = True
|
||||
else:
|
||||
if (
|
||||
source_account
|
||||
in condition_statement[condition_operator][value]
|
||||
):
|
||||
is_condition_valid = True
|
||||
if value == "aws:sourcearn":
|
||||
# Use the specialized function to properly validate SourceArn restrictions
|
||||
# Create a minimal statement to test with our function
|
||||
test_statement = {
|
||||
"Condition": {
|
||||
condition_operator: {
|
||||
value: condition_statement[
|
||||
condition_operator
|
||||
][value]
|
||||
}
|
||||
}
|
||||
}
|
||||
is_condition_valid = (
|
||||
has_restrictive_source_arn_condition(
|
||||
test_statement, source_account
|
||||
)
|
||||
)
|
||||
else:
|
||||
if (
|
||||
source_account
|
||||
in condition_statement[condition_operator][
|
||||
value
|
||||
]
|
||||
):
|
||||
is_condition_valid = True
|
||||
|
||||
return is_condition_valid
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from prowler.lib.check.models import Check, Check_Report_AWS
|
||||
from prowler.providers.aws.services.iam.lib.policy import (
|
||||
has_public_principal,
|
||||
has_restrictive_source_arn_condition,
|
||||
is_condition_block_restrictive,
|
||||
is_condition_block_restrictive_organization,
|
||||
is_condition_block_restrictive_sns_endpoint,
|
||||
@@ -16,46 +18,26 @@ class sns_topics_not_publicly_accessible(Check):
|
||||
report.status_extended = (
|
||||
f"SNS topic {topic.name} is not publicly accessible."
|
||||
)
|
||||
|
||||
if topic.policy:
|
||||
for statement in topic.policy["Statement"]:
|
||||
# Only check allow statements
|
||||
if statement["Effect"] == "Allow":
|
||||
if (
|
||||
"*" in statement["Principal"]
|
||||
or (
|
||||
"AWS" in statement["Principal"]
|
||||
and "*" in statement["Principal"]["AWS"]
|
||||
if statement["Effect"] == "Allow" and has_public_principal(
|
||||
statement
|
||||
):
|
||||
if has_restrictive_source_arn_condition(statement):
|
||||
break
|
||||
elif "Condition" in statement:
|
||||
condition_account = is_condition_block_restrictive(
|
||||
statement["Condition"], sns_client.audited_account
|
||||
)
|
||||
or (
|
||||
"CanonicalUser" in statement["Principal"]
|
||||
and "*" in statement["Principal"]["CanonicalUser"]
|
||||
condition_org = is_condition_block_restrictive_organization(
|
||||
statement["Condition"]
|
||||
)
|
||||
):
|
||||
condition_account = False
|
||||
condition_org = False
|
||||
condition_endpoint = False
|
||||
if (
|
||||
"Condition" in statement
|
||||
and is_condition_block_restrictive(
|
||||
statement["Condition"],
|
||||
sns_client.audited_account,
|
||||
condition_endpoint = (
|
||||
is_condition_block_restrictive_sns_endpoint(
|
||||
statement["Condition"]
|
||||
)
|
||||
):
|
||||
condition_account = True
|
||||
if (
|
||||
"Condition" in statement
|
||||
and is_condition_block_restrictive_organization(
|
||||
statement["Condition"],
|
||||
)
|
||||
):
|
||||
condition_org = True
|
||||
if (
|
||||
"Condition" in statement
|
||||
and is_condition_block_restrictive_sns_endpoint(
|
||||
statement["Condition"],
|
||||
)
|
||||
):
|
||||
condition_endpoint = True
|
||||
)
|
||||
|
||||
if condition_account and condition_org:
|
||||
report.status_extended = f"SNS topic {topic.name} is not public because its policy only allows access from the account {sns_client.audited_account} and an organization."
|
||||
@@ -69,7 +51,11 @@ class sns_topics_not_publicly_accessible(Check):
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"SNS topic {topic.name} is public because its policy allows public access."
|
||||
break
|
||||
else:
|
||||
# Public principal with no conditions = public
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"SNS topic {topic.name} is public because its policy allows public access."
|
||||
break
|
||||
|
||||
findings.append(report)
|
||||
|
||||
return findings
|
||||
|
||||
@@ -70,17 +70,44 @@ class Storage(AzureService):
|
||||
],
|
||||
key_expiration_period_in_days=key_expiration_period_in_days,
|
||||
location=storage_account.location,
|
||||
default_to_entra_authorization=getattr(
|
||||
storage_account,
|
||||
"default_to_o_auth_authentication",
|
||||
False,
|
||||
default_to_entra_authorization=(
|
||||
False
|
||||
if getattr(
|
||||
storage_account,
|
||||
"default_to_o_auth_authentication",
|
||||
False,
|
||||
)
|
||||
is None
|
||||
else getattr(
|
||||
storage_account,
|
||||
"default_to_o_auth_authentication",
|
||||
False,
|
||||
)
|
||||
),
|
||||
replication_settings=replication_settings,
|
||||
allow_cross_tenant_replication=getattr(
|
||||
storage_account, "allow_cross_tenant_replication", True
|
||||
allow_cross_tenant_replication=(
|
||||
True
|
||||
if getattr(
|
||||
storage_account,
|
||||
"allow_cross_tenant_replication",
|
||||
True,
|
||||
)
|
||||
is None
|
||||
else getattr(
|
||||
storage_account,
|
||||
"allow_cross_tenant_replication",
|
||||
True,
|
||||
)
|
||||
),
|
||||
allow_shared_key_access=getattr(
|
||||
storage_account, "allow_shared_key_access", True
|
||||
allow_shared_key_access=(
|
||||
True
|
||||
if getattr(
|
||||
storage_account, "allow_shared_key_access", True
|
||||
)
|
||||
is None
|
||||
else getattr(
|
||||
storage_account, "allow_shared_key_access", True
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
68
prowler/providers/m365/lib/jwt/jwt_decoder.py
Normal file
68
prowler/providers/m365/lib/jwt/jwt_decoder.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def decode_jwt(token: str) -> dict:
|
||||
"""
|
||||
Decodes the payload of a JWT without verifying its signature.
|
||||
|
||||
Args:
|
||||
token (str): JWT string in the format 'header.payload.signature'
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the decoded payload (claims), or an empty dict on failure.
|
||||
"""
|
||||
try:
|
||||
# Split the JWT into its 3 parts
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
"The token does not have the expected three-part structure."
|
||||
)
|
||||
|
||||
# Extract and decode the payload (second part)
|
||||
payload_b64 = parts[1]
|
||||
|
||||
# Add padding if necessary for base64 decoding
|
||||
padding = "=" * (-len(payload_b64) % 4)
|
||||
payload_b64 += padding
|
||||
|
||||
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||
payload_json = payload_bytes.decode("utf-8")
|
||||
payload = json.loads(payload_json)
|
||||
|
||||
return payload
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to decode the token: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def decode_msal_token(text: str) -> dict:
|
||||
"""
|
||||
Extracts and decodes the payload of a MSAL token from a given string.
|
||||
|
||||
Args:
|
||||
text (str): A string that contains the MSAL token, possibly over multiple lines.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the decoded payload (claims), or an empty dict on failure.
|
||||
"""
|
||||
try:
|
||||
# Join all lines and remove whitespace
|
||||
flattened = "".join(text.split())
|
||||
|
||||
# Search for a valid JWT pattern (three base64url parts separated by dots)
|
||||
match = re.search(
|
||||
r"([A-Za-z0-9-_]+\.[A-Za-z0-9-_]+\.[A-Za-z0-9-_]+)", flattened
|
||||
)
|
||||
if not match:
|
||||
raise ValueError("No valid JWT found in the input.")
|
||||
|
||||
token = match.group(1)
|
||||
return decode_jwt(token)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to extract and decode the token: {e}")
|
||||
return {}
|
||||
@@ -1,16 +1,14 @@
|
||||
import os
|
||||
import platform
|
||||
|
||||
import msal
|
||||
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.lib.powershell.powershell import PowerShellSession
|
||||
from prowler.providers.m365.exceptions.exceptions import (
|
||||
M365ExchangeConnectionError,
|
||||
M365GraphConnectionError,
|
||||
M365TeamsConnectionError,
|
||||
M365UserCredentialsError,
|
||||
M365UserNotBelongingToTenantError,
|
||||
)
|
||||
from prowler.providers.m365.lib.jwt.jwt_decoder import decode_jwt, decode_msal_token
|
||||
from prowler.providers.m365.models import M365Credentials, M365IdentityInfo
|
||||
|
||||
|
||||
@@ -162,26 +160,29 @@ class M365PowerShell(PowerShellSession):
|
||||
message=f"The user domain {user_domain} does not match any of the tenant domains: {', '.join(self.tenant_identity.tenant_domains)}",
|
||||
)
|
||||
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=credentials.client_id,
|
||||
client_credential=credentials.client_secret,
|
||||
authority=f"https://login.microsoftonline.com/{credentials.tenant_id}",
|
||||
)
|
||||
|
||||
# Validate credentials
|
||||
result = app.acquire_token_by_username_password(
|
||||
username=credentials.user,
|
||||
password=credentials.passwd,
|
||||
scopes=["https://graph.microsoft.com/.default"],
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise Exception(
|
||||
"Unexpected error: Acquiring token in behalf of user did not return a result."
|
||||
)
|
||||
|
||||
if "access_token" not in result:
|
||||
raise Exception(f"MsGraph Error {result.get('error_description')}")
|
||||
result = self.execute("Connect-ExchangeOnline -Credential $credential")
|
||||
if "https://aka.ms/exov3-module" not in result:
|
||||
if "AADSTS" in result: # Entra Security Token Service Error
|
||||
raise M365UserCredentialsError(
|
||||
file=os.path.basename(__file__),
|
||||
message=result,
|
||||
)
|
||||
else: # Could not connect to Exchange Online, try Microsoft Teams
|
||||
result = self.execute(
|
||||
"Connect-MicrosoftTeams -Credential $credential"
|
||||
)
|
||||
if self.tenant_identity.tenant_id not in result:
|
||||
if "AADSTS" in result: # Entra Security Token Service Error
|
||||
raise M365UserCredentialsError(
|
||||
file=os.path.basename(__file__),
|
||||
message=result,
|
||||
)
|
||||
else: # Unknown error, could be a permission issue or modules not installed
|
||||
raise Exception(
|
||||
file=os.path.basename(__file__),
|
||||
message=f"Error connecting to PowerShell modules: {result}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -191,6 +192,7 @@ class M365PowerShell(PowerShellSession):
|
||||
logger.info("Testing Microsoft Graph connection...")
|
||||
self.test_graph_connection()
|
||||
logger.info("Microsoft Graph connection successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Microsoft Graph connection failed: {e}")
|
||||
raise M365GraphConnectionError(
|
||||
@@ -199,34 +201,6 @@ class M365PowerShell(PowerShellSession):
|
||||
message="Check your Microsoft Application credentials and ensure the app has proper permissions",
|
||||
)
|
||||
|
||||
# Test Microsoft Teams connection
|
||||
try:
|
||||
logger.info("Testing Microsoft Teams connection...")
|
||||
self.test_teams_connection()
|
||||
logger.info("Microsoft Teams connection successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Microsoft Teams connection failed: {e}")
|
||||
raise M365TeamsConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
original_exception=e,
|
||||
message="Ensure the application has proper permission granted to access Microsoft Teams.",
|
||||
)
|
||||
|
||||
# Test Exchange Online connection
|
||||
try:
|
||||
logger.info("Testing Exchange Online connection...")
|
||||
self.test_exchange_connection()
|
||||
logger.info("Exchange Online connection successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Exchange Online connection failed: {e}")
|
||||
raise M365ExchangeConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
original_exception=e,
|
||||
message="Ensure the application has proper permission granted to access Exchange Online.",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def test_graph_connection(self) -> bool:
|
||||
"""Test Microsoft Graph API connection and raise exception if it fails."""
|
||||
try:
|
||||
@@ -253,19 +227,20 @@ class M365PowerShell(PowerShellSession):
|
||||
self.execute(
|
||||
'$teamsToken = Invoke-RestMethod -Uri "https://login.microsoftonline.com/$tenantID/oauth2/v2.0/token" -Method POST -Body $teamstokenBody | Select-Object -ExpandProperty Access_Token'
|
||||
)
|
||||
if self.execute("Write-Output $teamsToken") == "":
|
||||
raise M365TeamsConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
message="Microsoft Teams token is empty or invalid.",
|
||||
permissions = decode_jwt(self.execute("Write-Output $teamsToken")).get(
|
||||
"roles", []
|
||||
)
|
||||
if "application_access" not in permissions:
|
||||
logger.error(
|
||||
"Microsoft Teams connection failed: Please check your permissions and try again."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Microsoft Teams connection failed: {e}")
|
||||
raise M365TeamsConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
original_exception=e,
|
||||
message=f"Failed to connect to Microsoft Teams API: {str(e)}",
|
||||
logger.error(
|
||||
f"Microsoft Teams connection failed: {e}. Please check your permissions and try again."
|
||||
)
|
||||
return False
|
||||
|
||||
def test_exchange_connection(self) -> bool:
|
||||
"""Test Exchange Online API connection and raise exception if it fails."""
|
||||
@@ -276,19 +251,19 @@ class M365PowerShell(PowerShellSession):
|
||||
self.execute(
|
||||
'$exchangeToken = Get-MsalToken -clientID "$clientID" -tenantID "$tenantID" -clientSecret $SecureSecret -Scopes "https://outlook.office365.com/.default"'
|
||||
)
|
||||
if self.execute("Write-Output $exchangeToken") == "":
|
||||
raise M365ExchangeConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
message="Exchange Online token is empty or invalid.",
|
||||
token = decode_msal_token(self.execute("Write-Output $exchangeToken"))
|
||||
permissions = token.get("roles", [])
|
||||
if "Exchange.ManageAsApp" not in permissions:
|
||||
logger.error(
|
||||
"Exchange Online connection failed: Please check your permissions and try again."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Exchange Online connection failed: {e}")
|
||||
raise M365ExchangeConnectionError(
|
||||
file=os.path.basename(__file__),
|
||||
original_exception=e,
|
||||
message=f"Failed to connect to Exchange Online API: {str(e)}",
|
||||
logger.error(
|
||||
f"Exchange Online connection failed: {e}. Please check your permissions and try again."
|
||||
)
|
||||
return False
|
||||
|
||||
def connect_microsoft_teams(self) -> dict:
|
||||
"""
|
||||
@@ -302,18 +277,26 @@ class M365PowerShell(PowerShellSession):
|
||||
Note:
|
||||
This method requires the Microsoft Teams PowerShell module to be installed.
|
||||
"""
|
||||
if self.execute("Write-Output $credential") != "": # User Auth
|
||||
return self.execute("Connect-MicrosoftTeams -Credential $credential")
|
||||
else: # Application Auth
|
||||
self.execute(
|
||||
'$teamstokenBody = @{ Grant_Type = "client_credentials"; Scope = "48ac35b8-9aa8-4d74-927d-1f4a14a0b239/.default"; Client_Id = $clientID; Client_Secret = $clientSecret }'
|
||||
)
|
||||
self.execute(
|
||||
'$teamsToken = Invoke-RestMethod -Uri "https://login.microsoftonline.com/$tenantID/oauth2/v2.0/token" -Method POST -Body $teamstokenBody | Select-Object -ExpandProperty Access_Token'
|
||||
)
|
||||
return self.execute(
|
||||
'Connect-MicrosoftTeams -AccessTokens @("$graphToken","$teamsToken")'
|
||||
)
|
||||
# User Auth
|
||||
if self.execute("Write-Output $credential") != "":
|
||||
self.execute("Connect-MicrosoftTeams -Credential $credential")
|
||||
# Test connection with a simple call
|
||||
connection = self.execute("Get-CsTeamsClientConfiguration")
|
||||
if connection:
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
"Microsoft Teams connection failed: Please check your permissions and try again."
|
||||
)
|
||||
return connection
|
||||
# Application Auth
|
||||
else:
|
||||
connection = self.test_teams_connection()
|
||||
if connection:
|
||||
self.execute(
|
||||
'Connect-MicrosoftTeams -AccessTokens @("$graphToken","$teamsToken")'
|
||||
)
|
||||
return connection
|
||||
|
||||
def get_teams_settings(self) -> dict:
|
||||
"""
|
||||
@@ -407,18 +390,25 @@ class M365PowerShell(PowerShellSession):
|
||||
Note:
|
||||
This method requires the Exchange Online PowerShell module to be installed.
|
||||
"""
|
||||
if self.execute("Write-Output $credential") != "": # User Auth
|
||||
return self.execute("Connect-ExchangeOnline -Credential $credential")
|
||||
else: # Application Auth
|
||||
self.execute(
|
||||
'$SecureSecret = ConvertTo-SecureString "$clientSecret" -AsPlainText -Force'
|
||||
)
|
||||
self.execute(
|
||||
'$exchangeToken = Get-MsalToken -clientID "$clientID" -tenantID "$tenantID" -clientSecret $SecureSecret -Scopes "https://outlook.office365.com/.default"'
|
||||
)
|
||||
return self.execute(
|
||||
'Connect-ExchangeOnline -AccessToken $exchangeToken.AccessToken -Organization "$tenantID"'
|
||||
)
|
||||
# User Auth
|
||||
if self.execute("Write-Output $credential") != "":
|
||||
self.execute("Connect-ExchangeOnline -Credential $credential")
|
||||
connection = self.execute("Get-OrganizationConfig")
|
||||
if connection:
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
"Exchange Online connection failed: Please check your permissions and try again."
|
||||
)
|
||||
return False
|
||||
# Application Auth
|
||||
else:
|
||||
connection = self.test_exchange_connection()
|
||||
if connection:
|
||||
self.execute(
|
||||
'Connect-ExchangeOnline -AccessToken $exchangeToken.AccessToken -Organization "$tenantID"'
|
||||
)
|
||||
return connection
|
||||
|
||||
def get_audit_log_config(self) -> dict:
|
||||
"""
|
||||
|
||||
@@ -15,9 +15,9 @@ class AdminCenter(M365Service):
|
||||
self.organization_config = None
|
||||
self.sharing_policy = None
|
||||
if self.powershell:
|
||||
self.powershell.connect_exchange_online()
|
||||
self.organization_config = self._get_organization_config()
|
||||
self.sharing_policy = self._get_sharing_policy()
|
||||
if self.powershell.connect_exchange_online():
|
||||
self.organization_config = self._get_organization_config()
|
||||
self.sharing_policy = self._get_sharing_policy()
|
||||
self.powershell.close()
|
||||
|
||||
loop = get_event_loop()
|
||||
|
||||
@@ -26,7 +26,7 @@ class defender_domain_dkim_enabled(Check):
|
||||
report = CheckReportM365(
|
||||
metadata=self.metadata(),
|
||||
resource=config,
|
||||
resource_name="DKIM Configuration",
|
||||
resource_name=config.id,
|
||||
resource_id=config.id,
|
||||
)
|
||||
report.status = "FAIL"
|
||||
|
||||
@@ -21,18 +21,18 @@ class Defender(M365Service):
|
||||
self.inbound_spam_rules = {}
|
||||
self.report_submission_policy = None
|
||||
if self.powershell:
|
||||
self.powershell.connect_exchange_online()
|
||||
self.malware_policies = self._get_malware_filter_policy()
|
||||
self.malware_rules = self._get_malware_filter_rule()
|
||||
self.outbound_spam_policies = self._get_outbound_spam_filter_policy()
|
||||
self.outbound_spam_rules = self._get_outbound_spam_filter_rule()
|
||||
self.antiphishing_policies = self._get_antiphishing_policy()
|
||||
self.antiphishing_rules = self._get_antiphishing_rules()
|
||||
self.connection_filter_policy = self._get_connection_filter_policy()
|
||||
self.dkim_configurations = self._get_dkim_config()
|
||||
self.inbound_spam_policies = self._get_inbound_spam_filter_policy()
|
||||
self.inbound_spam_rules = self._get_inbound_spam_filter_rule()
|
||||
self.report_submission_policy = self._get_report_submission_policy()
|
||||
if self.powershell.connect_exchange_online():
|
||||
self.malware_policies = self._get_malware_filter_policy()
|
||||
self.malware_rules = self._get_malware_filter_rule()
|
||||
self.outbound_spam_policies = self._get_outbound_spam_filter_policy()
|
||||
self.outbound_spam_rules = self._get_outbound_spam_filter_rule()
|
||||
self.antiphishing_policies = self._get_antiphishing_policy()
|
||||
self.antiphishing_rules = self._get_antiphishing_rules()
|
||||
self.connection_filter_policy = self._get_connection_filter_policy()
|
||||
self.dkim_configurations = self._get_dkim_config()
|
||||
self.inbound_spam_policies = self._get_inbound_spam_filter_policy()
|
||||
self.inbound_spam_rules = self._get_inbound_spam_filter_rule()
|
||||
self.report_submission_policy = self._get_report_submission_policy()
|
||||
self.powershell.close()
|
||||
|
||||
def _get_malware_filter_policy(self):
|
||||
|
||||
@@ -21,15 +21,15 @@ class Exchange(M365Service):
|
||||
self.mailbox_audit_properties = []
|
||||
|
||||
if self.powershell:
|
||||
self.powershell.connect_exchange_online()
|
||||
self.organization_config = self._get_organization_config()
|
||||
self.mailboxes_config = self._get_mailbox_audit_config()
|
||||
self.external_mail_config = self._get_external_mail_config()
|
||||
self.transport_rules = self._get_transport_rules()
|
||||
self.transport_config = self._get_transport_config()
|
||||
self.mailbox_policy = self._get_mailbox_policy()
|
||||
self.role_assignment_policies = self._get_role_assignment_policies()
|
||||
self.mailbox_audit_properties = self._get_mailbox_audit_properties()
|
||||
if self.powershell.connect_exchange_online():
|
||||
self.organization_config = self._get_organization_config()
|
||||
self.mailboxes_config = self._get_mailbox_audit_config()
|
||||
self.external_mail_config = self._get_external_mail_config()
|
||||
self.transport_rules = self._get_transport_rules()
|
||||
self.transport_config = self._get_transport_config()
|
||||
self.mailbox_policy = self._get_mailbox_policy()
|
||||
self.role_assignment_policies = self._get_role_assignment_policies()
|
||||
self.mailbox_audit_properties = self._get_mailbox_audit_properties()
|
||||
self.powershell.close()
|
||||
|
||||
def _get_organization_config(self):
|
||||
|
||||
@@ -11,8 +11,8 @@ class Purview(M365Service):
|
||||
self.audit_log_config = None
|
||||
|
||||
if self.powershell:
|
||||
self.powershell.connect_exchange_online()
|
||||
self.audit_log_config = self._get_audit_log_config()
|
||||
if self.powershell.connect_exchange_online():
|
||||
self.audit_log_config = self._get_audit_log_config()
|
||||
self.powershell.close()
|
||||
|
||||
def _get_audit_log_config(self):
|
||||
|
||||
@@ -14,11 +14,11 @@ class Teams(M365Service):
|
||||
self.user_settings = None
|
||||
|
||||
if self.powershell:
|
||||
self.powershell.connect_microsoft_teams()
|
||||
self.teams_settings = self._get_teams_client_configuration()
|
||||
self.global_meeting_policy = self._get_global_meeting_policy()
|
||||
self.global_messaging_policy = self._get_global_messaging_policy()
|
||||
self.user_settings = self._get_user_settings()
|
||||
if self.powershell.connect_microsoft_teams():
|
||||
self.teams_settings = self._get_teams_client_configuration()
|
||||
self.global_meeting_policy = self._get_global_meeting_policy()
|
||||
self.global_messaging_policy = self._get_global_messaging_policy()
|
||||
self.user_settings = self._get_user_settings()
|
||||
self.powershell.close()
|
||||
|
||||
def _get_teams_client_configuration(self):
|
||||
|
||||
@@ -71,7 +71,7 @@ maintainers = [{name = "Prowler Engineering", email = "engineering@prowler.com"}
|
||||
name = "prowler"
|
||||
readme = "README.md"
|
||||
requires-python = ">3.9.1,<3.13"
|
||||
version = "5.9.0"
|
||||
version = "5.9.3"
|
||||
|
||||
[project.scripts]
|
||||
prowler = "prowler.__main__:prowler"
|
||||
|
||||
256
tests/contrib/wazuh/prowler_wrapper_security_test.py
Normal file
256
tests/contrib/wazuh/prowler_wrapper_security_test.py
Normal file
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Security test for prowler-wrapper.py command injection vulnerability
|
||||
This test demonstrates the command injection vulnerability and validates the fix
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestProwlerWrapperSecurity(unittest.TestCase):
|
||||
"""Test cases for command injection vulnerability in prowler-wrapper.py"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment"""
|
||||
# Create a temporary directory for testing
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.prowler_wrapper_path = os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
),
|
||||
"contrib",
|
||||
"wazuh",
|
||||
"prowler-wrapper.py",
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment"""
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
|
||||
def _import_prowler_wrapper(self):
|
||||
"""Helper to import prowler_wrapper with mocked WAZUH_PATH"""
|
||||
sys.path.insert(0, os.path.dirname(self.prowler_wrapper_path))
|
||||
|
||||
# Mock the WAZUH_PATH that's read at module level
|
||||
with patch("builtins.open", create=True) as mock_open:
|
||||
mock_open.return_value.readline.return_value = 'DIRECTORY="/opt/wazuh"'
|
||||
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"prowler_wrapper", self.prowler_wrapper_path
|
||||
)
|
||||
prowler_wrapper = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(prowler_wrapper)
|
||||
return prowler_wrapper._run_prowler
|
||||
|
||||
def test_command_injection_semicolon(self):
|
||||
"""Test command injection using semicolon"""
|
||||
# Create a test file that should not be created if injection is prevented
|
||||
test_file = os.path.join(self.test_dir, "pwned.txt")
|
||||
|
||||
# Malicious profile that attempts to create a file
|
||||
malicious_profile = f"test; touch {test_file}"
|
||||
|
||||
# Mock the subprocess.Popen to capture the command
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the vulnerable function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with malicious input
|
||||
_run_prowler(f'-p "{malicious_profile}" -V')
|
||||
|
||||
# Check that Popen was called
|
||||
self.assertTrue(mock_popen.called)
|
||||
|
||||
# Get the actual command that was passed to Popen
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# With the fix, the command should be a list (from shlex.split)
|
||||
# and should NOT have shell=True
|
||||
self.assertIsInstance(
|
||||
actual_command, list, "Command should be a list after shlex.split"
|
||||
)
|
||||
|
||||
# Check that shell=True is not in the call
|
||||
call_kwargs = mock_popen.call_args[1]
|
||||
self.assertNotIn(
|
||||
"shell",
|
||||
call_kwargs,
|
||||
"shell parameter should not be present (defaults to False)",
|
||||
)
|
||||
|
||||
def test_command_injection_ampersand(self):
|
||||
"""Test command injection using ampersand"""
|
||||
# Create a test file that should not be created if injection is prevented
|
||||
test_file = os.path.join(self.test_dir, "pwned2.txt")
|
||||
|
||||
# Malicious profile that attempts to create a file
|
||||
malicious_profile = f"test && touch {test_file}"
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with malicious input
|
||||
_run_prowler(f'-p "{malicious_profile}" -V')
|
||||
|
||||
# Get the actual command
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# Verify it's a list (safe execution)
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# The malicious characters should be preserved as part of the argument
|
||||
# not interpreted as shell commands
|
||||
command_str = " ".join(actual_command)
|
||||
self.assertIn(
|
||||
"&&",
|
||||
command_str,
|
||||
"Shell metacharacters should be preserved as literals",
|
||||
)
|
||||
|
||||
def test_command_injection_pipe(self):
|
||||
"""Test command injection using pipe"""
|
||||
malicious_profile = 'test | echo "injected"'
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with malicious input
|
||||
_run_prowler(f'-p "{malicious_profile}" -V')
|
||||
|
||||
# Get the actual command
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# Verify safe execution
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# Pipe should be preserved as literal
|
||||
command_str = " ".join(actual_command)
|
||||
self.assertIn("|", command_str)
|
||||
|
||||
def test_command_injection_backticks(self):
|
||||
"""Test command injection using backticks"""
|
||||
malicious_profile = "test `echo injected`"
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with malicious input
|
||||
_run_prowler(f'-p "{malicious_profile}" -V')
|
||||
|
||||
# Get the actual command
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# Verify safe execution
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# Backticks should be preserved as literals
|
||||
command_str = " ".join(actual_command)
|
||||
self.assertIn("`", command_str)
|
||||
|
||||
def test_command_injection_dollar_parentheses(self):
|
||||
"""Test command injection using $() syntax"""
|
||||
malicious_profile = "test $(echo injected)"
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with malicious input
|
||||
_run_prowler(f'-p "{malicious_profile}" -V')
|
||||
|
||||
# Get the actual command
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# Verify safe execution
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# $() should be preserved as literals
|
||||
command_str = " ".join(actual_command)
|
||||
self.assertIn("$(", command_str)
|
||||
|
||||
def test_legitimate_profile_name(self):
|
||||
"""Test that legitimate profile names still work correctly"""
|
||||
legitimate_profile = "production-aws-profile"
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with legitimate input
|
||||
result = _run_prowler(f"-p {legitimate_profile} -V")
|
||||
|
||||
# Verify the function returns output
|
||||
self.assertEqual(result, b"test output")
|
||||
|
||||
# Verify Popen was called correctly
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# Check the profile is passed correctly
|
||||
command_str = " ".join(actual_command)
|
||||
self.assertIn(legitimate_profile, command_str)
|
||||
|
||||
def test_shlex_split_behavior(self):
|
||||
"""Test that shlex properly handles quoted arguments"""
|
||||
profile_with_spaces = "my profile name"
|
||||
|
||||
with patch("subprocess.Popen") as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate.return_value = (b"test output", None)
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
# Import and run the function
|
||||
_run_prowler = self._import_prowler_wrapper()
|
||||
|
||||
# Run with profile containing spaces
|
||||
_run_prowler(f'-p "{profile_with_spaces}" -V')
|
||||
|
||||
# Get the actual command
|
||||
actual_command = mock_popen.call_args[0][0]
|
||||
|
||||
# Verify it's properly split
|
||||
self.assertIsInstance(actual_command, list)
|
||||
|
||||
# The profile name should be preserved as a single argument
|
||||
# despite containing spaces
|
||||
self.assertIn("my profile name", actual_command)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -404,7 +404,7 @@ class Test_ec2_securitygroup_allow_ingress_from_internet_to_all_ports:
|
||||
new=EC2(aws_provider),
|
||||
),
|
||||
mock.patch(
|
||||
"prowler.providers.aws.services.vpc.vpc_service.VPC",
|
||||
"prowler.providers.aws.services.ec2.ec2_securitygroup_allow_ingress_from_internet_to_all_ports.ec2_securitygroup_allow_ingress_from_internet_to_all_ports.vpc_client",
|
||||
new=VPC(aws_provider),
|
||||
),
|
||||
mock.patch(
|
||||
|
||||
@@ -6,6 +6,8 @@ from prowler.providers.aws.services.iam.lib.policy import (
|
||||
check_full_service_access,
|
||||
get_effective_actions,
|
||||
has_codebuild_trusted_principal,
|
||||
has_public_principal,
|
||||
has_restrictive_source_arn_condition,
|
||||
is_codebuild_using_allowed_github_org,
|
||||
is_condition_block_restrictive,
|
||||
is_condition_block_restrictive_organization,
|
||||
@@ -2451,3 +2453,266 @@ def test_has_codebuild_trusted_principal_list():
|
||||
],
|
||||
}
|
||||
assert has_codebuild_trusted_principal(trust_policy) is True
|
||||
|
||||
|
||||
class Test_has_public_principal:
|
||||
"""Tests for the has_public_principal function"""
|
||||
|
||||
def test_has_public_principal_wildcard_string(self):
|
||||
"""Test public principal detection with wildcard string"""
|
||||
statement = {"Principal": "*"}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_root_arn_string(self):
|
||||
"""Test public principal detection with root ARN string"""
|
||||
statement = {"Principal": "arn:aws:iam::*:root"}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_aws_dict_wildcard(self):
|
||||
"""Test public principal detection with AWS dict containing wildcard"""
|
||||
statement = {"Principal": {"AWS": "*"}}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_aws_dict_root_arn(self):
|
||||
"""Test public principal detection with AWS dict containing root ARN"""
|
||||
statement = {"Principal": {"AWS": "arn:aws:iam::*:root"}}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_aws_list_wildcard(self):
|
||||
"""Test public principal detection with AWS list containing wildcard"""
|
||||
statement = {"Principal": {"AWS": ["arn:aws:iam::123456789012:user/test", "*"]}}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_aws_list_root_arn(self):
|
||||
"""Test public principal detection with AWS list containing root ARN"""
|
||||
statement = {
|
||||
"Principal": {
|
||||
"AWS": ["arn:aws:iam::123456789012:user/test", "arn:aws:iam::*:root"]
|
||||
}
|
||||
}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_canonical_user_wildcard(self):
|
||||
"""Test public principal detection with CanonicalUser wildcard"""
|
||||
statement = {"Principal": {"CanonicalUser": "*"}}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_canonical_user_root_arn(self):
|
||||
"""Test public principal detection with CanonicalUser root ARN"""
|
||||
statement = {"Principal": {"CanonicalUser": "arn:aws:iam::*:root"}}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
def test_has_public_principal_no_principal(self):
|
||||
"""Test with statement that has no Principal field"""
|
||||
statement = {"Effect": "Allow", "Action": "s3:GetObject"}
|
||||
assert has_public_principal(statement) is False
|
||||
|
||||
def test_has_public_principal_empty_principal(self):
|
||||
"""Test with empty principal"""
|
||||
statement = {"Principal": ""}
|
||||
assert has_public_principal(statement) is False
|
||||
|
||||
def test_has_public_principal_specific_account(self):
|
||||
"""Test with specific account principal (not public)"""
|
||||
statement = {"Principal": {"AWS": "arn:aws:iam::123456789012:root"}}
|
||||
assert has_public_principal(statement) is False
|
||||
|
||||
def test_has_public_principal_service_principal(self):
|
||||
"""Test with service principal (not public)"""
|
||||
statement = {"Principal": {"Service": "lambda.amazonaws.com"}}
|
||||
assert has_public_principal(statement) is False
|
||||
|
||||
def test_has_public_principal_mixed_principals(self):
|
||||
"""Test with mixed principals including public one"""
|
||||
statement = {
|
||||
"Principal": {
|
||||
"AWS": ["arn:aws:iam::123456789012:user/test"],
|
||||
"Service": "lambda.amazonaws.com",
|
||||
"CanonicalUser": "*",
|
||||
}
|
||||
}
|
||||
assert has_public_principal(statement) is True
|
||||
|
||||
|
||||
class Test_has_restrictive_source_arn_condition:
|
||||
"""Tests for the has_restrictive_source_arn_condition function"""
|
||||
|
||||
def test_no_condition_block(self):
|
||||
"""Test statement without Condition block"""
|
||||
statement = {"Effect": "Allow", "Principal": "*", "Action": "s3:GetObject"}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_no_source_arn_condition(self):
|
||||
"""Test with condition block but no aws:SourceArn"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "s3:GetObject",
|
||||
"Condition": {"StringEquals": {"aws:SourceAccount": "123456789012"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_restrictive_source_arn_s3_bucket(self):
|
||||
"""Test restrictive SourceArn condition with S3 bucket"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:s3:::my-bucket"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is True
|
||||
|
||||
def test_restrictive_source_arn_lambda_function(self):
|
||||
"""Test restrictive SourceArn condition with Lambda function"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnEquals": {
|
||||
"aws:SourceArn": "arn:aws:lambda:us-east-1:123456789012:function:MyFunction"
|
||||
}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is True
|
||||
|
||||
def test_non_restrictive_global_wildcard(self):
|
||||
"""Test non-restrictive SourceArn with global wildcard"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "*"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_non_restrictive_service_wildcard(self):
|
||||
"""Test non-restrictive SourceArn with service wildcard"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:s3:::*"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_non_restrictive_multi_wildcard(self):
|
||||
"""Test non-restrictive SourceArn with multiple wildcards"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:*:*:*:*"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_non_restrictive_resource_wildcard(self):
|
||||
"""Test non-restrictive SourceArn with resource wildcard"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {"aws:SourceArn": "arn:aws:lambda:us-east-1:123456789012:*"}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_source_arn_list_with_valid_arn(self):
|
||||
"""Test SourceArn condition with list containing valid ARN"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {
|
||||
"aws:SourceArn": ["arn:aws:s3:::bucket1", "arn:aws:s3:::bucket2"]
|
||||
}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is True
|
||||
|
||||
def test_source_arn_list_with_wildcard(self):
|
||||
"""Test SourceArn condition with list containing wildcard"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": ["arn:aws:s3:::bucket1", "*"]}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is False
|
||||
|
||||
def test_source_arn_with_account_validation_match(self):
|
||||
"""Test SourceArn with account validation - matching account"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {
|
||||
"aws:SourceArn": "arn:aws:lambda:us-east-1:123456789012:function:MyFunction"
|
||||
}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement, "123456789012") is True
|
||||
|
||||
def test_source_arn_with_account_validation_mismatch(self):
|
||||
"""Test SourceArn with account validation - non-matching account"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {
|
||||
"aws:SourceArn": "arn:aws:lambda:us-east-1:123456789012:function:MyFunction"
|
||||
}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement, "987654321098") is False
|
||||
|
||||
def test_source_arn_with_account_wildcard(self):
|
||||
"""Test SourceArn with account wildcard"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {
|
||||
"aws:SourceArn": "arn:aws:lambda:us-east-1:*:function:MyFunction"
|
||||
}
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement, "123456789012") is False
|
||||
|
||||
def test_source_arn_s3_bucket_no_account_field(self):
|
||||
"""Test SourceArn with S3 bucket (no account field) - should be restrictive"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:s3:::my-bucket"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement, "123456789012") is True
|
||||
|
||||
def test_source_arn_case_insensitive(self):
|
||||
"""Test SourceArn condition key is case insensitive"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {"ArnLike": {"AWS:SourceArn": "arn:aws:s3:::my-bucket"}},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is True
|
||||
|
||||
def test_source_arn_mixed_operators(self):
|
||||
"""Test SourceArn with multiple condition operators"""
|
||||
statement = {
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": "sns:Publish",
|
||||
"Condition": {
|
||||
"ArnLike": {"aws:SourceArn": "arn:aws:s3:::my-bucket"},
|
||||
"StringEquals": {"aws:SourceAccount": "123456789012"},
|
||||
},
|
||||
}
|
||||
assert has_restrictive_source_arn_condition(statement) is True
|
||||
|
||||
@@ -2,9 +2,10 @@ from typing import Any, Dict
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from prowler.providers.aws.services.sns.sns_service import Topic
|
||||
from tests.providers.aws.utils import AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1
|
||||
import pytest
|
||||
|
||||
kms_key_id = str(uuid4())
|
||||
topic_name = "test-topic"
|
||||
@@ -98,6 +99,73 @@ test_policy_restricted_principal_account_organization = {
|
||||
]
|
||||
}
|
||||
|
||||
test_policy_restricted_source_arn = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": "SNS:Publish",
|
||||
"Resource": f"arn:aws:sns:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:{topic_name}",
|
||||
"Condition": {
|
||||
"ArnLike": {"aws:SourceArn": "arn:aws:s3:::test-bucket-name"}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
test_policy_invalid_source_arn = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": "SNS:Publish",
|
||||
"Resource": f"arn:aws:sns:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:{topic_name}",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "invalid-arn-format"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
test_policy_unrestricted_source_arn_wildcard = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": "SNS:Publish",
|
||||
"Resource": f"arn:aws:sns:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:{topic_name}",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "*"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
test_policy_unrestricted_source_arn_service_wildcard = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": "SNS:Publish",
|
||||
"Resource": f"arn:aws:sns:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:{topic_name}",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:s3:::*"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
test_policy_unrestricted_source_arn_multi_wildcard = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": "SNS:Publish",
|
||||
"Resource": f"arn:aws:sns:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:{topic_name}",
|
||||
"Condition": {"ArnLike": {"aws:SourceArn": "arn:aws:*:*:*:*"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def generate_policy_restricted_on_sns_endpoint(endpoint: str) -> Dict[str, Any]:
|
||||
return {
|
||||
@@ -396,6 +464,78 @@ class Test_sns_topics_not_publicly_accessible:
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
def test_topic_public_with_source_arn_restriction(self):
|
||||
sns_client = mock.MagicMock
|
||||
sns_client.audited_account = AWS_ACCOUNT_NUMBER
|
||||
sns_client.topics = []
|
||||
sns_client.topics.append(
|
||||
Topic(
|
||||
arn=topic_arn,
|
||||
name=topic_name,
|
||||
policy=test_policy_restricted_source_arn,
|
||||
region=AWS_REGION_EU_WEST_1,
|
||||
)
|
||||
)
|
||||
sns_client.provider = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata.organization_id = org_id
|
||||
with mock.patch(
|
||||
"prowler.providers.aws.services.sns.sns_service.SNS",
|
||||
sns_client,
|
||||
):
|
||||
from prowler.providers.aws.services.sns.sns_topics_not_publicly_accessible.sns_topics_not_publicly_accessible import (
|
||||
sns_topics_not_publicly_accessible,
|
||||
)
|
||||
|
||||
check = sns_topics_not_publicly_accessible()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "PASS"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"SNS topic {topic_name} is not publicly accessible."
|
||||
)
|
||||
assert result[0].resource_id == topic_name
|
||||
assert result[0].resource_arn == topic_arn
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
def test_topic_public_with_invalid_source_arn(self):
|
||||
sns_client = mock.MagicMock
|
||||
sns_client.audited_account = AWS_ACCOUNT_NUMBER
|
||||
sns_client.topics = []
|
||||
sns_client.topics.append(
|
||||
Topic(
|
||||
arn=topic_arn,
|
||||
name=topic_name,
|
||||
policy=test_policy_invalid_source_arn,
|
||||
region=AWS_REGION_EU_WEST_1,
|
||||
)
|
||||
)
|
||||
sns_client.provider = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata.organization_id = org_id
|
||||
with mock.patch(
|
||||
"prowler.providers.aws.services.sns.sns_service.SNS",
|
||||
sns_client,
|
||||
):
|
||||
from prowler.providers.aws.services.sns.sns_topics_not_publicly_accessible.sns_topics_not_publicly_accessible import (
|
||||
sns_topics_not_publicly_accessible,
|
||||
)
|
||||
|
||||
check = sns_topics_not_publicly_accessible()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "PASS"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"SNS topic {topic_name} is not publicly accessible."
|
||||
)
|
||||
assert result[0].resource_id == topic_name
|
||||
assert result[0].resource_arn == topic_arn
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
@@ -443,6 +583,114 @@ class Test_sns_topics_not_publicly_accessible:
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
def test_topic_public_with_unrestricted_source_arn_wildcard(self):
|
||||
sns_client = mock.MagicMock
|
||||
sns_client.audited_account = AWS_ACCOUNT_NUMBER
|
||||
sns_client.topics = []
|
||||
sns_client.topics.append(
|
||||
Topic(
|
||||
arn=topic_arn,
|
||||
name=topic_name,
|
||||
policy=test_policy_unrestricted_source_arn_wildcard,
|
||||
region=AWS_REGION_EU_WEST_1,
|
||||
)
|
||||
)
|
||||
sns_client.provider = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata.organization_id = org_id
|
||||
with mock.patch(
|
||||
"prowler.providers.aws.services.sns.sns_service.SNS",
|
||||
sns_client,
|
||||
):
|
||||
from prowler.providers.aws.services.sns.sns_topics_not_publicly_accessible.sns_topics_not_publicly_accessible import (
|
||||
sns_topics_not_publicly_accessible,
|
||||
)
|
||||
|
||||
check = sns_topics_not_publicly_accessible()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "FAIL"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"SNS topic {topic_name} is public because its policy allows public access."
|
||||
)
|
||||
assert result[0].resource_id == topic_name
|
||||
assert result[0].resource_arn == topic_arn
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
def test_topic_public_with_unrestricted_source_arn_service_wildcard(self):
|
||||
sns_client = mock.MagicMock
|
||||
sns_client.audited_account = AWS_ACCOUNT_NUMBER
|
||||
sns_client.topics = []
|
||||
sns_client.topics.append(
|
||||
Topic(
|
||||
arn=topic_arn,
|
||||
name=topic_name,
|
||||
policy=test_policy_unrestricted_source_arn_service_wildcard,
|
||||
region=AWS_REGION_EU_WEST_1,
|
||||
)
|
||||
)
|
||||
sns_client.provider = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata.organization_id = org_id
|
||||
with mock.patch(
|
||||
"prowler.providers.aws.services.sns.sns_service.SNS",
|
||||
sns_client,
|
||||
):
|
||||
from prowler.providers.aws.services.sns.sns_topics_not_publicly_accessible.sns_topics_not_publicly_accessible import (
|
||||
sns_topics_not_publicly_accessible,
|
||||
)
|
||||
|
||||
check = sns_topics_not_publicly_accessible()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "FAIL"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"SNS topic {topic_name} is public because its policy allows public access."
|
||||
)
|
||||
assert result[0].resource_id == topic_name
|
||||
assert result[0].resource_arn == topic_arn
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
def test_topic_public_with_unrestricted_source_arn_multi_wildcard(self):
|
||||
sns_client = mock.MagicMock
|
||||
sns_client.audited_account = AWS_ACCOUNT_NUMBER
|
||||
sns_client.topics = []
|
||||
sns_client.topics.append(
|
||||
Topic(
|
||||
arn=topic_arn,
|
||||
name=topic_name,
|
||||
policy=test_policy_unrestricted_source_arn_multi_wildcard,
|
||||
region=AWS_REGION_EU_WEST_1,
|
||||
)
|
||||
)
|
||||
sns_client.provider = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata = mock.MagicMock()
|
||||
sns_client.provider.organizations_metadata.organization_id = org_id
|
||||
with mock.patch(
|
||||
"prowler.providers.aws.services.sns.sns_service.SNS",
|
||||
sns_client,
|
||||
):
|
||||
from prowler.providers.aws.services.sns.sns_topics_not_publicly_accessible.sns_topics_not_publicly_accessible import (
|
||||
sns_topics_not_publicly_accessible,
|
||||
)
|
||||
|
||||
check = sns_topics_not_publicly_accessible()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "FAIL"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"SNS topic {topic_name} is public because its policy allows public access."
|
||||
)
|
||||
assert result[0].resource_id == topic_name
|
||||
assert result[0].resource_arn == topic_arn
|
||||
assert result[0].region == AWS_REGION_EU_WEST_1
|
||||
assert result[0].resource_tags == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
|
||||
259
tests/providers/m365/lib/jwt/jwt_decoder_test.py
Normal file
259
tests/providers/m365/lib/jwt/jwt_decoder_test.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from prowler.providers.m365.lib.jwt.jwt_decoder import decode_jwt, decode_msal_token
|
||||
|
||||
|
||||
class TestJwtDecoder:
|
||||
def test_decode_jwt_valid_token(self):
|
||||
"""Test decode_jwt with a valid JWT token"""
|
||||
# Create a mock JWT token
|
||||
header = {"alg": "HS256", "typ": "JWT"}
|
||||
payload = {
|
||||
"sub": "1234567890",
|
||||
"name": "John Doe",
|
||||
"iat": 1516239022,
|
||||
"roles": ["application_access", "user_read"],
|
||||
}
|
||||
|
||||
# Encode header and payload
|
||||
header_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
# Create JWT with dummy signature
|
||||
token = f"{header_b64}.{payload_b64}.dummy_signature"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == payload
|
||||
assert result["sub"] == "1234567890"
|
||||
assert result["name"] == "John Doe"
|
||||
assert result["roles"] == ["application_access", "user_read"]
|
||||
|
||||
def test_decode_jwt_valid_token_with_padding(self):
|
||||
"""Test decode_jwt with a token that needs base64 padding"""
|
||||
# Create mock payload that will need padding
|
||||
payload = {"test": "data"}
|
||||
payload_json = json.dumps(payload)
|
||||
|
||||
# Encode mock payload without padding
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(payload_json.encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
token = f"header.{payload_b64}.signature"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == payload
|
||||
|
||||
def test_decode_jwt_invalid_structure_two_parts(self):
|
||||
"""Test decode_jwt with token that has only 2 parts"""
|
||||
token = "header.payload" # Missing signature
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_jwt_invalid_structure_four_parts(self):
|
||||
"""Test decode_jwt with token that has 4 parts"""
|
||||
token = "header.payload.signature.extra"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_jwt_invalid_base64(self):
|
||||
"""Test decode_jwt with invalid base64 in payload"""
|
||||
token = "header.invalid_base64!@#.signature"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_jwt_invalid_json(self):
|
||||
"""Test decode_jwt with invalid JSON in payload"""
|
||||
# Create invalid JSON base64
|
||||
invalid_json = "{'invalid': json,}"
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(invalid_json.encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
token = f"header.{payload_b64}.signature"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_jwt_empty_token(self):
|
||||
"""Test decode_jwt with empty token"""
|
||||
result = decode_jwt("")
|
||||
assert result == {}
|
||||
|
||||
def test_decode_jwt_none_token(self):
|
||||
"""Test decode_jwt with None token"""
|
||||
assert decode_jwt(None) == {}
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_decode_jwt_prints_error_on_failure(self, mock_print):
|
||||
"""Test that decode_jwt prints error message on failure"""
|
||||
token = "invalid.token"
|
||||
|
||||
result = decode_jwt(token)
|
||||
|
||||
assert result == {}
|
||||
mock_print.assert_called_once()
|
||||
assert "Failed to decode the token:" in mock_print.call_args[0][0]
|
||||
|
||||
def test_decode_msal_token_valid_single_line(self):
|
||||
"""Test decode_msal_token with valid JWT in single line"""
|
||||
# Create a valid JWT
|
||||
payload = {"roles": ["Exchange.ManageAsApp"], "tenant": "test-tenant"}
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
jwt_token = f"header.{payload_b64}.signature"
|
||||
text = f"Some text before {jwt_token} some text after"
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == payload
|
||||
assert result["roles"] == ["Exchange.ManageAsApp"]
|
||||
|
||||
def test_decode_msal_token_valid_multiline(self):
|
||||
"""Test decode_msal_token with valid JWT across multiple lines"""
|
||||
payload = {"roles": ["application_access"], "user": "test@contoso.com"}
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
jwt_token = f"header.{payload_b64}.signature"
|
||||
text = f"""Line 1
|
||||
Line 2 with {jwt_token}
|
||||
Line 3"""
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == payload
|
||||
assert result["user"] == "test@contoso.com"
|
||||
|
||||
def test_decode_msal_token_with_whitespace(self):
|
||||
"""Test decode_msal_token with JWT containing whitespace"""
|
||||
payload = {"test": "data"}
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
jwt_token = f"header.{payload_b64}.signature"
|
||||
text = f" Token: {jwt_token} "
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == payload
|
||||
|
||||
def test_decode_msal_token_no_jwt_found(self):
|
||||
"""Test decode_msal_token when no JWT pattern is found"""
|
||||
text = "This text contains no JWT tokens at all"
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_msal_token_invalid_jwt_pattern(self):
|
||||
"""Test decode_msal_token with text that looks like JWT but isn't"""
|
||||
text = "header.payload" # Only 2 parts, not valid JWT
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_decode_msal_token_empty_text(self):
|
||||
"""Test decode_msal_token with empty text"""
|
||||
result = decode_msal_token("")
|
||||
assert result == {}
|
||||
|
||||
def test_decode_msal_token_none_text(self):
|
||||
"""Test decode_msal_token with None text"""
|
||||
assert decode_msal_token(None) == {}
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_decode_msal_token_prints_error_on_failure(self, mock_print):
|
||||
"""Test that decode_msal_token prints error message on failure"""
|
||||
text = "No JWT here"
|
||||
|
||||
result = decode_msal_token(text)
|
||||
|
||||
assert result == {}
|
||||
mock_print.assert_called_once()
|
||||
assert "Failed to extract and decode the token:" in mock_print.call_args[0][0]
|
||||
|
||||
def test_decode_msal_token_real_world_scenario(self):
|
||||
"""Test decode_msal_token with a realistic PowerShell output scenario"""
|
||||
# Simulate output from Get-MsalToken or similar
|
||||
payload = {
|
||||
"aud": "https://graph.microsoft.com",
|
||||
"iss": "https://sts.windows.net/tenant-id/",
|
||||
"iat": 1640995200,
|
||||
"exp": 1641081600,
|
||||
"roles": ["Application.ReadWrite.All"],
|
||||
"sub": "app-subject-id",
|
||||
}
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
jwt_token = f"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.{payload_b64}.signature123"
|
||||
|
||||
# Simulate PowerShell output format
|
||||
powershell_output = f"""
|
||||
AccessToken : {jwt_token}
|
||||
TokenType : Bearer
|
||||
ExpiresOn : 1/2/2022 12:00:00 AM +00:00
|
||||
ExtendedExpiresOn : 1/2/2022 12:00:00 AM +00:00
|
||||
"""
|
||||
|
||||
result = decode_msal_token(powershell_output)
|
||||
|
||||
assert result == payload
|
||||
assert result["roles"] == ["Application.ReadWrite.All"]
|
||||
assert result["aud"] == "https://graph.microsoft.com"
|
||||
|
||||
def test_decode_msal_token_with_jwt_in_json(self):
|
||||
"""Test decode_msal_token with JWT embedded in JSON-like structure"""
|
||||
payload = {"tenant": "test", "scope": "https://graph.microsoft.com/.default"}
|
||||
payload_b64 = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
jwt_token = f"header.{payload_b64}.signature"
|
||||
|
||||
json_like_text = f'{{"access_token": "{jwt_token}", "token_type": "Bearer"}}'
|
||||
|
||||
result = decode_msal_token(json_like_text)
|
||||
|
||||
assert result == payload
|
||||
@@ -4,9 +4,8 @@ import pytest
|
||||
|
||||
from prowler.lib.powershell.powershell import PowerShellSession
|
||||
from prowler.providers.m365.exceptions.exceptions import (
|
||||
M365ExchangeConnectionError,
|
||||
M365GraphConnectionError,
|
||||
M365TeamsConnectionError,
|
||||
M365UserCredentialsError,
|
||||
M365UserNotBelongingToTenantError,
|
||||
)
|
||||
from prowler.providers.m365.lib.powershell.m365_powershell import M365PowerShell
|
||||
@@ -113,15 +112,9 @@ class Testm365PowerShell:
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@patch("msal.ConfidentialClientApplication")
|
||||
def test_test_credentials(self, mock_msal, mock_popen):
|
||||
def test_test_credentials(self, mock_popen):
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
mock_msal_instance = MagicMock()
|
||||
mock_msal.return_value = mock_msal_instance
|
||||
mock_msal_instance.acquire_token_by_username_password.return_value = {
|
||||
"access_token": "test_token"
|
||||
}
|
||||
|
||||
credentials = M365Credentials(
|
||||
user="test@contoso.onmicrosoft.com",
|
||||
@@ -143,7 +136,11 @@ class Testm365PowerShell:
|
||||
|
||||
# Mock encrypt_password to return a known value
|
||||
session.encrypt_password = MagicMock(return_value="encrypted_password")
|
||||
session.execute = MagicMock()
|
||||
|
||||
# Mock execute to simulate successful Connect-ExchangeOnline
|
||||
session.execute = MagicMock(
|
||||
return_value="Connected successfully https://aka.ms/exov3-module"
|
||||
)
|
||||
|
||||
# Execute the test
|
||||
result = session.test_credentials(credentials)
|
||||
@@ -156,18 +153,10 @@ class Testm365PowerShell:
|
||||
session.execute.assert_any_call(
|
||||
f'$credential = New-Object System.Management.Automation.PSCredential("{session.sanitize(credentials.user)}", $securePassword)'
|
||||
)
|
||||
session.execute.assert_any_call(
|
||||
"Connect-ExchangeOnline -Credential $credential"
|
||||
)
|
||||
|
||||
# Verify MSAL was called with the correct parameters
|
||||
mock_msal.assert_called_once_with(
|
||||
client_id="test_client_id",
|
||||
client_credential="test_client_secret",
|
||||
authority="https://login.microsoftonline.com/test_tenant_id",
|
||||
)
|
||||
mock_msal_instance.acquire_token_by_username_password.assert_called_once_with(
|
||||
username="test@contoso.onmicrosoft.com",
|
||||
password="test_password", # Original password, not encrypted
|
||||
scopes=["https://graph.microsoft.com/.default"],
|
||||
)
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@@ -255,13 +244,9 @@ class Testm365PowerShell:
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@patch("msal.ConfidentialClientApplication")
|
||||
def test_test_credentials_auth_failure(self, mock_msal, mock_popen):
|
||||
def test_test_credentials_auth_failure_aadsts_error(self, mock_popen):
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
mock_msal_instance = MagicMock()
|
||||
mock_msal.return_value = mock_msal_instance
|
||||
mock_msal_instance.acquire_token_by_username_password.return_value = None
|
||||
|
||||
credentials = M365Credentials(
|
||||
user="test@contoso.onmicrosoft.com",
|
||||
@@ -281,46 +266,37 @@ class Testm365PowerShell:
|
||||
)
|
||||
session = M365PowerShell(credentials, identity)
|
||||
|
||||
# Mock the execute method to return the decrypted password
|
||||
def mock_execute(command, *args, **kwargs):
|
||||
if "Write-Output" in command:
|
||||
return "decrypted_password"
|
||||
return None
|
||||
# Mock encrypt_password and execute to simulate AADSTS error
|
||||
session.encrypt_password = MagicMock(return_value="encrypted_password")
|
||||
session.execute = MagicMock(
|
||||
return_value="AADSTS50126: Error validating credentials due to invalid username or password"
|
||||
)
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
session.process.stdin.write = MagicMock()
|
||||
session.read_output = MagicMock(return_value="decrypted_password")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(M365UserCredentialsError) as exc_info:
|
||||
session.test_credentials(credentials)
|
||||
|
||||
assert (
|
||||
"Unexpected error: Acquiring token in behalf of user did not return a result."
|
||||
"AADSTS50126: Error validating credentials due to invalid username or password"
|
||||
in str(exc_info.value)
|
||||
)
|
||||
|
||||
mock_msal.assert_called_once_with(
|
||||
client_id="test_client_id",
|
||||
client_credential="test_client_secret",
|
||||
authority="https://login.microsoftonline.com/test_tenant_id",
|
||||
# Verify execute was called with the correct commands
|
||||
session.execute.assert_any_call(
|
||||
f'$securePassword = "{credentials.encrypted_passwd}" | ConvertTo-SecureString'
|
||||
)
|
||||
mock_msal_instance.acquire_token_by_username_password.assert_called_once_with(
|
||||
username="test@contoso.onmicrosoft.com",
|
||||
password="test_password",
|
||||
scopes=["https://graph.microsoft.com/.default"],
|
||||
session.execute.assert_any_call(
|
||||
f'$credential = New-Object System.Management.Automation.PSCredential("{session.sanitize(credentials.user)}", $securePassword)'
|
||||
)
|
||||
session.execute.assert_any_call(
|
||||
"Connect-ExchangeOnline -Credential $credential"
|
||||
)
|
||||
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@patch("msal.ConfidentialClientApplication")
|
||||
def test_test_credentials_auth_failure_no_access_token(self, mock_msal, mock_popen):
|
||||
def test_test_credentials_auth_failure_no_access_token(self, mock_popen):
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
mock_msal_instance = MagicMock()
|
||||
mock_msal.return_value = mock_msal_instance
|
||||
mock_msal_instance.acquire_token_by_username_password.return_value = {
|
||||
"error_description": "invalid_grant: authentication failed"
|
||||
}
|
||||
|
||||
credentials = M365Credentials(
|
||||
user="test@contoso.onmicrosoft.com",
|
||||
@@ -340,31 +316,29 @@ class Testm365PowerShell:
|
||||
)
|
||||
session = M365PowerShell(credentials, identity)
|
||||
|
||||
# Mock the execute method to return the decrypted password
|
||||
def mock_execute(command, *args, **kwargs):
|
||||
if "Write-Output" in command:
|
||||
return "decrypted_password"
|
||||
return None
|
||||
# Mock encrypt_password and execute to simulate AADSTS invalid grant error
|
||||
session.encrypt_password = MagicMock(return_value="encrypted_password")
|
||||
session.execute = MagicMock(
|
||||
return_value="AADSTS70002: The request body must contain the following parameter: 'client_secret' or 'client_assertion'."
|
||||
)
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
session.process.stdin.write = MagicMock()
|
||||
session.read_output = MagicMock(return_value="decrypted_password")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(M365UserCredentialsError) as exc_info:
|
||||
session.test_credentials(credentials)
|
||||
assert "MsGraph Error invalid_grant: authentication failed" in str(
|
||||
exc_info.value
|
||||
|
||||
assert (
|
||||
"AADSTS70002: The request body must contain the following parameter: 'client_secret' or 'client_assertion'."
|
||||
in str(exc_info.value)
|
||||
)
|
||||
|
||||
mock_msal.assert_called_once_with(
|
||||
client_id="test_client_id",
|
||||
client_credential="test_client_secret",
|
||||
authority="https://login.microsoftonline.com/test_tenant_id",
|
||||
# Verify execute was called with the correct commands
|
||||
session.execute.assert_any_call(
|
||||
f'$securePassword = "{credentials.encrypted_passwd}" | ConvertTo-SecureString'
|
||||
)
|
||||
mock_msal_instance.acquire_token_by_username_password.assert_called_once_with(
|
||||
username="test@contoso.onmicrosoft.com",
|
||||
password="test_password",
|
||||
scopes=["https://graph.microsoft.com/.default"],
|
||||
session.execute.assert_any_call(
|
||||
f'$credential = New-Object System.Management.Automation.PSCredential("{session.sanitize(credentials.user)}", $securePassword)'
|
||||
)
|
||||
session.execute.assert_any_call(
|
||||
"Connect-ExchangeOnline -Credential $credential"
|
||||
)
|
||||
|
||||
session.close()
|
||||
@@ -744,7 +718,8 @@ class Testm365PowerShell:
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
def test_test_teams_connection_success(self, mock_popen):
|
||||
@patch("prowler.providers.m365.lib.powershell.m365_powershell.decode_jwt")
|
||||
def test_test_teams_connection_success(self, mock_decode_jwt, mock_popen):
|
||||
"""Test test_teams_connection when token is valid"""
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
@@ -766,17 +741,23 @@ class Testm365PowerShell:
|
||||
return None
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
# Mock JWT decode to return proper permissions
|
||||
mock_decode_jwt.return_value = {"roles": ["application_access"]}
|
||||
|
||||
result = session.test_teams_connection()
|
||||
|
||||
assert result is True
|
||||
# Verify all expected PowerShell commands were called
|
||||
assert session.execute.call_count == 3
|
||||
mock_decode_jwt.assert_called_once_with("valid_teams_token")
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
def test_test_teams_connection_empty_token(self, mock_popen):
|
||||
"""Test test_teams_connection when token is empty"""
|
||||
@patch("prowler.providers.m365.lib.powershell.m365_powershell.decode_jwt")
|
||||
def test_test_teams_connection_missing_permissions(
|
||||
self, mock_decode_jwt, mock_popen
|
||||
):
|
||||
"""Test test_teams_connection when token lacks required permissions"""
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
credentials = M365Credentials(user="test@example.com", passwd="test_password")
|
||||
@@ -790,18 +771,23 @@ class Testm365PowerShell:
|
||||
)
|
||||
session = M365PowerShell(credentials, identity)
|
||||
|
||||
# Mock execute to return empty token when checking
|
||||
# Mock execute to return valid token but decode returns no permissions
|
||||
def mock_execute(command, *args, **kwargs):
|
||||
if "Write-Output $teamsToken" in command:
|
||||
return ""
|
||||
return "valid_teams_token"
|
||||
return None
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
# Mock JWT decode to return missing required permission
|
||||
mock_decode_jwt.return_value = {"roles": ["other_permission"]}
|
||||
|
||||
with pytest.raises(M365TeamsConnectionError) as exc_info:
|
||||
session.test_teams_connection()
|
||||
with patch("prowler.lib.logger.logger.error") as mock_error:
|
||||
result = session.test_teams_connection()
|
||||
|
||||
assert "Microsoft Teams token is empty or invalid" in str(exc_info.value)
|
||||
assert result is False
|
||||
mock_error.assert_called_once_with(
|
||||
"Microsoft Teams connection failed: Please check your permissions and try again."
|
||||
)
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@@ -823,16 +809,18 @@ class Testm365PowerShell:
|
||||
# Mock execute to raise an exception
|
||||
session.execute = MagicMock(side_effect=Exception("Teams API error"))
|
||||
|
||||
with pytest.raises(M365TeamsConnectionError) as exc_info:
|
||||
session.test_teams_connection()
|
||||
with patch("prowler.lib.logger.logger.error") as mock_error:
|
||||
result = session.test_teams_connection()
|
||||
|
||||
assert "Failed to connect to Microsoft Teams API: Teams API error" in str(
|
||||
exc_info.value
|
||||
assert result is False
|
||||
mock_error.assert_called_once_with(
|
||||
"Microsoft Teams connection failed: Teams API error. Please check your permissions and try again."
|
||||
)
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
def test_test_exchange_connection_success(self, mock_popen):
|
||||
@patch("prowler.providers.m365.lib.powershell.m365_powershell.decode_msal_token")
|
||||
def test_test_exchange_connection_success(self, mock_decode_msal_token, mock_popen):
|
||||
"""Test test_exchange_connection when token is valid"""
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
@@ -854,17 +842,23 @@ class Testm365PowerShell:
|
||||
return None
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
# Mock MSAL token decode to return proper permissions
|
||||
mock_decode_msal_token.return_value = {"roles": ["Exchange.ManageAsApp"]}
|
||||
|
||||
result = session.test_exchange_connection()
|
||||
|
||||
assert result is True
|
||||
# Verify all expected PowerShell commands were called
|
||||
assert session.execute.call_count == 3
|
||||
mock_decode_msal_token.assert_called_once_with("valid_exchange_token")
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
def test_test_exchange_connection_empty_token(self, mock_popen):
|
||||
"""Test test_exchange_connection when token is empty"""
|
||||
@patch("prowler.providers.m365.lib.powershell.m365_powershell.decode_msal_token")
|
||||
def test_test_exchange_connection_missing_permissions(
|
||||
self, mock_decode_msal_token, mock_popen
|
||||
):
|
||||
"""Test test_exchange_connection when token lacks required permissions"""
|
||||
mock_process = MagicMock()
|
||||
mock_popen.return_value = mock_process
|
||||
credentials = M365Credentials(user="test@example.com", passwd="test_password")
|
||||
@@ -878,18 +872,23 @@ class Testm365PowerShell:
|
||||
)
|
||||
session = M365PowerShell(credentials, identity)
|
||||
|
||||
# Mock execute to return empty token when checking
|
||||
# Mock execute to return valid token but decode returns no permissions
|
||||
def mock_execute(command, *args, **kwargs):
|
||||
if "Write-Output $exchangeToken" in command:
|
||||
return ""
|
||||
return "valid_exchange_token"
|
||||
return None
|
||||
|
||||
session.execute = MagicMock(side_effect=mock_execute)
|
||||
# Mock MSAL token decode to return missing required permission
|
||||
mock_decode_msal_token.return_value = {"roles": ["other_permission"]}
|
||||
|
||||
with pytest.raises(M365ExchangeConnectionError) as exc_info:
|
||||
session.test_exchange_connection()
|
||||
with patch("prowler.lib.logger.logger.error") as mock_error:
|
||||
result = session.test_exchange_connection()
|
||||
|
||||
assert "Exchange Online token is empty or invalid" in str(exc_info.value)
|
||||
assert result is False
|
||||
mock_error.assert_called_once_with(
|
||||
"Exchange Online connection failed: Please check your permissions and try again."
|
||||
)
|
||||
session.close()
|
||||
|
||||
@patch("subprocess.Popen")
|
||||
@@ -911,11 +910,12 @@ class Testm365PowerShell:
|
||||
# Mock execute to raise an exception
|
||||
session.execute = MagicMock(side_effect=Exception("Exchange API error"))
|
||||
|
||||
with pytest.raises(M365ExchangeConnectionError) as exc_info:
|
||||
session.test_exchange_connection()
|
||||
with patch("prowler.lib.logger.logger.error") as mock_error:
|
||||
result = session.test_exchange_connection()
|
||||
|
||||
assert "Failed to connect to Exchange Online API: Exchange API error" in str(
|
||||
exc_info.value
|
||||
assert result is False
|
||||
mock_error.assert_called_once_with(
|
||||
"Exchange Online connection failed: Exchange API error. Please check your permissions and try again."
|
||||
)
|
||||
session.close()
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class Test_defender_domain_dkim_enabled:
|
||||
== "DKIM is enabled for domain with ID domain1."
|
||||
)
|
||||
assert result[0].resource == defender_client.dkim_configurations[0].dict()
|
||||
assert result[0].resource_name == "DKIM Configuration"
|
||||
assert result[0].resource_name == "domain1"
|
||||
assert result[0].resource_id == "domain1"
|
||||
assert result[0].location == "global"
|
||||
|
||||
@@ -86,7 +86,7 @@ class Test_defender_domain_dkim_enabled:
|
||||
== "DKIM is not enabled for domain with ID domain2."
|
||||
)
|
||||
assert result[0].resource == defender_client.dkim_configurations[0].dict()
|
||||
assert result[0].resource_name == "DKIM Configuration"
|
||||
assert result[0].resource_name == "domain2"
|
||||
assert result[0].resource_id == "domain2"
|
||||
assert result[0].location == "global"
|
||||
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
|
||||
All notable changes to the **Prowler UI** are documented in this file.
|
||||
|
||||
## [1.10.0] (Prowler v5.10.0 - UNRELEASED)
|
||||
|
||||
### Added
|
||||
|
||||
- Lighthouse banner [(#8259)](https://github.com/prowler-cloud/prowler/pull/8259)
|
||||
- Integration with Amazon S3, enabling storage and retrieval of scan data via S3 buckets [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
|
||||
|
||||
___
|
||||
|
||||
## [v1.9.3] (Prowler v5.9.3)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Display error messages and allow editing last message in Lighthouse [(#8358)](https://github.com/prowler-cloud/prowler/pull/8358)
|
||||
|
||||
## [v1.9.0] (Prowler v5.9.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { LangChainAdapter, Message } from "ai";
|
||||
|
||||
import { getLighthouseConfig } from "@/actions/lighthouse/lighthouse";
|
||||
import { getErrorMessage } from "@/lib/helper";
|
||||
import { getCurrentDataSection } from "@/lib/lighthouse/data";
|
||||
import {
|
||||
convertLangChainMessageToVercelMessage,
|
||||
@@ -73,22 +74,38 @@ export async function POST(req: Request) {
|
||||
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
for await (const { event, data, tags } of agentStream) {
|
||||
if (event === "on_chat_model_stream") {
|
||||
if (data.chunk.content && !!tags && tags.includes("supervisor")) {
|
||||
const chunk = data.chunk;
|
||||
const aiMessage = convertLangChainMessageToVercelMessage(chunk);
|
||||
controller.enqueue(aiMessage);
|
||||
try {
|
||||
for await (const { event, data, tags } of agentStream) {
|
||||
if (event === "on_chat_model_stream") {
|
||||
if (data.chunk.content && !!tags && tags.includes("supervisor")) {
|
||||
const chunk = data.chunk;
|
||||
const aiMessage = convertLangChainMessageToVercelMessage(chunk);
|
||||
controller.enqueue(aiMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
} catch (error) {
|
||||
const errorName =
|
||||
error instanceof Error ? error.constructor.name : "UnknownError";
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
controller.enqueue({
|
||||
id: "error-" + Date.now(),
|
||||
role: "assistant",
|
||||
content: `[LIGHTHOUSE_ANALYST_ERROR]: ${errorName}: ${errorMessage}`,
|
||||
});
|
||||
controller.close();
|
||||
}
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
|
||||
return LangChainAdapter.toDataStreamResponse(stream);
|
||||
} catch (error) {
|
||||
console.error("Error in POST request:", error);
|
||||
return Response.json({ error: "An error occurred" }, { status: 500 });
|
||||
return Response.json(
|
||||
{ error: await getErrorMessage(error) },
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useChat } from "@ai-sdk/react";
|
||||
import Link from "next/link";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
|
||||
import { MemoizedMarkdown } from "@/components/lighthouse/memoized-markdown";
|
||||
@@ -25,20 +25,54 @@ interface ChatFormData {
|
||||
}
|
||||
|
||||
export const Chat = ({ hasConfig, isActive }: ChatProps) => {
|
||||
const { messages, handleSubmit, handleInputChange, append, status } = useChat(
|
||||
{
|
||||
api: "/api/lighthouse/analyst",
|
||||
credentials: "same-origin",
|
||||
experimental_throttle: 100,
|
||||
sendExtraMessageFields: true,
|
||||
onFinish: () => {
|
||||
// Handle chat completion
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error("Chat error:", error);
|
||||
},
|
||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
messages,
|
||||
handleSubmit,
|
||||
handleInputChange,
|
||||
append,
|
||||
status,
|
||||
error,
|
||||
setMessages,
|
||||
} = useChat({
|
||||
api: "/api/lighthouse/analyst",
|
||||
credentials: "same-origin",
|
||||
experimental_throttle: 100,
|
||||
sendExtraMessageFields: true,
|
||||
onFinish: (message) => {
|
||||
// There is no specific way to output the error message from langgraph supervisor
|
||||
// Hence, all error messages are sent as normal messages with the prefix [LIGHTHOUSE_ANALYST_ERROR]:
|
||||
// Detect error messages sent from backend using specific prefix and display the error
|
||||
if (message.content?.startsWith("[LIGHTHOUSE_ANALYST_ERROR]:")) {
|
||||
const errorText = message.content
|
||||
.replace("[LIGHTHOUSE_ANALYST_ERROR]:", "")
|
||||
.trim();
|
||||
setErrorMessage(errorText);
|
||||
// Remove error message from chat history
|
||||
setMessages((prev) =>
|
||||
prev.filter(
|
||||
(m) => !m.content?.startsWith("[LIGHTHOUSE_ANALYST_ERROR]:"),
|
||||
),
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
onError: (error) => {
|
||||
console.error("Chat error:", error);
|
||||
|
||||
if (
|
||||
error?.message?.includes("<html>") &&
|
||||
error?.message?.includes("<title>403 Forbidden</title>")
|
||||
) {
|
||||
setErrorMessage("403 Forbidden");
|
||||
return;
|
||||
}
|
||||
|
||||
setErrorMessage(
|
||||
error?.message || "An error occurred. Please retry your message.",
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const form = useForm<ChatFormData>({
|
||||
defaultValues: {
|
||||
@@ -49,6 +83,30 @@ export const Chat = ({ hasConfig, isActive }: ChatProps) => {
|
||||
const messageValue = form.watch("message");
|
||||
const messagesContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const latestUserMsgRef = useRef<HTMLDivElement | null>(null);
|
||||
const messageValueRef = useRef<string>("");
|
||||
|
||||
// Keep ref in sync with current value
|
||||
messageValueRef.current = messageValue;
|
||||
|
||||
// Restore last user message to input when any error occurs
|
||||
useEffect(() => {
|
||||
if (errorMessage) {
|
||||
// Capture current messages to avoid dependency issues
|
||||
setMessages((currentMessages) => {
|
||||
const lastUserMessage = currentMessages
|
||||
.filter((m) => m.role === "user")
|
||||
.pop();
|
||||
|
||||
if (lastUserMessage) {
|
||||
form.setValue("message", lastUserMessage.content);
|
||||
// Remove the last user message from history since it's now in the input
|
||||
return currentMessages.slice(0, -1);
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
});
|
||||
}
|
||||
}, [errorMessage, form, setMessages]);
|
||||
|
||||
// Sync form value with chat input
|
||||
useEffect(() => {
|
||||
@@ -67,6 +125,8 @@ export const Chat = ({ hasConfig, isActive }: ChatProps) => {
|
||||
|
||||
const onFormSubmit = form.handleSubmit((data) => {
|
||||
if (data.message.trim()) {
|
||||
// Clear error on new submission
|
||||
setErrorMessage(null);
|
||||
handleSubmit();
|
||||
}
|
||||
});
|
||||
@@ -148,7 +208,54 @@ export const Chat = ({ hasConfig, isActive }: ChatProps) => {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{messages.length === 0 ? (
|
||||
{/* Error Banner */}
|
||||
{(error || errorMessage) && (
|
||||
<div className="mx-4 mt-4 rounded-lg border border-red-200 bg-red-50 p-4 dark:border-red-800 dark:bg-red-900/20">
|
||||
<div className="flex items-start">
|
||||
<div className="flex-shrink-0">
|
||||
<svg
|
||||
className="h-5 w-5 text-red-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.28 7.22a.75.75 0 00-1.06 1.06L8.94 10l-1.72 1.72a.75.75 0 101.06 1.06L10 11.06l1.72 1.72a.75.75 0 101.06-1.06L11.06 10l1.72-1.72a.75.75 0 00-1.06-1.06L10 8.94 8.28 7.22z"
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<div className="ml-3">
|
||||
<h3 className="text-sm font-medium text-red-800 dark:text-red-200">
|
||||
Error
|
||||
</h3>
|
||||
<p className="mt-1 text-sm text-red-700 dark:text-red-300">
|
||||
{errorMessage ||
|
||||
error?.message ||
|
||||
"An error occurred. Please retry your message."}
|
||||
</p>
|
||||
{/* Original error details for native errors */}
|
||||
{error && (error as any).status && (
|
||||
<p className="mt-1 text-xs text-red-600 dark:text-red-400">
|
||||
Status: {(error as any).status}
|
||||
</p>
|
||||
)}
|
||||
{error && (error as any).body && (
|
||||
<details className="mt-2">
|
||||
<summary className="cursor-pointer text-xs text-red-600 hover:text-red-800 dark:text-red-400 dark:hover:text-red-300">
|
||||
Show details
|
||||
</summary>
|
||||
<pre className="mt-1 max-h-20 overflow-auto rounded bg-red-100 p-2 text-xs text-red-800 dark:bg-red-900/30 dark:text-red-200">
|
||||
{JSON.stringify((error as any).body, null, 2)}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{messages.length === 0 && !errorMessage && !error ? (
|
||||
<div className="flex flex-1 items-center justify-center p-4">
|
||||
<div className="w-full max-w-2xl">
|
||||
<h2 className="mb-4 text-center font-sans text-xl">Suggestions</h2>
|
||||
@@ -232,7 +339,11 @@ export const Chat = ({ hasConfig, isActive }: ChatProps) => {
|
||||
control={form.control}
|
||||
name="message"
|
||||
label=""
|
||||
placeholder="Type your message..."
|
||||
placeholder={
|
||||
error || errorMessage
|
||||
? "Edit your message and try again..."
|
||||
: "Type your message..."
|
||||
}
|
||||
variant="bordered"
|
||||
minRows={1}
|
||||
maxRows={6}
|
||||
|
||||
Reference in New Issue
Block a user