diff --git a/prowler/CHANGELOG.md b/prowler/CHANGELOG.md index 92ac4df2f2..255357511c 100644 --- a/prowler/CHANGELOG.md +++ b/prowler/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to the **Prowler SDK** are documented in this file. ### Added - Add Prowler ThreatScore for the Alibaba Cloud provider [(#9511)](https://github.com/prowler-cloud/prowler/pull/9511) - `compute_instance_group_multiple_zones` check for GCP provider [(#9566)](https://github.com/prowler-cloud/prowler/pull/9566) +- Bedrock service pagination [(#9606)](https://github.com/prowler-cloud/prowler/pull/9606) ### Changed - Update AWS Step Functions service metadata to new format [(#9432)](https://github.com/prowler-cloud/prowler/pull/9432) diff --git a/prowler/providers/aws/services/bedrock/bedrock_service.py b/prowler/providers/aws/services/bedrock/bedrock_service.py index c00fc61ac0..c0e3c6717a 100644 --- a/prowler/providers/aws/services/bedrock/bedrock_service.py +++ b/prowler/providers/aws/services/bedrock/bedrock_service.py @@ -55,16 +55,18 @@ class Bedrock(AWSService): def _list_guardrails(self, regional_client): logger.info("Bedrock - Listing Guardrails...") try: - for guardrail in regional_client.list_guardrails().get("guardrails", []): - if not self.audit_resources or ( - is_resource_filtered(guardrail["arn"], self.audit_resources) - ): - self.guardrails[guardrail["arn"]] = Guardrail( - id=guardrail["id"], - name=guardrail["name"], - arn=guardrail["arn"], - region=regional_client.region, - ) + paginator = regional_client.get_paginator("list_guardrails") + for page in paginator.paginate(): + for guardrail in page.get("guardrails", []): + if not self.audit_resources or ( + is_resource_filtered(guardrail["arn"], self.audit_resources) + ): + self.guardrails[guardrail["arn"]] = Guardrail( + id=guardrail["id"], + name=guardrail["name"], + arn=guardrail["arn"], + region=regional_client.region, + ) except Exception as error: logger.error( f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" @@ -130,20 +132,22 @@ class BedrockAgent(AWSService): def _list_agents(self, regional_client): logger.info("Bedrock Agent - Listing Agents...") try: - for agent in regional_client.list_agents().get("agentSummaries", []): - agent_arn = f"arn:aws:bedrock:{regional_client.region}:{self.audited_account}:agent/{agent['agentId']}" - if not self.audit_resources or ( - is_resource_filtered(agent_arn, self.audit_resources) - ): - self.agents[agent_arn] = Agent( - id=agent["agentId"], - name=agent["agentName"], - arn=agent_arn, - guardrail_id=agent.get("guardrailConfiguration", {}).get( - "guardrailIdentifier" - ), - region=regional_client.region, - ) + paginator = regional_client.get_paginator("list_agents") + for page in paginator.paginate(): + for agent in page.get("agentSummaries", []): + agent_arn = f"arn:aws:bedrock:{regional_client.region}:{self.audited_account}:agent/{agent['agentId']}" + if not self.audit_resources or ( + is_resource_filtered(agent_arn, self.audit_resources) + ): + self.agents[agent_arn] = Agent( + id=agent["agentId"], + name=agent["agentName"], + arn=agent_arn, + guardrail_id=agent.get("guardrailConfiguration", {}).get( + "guardrailIdentifier" + ), + region=regional_client.region, + ) except Exception as error: logger.error( f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" diff --git a/tests/providers/aws/services/bedrock/bedrock_service_test.py b/tests/providers/aws/services/bedrock/bedrock_service_test.py index 95901f20f2..ed39ac865b 100644 --- a/tests/providers/aws/services/bedrock/bedrock_service_test.py +++ b/tests/providers/aws/services/bedrock/bedrock_service_test.py @@ -1,4 +1,5 @@ from unittest import mock +from unittest.mock import MagicMock import botocore from boto3 import client @@ -215,3 +216,128 @@ class Test_Bedrock_Agent_Service: "Key": "test-tag-key", } ] + + +class TestBedrockPagination: + """Test suite for Bedrock Guardrail pagination logic.""" + + def test_list_guardrails_pagination(self): + """Test that list_guardrails iterates through all pages.""" + # Mock the audit_info + audit_info = MagicMock() + audit_info.audited_partition = "aws" + audit_info.audited_account = "123456789012" + audit_info.audit_resources = None + + # Mock the regional client + regional_client = MagicMock() + regional_client.region = "us-east-1" + + # Mock paginator + paginator = MagicMock() + page1 = { + "guardrails": [ + { + "id": "g-1", + "name": "guardrail-1", + "arn": "arn:aws:bedrock:us-east-1:123456789012:guardrail/g-1", + } + ] + } + page2 = { + "guardrails": [ + { + "id": "g-2", + "name": "guardrail-2", + "arn": "arn:aws:bedrock:us-east-1:123456789012:guardrail/g-2", + } + ] + } + paginator.paginate.return_value = [page1, page2] + regional_client.get_paginator.return_value = paginator + + # Initialize service and inject mock client + bedrock_service = Bedrock(audit_info) + bedrock_service.regional_clients = {"us-east-1": regional_client} + bedrock_service.guardrails = {} # Clear any init side effects + + # Run the method under test + bedrock_service._list_guardrails(regional_client) + + # Assertions + assert len(bedrock_service.guardrails) == 2 + assert ( + "arn:aws:bedrock:us-east-1:123456789012:guardrail/g-1" + in bedrock_service.guardrails + ) + assert ( + "arn:aws:bedrock:us-east-1:123456789012:guardrail/g-2" + in bedrock_service.guardrails + ) + + # Verify paginator was used + regional_client.get_paginator.assert_called_once_with("list_guardrails") + paginator.paginate.assert_called_once() + + +class TestBedrockAgentPagination: + """Test suite for Bedrock Agent pagination logic.""" + + def test_list_agents_pagination(self): + """Test that list_agents iterates through all pages.""" + # Mock the audit_info + audit_info = MagicMock() + audit_info.audited_partition = "aws" + audit_info.audited_account = "123456789012" + audit_info.audit_resources = None + + # Mock the regional client + regional_client = MagicMock() + regional_client.region = "us-east-1" + + # Mock paginator + paginator = MagicMock() + page1 = { + "agentSummaries": [ + { + "agentId": "agent-1", + "agentName": "agent-name-1", + "agentStatus": "PREPARED", + } + ] + } + page2 = { + "agentSummaries": [ + { + "agentId": "agent-2", + "agentName": "agent-name-2", + "agentStatus": "PREPARED", + } + ] + } + paginator.paginate.return_value = [page1, page2] + regional_client.get_paginator.return_value = paginator + + # Initialize service and inject mock client + bedrock_agent_service = BedrockAgent(audit_info) + bedrock_agent_service.regional_clients = {"us-east-1": regional_client} + bedrock_agent_service.agents = {} # Clear init side effects + bedrock_agent_service.audited_account = "123456789012" + + # Run method + bedrock_agent_service._list_agents(regional_client) + + # Assertions + assert len(bedrock_agent_service.agents) == 2 + assert ( + "arn:aws:bedrock:us-east-1:123456789012:agent/agent-1" + in bedrock_agent_service.agents + ) + assert ( + "arn:aws:bedrock:us-east-1:123456789012:agent/agent-2" + in bedrock_agent_service.agents + ) + + # Verify paginator was used + regional_client.get_paginator.assert_called_once_with("list_agents") + paginator.paginate.assert_called_once()