mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-24 04:28:02 +00:00
chore(azure): add vault parallelization in keyvault service (#9876)
This commit is contained in:
committed by
GitHub
parent
35f263dea6
commit
5784592437
@@ -16,6 +16,7 @@ All notable changes to the **Prowler SDK** are documented in this file.
|
||||
- Update Azure Container Registry service metadata to new format [(#9615)](https://github.com/prowler-cloud/prowler/pull/9615)
|
||||
- Update Azure Cosmos DB service metadata to new format [(#9616)](https://github.com/prowler-cloud/prowler/pull/9616)
|
||||
- Update Azure Databricks service metadata to new format [(#9617)](https://github.com/prowler-cloud/prowler/pull/9617)
|
||||
- Parallelize Azure Key Vault vaults and vaults contents retrieval to improve performance [(#9876)](https://github.com/prowler-cloud/prowler/pull/9876)
|
||||
- Update Azure IAM service metadata to new format [(#9620)](https://github.com/prowler-cloud/prowler/pull/9620)
|
||||
- Update Azure Policy service metadata to new format [(#9625)](https://github.com/prowler-cloud/prowler/pull/9625)
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.azure.azure_provider import AzureProvider
|
||||
|
||||
MAX_WORKERS = 10
|
||||
|
||||
|
||||
class AzureService:
|
||||
def __init__(
|
||||
@@ -20,6 +24,25 @@ class AzureService:
|
||||
self.audit_config = provider.audit_config
|
||||
self.fixer_config = provider.fixer_config
|
||||
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
def __threading_call__(self, call, iterator):
|
||||
"""Execute a function across multiple items using threading."""
|
||||
items = list(iterator) if not isinstance(iterator, list) else iterator
|
||||
|
||||
futures = {self.thread_pool.submit(call, item): item for item in items}
|
||||
results = []
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
def __set_clients__(self, identity, session, service, region_config):
|
||||
clients = {}
|
||||
try:
|
||||
|
||||
@@ -5,22 +5,21 @@ from prowler.providers.azure.services.keyvault.keyvault_client import keyvault_c
|
||||
class keyvault_rbac_secret_expiration_set(Check):
|
||||
def execute(self) -> Check_Report_Azure:
|
||||
findings = []
|
||||
|
||||
for subscription, key_vaults in keyvault_client.key_vaults.items():
|
||||
for keyvault in key_vaults:
|
||||
if keyvault.properties.enable_rbac_authorization and keyvault.secrets:
|
||||
report = Check_Report_Azure(
|
||||
metadata=self.metadata(), resource=keyvault
|
||||
)
|
||||
report.subscription = subscription
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Keyvault {keyvault.name} from subscription {subscription} has all the secrets with expiration date set."
|
||||
has_secret_without_expiration = False
|
||||
for secret in keyvault.secrets:
|
||||
report = Check_Report_Azure(
|
||||
metadata=self.metadata(), resource=secret
|
||||
)
|
||||
report.subscription = subscription
|
||||
if not secret.attributes.expires and secret.enabled:
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"Keyvault {keyvault.name} from subscription {subscription} has the secret {secret.name} without expiration date set."
|
||||
has_secret_without_expiration = True
|
||||
findings.append(report)
|
||||
if not has_secret_without_expiration:
|
||||
report.status_extended = f"Secret '{secret.name}' in KeyVault '{keyvault.name}' does not have expiration date set."
|
||||
else:
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Secret '{secret.name}' in KeyVault '{keyvault.name}' has expiration date set."
|
||||
findings.append(report)
|
||||
|
||||
return findings
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
@@ -20,99 +21,155 @@ class KeyVault(AzureService):
|
||||
self.key_vaults = self._get_key_vaults(provider)
|
||||
|
||||
def _get_key_vaults(self, provider):
|
||||
"""
|
||||
Get all KeyVaults with parallel processing.
|
||||
|
||||
Optimizations:
|
||||
1. Uses list_by_subscription() for full Vault objects
|
||||
2. Processes vaults in parallel using __threading_call__
|
||||
3. Each vault's keys/secrets/monitor fetched in parallel
|
||||
"""
|
||||
logger.info("KeyVault - Getting key_vaults...")
|
||||
key_vaults = {}
|
||||
|
||||
for subscription, client in self.clients.items():
|
||||
try:
|
||||
key_vaults.update({subscription: []})
|
||||
key_vaults_list = client.vaults.list()
|
||||
for keyvault in key_vaults_list:
|
||||
resource_group = keyvault.id.split("/")[4]
|
||||
keyvault_name = keyvault.name
|
||||
keyvault_properties = client.vaults.get(
|
||||
resource_group, keyvault_name
|
||||
).properties
|
||||
keys = self._get_keys(
|
||||
subscription, resource_group, keyvault_name, provider
|
||||
)
|
||||
secrets = self._get_secrets(
|
||||
subscription, resource_group, keyvault_name
|
||||
)
|
||||
key_vaults[subscription].append(
|
||||
KeyVaultInfo(
|
||||
id=getattr(keyvault, "id", ""),
|
||||
name=getattr(keyvault, "name", ""),
|
||||
location=getattr(keyvault, "location", ""),
|
||||
resource_group=resource_group,
|
||||
properties=VaultProperties(
|
||||
tenant_id=getattr(keyvault_properties, "tenant_id", ""),
|
||||
enable_rbac_authorization=getattr(
|
||||
keyvault_properties,
|
||||
"enable_rbac_authorization",
|
||||
False,
|
||||
),
|
||||
private_endpoint_connections=[
|
||||
PrivateEndpointConnection(id=conn.id)
|
||||
for conn in (
|
||||
getattr(
|
||||
keyvault_properties,
|
||||
"private_endpoint_connections",
|
||||
[],
|
||||
)
|
||||
or []
|
||||
)
|
||||
],
|
||||
enable_soft_delete=getattr(
|
||||
keyvault_properties, "enable_soft_delete", False
|
||||
),
|
||||
enable_purge_protection=getattr(
|
||||
keyvault_properties,
|
||||
"enable_purge_protection",
|
||||
False,
|
||||
),
|
||||
public_network_access_disabled=(
|
||||
getattr(
|
||||
keyvault_properties,
|
||||
"public_network_access",
|
||||
"Enabled",
|
||||
)
|
||||
== "Disabled"
|
||||
),
|
||||
),
|
||||
keys=keys,
|
||||
secrets=secrets,
|
||||
monitor_diagnostic_settings=self._get_vault_monitor_settings(
|
||||
keyvault_name, resource_group, subscription
|
||||
),
|
||||
)
|
||||
)
|
||||
key_vaults[subscription] = []
|
||||
vaults_list = list(client.vaults.list_by_subscription())
|
||||
|
||||
if not vaults_list:
|
||||
continue
|
||||
|
||||
# Prepare items for parallel processing
|
||||
items = [
|
||||
{
|
||||
"subscription": subscription,
|
||||
"keyvault": vault,
|
||||
"provider": provider,
|
||||
}
|
||||
for vault in vaults_list
|
||||
]
|
||||
|
||||
# Process all KeyVaults in parallel
|
||||
results = self.__threading_call__(self._process_single_keyvault, items)
|
||||
key_vaults[subscription] = results
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
return key_vaults
|
||||
|
||||
def _process_single_keyvault(self, item: dict) -> Optional["KeyVaultInfo"]:
|
||||
"""Process a single KeyVault in parallel."""
|
||||
subscription = item["subscription"]
|
||||
keyvault = item["keyvault"]
|
||||
provider = item["provider"]
|
||||
|
||||
try:
|
||||
resource_group = keyvault.id.split("/")[4]
|
||||
keyvault_name = keyvault.name
|
||||
keyvault_properties = keyvault.properties
|
||||
|
||||
# Fetch keys, secrets, and monitor in parallel
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
keys_future = executor.submit(
|
||||
self._get_keys,
|
||||
subscription,
|
||||
resource_group,
|
||||
keyvault_name,
|
||||
provider,
|
||||
)
|
||||
secrets_future = executor.submit(
|
||||
self._get_secrets, subscription, resource_group, keyvault_name
|
||||
)
|
||||
monitor_future = executor.submit(
|
||||
self._get_vault_monitor_settings,
|
||||
keyvault_name,
|
||||
resource_group,
|
||||
subscription,
|
||||
)
|
||||
|
||||
keys = keys_future.result()
|
||||
secrets = secrets_future.result()
|
||||
monitor_settings = monitor_future.result()
|
||||
|
||||
return KeyVaultInfo(
|
||||
id=getattr(keyvault, "id", ""),
|
||||
name=getattr(keyvault, "name", ""),
|
||||
location=getattr(keyvault, "location", ""),
|
||||
resource_group=resource_group,
|
||||
properties=VaultProperties(
|
||||
tenant_id=getattr(keyvault_properties, "tenant_id", ""),
|
||||
enable_rbac_authorization=getattr(
|
||||
keyvault_properties,
|
||||
"enable_rbac_authorization",
|
||||
False,
|
||||
),
|
||||
private_endpoint_connections=[
|
||||
PrivateEndpointConnection(id=conn.id)
|
||||
for conn in (
|
||||
getattr(
|
||||
keyvault_properties,
|
||||
"private_endpoint_connections",
|
||||
[],
|
||||
)
|
||||
or []
|
||||
)
|
||||
],
|
||||
enable_soft_delete=getattr(
|
||||
keyvault_properties, "enable_soft_delete", False
|
||||
),
|
||||
enable_purge_protection=getattr(
|
||||
keyvault_properties,
|
||||
"enable_purge_protection",
|
||||
False,
|
||||
),
|
||||
public_network_access_disabled=(
|
||||
getattr(
|
||||
keyvault_properties,
|
||||
"public_network_access",
|
||||
"Enabled",
|
||||
)
|
||||
== "Disabled"
|
||||
),
|
||||
),
|
||||
keys=keys,
|
||||
secrets=secrets,
|
||||
monitor_diagnostic_settings=monitor_settings,
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"KeyVault {keyvault.name} in {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_keys(self, subscription, resource_group, keyvault_name, provider):
|
||||
logger.info(f"KeyVault - Getting keys for {keyvault_name}...")
|
||||
keys = []
|
||||
keys_dict = {}
|
||||
|
||||
try:
|
||||
client = self.clients[subscription]
|
||||
keys_list = client.keys.list(resource_group, keyvault_name)
|
||||
for key in keys_list:
|
||||
keys.append(
|
||||
Key(
|
||||
id=getattr(key, "id", ""),
|
||||
name=getattr(key, "name", ""),
|
||||
key_obj = Key(
|
||||
id=getattr(key, "id", ""),
|
||||
name=getattr(key, "name", ""),
|
||||
enabled=getattr(key.attributes, "enabled", False),
|
||||
location=getattr(key, "location", ""),
|
||||
attributes=KeyAttributes(
|
||||
enabled=getattr(key.attributes, "enabled", False),
|
||||
location=getattr(key, "location", ""),
|
||||
attributes=KeyAttributes(
|
||||
enabled=getattr(key.attributes, "enabled", False),
|
||||
created=getattr(key.attributes, "created", 0),
|
||||
updated=getattr(key.attributes, "updated", 0),
|
||||
expires=getattr(key.attributes, "expires", 0),
|
||||
),
|
||||
)
|
||||
created=getattr(key.attributes, "created", 0),
|
||||
updated=getattr(key.attributes, "updated", 0),
|
||||
expires=getattr(key.attributes, "expires", 0),
|
||||
),
|
||||
)
|
||||
keys.append(key_obj)
|
||||
keys_dict[key_obj.name] = key_obj
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
@@ -124,12 +181,19 @@ class KeyVault(AzureService):
|
||||
# TODO: review the following line
|
||||
credential=provider.session,
|
||||
)
|
||||
properties = key_client.list_properties_of_keys()
|
||||
for prop in properties:
|
||||
policy = key_client.get_key_rotation_policy(prop.name)
|
||||
for key in keys:
|
||||
if key.name == prop.name:
|
||||
key.rotation_policy = KeyRotationPolicy(
|
||||
properties = list(key_client.list_properties_of_keys())
|
||||
|
||||
if properties:
|
||||
items = [
|
||||
{"key_client": key_client, "prop": prop} for prop in properties
|
||||
]
|
||||
rotation_results = self.__threading_call__(
|
||||
self._get_single_rotation_policy, items
|
||||
)
|
||||
|
||||
for name, policy in rotation_results:
|
||||
if policy and name in keys_dict:
|
||||
keys_dict[name].rotation_policy = KeyRotationPolicy(
|
||||
id=getattr(policy, "id", ""),
|
||||
lifetime_actions=[
|
||||
KeyRotationLifetimeAction(action=action.action)
|
||||
@@ -142,8 +206,25 @@ class KeyVault(AzureService):
|
||||
logger.warning(
|
||||
f"Subscription name: {subscription} -- has no access policy configured for keyvault {keyvault_name}"
|
||||
)
|
||||
|
||||
return keys
|
||||
|
||||
def _get_single_rotation_policy(self, item: dict) -> tuple:
|
||||
"""Thread-safe rotation policy retrieval."""
|
||||
key_client = item["key_client"]
|
||||
prop = item["prop"]
|
||||
|
||||
try:
|
||||
policy = key_client.get_key_rotation_policy(prop.name)
|
||||
return (prop.name, policy)
|
||||
except HttpResponseError:
|
||||
return (prop.name, None)
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
f"KeyVault - Failed to get rotation policy for key {prop.name}: {error}"
|
||||
)
|
||||
return (prop.name, None)
|
||||
|
||||
def _get_secrets(self, subscription, resource_group, keyvault_name):
|
||||
logger.info(f"KeyVault - Getting secrets for {keyvault_name}...")
|
||||
secrets = []
|
||||
@@ -177,6 +258,7 @@ class KeyVault(AzureService):
|
||||
logger.error(
|
||||
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
return secrets
|
||||
|
||||
def _get_vault_monitor_settings(self, keyvault_name, resource_group, subscription):
|
||||
@@ -192,8 +274,9 @@ class KeyVault(AzureService):
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"Subscription name: {self.subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"Subscription name: {subscription} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
return monitor_diagnostics_settings
|
||||
|
||||
|
||||
|
||||
@@ -97,11 +97,12 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
Secret,
|
||||
)
|
||||
|
||||
secret_id = str(uuid4())
|
||||
secret = Secret(
|
||||
id="id",
|
||||
id=secret_id,
|
||||
name=secret_name,
|
||||
enabled=True,
|
||||
location="location",
|
||||
location="westeurope",
|
||||
attributes=SecretAttributes(expires=None),
|
||||
)
|
||||
keyvault_client.key_vaults = {
|
||||
@@ -127,11 +128,11 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
assert result[0].status == "FAIL"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"Keyvault {keyvault_name} from subscription {AZURE_SUBSCRIPTION_ID} has the secret {secret_name} without expiration date set."
|
||||
== f"Secret '{secret_name}' in KeyVault '{keyvault_name}' does not have expiration date set."
|
||||
)
|
||||
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
|
||||
assert result[0].resource_name == keyvault_name
|
||||
assert result[0].resource_id == keyvault_id
|
||||
assert result[0].resource_name == secret_name
|
||||
assert result[0].resource_id == secret_id
|
||||
assert result[0].location == "westeurope"
|
||||
|
||||
def test_key_vaults_invalid_multiple_secrets(self):
|
||||
@@ -159,18 +160,20 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
Secret,
|
||||
)
|
||||
|
||||
secret1_id = str(uuid4())
|
||||
secret2_id = str(uuid4())
|
||||
secret1 = Secret(
|
||||
id="id",
|
||||
id=secret1_id,
|
||||
name=secret1_name,
|
||||
enabled=True,
|
||||
location="location",
|
||||
location="westeurope",
|
||||
attributes=SecretAttributes(expires=None),
|
||||
)
|
||||
secret2 = Secret(
|
||||
id="id",
|
||||
id=secret2_id,
|
||||
name=secret2_name,
|
||||
enabled=True,
|
||||
location="location",
|
||||
location="westeurope",
|
||||
attributes=SecretAttributes(expires=84934),
|
||||
)
|
||||
keyvault_client.key_vaults = {
|
||||
@@ -192,16 +195,35 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
}
|
||||
check = keyvault_rbac_secret_expiration_set()
|
||||
result = check.execute()
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "FAIL"
|
||||
# Now we get 1 finding per secret (2 total)
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the FAIL and PASS results by status
|
||||
fail_results = [r for r in result if r.status == "FAIL"]
|
||||
pass_results = [r for r in result if r.status == "PASS"]
|
||||
|
||||
assert len(fail_results) == 1
|
||||
assert len(pass_results) == 1
|
||||
|
||||
# Verify FAIL finding (secret1 without expiration)
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"Keyvault {keyvault_name} from subscription {AZURE_SUBSCRIPTION_ID} has the secret {secret1_name} without expiration date set."
|
||||
fail_results[0].status_extended
|
||||
== f"Secret '{secret1_name}' in KeyVault '{keyvault_name}' does not have expiration date set."
|
||||
)
|
||||
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
|
||||
assert result[0].resource_name == keyvault_name
|
||||
assert result[0].resource_id == keyvault_id
|
||||
assert result[0].location == "westeurope"
|
||||
assert fail_results[0].subscription == AZURE_SUBSCRIPTION_ID
|
||||
assert fail_results[0].resource_name == secret1_name
|
||||
assert fail_results[0].resource_id == secret1_id
|
||||
assert fail_results[0].location == "westeurope"
|
||||
|
||||
# Verify PASS finding (secret2 with expiration)
|
||||
assert (
|
||||
pass_results[0].status_extended
|
||||
== f"Secret '{secret2_name}' in KeyVault '{keyvault_name}' has expiration date set."
|
||||
)
|
||||
assert pass_results[0].subscription == AZURE_SUBSCRIPTION_ID
|
||||
assert pass_results[0].resource_name == secret2_name
|
||||
assert pass_results[0].resource_id == secret2_id
|
||||
assert pass_results[0].location == "westeurope"
|
||||
|
||||
def test_key_vaults_valid_keys(self):
|
||||
keyvault_client = mock.MagicMock
|
||||
@@ -226,11 +248,13 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
Secret,
|
||||
)
|
||||
|
||||
secret_name = "secret-name"
|
||||
secret_id = str(uuid4())
|
||||
secret = Secret(
|
||||
id="id",
|
||||
name="name",
|
||||
id=secret_id,
|
||||
name=secret_name,
|
||||
enabled=False,
|
||||
location="location",
|
||||
location="westeurope",
|
||||
attributes=SecretAttributes(expires=None),
|
||||
)
|
||||
keyvault_client.key_vaults = {
|
||||
@@ -256,9 +280,9 @@ class Test_keyvault_rbac_secret_expiration_set:
|
||||
assert result[0].status == "PASS"
|
||||
assert (
|
||||
result[0].status_extended
|
||||
== f"Keyvault {keyvault_name} from subscription {AZURE_SUBSCRIPTION_ID} has all the secrets with expiration date set."
|
||||
== f"Secret '{secret_name}' in KeyVault '{keyvault_name}' has expiration date set."
|
||||
)
|
||||
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
|
||||
assert result[0].resource_name == keyvault_name
|
||||
assert result[0].resource_id == keyvault_id
|
||||
assert result[0].resource_name == secret_name
|
||||
assert result[0].resource_id == secret_id
|
||||
assert result[0].location == "westeurope"
|
||||
|
||||
Reference in New Issue
Block a user