feat(mcp_server): add provider management tools (#9350)

This commit is contained in:
Rubén De la Torre Vico
2025-12-10 17:31:21 +01:00
committed by GitHub
parent 0c3ba0b737
commit 978e2c82af
6 changed files with 826 additions and 5 deletions

View File

@@ -7,6 +7,7 @@ All notable changes to the **Prowler MCP Server** are documented in this file.
### Added
- Remove all Prowler App MCP tools; and add new MCP Server tools for Prowler Findings and Compliance [(#9300)](https://github.com/prowler-cloud/prowler/pull/9300)
- Add new MCP Server tools for Prowler Providers Management [(#9350)](https://github.com/prowler-cloud/prowler/pull/9350)
---

View File

@@ -1,7 +1,6 @@
"""Pydantic models for Prowler App MCP Server."""
from prowler_mcp_server.prowler_app.models.base import MinimalSerializerMixin
from prowler_mcp_server.prowler_app.models.findings import (
CheckMetadata,
CheckRemediation,

View File

@@ -27,18 +27,19 @@ class MinimalSerializerMixin(BaseModel):
Dictionary with non-empty values only
"""
data = handler(self)
return {k: v for k, v in data.items() if not self._should_exclude(v)}
return {k: v for k, v in data.items() if not self._should_exclude(k, v)}
def _should_exclude(self, value: Any) -> bool:
"""Determine if a value should be excluded from serialization.
def _should_exclude(self, key: str, value: Any) -> bool:
"""Determine if a key-value pair should be excluded from serialization.
Override this method in subclasses for custom exclusion logic.
Args:
key: Field name
value: Field value
Returns:
True if the value should be excluded, False otherwise
True if the field should be excluded, False otherwise
"""
# None values
if value is None:

View File

@@ -0,0 +1,134 @@
"""Pydantic models for simplified provider responses."""
from typing import Any, Literal
from prowler_mcp_server.prowler_app.models.base import MinimalSerializerMixin
from pydantic import BaseModel
class SimplifiedProvider(MinimalSerializerMixin, BaseModel):
"""Simplified provider for list/search operations."""
id: str
uid: str
alias: str | None = None
provider: str
connected: bool | None = None
secret_type: Literal["role", "service_account", "static"] | None = None
def _should_exclude(self, key: str, value: Any) -> bool:
"""Override to always include connected and secret_type fields even when None."""
# Always include these fields regardless of value (None has semantic meaning)
if key == "connected" or key == "secret_type":
return False
# Use parent class logic for other fields
return super()._should_exclude(key, value)
@classmethod
def from_api_response(cls, data: dict[str, Any]) -> "SimplifiedProvider":
"""Transform JSON:API provider response to simplified format."""
attributes = data["attributes"]
connection_data = attributes.get("connection", {})
return cls(
id=data["id"],
uid=attributes["uid"],
alias=attributes.get("alias"),
provider=attributes["provider"],
connected=connection_data.get("connected"),
secret_type=None, # Will be populated separately via secret endpoint
)
class DetailedProvider(SimplifiedProvider):
"""Detailed provider with complete information for deep analysis.
Extends SimplifiedProvider with temporal metadata and relationships.
Use this when you need complete context about a specific provider.
"""
inserted_at: str | None = None
updated_at: str | None = None
last_checked_at: str | None = None
provider_group_ids: list[str] | None = None
@classmethod
def from_api_response(cls, data: dict[str, Any]) -> "DetailedProvider":
"""Transform JSON:API provider response to detailed format."""
attributes = data["attributes"]
connection_data = attributes.get("connection", {})
relationships = data.get("relationships", {})
# Extract provider groups relationship
provider_group_ids = None
groups_data = relationships.get("provider_groups", {}).get("data", [])
if groups_data:
provider_group_ids = [group["id"] for group in groups_data]
return cls(
id=data["id"],
uid=attributes["uid"],
alias=attributes.get("alias"),
provider=attributes["provider"],
connected=connection_data.get("connected"),
inserted_at=attributes.get("inserted_at"),
updated_at=attributes.get("updated_at"),
last_checked_at=connection_data.get("last_checked_at"),
provider_group_ids=provider_group_ids,
)
class ProvidersListResponse(BaseModel):
"""Simplified response for providers list queries."""
providers: list[SimplifiedProvider]
total_num_providers: int
total_num_pages: int
current_page: int
@classmethod
def from_api_response(cls, response: dict[str, Any]) -> "ProvidersListResponse":
"""Transform JSON:API response to simplified format."""
data = response["data"]
meta = response["meta"]
pagination = meta["pagination"]
providers = [SimplifiedProvider.from_api_response(item) for item in data]
return cls(
providers=providers,
total_num_providers=pagination["count"],
total_num_pages=pagination["pages"],
current_page=pagination["page"],
)
class ProviderConnectionStatus(MinimalSerializerMixin, BaseModel):
"""Result of provider connection operation."""
provider: DetailedProvider
connected: Literal["connected", "failed", "not_tested"]
error: str | None = None
@classmethod
def create(
cls,
provider_data: dict[str, Any],
connection_status: dict[str, Any],
) -> "ProviderConnectionStatus":
"""Create connection status from provider data and connection test result."""
connected: str | None = connection_status.get("connected", None)
if connected is None:
connected = "not_tested"
elif connected:
connected = "connected"
else:
connected = "failed"
return cls(
provider=DetailedProvider.from_api_response(provider_data),
connected=connected,
error=connection_status.get("error", None),
)

View File

@@ -0,0 +1,623 @@
"""Provider Management tools for Prowler App MCP Server.
This module provides tools for managing provider connections,
including searching, connecting, and deleting providers.
"""
from typing import Any
from prowler_mcp_server.prowler_app.models.providers import (
ProviderConnectionStatus,
ProvidersListResponse,
)
from prowler_mcp_server.prowler_app.tools.base import BaseTool
from pydantic import Field
class ProvidersTools(BaseTool):
"""Tools for provider management operations
Provides tools for:
- prowler_app_search_providers: Search and view configured providers with their connection status
- prowler_app_connect_provider: Connect or register a provider for security scanning in Prowler
- prowler_app_delete_provider: Permanently remove a provider from Prowler
"""
async def search_providers(
self,
provider_id: list[str] = Field(
default=[],
description="Filter by Prowler's internal UUID(s) (v4) for the provider(s), generated when the provider is registered in the system.",
),
provider_uid: list[str] = Field(
default=[],
description="Filter by provider's unique identifier(s), this ID is the one provided by the provider itself. Format varies by provider type: AWS Account ID (12 digits), Azure Subscription ID (UUID), GCP Project ID (string), Kubernetes namespace, GitHub username/organization, M365 domain ID, etc. All supported provider types are listed in the Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server",
),
provider_type: list[str] = Field(
default=[],
description="Filter by provider type. Valid values include: 'aws', 'azure', 'gcp', 'kubernetes'... For more valid values, please refer to Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server.",
),
alias: str | None = Field(
default=None,
description="Search by provider alias/friendly name. Partial match supported (case-insensitive). Use this to find providers by their human-readable name (e.g., 'Production', 'Dev', 'AWS Main')",
),
connected: (
bool | str | None
) = Field( # Wrong `str` hint type due to bad MCP Clients implementation
default=None,
description="Filter by connection status. True returns only successfully connected providers (credentials work), False returns only providers with failed connections (credentials invalid). If not specified, returns all connected, failed and not tested providers. Strings 'true' and 'false' are also accepted.",
),
page_size: int = Field(
default=50, description="Number of results to return per page"
),
page_number: int = Field(
default=1,
description="Page number to retrieve (1-indexed)",
),
) -> dict[str, Any]:
"""Search and view configured providers to be scanned with Prowler.
This tool returns a unified view of all providers configured in Prowler.
For getting more details about what types of providers are available to be scanned with Prowler or
what are the UIDs are accepted for each provider type, please refer to Prowler Hub/Prowler Documentation
that you can also find in form of tools in this MCP Server.
Each provider includes:
- Provider identification: Prowler Internal ID, External Provider UID, Provider Alias
- Provider context: Provider Type
- Connection status: Connected (true), Failed (false), Not Tested (null)
"""
self.api_client.validate_page_size(page_size)
params = {
"fields[providers]": "uid,alias,provider,connection,secret",
"page[number]": page_number,
"page[size]": page_size,
}
# Build filter parameters
if provider_id:
params["filter[id__in]"] = provider_id
if provider_uid:
params["filter[uid__in]"] = provider_uid
if provider_type:
params["filter[provider__in]"] = provider_type
if alias:
params["filter[alias__icontains]"] = alias
if connected is not None:
if isinstance(connected, bool):
params["filter[connected]"] = connected
else:
if connected.lower() == "true":
params["filter[connected]"] = True
elif connected.lower() == "false":
params["filter[connected]"] = False
else:
raise ValueError(
f"Invalid connected value: {connected}. Valid values are True, False, 'true', 'false' or None."
)
clean_params = self.api_client.build_filter_params(params)
api_response = await self.api_client.get(
"/api/v1/providers", params=clean_params
)
simplified_response = ProvidersListResponse.from_api_response(api_response)
# Fetch secret_type for each provider that has a secret
for provider in simplified_response.providers:
# Get the provider data from the API response to access relationships
provider_data = next(
(
provider_api_response
for provider_api_response in api_response["data"]
if provider_api_response["id"] == provider.id
),
None,
)
if provider_data:
secret_relationship = provider_data.get("relationships", {}).get(
"secret", {}
)
secret_data = secret_relationship.get("data")
if secret_data:
secret_id = secret_data["id"]
provider.secret_type = await self._get_secret_type(secret_id)
return simplified_response.model_dump()
async def connect_provider(
self,
provider_uid: str = Field(
description="Provider's unique identifier. For supported UID provider formats, please refer to Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server"
),
provider_type: str = Field(
description="Type of provider to be scanned with Prowler. Valid values include: 'aws', 'azure', 'gcp', 'kubernetes'... For more valid values, please refer to Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server."
),
alias: str | None = Field(
default=None,
description="Human-friendly name for this provider. Optional but recommended for easy identification. Use descriptive names to distinguish multiple accounts of the same type.",
),
credentials: dict[str, Any] | None = Field(
default=None,
description="Provider-specific credentials for authentication. Optional - if not provided, provider is created but not connected. Structure varies by provider type. For supported provider types, please refer to Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server",
),
) -> dict[str, Any]:
"""Register a provider to be scanned with Prowler.
This tool will register a provider in Prowler App, even if the UID is wrong.
If the provider is already registered, it will be updated with the new provided alias or credentials if provided.
If credentials are provided, they will be added to the indicated provider, if the provider does not exist, it will be created and the credentials will be added to it.
If the connection test is successful, the provider will be connected.
If the connection test fails, the provider will be created but not connected.
The tool always returns the provider details after its registration or update.
Example Input:
- AWS Static Credentials:
```json
{
"provider_uid": "123456789012",
"provider_type": "aws",
"alias": "production-aws-account",
"credentials": {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "...",
"aws_session_token": "..."
}
}
```
- AWS Assume Role:
```json
{
"provider_uid": "987654321098",
"provider_type": "aws",
"alias": "staging-aws-account",
"credentials": {
"role_arn": "arn:aws:iam::987654321098:role/ProwlerScanRole",
"external_id": "...",
"aws_access_key_id": "AKIA...", # Optional
"aws_secret_access_key": "...", # Optional
"aws_session_token": "...", # Optional
"session_duration": 3600, # Optional
"role_session_name": "..." # Optional
}
}
```
- Azure/M365 Static Credentials:
```json
{
"provider_uid": "a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d",
"provider_type": "azure",
"alias": "production-azure-subscription",
"credentials": {
"client_id": "...",
"client_secret": "...",
"tenant_id": "..."
}
}
```
- GCP Service Account Account Key:
```json
{
"provider_uid": "my-gcp-project-prod",
"provider_type": "gcp",
"alias": "production-gcp-project",
"credentials": {
"service_account_key": {
"type": "service_account",
"project_id": "...",
"private_key_id": "...",
"private_key": "...",
"client_email": "...",
}
}
}
```
- Kubernetes Static Credentials:
```json
{
"provider_uid": "prod-k8s-cluster",
"provider_type": "kubernetes",
"alias": "production-kubernetes-cluster",
"credentials": {
"kubeconfig_content": "..."
}
}
```
- GitHub OAuth App Token:
```json
{
"provider_uid": "my-organization",
"provider_type": "github",
"alias": "my-github-organization",
"credentials": {
"oauth_app_token": "..."
}
}
NOTE: THERE ARE MORE PROVIDER TYPES AND CREDENTIAL TYPES AVAILABLE, PLEASE REFER TO THE Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server.
"""
# Step 1: Check if provider already exists
prowler_provider_id = await self._check_provider_exists(provider_uid)
# Step 2: Create or update provider
if prowler_provider_id is None:
prowler_provider_id = await self._create_provider(
provider_uid, provider_type, alias
)
elif alias:
await self._update_provider_alias(prowler_provider_id, alias)
# Step 3: Handle credentials if provided and capture secret response
secret_response = None
if credentials:
secret_response = await self._store_credentials(
prowler_provider_id, credentials
)
# Step 4: Test connection
connection_status = await self._test_connection(prowler_provider_id)
# Step 5: Get final provider state with relationships
final_provider = await self._get_final_provider_state(prowler_provider_id)
# Transform to structured response using model
connection_result = ProviderConnectionStatus.create(
provider_data=final_provider["data"],
connection_status=connection_status,
)
if secret_response:
# We just stored credentials, use the secret_type from the response
connection_result.provider.secret_type = (
secret_response.get("data", {}).get("attributes", {}).get("secret_type")
)
else:
# No new credentials provided, check if provider has an existing secret
secret_data = (
final_provider.get("data", {})
.get("relationships", {})
.get("secret", {})
.get("data")
)
if secret_data:
# Provider has existing secret, fetch its type
secret_id = secret_data["id"]
connection_result.provider.secret_type = await self._get_secret_type(
secret_id
)
return connection_result.model_dump()
async def delete_provider(
self,
provider_id: str = Field(
description="Prowler's internal UUID (v4) for the provider to permanently remove, generated when the provider was registered in the system. Use `prowler_app_search_providers` tool to find the provider_id if you only know the alias or the provider's own identifier (provider_uid)"
),
) -> dict[str, Any]:
"""Permanently remove a registered provider from Prowler.
WARNING: This is a destructive operation that cannot be undone. The provider will need to be
re-added with prowler_app_connect_provider if you want to scan it again.
The tool always returns the deletion status and message.
"""
self.logger.info(f"Deleting provider {provider_id}...")
try:
# Initiate the deletion task
task_response = await self.api_client.delete(
f"/api/v1/providers/{provider_id}"
)
task_id = task_response.get("data", {}).get("id")
# Poll until task completes (with 60 second timeout)
await self.api_client.poll_task_until_complete(
task_id=task_id, timeout=60, poll_interval=1.0
)
# If we reach here, the task completed successfully
return {
"deleted": True,
"message": f"Provider {provider_id} deleted successfully",
}
except Exception as e:
self.logger.error(f"Provider deletion failed: {e}")
return {
"deleted": False,
"message": f"Provider {provider_id} deletion failed: {str(e)}",
}
# Private helper methods
async def _check_provider_exists(self, provider_uid: str) -> str | None:
"""Check if a provider already exists by its UID.
Args:
provider_uid: The provider's unique identifier (e.g., AWS account ID)
Returns:
The Prowler-generated provider ID if exists, None otherwise
Raises:
Exception: If multiple providers with the same UID are found (data integrity issue)
Exception: If API request fails
"""
self.logger.info(f"Checking if provider {provider_uid} exists...")
response = await self.api_client.get(
"/api/v1/providers", params={"filter[uid]": provider_uid}
)
providers = response.get("data", [])
if len(providers) == 0:
self.logger.info(f"Provider {provider_uid} does not exist")
return None
elif len(providers) == 1:
prowler_provider_id = providers[0].get("id")
self.logger.info(
f"Provider {provider_uid} exists with ID {prowler_provider_id}"
)
return prowler_provider_id
else:
# Multiple providers with the same UID is a data integrity issue
raise Exception(
f"Data integrity error: Found {len(providers)} providers with UID '{provider_uid}'. "
f"Each provider UID should be unique. Please contact support or manually clean up duplicate providers."
)
async def _create_provider(
self, provider_uid: str, provider_type: str, alias: str | None
) -> str:
"""Create a new provider.
Args:
provider_uid: The provider's unique identifier
provider_type: Type of provider to be scanned with Prowler (aws, azure, gcp, etc.)
alias: Optional human-friendly name for the provider
Returns:
The provider UID (which is used as the ID)
"""
self.logger.info(f"Creating provider {provider_uid} (type: {provider_type})...")
provider_body = {
"data": {
"type": "providers",
"attributes": {
"uid": provider_uid,
"provider": provider_type,
},
}
}
if alias:
provider_body["data"]["attributes"]["alias"] = alias
await self.api_client.post("/api/v1/providers", json_data=provider_body)
provider_id = await self._check_provider_exists(provider_uid)
if provider_id is None:
raise Exception(f"Provider {provider_uid} creation failed")
return provider_id
async def _update_provider_alias(
self, prowler_provider_id: str, alias: str
) -> None:
"""Update the alias of an existing provider.
Args:
prowler_provider_id: The Prowler-generated provider ID
alias: New human-friendly name for the provider
"""
self.logger.info(f"Updating provider {prowler_provider_id} alias...")
update_body = {
"data": {
"type": "providers",
"id": prowler_provider_id,
"attributes": {
"alias": alias,
},
}
}
result = await self.api_client.patch(
f"/api/v1/providers/{prowler_provider_id}", json_data=update_body
)
if result.get("data", {}).get("attributes", {}).get("alias") != alias:
raise Exception(f"Provider {prowler_provider_id} alias update failed")
def _determine_secret_type(self, credentials: dict[str, Any]) -> str:
"""Determine the secret type from credentials structure.
Args:
credentials: The credentials dictionary
Returns:
Secret type: "role", "service_account", or "static"
"""
if "role_arn" in credentials:
return "role"
elif "service_account_key" in credentials:
return "service_account"
else:
return "static"
async def _get_provider_secret_id(self, prowler_provider_id: str) -> str | None:
"""Get the secret ID for a provider if it exists.
Args:
prowler_provider_id: The Prowler-generated provider ID
Returns:
The secret ID if exists, None otherwise
"""
try:
response = await self.api_client.get(
"/api/v1/providers/secrets",
params={"filter[provider]": prowler_provider_id},
)
secrets = response.get("data", [])
if len(secrets) > 0:
secret_id = secrets[0].get("id")
self.logger.info(
f"Found existing secret {secret_id} for provider {prowler_provider_id}"
)
return secret_id
else:
self.logger.info(
f"No existing secret found for provider {prowler_provider_id}"
)
return None
except Exception as e:
self.logger.error(f"Error checking for existing secret: {e}")
return None
async def _get_secret_type(self, secret_id: str) -> str | None:
"""Get the secret type for a given secret ID.
Args:
secret_id: The secret ID from provider relationships
Returns:
The secret type ("role", "service_account", or "static") if found, None otherwise
"""
try:
response = await self.api_client.get(
f"/api/v1/providers/secrets/{secret_id}",
params={"fields[provider-secrets]": "secret_type"},
)
secret_type = (
response.get("data", {}).get("attributes", {}).get("secret_type")
)
return secret_type
except Exception as e:
self.logger.error(f"Error fetching secret type for {secret_id}: {e}")
return None
async def _store_credentials(
self, prowler_provider_id: str, credentials: dict[str, Any]
) -> dict[str, Any]:
"""Store or update credentials for a provider.
Args:
prowler_provider_id: The Prowler-generated provider ID
credentials: The credentials to store
Returns:
The API response with the secret data
"""
self.logger.info(
f"Adding/updating credentials for provider {prowler_provider_id}..."
)
secret_type = self._determine_secret_type(credentials)
# Check if a secret already exists for this provider
existing_secret_id = await self._get_provider_secret_id(prowler_provider_id)
if existing_secret_id:
# Update existing secret
self.logger.info(f"Updating existing secret {existing_secret_id}...")
update_body = {
"data": {
"type": "provider-secrets",
"id": existing_secret_id,
"attributes": {
"secret_type": secret_type,
"secret": credentials,
},
"relationships": {
"provider": {
"data": {
"type": "providers",
"id": prowler_provider_id,
}
}
},
}
}
try:
response = await self.api_client.patch(
f"/api/v1/providers/secrets/{existing_secret_id}",
json_data=update_body,
)
self.logger.info("Credentials updated successfully")
return response
except Exception as e:
self.logger.error(f"Error updating credentials: {e}")
raise
else:
# Create new secret
self.logger.info("Creating new secret...")
secret_body = {
"data": {
"type": "provider-secrets",
"attributes": {
"secret_type": secret_type,
"secret": credentials,
},
"relationships": {
"provider": {
"data": {
"type": "providers",
"id": prowler_provider_id,
}
}
},
}
}
try:
response = await self.api_client.post(
"/api/v1/providers/secrets", json_data=secret_body
)
self.logger.info("Credentials added successfully")
return response
except Exception as e:
self.logger.error(f"Error adding credentials: {e}")
raise
async def _test_connection(self, prowler_provider_id: str) -> dict[str, Any]:
"""Test connection to a provider.
Args:
prowler_provider_id: The Prowler-generated provider ID
Returns:
Connection status dictionary with 'connected' boolean and optional 'error' message
"""
self.logger.info(f"Testing connection for provider {prowler_provider_id}...")
try:
# Initiate the connection test task
task_response = await self.api_client.post(
f"/api/v1/providers/{prowler_provider_id}/connection", json_data={}
)
task_id = task_response.get("data", {}).get("id")
# Poll until task completes (with 60 second timeout)
completed_task = await self.api_client.poll_task_until_complete(
task_id=task_id, timeout=60, poll_interval=1.0
)
# Extract the result from the completed task
task_result = (
completed_task.get("data", {}).get("attributes", {}).get("result", {})
)
return task_result
except Exception as e:
self.logger.error(f"Connection test failed: {e}")
return {"connected": False, "error": str(e)}
async def _get_final_provider_state(
self, prowler_provider_id: str
) -> dict[str, Any]:
"""Get final provider state with relationships.
Args:
prowler_provider_id: The Prowler-generated provider ID
Returns:
Provider data dictionary
"""
return await self.api_client.get(
f"/api/v1/providers/{prowler_provider_id}",
)

View File

@@ -1,5 +1,6 @@
"""Shared API client utilities for Prowler App tools."""
import asyncio
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict
@@ -180,6 +181,68 @@ class ProwlerAPIClient(metaclass=SingletonMeta):
"""
return await self._make_request(HTTPMethod.DELETE, path, params=params)
async def poll_task_until_complete(
self,
task_id: str,
timeout: int = 60,
poll_interval: float = 1.0,
) -> dict[str, any]:
"""Poll a task until it reaches a terminal state.
This method polls the task endpoint at regular intervals until the task
completes, fails, or times out. It's designed for async operations like
provider connection tests and deletions that return task IDs.
Args:
task_id: The UUID of the task to poll (UUID object or string)
timeout: Maximum time to wait in seconds (default: 60)
poll_interval: Time between polls in seconds (default: 1.0)
Returns:
The complete task response when terminal state is reached
Raises:
Exception: If task fails, is cancelled, or timeout is exceeded
"""
terminal_states = {"completed", "failed", "cancelled"}
start_time = asyncio.get_event_loop().time()
max_time = start_time + timeout
logger.info(
f"Polling task {task_id} (timeout: {timeout}s, interval: {poll_interval}s)"
)
while True:
# Check if we've exceeded the timeout
current_time = asyncio.get_event_loop().time()
if current_time >= max_time:
raise Exception(
f"Task {task_id} polling timed out after {timeout} seconds. "
f"The task may still be running. Try increasing the timeout or check task status manually."
)
# Fetch current task state
response = await self.get(f"/api/v1/tasks/{task_id}")
task_data = response.get("data", {})
task_attrs = task_data.get("attributes", {})
state = task_attrs.get("state")
logger.debug(f"Task {task_id} state: {state}")
# Check if we've reached a terminal state
if state in terminal_states:
if state == "completed":
logger.info(f"Task {task_id} completed successfully")
return response
elif state == "failed":
error_msg = task_attrs.get("error", "Unknown error")
raise Exception(f"Task {task_id} failed: {error_msg}")
elif state == "cancelled":
raise Exception(f"Task {task_id} was cancelled")
# Wait before next poll
await asyncio.sleep(poll_interval)
def _validate_date_format(self, date_str: str, param_name: str) -> datetime:
"""Validate date string format.