From 231bfd6f41b79f1ec6a59fb8c9302176124f42a0 Mon Sep 17 00:00:00 2001 From: Alejandro Bailo <59607668+alejandrobailo@users.noreply.github.com> Date: Wed, 25 Feb 2026 12:56:26 +0100 Subject: [PATCH] feat(ui): add organization server actions and scan launching (#10155) --- .../organizations.adapter.test.ts | 177 ++++++++++ .../organizations/organizations.test.ts | 140 ++++++++ ui/actions/organizations/organizations.ts | 326 ++++++++++++++++++ ui/actions/scans/scans.test.ts | 71 ++++ ui/actions/scans/scans.ts | 70 ++++ ui/lib/concurrency.test.ts | 40 +++ ui/lib/concurrency.ts | 38 ++ 7 files changed, 862 insertions(+) create mode 100644 ui/actions/organizations/organizations.adapter.test.ts create mode 100644 ui/actions/organizations/organizations.test.ts create mode 100644 ui/actions/organizations/organizations.ts create mode 100644 ui/actions/scans/scans.test.ts create mode 100644 ui/lib/concurrency.test.ts create mode 100644 ui/lib/concurrency.ts diff --git a/ui/actions/organizations/organizations.adapter.test.ts b/ui/actions/organizations/organizations.adapter.test.ts new file mode 100644 index 0000000000..07cff60c0c --- /dev/null +++ b/ui/actions/organizations/organizations.adapter.test.ts @@ -0,0 +1,177 @@ +import { describe, expect, it } from "vitest"; + +import { + APPLY_STATUS, + ApplyStatus, + DiscoveryResult, +} from "@/types/organizations"; + +import { + buildAccountLookup, + buildOrgTreeData, + getOuIdsForSelectedAccounts, + getSelectableAccountIds, +} from "./organizations.adapter"; + +const discoveryFixture: DiscoveryResult = { + roots: [ + { + id: "r-root", + arn: "arn:aws:organizations::123:root/o-example/r-root", + name: "Root", + policy_types: [], + }, + ], + organizational_units: [ + { + id: "ou-parent", + name: "Parent OU", + arn: "arn:aws:organizations::123:ou/o-example/ou-parent", + parent_id: "r-root", + }, + { + id: "ou-child", + name: "Child OU", + arn: "arn:aws:organizations::123:ou/o-example/ou-child", + parent_id: "ou-parent", + }, + ], + accounts: [ + { + id: "111111111111", + arn: "arn:aws:organizations::123:account/o-example/111111111111", + name: "App Account", + email: "app@example.com", + status: "ACTIVE", + joined_method: "CREATED", + joined_timestamp: "2024-01-01T00:00:00Z", + parent_id: "ou-child", + registration: { + provider_exists: false, + provider_id: null, + organization_relation: "link_required", + organizational_unit_relation: "link_required", + provider_secret_state: "will_create", + apply_status: APPLY_STATUS.READY, + blocked_reasons: [], + }, + }, + { + id: "222222222222", + arn: "arn:aws:organizations::123:account/o-example/222222222222", + name: "Security Account", + email: "security@example.com", + status: "ACTIVE", + joined_method: "CREATED", + joined_timestamp: "2024-01-01T00:00:00Z", + parent_id: "ou-parent", + registration: { + provider_exists: false, + provider_id: null, + organization_relation: "link_required", + organizational_unit_relation: "link_required", + provider_secret_state: "manual_required", + apply_status: APPLY_STATUS.BLOCKED, + blocked_reasons: ["role_missing"], + }, + }, + { + id: "333333333333", + arn: "arn:aws:organizations::123:account/o-example/333333333333", + name: "Legacy Account", + email: "legacy@example.com", + status: "ACTIVE", + joined_method: "INVITED", + joined_timestamp: "2024-01-01T00:00:00Z", + parent_id: "r-root", + }, + ], +}; + +describe("buildOrgTreeData", () => { + it("builds nested tree structure and marks blocked accounts as disabled", () => { + // Given / When + const treeData = buildOrgTreeData(discoveryFixture); + + // Then + expect(treeData).toHaveLength(2); + expect(treeData.map((node) => node.id)).toEqual( + expect.arrayContaining(["ou-parent", "333333333333"]), + ); + + const parentOuNode = treeData.find((node) => node.id === "ou-parent"); + expect(parentOuNode).toBeDefined(); + expect(parentOuNode?.children?.map((node) => node.id)).toEqual( + expect.arrayContaining(["ou-child", "222222222222"]), + ); + + const blockedAccount = parentOuNode?.children?.find( + (node) => node.id === "222222222222", + ); + expect(blockedAccount?.disabled).toBe(true); + }); +}); + +describe("getSelectableAccountIds", () => { + it("returns all accounts except explicitly blocked ones", () => { + const selectableIds = getSelectableAccountIds(discoveryFixture); + + expect(selectableIds).toEqual(["111111111111", "333333333333"]); + }); + + it("excludes accounts with explicit non-ready status values", () => { + const discoveryWithUnexpectedStatus = { + ...discoveryFixture, + accounts: [ + ...discoveryFixture.accounts, + { + id: "444444444444", + arn: "arn:aws:organizations::123:account/o-example/444444444444", + name: "Pending Account", + email: "pending@example.com", + status: "ACTIVE", + joined_method: "CREATED", + joined_timestamp: "2024-01-01T00:00:00Z", + parent_id: "r-root", + registration: { + provider_exists: false, + provider_id: null, + organization_relation: "link_required", + organizational_unit_relation: "link_required", + provider_secret_state: "will_create", + apply_status: "pending" as unknown as ApplyStatus, + blocked_reasons: [], + }, + }, + ], + } satisfies DiscoveryResult; + + const selectableIds = getSelectableAccountIds( + discoveryWithUnexpectedStatus, + ); + + expect(selectableIds).toEqual(["111111111111", "333333333333"]); + }); +}); + +describe("buildAccountLookup", () => { + it("creates a lookup map for all discovered accounts", () => { + const lookup = buildAccountLookup(discoveryFixture); + + expect(lookup.get("111111111111")?.name).toBe("App Account"); + expect(lookup.get("333333333333")?.name).toBe("Legacy Account"); + expect(lookup.size).toBe(3); + }); +}); + +describe("getOuIdsForSelectedAccounts", () => { + it("collects all ancestor OUs for selected accounts without duplicates", () => { + const ouIds = getOuIdsForSelectedAccounts(discoveryFixture, [ + "111111111111", + "222222222222", + ]); + + expect(ouIds).toEqual(expect.arrayContaining(["ou-parent", "ou-child"])); + expect(ouIds.length).toBe(2); + }); +}); diff --git a/ui/actions/organizations/organizations.test.ts b/ui/actions/organizations/organizations.test.ts new file mode 100644 index 0000000000..52a4b82eb3 --- /dev/null +++ b/ui/actions/organizations/organizations.test.ts @@ -0,0 +1,140 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const { + fetchMock, + getAuthHeadersMock, + handleApiErrorMock, + handleApiResponseMock, + revalidatePathMock, +} = vi.hoisted(() => ({ + fetchMock: vi.fn(), + getAuthHeadersMock: vi.fn(), + handleApiErrorMock: vi.fn(), + handleApiResponseMock: vi.fn(), + revalidatePathMock: vi.fn(), +})); + +vi.mock("next/cache", () => ({ + revalidatePath: revalidatePathMock, +})); + +vi.mock("@/lib", () => ({ + apiBaseUrl: "https://api.example.com/api/v1", + getAuthHeaders: getAuthHeadersMock, +})); + +vi.mock("@/lib/server-actions-helper", () => ({ + handleApiError: handleApiErrorMock, + handleApiResponse: handleApiResponseMock, +})); + +import { + applyDiscovery, + getDiscovery, + triggerDiscovery, + updateOrganizationSecret, +} from "./organizations"; + +describe("organizations actions", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.stubGlobal("fetch", fetchMock); + getAuthHeadersMock.mockResolvedValue({ Authorization: "Bearer token" }); + handleApiErrorMock.mockReturnValue({ error: "Unexpected error" }); + }); + + it("rejects invalid organization secret identifiers", async () => { + // Given + const formData = new FormData(); + formData.set("organizationSecretId", "../secret-id"); + formData.set("roleArn", "arn:aws:iam::123456789012:role/ProwlerOrgRole"); + formData.set("externalId", "o-abc123def4"); + + // When + const result = await updateOrganizationSecret(formData); + + // Then + expect(result).toEqual({ error: "Invalid organization secret ID" }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("rejects invalid discovery identifiers before building the request URL", async () => { + // When + const result = await getDiscovery( + "123e4567-e89b-12d3-a456-426614174000", + "discovery/../id", + ); + + // Then + expect(result).toEqual({ error: "Invalid discovery ID" }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("rejects invalid organization identifiers before triggering discovery", async () => { + // When + const result = await triggerDiscovery("org/id-with-slash"); + + // Then + expect(result).toEqual({ error: "Invalid organization ID" }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("revalidates providers only when apply discovery succeeds", async () => { + // Given + fetchMock.mockResolvedValue( + new Response(JSON.stringify({ data: { id: "apply-1" } }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + handleApiResponseMock.mockResolvedValueOnce({ error: "Apply failed" }); + handleApiResponseMock.mockResolvedValueOnce({ data: { id: "apply-1" } }); + + // When + const failedResult = await applyDiscovery( + "123e4567-e89b-12d3-a456-426614174000", + "223e4567-e89b-12d3-a456-426614174111", + [], + [], + ); + const successfulResult = await applyDiscovery( + "123e4567-e89b-12d3-a456-426614174000", + "223e4567-e89b-12d3-a456-426614174111", + [], + [], + ); + + // Then + expect(failedResult).toEqual({ error: "Apply failed" }); + expect(successfulResult).toEqual({ data: { id: "apply-1" } }); + expect(revalidatePathMock).toHaveBeenCalledTimes(1); + expect(revalidatePathMock).toHaveBeenCalledWith("/providers"); + }); + + it("revalidates providers when response contains error set to null", async () => { + // Given + fetchMock.mockResolvedValue( + new Response(JSON.stringify({ data: { id: "apply-2" } }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + handleApiResponseMock.mockResolvedValueOnce({ + data: { id: "apply-2" }, + error: null, + }); + + // When + const result = await applyDiscovery( + "123e4567-e89b-12d3-a456-426614174000", + "223e4567-e89b-12d3-a456-426614174111", + [], + [], + ); + + // Then + expect(result).toEqual({ data: { id: "apply-2" }, error: null }); + expect(revalidatePathMock).toHaveBeenCalledTimes(1); + expect(revalidatePathMock).toHaveBeenCalledWith("/providers"); + }); +}); diff --git a/ui/actions/organizations/organizations.ts b/ui/actions/organizations/organizations.ts new file mode 100644 index 0000000000..16ddd7593d --- /dev/null +++ b/ui/actions/organizations/organizations.ts @@ -0,0 +1,326 @@ +"use server"; + +import { revalidatePath } from "next/cache"; + +import { apiBaseUrl, getAuthHeaders } from "@/lib"; +import { handleApiError, handleApiResponse } from "@/lib/server-actions-helper"; + +const PATH_IDENTIFIER_PATTERN = /^[A-Za-z0-9_-]+$/; + +type PathIdentifierValidationResult = { value: string } | { error: string }; + +function validatePathIdentifier( + value: string | null | undefined, + requiredError: string, + invalidError: string, +): PathIdentifierValidationResult { + const normalizedValue = value?.trim(); + + if (!normalizedValue) { + return { error: requiredError }; + } + + if (!PATH_IDENTIFIER_PATTERN.test(normalizedValue)) { + return { error: invalidError }; + } + + return { value: normalizedValue }; +} + +function hasActionError(result: unknown): result is { error: unknown } { + return Boolean( + result && + typeof result === "object" && + "error" in (result as Record) && + (result as Record).error !== null && + (result as Record).error !== undefined, + ); +} + +/** + * Creates an AWS Organization resource. + * POST /api/v1/organizations + */ +export const createOrganization = async (formData: FormData) => { + const headers = await getAuthHeaders({ contentType: true }); + const url = new URL(`${apiBaseUrl}/organizations`); + + const name = formData.get("name") as string; + const externalId = formData.get("externalId") as string; + + try { + const response = await fetch(url.toString(), { + method: "POST", + headers, + body: JSON.stringify({ + data: { + type: "organizations", + attributes: { + name, + org_type: "aws", + external_id: externalId, + }, + }, + }), + }); + + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Lists AWS Organizations filtered by external ID. + * GET /api/v1/organizations?filter[external_id]={externalId}&filter[org_type]=aws + */ +export const listOrganizationsByExternalId = async (externalId: string) => { + const headers = await getAuthHeaders({ contentType: false }); + const url = new URL(`${apiBaseUrl}/organizations`); + url.searchParams.set("filter[external_id]", externalId); + url.searchParams.set("filter[org_type]", "aws"); + + try { + const response = await fetch(url.toString(), { headers }); + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Creates an organization secret (role-based credentials). + * POST /api/v1/organization-secrets + */ +export const createOrganizationSecret = async (formData: FormData) => { + const headers = await getAuthHeaders({ contentType: true }); + const url = new URL(`${apiBaseUrl}/organization-secrets`); + + const organizationId = formData.get("organizationId") as string; + const roleArn = formData.get("roleArn") as string; + const externalId = formData.get("externalId") as string; + + try { + const response = await fetch(url.toString(), { + method: "POST", + headers, + body: JSON.stringify({ + data: { + type: "organization-secrets", + attributes: { + secret_type: "role", + secret: { + role_arn: roleArn, + external_id: externalId, + }, + }, + relationships: { + organization: { + data: { + type: "organizations", + id: organizationId, + }, + }, + }, + }, + }), + }); + + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Updates an organization secret (role-based credentials). + * PATCH /api/v1/organization-secrets/{id} + */ +export const updateOrganizationSecret = async (formData: FormData) => { + const headers = await getAuthHeaders({ contentType: true }); + const organizationSecretId = formData.get("organizationSecretId") as + | string + | null; + const roleArn = formData.get("roleArn") as string; + const externalId = formData.get("externalId") as string; + + const organizationSecretIdValidation = validatePathIdentifier( + organizationSecretId, + "Organization secret ID is required", + "Invalid organization secret ID", + ); + if ("error" in organizationSecretIdValidation) { + return organizationSecretIdValidation; + } + + const url = new URL( + `${apiBaseUrl}/organization-secrets/${encodeURIComponent(organizationSecretIdValidation.value)}`, + ); + + try { + const response = await fetch(url.toString(), { + method: "PATCH", + headers, + body: JSON.stringify({ + data: { + type: "organization-secrets", + id: organizationSecretIdValidation.value, + attributes: { + secret_type: "role", + secret: { + role_arn: roleArn, + external_id: externalId, + }, + }, + }, + }), + }); + + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Lists organization secrets for an organization. + * GET /api/v1/organization-secrets?filter[organization_id]={organizationId} + */ +export const listOrganizationSecretsByOrganizationId = async ( + organizationId: string, +) => { + const headers = await getAuthHeaders({ contentType: false }); + const url = new URL(`${apiBaseUrl}/organization-secrets`); + url.searchParams.set("filter[organization_id]", organizationId); + + try { + const response = await fetch(url.toString(), { headers }); + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Triggers an async discovery of the AWS Organization. + * POST /api/v1/organizations/{id}/discover + */ +export const triggerDiscovery = async (organizationId: string) => { + const headers = await getAuthHeaders({ contentType: false }); + const organizationIdValidation = validatePathIdentifier( + organizationId, + "Organization ID is required", + "Invalid organization ID", + ); + if ("error" in organizationIdValidation) { + return organizationIdValidation; + } + const url = new URL( + `${apiBaseUrl}/organizations/${encodeURIComponent(organizationIdValidation.value)}/discover`, + ); + + try { + const response = await fetch(url.toString(), { + method: "POST", + headers, + }); + + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Polls the discovery status. + * GET /api/v1/organizations/{orgId}/discoveries/{discoveryId} + */ +export const getDiscovery = async ( + organizationId: string, + discoveryId: string, +) => { + const headers = await getAuthHeaders({ contentType: false }); + const organizationIdValidation = validatePathIdentifier( + organizationId, + "Organization ID is required", + "Invalid organization ID", + ); + if ("error" in organizationIdValidation) { + return organizationIdValidation; + } + const discoveryIdValidation = validatePathIdentifier( + discoveryId, + "Discovery ID is required", + "Invalid discovery ID", + ); + if ("error" in discoveryIdValidation) { + return discoveryIdValidation; + } + const url = new URL( + `${apiBaseUrl}/organizations/${encodeURIComponent(organizationIdValidation.value)}/discoveries/${encodeURIComponent(discoveryIdValidation.value)}`, + ); + + try { + const response = await fetch(url.toString(), { headers }); + + return handleApiResponse(response); + } catch (error) { + return handleApiError(error); + } +}; + +/** + * Applies discovery results — creates providers, links to org/OUs, auto-generates secrets. + * POST /api/v1/organizations/{orgId}/discoveries/{discoveryId}/apply + */ +export const applyDiscovery = async ( + organizationId: string, + discoveryId: string, + accounts: Array<{ id: string; alias?: string }>, + organizationalUnits: Array<{ id: string }>, +) => { + const headers = await getAuthHeaders({ contentType: true }); + const organizationIdValidation = validatePathIdentifier( + organizationId, + "Organization ID is required", + "Invalid organization ID", + ); + if ("error" in organizationIdValidation) { + return organizationIdValidation; + } + const discoveryIdValidation = validatePathIdentifier( + discoveryId, + "Discovery ID is required", + "Invalid discovery ID", + ); + if ("error" in discoveryIdValidation) { + return discoveryIdValidation; + } + const url = new URL( + `${apiBaseUrl}/organizations/${encodeURIComponent(organizationIdValidation.value)}/discoveries/${encodeURIComponent(discoveryIdValidation.value)}/apply`, + ); + + try { + const response = await fetch(url.toString(), { + method: "POST", + headers, + body: JSON.stringify({ + data: { + type: "organization-discoveries", + attributes: { + accounts, + organizational_units: organizationalUnits, + }, + }, + }), + }); + + const result = await handleApiResponse(response); + if (!hasActionError(result)) { + revalidatePath("/providers"); + } + return result; + } catch (error) { + return handleApiError(error); + } +}; diff --git a/ui/actions/scans/scans.test.ts b/ui/actions/scans/scans.test.ts new file mode 100644 index 0000000000..ae82ba860d --- /dev/null +++ b/ui/actions/scans/scans.test.ts @@ -0,0 +1,71 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const { + fetchMock, + getAuthHeadersMock, + handleApiErrorMock, + handleApiResponseMock, +} = vi.hoisted(() => ({ + fetchMock: vi.fn(), + getAuthHeadersMock: vi.fn(), + handleApiErrorMock: vi.fn(), + handleApiResponseMock: vi.fn(), +})); + +vi.mock("@/lib", () => ({ + apiBaseUrl: "https://api.example.com/api/v1", + getAuthHeaders: getAuthHeadersMock, + getErrorMessage: (error: unknown) => + error instanceof Error ? error.message : String(error), +})); + +vi.mock("@/lib/server-actions-helper", () => ({ + handleApiError: handleApiErrorMock, + handleApiResponse: handleApiResponseMock, +})); + +vi.mock("@/lib/sentry-breadcrumbs", () => ({ + addScanOperation: vi.fn(), +})); + +import { launchOrganizationScans } from "./scans"; + +describe("launchOrganizationScans", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.stubGlobal("fetch", fetchMock); + getAuthHeadersMock.mockResolvedValue({ Authorization: "Bearer token" }); + handleApiResponseMock.mockResolvedValue({ data: { id: "scan-id" } }); + handleApiErrorMock.mockReturnValue({ error: "Scan launch failed." }); + }); + + it("limits concurrent launch requests to avoid overwhelming the backend", async () => { + // Given + const providerIds = Array.from( + { length: 12 }, + (_, index) => `provider-${index + 1}`, + ); + let activeRequests = 0; + let maxActiveRequests = 0; + + fetchMock.mockImplementation(async () => { + activeRequests += 1; + maxActiveRequests = Math.max(maxActiveRequests, activeRequests); + await new Promise((resolve) => setTimeout(resolve, 5)); + activeRequests -= 1; + + return new Response(JSON.stringify({ data: { id: "scan-id" } }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + }); + + // When + const result = await launchOrganizationScans(providerIds, "daily"); + + // Then + expect(maxActiveRequests).toBeLessThanOrEqual(5); + expect(result.successCount).toBe(providerIds.length); + expect(result.failureCount).toBe(0); + }); +}); diff --git a/ui/actions/scans/scans.ts b/ui/actions/scans/scans.ts index 52da71bb3b..8716c5c1c6 100644 --- a/ui/actions/scans/scans.ts +++ b/ui/actions/scans/scans.ts @@ -7,12 +7,15 @@ import { COMPLIANCE_REPORT_DISPLAY_NAMES, type ComplianceReportType, } from "@/lib/compliance/compliance-report-types"; +import { runWithConcurrencyLimit } from "@/lib/concurrency"; import { appendSanitizedProviderTypeFilters, sanitizeProviderTypesCsv, } from "@/lib/provider-filters"; import { addScanOperation } from "@/lib/sentry-breadcrumbs"; import { handleApiError, handleApiResponse } from "@/lib/server-actions-helper"; + +const ORGANIZATION_SCAN_CONCURRENCY_LIMIT = 5; export const getScans = async ({ page = 1, query = "", @@ -165,6 +168,73 @@ export const scheduleDaily = async (formData: FormData) => { } }; +export const launchOrganizationScans = async ( + providerIds: string[], + scheduleOption: "daily" | "single", +) => { + const validProviderIds = providerIds.filter(Boolean); + if (validProviderIds.length === 0) { + return { + successCount: 0, + failureCount: 0, + totalCount: 0, + }; + } + + const launchResults = await runWithConcurrencyLimit( + validProviderIds, + ORGANIZATION_SCAN_CONCURRENCY_LIMIT, + async (providerId) => { + try { + const formData = new FormData(); + formData.set("providerId", providerId); + + const result = + scheduleOption === "daily" + ? await scheduleDaily(formData) + : await scanOnDemand(formData); + + return { + providerId, + ok: !result?.error, + error: result?.error ? String(result.error) : null, + }; + } catch (error) { + return { + providerId, + ok: false, + error: + error instanceof Error ? error.message : "Failed to launch scan.", + }; + } + }, + ); + + const summary = launchResults.reduce( + (acc, item) => { + if (item.ok) { + acc.successCount += 1; + return acc; + } + + acc.failureCount += 1; + acc.errors.push({ + providerId: item.providerId, + error: item.error || "Failed to launch scan.", + }); + return acc; + }, + { + successCount: 0, + failureCount: 0, + totalCount: validProviderIds.length, + errors: [] as Array<{ providerId: string; error: string }>, + }, + ); + + return summary; +}; + export const updateScan = async (formData: FormData) => { const headers = await getAuthHeaders({ contentType: true }); diff --git a/ui/lib/concurrency.test.ts b/ui/lib/concurrency.test.ts new file mode 100644 index 0000000000..45f39ea772 --- /dev/null +++ b/ui/lib/concurrency.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from "vitest"; + +import { runWithConcurrencyLimit } from "./concurrency"; + +describe("runWithConcurrencyLimit", () => { + it("should process items without exceeding the configured concurrency", async () => { + // Given + const items = Array.from({ length: 12 }, (_, index) => index + 1); + let activeTasks = 0; + let maxActiveTasks = 0; + + // When + const results = await runWithConcurrencyLimit(items, 4, async (item) => { + activeTasks += 1; + maxActiveTasks = Math.max(maxActiveTasks, activeTasks); + await new Promise((resolve) => setTimeout(resolve, 5)); + activeTasks -= 1; + return item * 2; + }); + + // Then + expect(maxActiveTasks).toBeLessThanOrEqual(4); + expect(results).toEqual(items.map((item) => item * 2)); + }); + + it("should reject when worker throws an uncaught error", async () => { + // Given + const items = [1, 2, 3]; + + // When / Then + await expect( + runWithConcurrencyLimit(items, 2, async (item) => { + if (item === 2) { + throw new Error("boom"); + } + return item; + }), + ).rejects.toThrow("boom"); + }); +}); diff --git a/ui/lib/concurrency.ts b/ui/lib/concurrency.ts new file mode 100644 index 0000000000..035ce2d063 --- /dev/null +++ b/ui/lib/concurrency.ts @@ -0,0 +1,38 @@ +/** + * Runs async work over items with a fixed concurrency limit. + * + * Note: if `worker` throws, this function rejects. Callers should handle + * expected per-item errors inside the worker and return a typed result. + */ +export async function runWithConcurrencyLimit( + items: T[], + concurrencyLimit: number, + worker: (item: T, index: number) => Promise, +): Promise { + if (items.length === 0) { + return []; + } + + const normalizedConcurrency = Math.max(1, Math.floor(concurrencyLimit)); + const results = new Array(items.length); + let currentIndex = 0; + + const runWorker = async () => { + while (currentIndex < items.length) { + const assignedIndex = currentIndex; + currentIndex += 1; + results[assignedIndex] = await worker( + items[assignedIndex], + assignedIndex, + ); + } + }; + + const workers = Array.from( + { length: Math.min(normalizedConcurrency, items.length) }, + () => runWorker(), + ); + + await Promise.all(workers); + return results; +}