feat: add vault parallelization

This commit is contained in:
HugoPBrito
2026-01-23 12:57:36 +01:00
parent f11f71bc42
commit 4baf8e8873
3 changed files with 280 additions and 104 deletions

View File

@@ -1,6 +1,11 @@
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from prowler.lib.logger import logger
from prowler.providers.azure.azure_provider import AzureProvider
MAX_WORKERS = 10
class AzureService:
def __init__(
@@ -20,6 +25,40 @@ class AzureService:
self.audit_config = provider.audit_config
self.fixer_config = provider.fixer_config
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
def __threading_call__(self, call, iterator):
"""Execute a function across multiple items using threading."""
items = list(iterator) if not isinstance(iterator, list) else iterator
item_count = len(items)
call_name = getattr(call, "__name__", str(call)).strip("_")
call_name = " ".join(word.capitalize() for word in call_name.split("_"))
logger.info(
f"Azure - Starting threads for '{call_name}' to process {item_count} items..."
)
start_time = time.perf_counter()
futures = {self.thread_pool.submit(call, item): item for item in items}
results = []
for future in as_completed(futures):
item = futures[future]
try:
result = future.result()
if result is not None:
results.append(result)
except Exception as error:
logger.error(
f"Azure - '{call_name}' failed for item {item}: {error.__class__.__name__}: {error}"
)
elapsed = time.perf_counter() - start_time
logger.info(
f"Azure - Completed '{call_name}' for {item_count} items in {elapsed:.2f}s"
)
return results
def __set_clients__(self, identity, session, service, region_config):
clients = {}
try:
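For reference, a minimal usage sketch of the helper added above; `ExampleService`, `describe_all`, and `_describe_item` are hypothetical names that only illustrate the contract of `__threading_call__` (one submitted task per item on the shared pool, `None` results dropped):

    from prowler.providers.azure.lib.service.service import AzureService  # assumed import path


    class ExampleService(AzureService):
        def describe_all(self, item_ids):
            # One task per item on the shared MAX_WORKERS pool; items whose
            # worker raised or returned None are filtered out by the helper.
            return self.__threading_call__(self._describe_item, item_ids)

        def _describe_item(self, item_id):
            # Placeholder per-item work; a real service would call the Azure SDK here.
            return {"id": item_id}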

View File

@@ -1,26 +1,37 @@
import time
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.lib.logger import logger
from prowler.providers.azure.services.keyvault.keyvault_client import keyvault_client
class keyvault_rbac_secret_expiration_set(Check):
def execute(self) -> list[Check_Report_Azure]:
start_time = time.perf_counter()
findings = []
total_secrets = 0
for subscription, key_vaults in keyvault_client.key_vaults.items():
for keyvault in key_vaults:
if keyvault.properties.enable_rbac_authorization and keyvault.secrets:
for secret in keyvault.secrets:
total_secrets += 1
report = Check_Report_Azure(
metadata=self.metadata(), resource=secret
)
report.subscription = subscription
if not secret.attributes.expires and secret.enabled:
report.status = "FAIL"
report.status_extended = f"Secret '{secret.name}' in KeyVault '{keyvault.name}' does not have expiration date set."
else:
report.status = "PASS"
report.status_extended = f"Secret '{secret.name}' in KeyVault '{keyvault.name}' has expiration date set."
findings.append(report)
elapsed = time.perf_counter() - start_time
logger.info(
f"Check keyvault_rbac_secret_expiration_set: "
f"processed {total_secrets} secrets, created {len(findings)} findings in {elapsed:.2f}s"
)
return findings

View File

@@ -1,3 +1,6 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Union
@@ -16,103 +19,180 @@ from prowler.providers.azure.services.monitor.monitor_service import DiagnosticS
class KeyVault(AzureService):
def __init__(self, provider: AzureProvider):
super().__init__(KeyVaultManagementClient, provider)
# TODO: review this credentials assignment
self._provider = provider
self.key_vaults = self._get_key_vaults()
def _get_key_vaults(self):
"""
Get all KeyVaults with parallel processing.
Optimizations:
1. Uses list_by_subscription() for full Vault objects
2. Processes vaults in parallel using __threading_call__
3. Each vault's keys/secrets/monitor fetched in parallel
"""
logger.info("KeyVault - Getting key_vaults...")
total_start = time.perf_counter()
key_vaults = {}
for subscription, client in self.clients.items():
try:
key_vaults[subscription] = []
list_start = time.perf_counter()
vaults_list = list(client.vaults.list_by_subscription())
list_elapsed = time.perf_counter() - list_start
logger.info(f"KeyVault - list_by_subscription took {list_elapsed:.2f}s")
if not vaults_list:
continue
logger.info(
f"KeyVault - Found {len(vaults_list)} vaults in subscription {subscription}"
)
# Prepare items for parallel processing
items = [
{"subscription": subscription, "keyvault": vault}
for vault in vaults_list
]
# Process all KeyVaults in parallel
results = self.__threading_call__(self._process_single_keyvault, items)
key_vaults[subscription] = results
except Exception as error:
logger.error(
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
total_elapsed = time.perf_counter() - total_start
logger.info(f"KeyVault - _get_key_vaults TOTAL took {total_elapsed:.2f}s")
return key_vaults
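The mapping returned here is keyed by subscription name and holds one KeyVaultInfo per successfully processed vault; a vault whose worker raised returns None from _process_single_keyvault and is dropped by __threading_call__. A small hypothetical sketch of consuming that shape (`provider` is assumed to be an already-authenticated AzureProvider):

    # Hypothetical consumer of the {subscription_name: [KeyVaultInfo, ...]} mapping
    # built by _get_key_vaults(); failed vaults are simply absent, never None.
    keyvault = KeyVault(provider)
    for subscription, vaults in keyvault.key_vaults.items():
        for vault in vaults:
            print(subscription, vault.name, len(vault.keys), len(vault.secrets))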
def _process_single_keyvault(self, item: dict) -> Optional["KeyVaultInfo"]:
"""Process a single KeyVault in parallel."""
subscription = item["subscription"]
keyvault = item["keyvault"]
thread_id = threading.current_thread().name
try:
start_time = time.perf_counter()
resource_group = keyvault.id.split("/")[4]
keyvault_name = keyvault.name
logger.info(
f"KeyVault - [{thread_id}] Processing vault {keyvault_name} START"
)
keyvault_properties = keyvault.properties
# Fetch keys, secrets, and monitor in parallel
parallel_start = time.perf_counter()
with ThreadPoolExecutor(max_workers=3) as executor:
keys_future = executor.submit(
self._get_keys, subscription, resource_group, keyvault_name
)
secrets_future = executor.submit(
self._get_secrets, subscription, resource_group, keyvault_name
)
monitor_future = executor.submit(
self._get_vault_monitor_settings,
keyvault_name,
resource_group,
subscription,
)
keys = keys_future.result()
secrets = secrets_future.result()
monitor_settings = monitor_future.result()
parallel_elapsed = time.perf_counter() - parallel_start
total_elapsed = time.perf_counter() - start_time
logger.info(
f"KeyVault - [{thread_id}] Vault {keyvault_name} DONE: "
f"parallel={parallel_elapsed:.2f}s, total={total_elapsed:.2f}s, "
f"keys={len(keys)}, secrets={len(secrets)}"
)
return KeyVaultInfo(
id=getattr(keyvault, "id", ""),
name=getattr(keyvault, "name", ""),
location=getattr(keyvault, "location", ""),
resource_group=resource_group,
properties=VaultProperties(
tenant_id=getattr(keyvault_properties, "tenant_id", ""),
enable_rbac_authorization=getattr(
keyvault_properties,
"enable_rbac_authorization",
False,
),
private_endpoint_connections=[
PrivateEndpointConnection(id=conn.id)
for conn in (
getattr(
keyvault_properties,
"private_endpoint_connections",
[],
)
or []
)
],
enable_soft_delete=getattr(
keyvault_properties, "enable_soft_delete", False
),
enable_purge_protection=getattr(
keyvault_properties,
"enable_purge_protection",
False,
),
public_network_access_disabled=(
getattr(
keyvault_properties,
"public_network_access",
"Enabled",
)
== "Disabled"
),
),
keys=keys,
secrets=secrets,
monitor_diagnostic_settings=monitor_settings,
)
except Exception as error:
logger.error(
f"KeyVault {keyvault.name} in {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return None
def _get_keys(self, subscription, resource_group, keyvault_name):
thread_id = threading.current_thread().name
start_time = time.perf_counter()
logger.info(f"KeyVault - [{thread_id}] _get_keys({keyvault_name}) START")
keys = []
keys_dict = {}
try:
client = self.clients[subscription]
keys_list = client.keys.list(resource_group, keyvault_name)
for key in keys_list:
key_obj = Key(
id=getattr(key, "id", ""),
name=getattr(key, "name", ""),
enabled=getattr(key.attributes, "enabled", False),
location=getattr(key, "location", ""),
attributes=KeyAttributes(
enabled=getattr(key.attributes, "enabled", False),
created=getattr(key.attributes, "created", 0),
updated=getattr(key.attributes, "updated", 0),
expires=getattr(key.attributes, "expires", 0),
),
)
keys.append(key_obj)
keys_dict[key_obj.name] = key_obj
except Exception as error:
logger.error(
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -121,15 +201,21 @@ class KeyVault(AzureService):
try:
key_client = KeyClient(
vault_url=f"https://{keyvault_name}.vault.azure.net/",
# TODO: review the following line
credential=self._provider.session,
)
properties = list(key_client.list_properties_of_keys())
if properties:
items = [
{"key_client": key_client, "prop": prop} for prop in properties
]
rotation_results = self.__threading_call__(
self._get_single_rotation_policy, items
)
for name, policy in rotation_results:
if policy and name in keys_dict:
keys_dict[name].rotation_policy = KeyRotationPolicy(
id=getattr(policy, "id", ""),
lifetime_actions=[
KeyRotationLifetimeAction(action=action.action)
@@ -137,15 +223,40 @@ class KeyVault(AzureService):
],
)
# TODO: handle different errors here since we are catching all HTTP Errors here
except HttpResponseError:
logger.warning(
f"Subscription name: {subscription} -- has no access policy configured for keyvault {keyvault_name}"
)
elapsed = time.perf_counter() - start_time
logger.info(
f"KeyVault - [{thread_id}] _get_keys({keyvault_name}) DONE: "
f"{len(keys)} keys in {elapsed:.2f}s"
)
return keys
def _get_single_rotation_policy(self, item: dict) -> tuple:
"""Thread-safe rotation policy retrieval."""
key_client = item["key_client"]
prop = item["prop"]
try:
policy = key_client.get_key_rotation_policy(prop.name)
return (prop.name, policy)
except HttpResponseError:
return (prop.name, None)
except Exception as error:
logger.warning(
f"KeyVault - Failed to get rotation policy for key {prop.name}: {error}"
)
return (prop.name, None)
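A design note on the rotation-policy fan-out above: `_get_single_rotation_policy` is dispatched through `self.__threading_call__`, so those lookups queue on the same shared pool that is already running the per-vault `_process_single_keyvault` tasks; if every one of the MAX_WORKERS workers is held by a vault task that is itself waiting on its keys, the rotation lookups cannot be scheduled and the collection can stall. A hedged alternative sketch (not part of this commit; `_fetch_rotation_policies` is a hypothetical helper method) that keeps the fan-out on a short-lived local executor instead:

    from concurrent.futures import ThreadPoolExecutor

    def _fetch_rotation_policies(self, key_client, properties):
        # Sketch only: yields the same (key_name, policy) tuples as the
        # shared-pool version, but runs on a dedicated executor so the lookups
        # never wait behind the per-vault tasks occupying the service's pool.
        with ThreadPoolExecutor(max_workers=3) as executor:
            return list(
                executor.map(
                    lambda prop: self._get_single_rotation_policy(
                        {"key_client": key_client, "prop": prop}
                    ),
                    properties,
                )
            )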
def _get_secrets(self, subscription, resource_group, keyvault_name):
logger.info(f"KeyVault - Getting secrets for {keyvault_name}...")
thread_id = threading.current_thread().name
start_time = time.perf_counter()
logger.info(f"KeyVault - [{thread_id}] _get_secrets({keyvault_name}) START")
secrets = []
try:
client = self.clients[subscription]
@@ -177,12 +288,20 @@ class KeyVault(AzureService):
logger.error(
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
elapsed = time.perf_counter() - start_time
logger.info(
f"KeyVault - [{thread_id}] _get_secrets({keyvault_name}) DONE: "
f"{len(secrets)} secrets in {elapsed:.2f}s"
)
return secrets
def _get_vault_monitor_settings(self, keyvault_name, resource_group, subscription):
thread_id = threading.current_thread().name
start_time = time.perf_counter()
logger.info(f"KeyVault - [{thread_id}] _get_monitor({keyvault_name}) START")
monitor_diagnostics_settings = []
try:
monitor_diagnostics_settings = monitor_client.diagnostic_settings_with_uri(
@@ -192,8 +311,15 @@ class KeyVault(AzureService):
)
except Exception as error:
logger.error(
f"Subscription name: {self.subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
elapsed = time.perf_counter() - start_time
logger.info(
f"KeyVault - [{thread_id}] _get_monitor({keyvault_name}) DONE: "
f"{len(monitor_diagnostics_settings)} settings in {elapsed:.2f}s"
)
return monitor_diagnostics_settings