Compare commits

...

1 Commits

Author SHA1 Message Date
Andoni A. d3daeb1d75 wip: ratdoni' 2026-04-28 10:38:47 +02:00
6 changed files with 139 additions and 9 deletions
+23 -2
View File
@@ -1,7 +1,15 @@
import sentry_sdk
from config.env import env
_SENTRY_TAG_FIELDS = {
"prowler_provider": "provider",
"prowler_region": "region",
"prowler_service": "service",
"prowler_tenant_id": "tenant_id",
"prowler_scan_id": "scan_id",
"prowler_provider_uid": "provider_uid",
}
IGNORED_EXCEPTIONS = [
# Provider is not connected due to credentials errors
"is not connected",
@@ -81,7 +89,10 @@ IGNORED_EXCEPTIONS = [
def before_send(event, hint):
"""
before_send handles the Sentry events in order to send them or not
before_send handles the Sentry events in order to send them or not.
It also promotes prowler context fields (injected by ProwlerContextFilter)
from the LogRecord into Sentry event tags so they become searchable.
"""
# Ignore logs with the ignored_exceptions
# https://docs.python.org/3/library/logging.html#logrecord-objects
@@ -105,6 +116,16 @@ def before_send(event, hint):
if log_lvl <= 40 and any(ignored in log_msg for ignored in IGNORED_EXCEPTIONS):
return None # Explicitly return None to drop the event
# Promote prowler context fields to Sentry tags
for record_attr, tag_name in _SENTRY_TAG_FIELDS.items():
value = getattr(log_record, record_attr, None)
if value:
event.setdefault("tags", {})
if isinstance(event["tags"], dict):
event["tags"][tag_name] = str(value)
elif isinstance(event["tags"], list):
event["tags"].append([tag_name, str(value)])
# Ignore exceptions with the ignored_exceptions
if "exc_info" in hint and hint["exc_info"]:
exc_value = str(hint["exc_info"][1])
+16
View File
@@ -853,6 +853,22 @@ def perform_prowler_scan(
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save()
# Enrich Sentry context for all downstream errors (Layer 2: app-only tags)
from prowler.lib.logger import (
prowler_provider_uid_var,
prowler_scan_id_var,
prowler_tenant_id_var,
)
prowler_tenant_id_var.set(str(tenant_id))
prowler_scan_id_var.set(str(scan_id))
prowler_provider_uid_var.set(str(provider_instance.uid))
sentry_sdk.set_tag("provider", str(provider_instance.provider))
sentry_sdk.set_tag("tenant_id", str(tenant_id))
sentry_sdk.set_tag("scan_id", str(scan_id))
sentry_sdk.set_tag("provider_uid", str(provider_instance.uid))
# Find the mutelist processor if it exists
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
try:
+47
View File
@@ -1,6 +1,51 @@
import contextvars
import logging
from os import environ
# Core context — set by provider service base classes (Layer 1)
prowler_provider_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_provider", default=""
)
prowler_region_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_region", default=""
)
prowler_service_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_service", default=""
)
# App context — set by API layer only (Layer 2)
prowler_tenant_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_tenant_id", default=""
)
prowler_scan_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_scan_id", default=""
)
prowler_provider_uid_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"prowler_provider_uid", default=""
)
_PROWLER_CONTEXT_VARS = {
"prowler_provider": prowler_provider_var,
"prowler_region": prowler_region_var,
"prowler_service": prowler_service_var,
"prowler_tenant_id": prowler_tenant_id_var,
"prowler_scan_id": prowler_scan_id_var,
"prowler_provider_uid": prowler_provider_uid_var,
}
class ProwlerContextFilter(logging.Filter):
"""Injects prowler context from contextvars into every LogRecord."""
def filter(self, record: logging.LogRecord) -> bool:
for attr, var in _PROWLER_CONTEXT_VARS.items():
if not hasattr(record, attr):
value = var.get()
if value:
setattr(record, attr, value)
return True
# Logging levels
logging_levels = {
"CRITICAL": logging.CRITICAL,
@@ -54,6 +99,8 @@ def set_logging_config(log_level: str, log_file: str = None, only_logs: bool = F
datefmt="%m/%d/%Y %I:%M:%S %p",
)
logging.getLogger().addFilter(ProwlerContextFilter())
# Retrieve the logger instance
logger = logging.getLogger()
+31 -3
View File
@@ -1,6 +1,12 @@
import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
from prowler.lib.logger import logger
from prowler.lib.logger import (
logger,
prowler_provider_var,
prowler_region_var,
prowler_service_var,
)
from prowler.providers.aws.aws_provider import AwsProvider
# TODO: review the following code
@@ -66,6 +72,10 @@ class AWSService:
)
self.client = self.session.client(self.service, self.region)
# Set Sentry context for this provider/service
prowler_provider_var.set("aws")
prowler_service_var.set(self.service)
# Thread pool for __threading_call__
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
@@ -93,8 +103,26 @@ class AWSService:
f"{self.service.upper()} - Starting threads for '{call_name}' function to process {item_count} items..."
)
# Submit tasks to the thread pool
futures = [self.thread_pool.submit(call, item) for item in items]
# Submit tasks to the thread pool with context propagation.
# copy_context() gives each thread an isolated snapshot of the
# current contextvars so prowler_region_var can be set per-thread
# without races (required for Python <3.12).
futures = []
for item in items:
ctx = contextvars.copy_context()
region = getattr(item, "region", None) or (
getattr(item, "_client_config", None)
and item._client_config.region_name
)
def _call_with_region(fn, arg, rgn):
if rgn:
prowler_region_var.set(rgn)
return fn(arg)
futures.append(
self.thread_pool.submit(ctx.run, _call_with_region, call, item, region)
)
# Wait for all tasks to complete
for future in as_completed(futures):
@@ -1,6 +1,7 @@
import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
from prowler.lib.logger import logger
from prowler.lib.logger import logger, prowler_provider_var, prowler_service_var
from prowler.providers.azure.azure_provider import AzureProvider
MAX_WORKERS = 10
@@ -24,13 +25,19 @@ class AzureService:
self.audit_config = provider.audit_config
self.fixer_config = provider.fixer_config
prowler_provider_var.set("azure")
prowler_service_var.set(self.__class__.__name__.lower())
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
def __threading_call__(self, call, iterator):
"""Execute a function across multiple items using threading."""
items = list(iterator) if not isinstance(iterator, list) else iterator
futures = {self.thread_pool.submit(call, item): item for item in items}
futures = {}
for item in items:
ctx = contextvars.copy_context()
futures[self.thread_pool.submit(ctx.run, call, item)] = item
results = []
for future in as_completed(futures):
+13 -2
View File
@@ -1,3 +1,4 @@
import contextvars
import threading
import google_auth_httplib2
@@ -6,7 +7,12 @@ from google.oauth2.credentials import Credentials
from googleapiclient import discovery
from googleapiclient.discovery import Resource
from prowler.lib.logger import logger
from prowler.lib.logger import (
logger,
prowler_provider_var,
prowler_region_var,
prowler_service_var,
)
from prowler.providers.gcp.config import DEFAULT_RETRY_ATTEMPTS
from prowler.providers.gcp.gcp_provider import GcpProvider
@@ -38,13 +44,18 @@ class GCPService:
self.audit_config = provider.audit_config
self.fixer_config = provider.fixer_config
prowler_provider_var.set("gcp")
prowler_service_var.set(self.service)
prowler_region_var.set(self.region)
def _get_client(self):
return self.client
def __threading_call__(self, call, iterator):
threads = []
for value in iterator:
threads.append(threading.Thread(target=call, args=(value,)))
ctx = contextvars.copy_context()
threads.append(threading.Thread(target=ctx.run, args=(call, value)))
for t in threads:
t.start()
for t in threads: