From f8a69b4b7909812f8de0d7a4fc4f8e9dafadebb4 Mon Sep 17 00:00:00 2001 From: Quan HL Date: Mon, 9 Jun 2025 15:01:25 +0700 Subject: [PATCH] Support Aws S2S --- lib/tasks/llm/index.js | 5 + lib/tasks/llm/llms/aws_s2s.js | 362 ++++++++++++++++++++++++++++++++++ lib/utils/constants.json | 7 + 3 files changed, 374 insertions(+) create mode 100644 lib/tasks/llm/llms/aws_s2s.js diff --git a/lib/tasks/llm/index.js b/lib/tasks/llm/index.js index 60dcf12c..947144b4 100644 --- a/lib/tasks/llm/index.js +++ b/lib/tasks/llm/index.js @@ -6,6 +6,7 @@ const TaskLlmUltravox_S2S = require('./llms/ultravox_s2s'); const TaskLlmElevenlabs_S2S = require('./llms/elevenlabs_s2s'); const TaskLlmGoogle_S2S = require('./llms/google_s2s'); const LlmMcpService = require('../../utils/llm-mcp'); +const TaskLlmAws_S2S = require('./llms/aws_s2s'); class TaskLlm extends Task { constructor(logger, opts) { @@ -85,6 +86,10 @@ class TaskLlm extends Task { llm = new TaskLlmGoogle_S2S(this.logger, this.data, this); break; + case 'aws': + llm = new TaskLlmAws_S2S(this.logger, this.data, this); + break; + default: throw new Error(`Unsupported vendor ${this.vendor} for LLM`); } diff --git a/lib/tasks/llm/llms/aws_s2s.js b/lib/tasks/llm/llms/aws_s2s.js new file mode 100644 index 00000000..6d039617 --- /dev/null +++ b/lib/tasks/llm/llms/aws_s2s.js @@ -0,0 +1,362 @@ +const Task = require('../../task'); +const TaskName = 'Llm_Aws_s2s'; +const {LlmEvents_Aws} = require('../../../utils/constants'); +const ClientEvent = 'client.event'; +const SessionDelete = 'session.delete'; + +const aws_server_events = [ + 'error', + 'session.created', + 'session.updated', +]; + +const expandWildcards = (events) => { + const expandedEvents = []; + + events.forEach((evt) => { + if (evt.endsWith('.*')) { + const prefix = evt.slice(0, -2); // Remove the wildcard ".*" + const matchingEvents = aws_server_events.filter((e) => e.startsWith(prefix)); + expandedEvents.push(...matchingEvents); + } else { + expandedEvents.push(evt); + } + }); + + return expandedEvents; +}; + +class TaskLlmAws_S2S extends Task { + constructor(logger, opts, parentTask) { + super(logger, opts, parentTask); + this.parent = parentTask; + + this.vendor = this.parent.vendor; + this.model = this.parent.model; + this.auth = this.parent.auth; + this.connectionOptions = this.parent.connectOptions; + + const {access_key_id, secret_access_key, aws_region} = this.auth || {}; + if (!access_key_id) throw new Error('auth.access_key_id is required for Aws S2S'); + if (!secret_access_key) throw new Error('auth.secret_access_key is required for Aws S2S'); + if (!aws_region) throw new Error('auth.aws_region is required for Aws S2S'); + + this.access_key_id = access_key_id; + this.secret_access_key = secret_access_key; + this.aws_region = aws_region; + this.actionHook = this.data.actionHook; + this.eventHook = this.data.eventHook; + this.toolHook = this.data.toolHook; + const {inferenceConfiguration, toolConfiguration} = this.data.llmOptions; + + this.inferenceConfiguration = inferenceConfiguration; + this.toolConfiguration = toolConfiguration; + + this.results = { + completionReason: 'normal conversation end' + }; + + /** + * only one of these will have items, + * if includeEvents, then these are the events to include + * if excludeEvents, then these are the events to exclude + */ + this.includeEvents = []; + this.excludeEvents = []; + + /* default to all events if user did not specify */ + this._populateEvents(this.data.events || aws_server_events); + + this.addCustomEventListener = parentTask.addCustomEventListener.bind(parentTask); + this.removeCustomEventListeners = parentTask.removeCustomEventListeners.bind(parentTask); + } + + get name() { return TaskName; } + + get host() { + const {host} = this.connectionOptions || {}; + return host || (this.vendor === 'aws' ? 'api.aws.com' : void 0); + } + + get path() { + const {path} = this.connectionOptions || {}; + if (path) return path; + + switch (this.vendor) { + case 'aws': + return `v1/realtime?model=${this.model}`; + case 'microsoft': + return `aws/realtime?api-version=2024-10-01-preview&deployment=${this.model}`; + } + } + + async _api(ep, args) { + const res = await ep.api('uuid_aws_s2s', `^^|${args.join('|')}`); + if (!res.body?.startsWith('+OK')) { + throw new Error({args}, `Error calling uuid_aws_s2s: ${res.body}`); + } + } + + async exec(cs, {ep}) { + await super.exec(cs); + + await this._startListening(cs, ep); + + await this.awaitTaskDone(); + + /* note: the parent llm verb started the span, which is why this is necessary */ + await this.parent.performAction(this.results); + + this._unregisterHandlers(); + } + + async kill(cs) { + super.kill(cs); + + this._api(cs.ep, [cs.ep.uuid, SessionDelete]) + .catch((err) => this.logger.info({err}, 'TaskLlmAws_S2S:kill - error deleting session')); + + this.notifyTaskDone(); + } + + /** + * Send function call output to the Aws server in the form of conversation.item.create + * per https://platform.aws.com/docs/guides/realtime/function-calls + */ + async processToolOutput(ep, tool_call_id, data) { + try { + this.logger.debug({tool_call_id, data}, 'TaskLlmAws_S2S:processToolOutput'); + + if (!data.type || data.type !== 'conversation.item.create') { + this.logger.info({data}, + 'TaskLlmAws_S2S:processToolOutput - invalid tool output, must be conversation.item.create'); + } + else { + await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify(data)]); + + // spec also recommends to send immediate response.create + await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify({type: 'response.create'})]); + } + } catch (err) { + this.logger.info({err}, 'TaskLlmAws_S2S:processToolOutput'); + } + } + + /** + * Send a session.update to the Aws server + * Note: creating and deleting conversation items also supported as well as interrupting the assistant + */ + async processLlmUpdate(ep, data, _callSid) { + try { + this.logger.debug({data, _callSid}, 'TaskLlmAws_S2S:processLlmUpdate'); + + if (!data.type || ![ + 'session.update', + 'conversation.item.create', + 'conversation.item.delete', + 'response.cancel' + ].includes(data.type)) { + this.logger.info({data}, 'TaskLlmAws_S2S:processLlmUpdate - invalid mid-call request'); + } + else { + await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify(data)]); + } + } catch (err) { + this.logger.info({err}, 'TaskLlmAws_S2S:processLlmUpdate'); + } + } + + async _startListening(cs, ep) { + this._registerHandlers(ep); + + try { + const args = [ep.uuid, 'session.create', this.host, this.path, this.authType, this.apiKey]; + await this._api(ep, args); + } catch (err) { + this.logger.error({err}, 'TaskLlmAws_S2S:_startListening'); + this.notifyTaskDone(); + } + } + + async _sendClientEvent(ep, obj) { + let ok = true; + this.logger.debug({obj}, 'TaskLlmAws_S2S:_sendClientEvent'); + try { + const args = [ep.uuid, ClientEvent, JSON.stringify(obj)]; + await this._api(ep, args); + } catch (err) { + ok = false; + this.logger.error({err}, 'TaskLlmAws_S2S:_sendClientEvent - Error'); + } + return ok; + } + + async _sendInitialMessage(ep) { + let obj = {type: 'response.create', response: this.response_create}; + if (!await this._sendClientEvent(ep, obj)) { + this.notifyTaskDone(); + } + + /* send immediate session.update if present */ + else if (this.session_update) { + if (this.parent.isMcpEnabled) { + this.logger.debug('TaskLlmAws_S2S:_sendInitialMessage - mcp enabled'); + const tools = await this.parent.mcpService.getAvailableMcpTools(); + if (tools && tools.length > 0 && this.session_update) { + const convertedTools = tools.map((tool) => ({ + name: tool.name, + type: 'function', + description: tool.description, + parameters: tool.inputSchema + })); + + this.session_update.tools = [ + ...convertedTools, + ...(this.session_update.tools || []) + ]; + } + } + obj = {type: 'session.update', session: this.session_update}; + this.logger.debug({obj}, 'TaskLlmAws_S2S:_sendInitialMessage - sending session.update'); + if (!await this._sendClientEvent(ep, obj)) { + this.notifyTaskDone(); + } + } + } + + _registerHandlers(ep) { + this.addCustomEventListener(ep, LlmEvents_Aws.Connect, this._onConnect.bind(this, ep)); + this.addCustomEventListener(ep, LlmEvents_Aws.ConnectFailure, this._onConnectFailure.bind(this, ep)); + this.addCustomEventListener(ep, LlmEvents_Aws.Disconnect, this._onDisconnect.bind(this, ep)); + this.addCustomEventListener(ep, LlmEvents_Aws.ServerEvent, this._onServerEvent.bind(this, ep)); + } + + _unregisterHandlers() { + this.removeCustomEventListeners(); + } + + _onError(ep, evt) { + this.logger.info({evt}, 'TaskLlmAws_S2S:_onError'); + this.notifyTaskDone(); + } + + _onConnect(ep) { + this.logger.debug('TaskLlmAws_S2S:_onConnect'); + this._sendInitialMessage(ep); + } + _onConnectFailure(_ep, evt) { + this.logger.info(evt, 'TaskLlmAws_S2S:_onConnectFailure'); + this.results = {completionReason: 'connection failure'}; + this.notifyTaskDone(); + } + _onDisconnect(_ep, evt) { + this.logger.info(evt, 'TaskLlmAws_S2S:_onConnectFailure'); + this.results = {completionReason: 'disconnect from remote end'}; + this.notifyTaskDone(); + } + async _onServerEvent(ep, evt) { + let endConversation = false; + const type = evt.type; + this.logger.info({evt}, 'TaskLlmAws_S2S:_onServerEvent'); + + /* check for failures, such as rate limit exceeded, that should terminate the conversation */ + if (type === 'response.done' && evt.response.status === 'failed') { + endConversation = true; + this.results = { + completionReason: 'server failure', + error: evt.response.status_details?.error + }; + } + + /* server errors of some sort */ + else if (type === 'error') { + endConversation = true; + this.results = { + completionReason: 'server error', + error: evt.error + }; + } + + /* tool calls */ + else if (type === 'response.output_item.done' && evt.item?.type === 'function_call') { + this.logger.debug({evt}, 'TaskLlmAws_S2S:_onServerEvent - function_call'); + const {name, call_id} = evt.item; + const args = JSON.parse(evt.item.arguments); + + const mcpTools = this.parent.isMcpEnabled ? await this.parent.mcpService.getAvailableMcpTools() : []; + if (mcpTools.some((tool) => tool.name === name)) { + this.logger.debug({call_id, name, args}, 'TaskLlmAws_S2S:_onServerEvent - calling mcp tool'); + try { + const res = await this.parent.mcpService.callMcpTool(name, args); + this.logger.debug({res}, 'TaskLlmAws_S2S:_onServerEvent - function_call - mcp result'); + this.processToolOutput(ep, call_id, { + type: 'conversation.item.create', + item: { + type: 'function_call_output', + call_id, + output: res.content[0]?.text || 'There is no output from the function call', + } + }); + return; + } catch (err) { + this.logger.info({err, evt}, 'TaskLlmAws_S2S - error calling function'); + this.results = { + completionReason: 'client error calling mcp function', + error: err + }; + endConversation = true; + } + } + else if (!this.toolHook) { + this.logger.warn({evt}, 'TaskLlmAws_S2S:_onServerEvent - no toolHook defined!'); + } + else { + try { + await this.parent.sendToolHook(call_id, {name, args}); + } catch (err) { + this.logger.info({err, evt}, 'TaskLlmAws - error calling function'); + this.results = { + completionReason: 'client error calling function', + error: err + }; + endConversation = true; + } + } + } + + /* check whether we should notify on this event */ + if (this.includeEvents.length > 0 ? this.includeEvents.includes(type) : !this.excludeEvents.includes(type)) { + this.parent.sendEventHook(evt) + .catch((err) => this.logger.info({err}, 'TaskLlmAws_S2S:_onServerEvent - error sending event hook')); + } + + if (endConversation) { + this.logger.info({results: this.results}, 'TaskLlmAws_S2S:_onServerEvent - ending conversation due to error'); + this.notifyTaskDone(); + } + } + + _populateEvents(events) { + if (events.includes('all')) { + /* work by excluding specific events */ + const exclude = events + .filter((evt) => evt.startsWith('-')) + .map((evt) => evt.slice(1)); + if (exclude.length === 0) this.includeEvents = aws_server_events; + else this.excludeEvents = expandWildcards(exclude); + } + else { + /* work by including specific events */ + const include = events + .filter((evt) => !evt.startsWith('-')); + this.includeEvents = expandWildcards(include); + } + + this.logger.debug({ + includeEvents: this.includeEvents, + excludeEvents: this.excludeEvents + }, 'TaskLlmAws_S2S:_populateEvents'); + } +} + +module.exports = TaskLlmAws_S2S; diff --git a/lib/utils/constants.json b/lib/utils/constants.json index 88b89282..4949c70f 100644 --- a/lib/utils/constants.json +++ b/lib/utils/constants.json @@ -194,6 +194,13 @@ "Disconnect": "openai_s2s::disconnect", "ServerEvent": "openai_s2s::server_event" }, + "LlmEvents_Aws": { + "Error": "error", + "Connect": "aws_s2s::connect", + "ConnectFailure": "aws_s2s::connect_failed", + "Disconnect": "aws_s2s::disconnect", + "ServerEvent": "aws_s2s::server_event" + }, "LlmEvents_Google": { "Error": "error", "Connect": "google_s2s::connect",