Support Aws S2S

This commit is contained in:
Quan HL
2025-06-09 15:01:25 +07:00
parent e686a11808
commit f8a69b4b79
3 changed files with 374 additions and 0 deletions

View File

@@ -6,6 +6,7 @@ const TaskLlmUltravox_S2S = require('./llms/ultravox_s2s');
const TaskLlmElevenlabs_S2S = require('./llms/elevenlabs_s2s'); const TaskLlmElevenlabs_S2S = require('./llms/elevenlabs_s2s');
const TaskLlmGoogle_S2S = require('./llms/google_s2s'); const TaskLlmGoogle_S2S = require('./llms/google_s2s');
const LlmMcpService = require('../../utils/llm-mcp'); const LlmMcpService = require('../../utils/llm-mcp');
const TaskLlmAws_S2S = require('./llms/aws_s2s');
class TaskLlm extends Task { class TaskLlm extends Task {
constructor(logger, opts) { constructor(logger, opts) {
@@ -85,6 +86,10 @@ class TaskLlm extends Task {
llm = new TaskLlmGoogle_S2S(this.logger, this.data, this); llm = new TaskLlmGoogle_S2S(this.logger, this.data, this);
break; break;
case 'aws':
llm = new TaskLlmAws_S2S(this.logger, this.data, this);
break;
default: default:
throw new Error(`Unsupported vendor ${this.vendor} for LLM`); throw new Error(`Unsupported vendor ${this.vendor} for LLM`);
} }

View File

@@ -0,0 +1,362 @@
const Task = require('../../task');
const TaskName = 'Llm_Aws_s2s';
const {LlmEvents_Aws} = require('../../../utils/constants');
const ClientEvent = 'client.event';
const SessionDelete = 'session.delete';
const aws_server_events = [
'error',
'session.created',
'session.updated',
];
const expandWildcards = (events) => {
const expandedEvents = [];
events.forEach((evt) => {
if (evt.endsWith('.*')) {
const prefix = evt.slice(0, -2); // Remove the wildcard ".*"
const matchingEvents = aws_server_events.filter((e) => e.startsWith(prefix));
expandedEvents.push(...matchingEvents);
} else {
expandedEvents.push(evt);
}
});
return expandedEvents;
};
class TaskLlmAws_S2S extends Task {
constructor(logger, opts, parentTask) {
super(logger, opts, parentTask);
this.parent = parentTask;
this.vendor = this.parent.vendor;
this.model = this.parent.model;
this.auth = this.parent.auth;
this.connectionOptions = this.parent.connectOptions;
const {access_key_id, secret_access_key, aws_region} = this.auth || {};
if (!access_key_id) throw new Error('auth.access_key_id is required for Aws S2S');
if (!secret_access_key) throw new Error('auth.secret_access_key is required for Aws S2S');
if (!aws_region) throw new Error('auth.aws_region is required for Aws S2S');
this.access_key_id = access_key_id;
this.secret_access_key = secret_access_key;
this.aws_region = aws_region;
this.actionHook = this.data.actionHook;
this.eventHook = this.data.eventHook;
this.toolHook = this.data.toolHook;
const {inferenceConfiguration, toolConfiguration} = this.data.llmOptions;
this.inferenceConfiguration = inferenceConfiguration;
this.toolConfiguration = toolConfiguration;
this.results = {
completionReason: 'normal conversation end'
};
/**
* only one of these will have items,
* if includeEvents, then these are the events to include
* if excludeEvents, then these are the events to exclude
*/
this.includeEvents = [];
this.excludeEvents = [];
/* default to all events if user did not specify */
this._populateEvents(this.data.events || aws_server_events);
this.addCustomEventListener = parentTask.addCustomEventListener.bind(parentTask);
this.removeCustomEventListeners = parentTask.removeCustomEventListeners.bind(parentTask);
}
get name() { return TaskName; }
get host() {
const {host} = this.connectionOptions || {};
return host || (this.vendor === 'aws' ? 'api.aws.com' : void 0);
}
get path() {
const {path} = this.connectionOptions || {};
if (path) return path;
switch (this.vendor) {
case 'aws':
return `v1/realtime?model=${this.model}`;
case 'microsoft':
return `aws/realtime?api-version=2024-10-01-preview&deployment=${this.model}`;
}
}
async _api(ep, args) {
const res = await ep.api('uuid_aws_s2s', `^^|${args.join('|')}`);
if (!res.body?.startsWith('+OK')) {
throw new Error({args}, `Error calling uuid_aws_s2s: ${res.body}`);
}
}
async exec(cs, {ep}) {
await super.exec(cs);
await this._startListening(cs, ep);
await this.awaitTaskDone();
/* note: the parent llm verb started the span, which is why this is necessary */
await this.parent.performAction(this.results);
this._unregisterHandlers();
}
async kill(cs) {
super.kill(cs);
this._api(cs.ep, [cs.ep.uuid, SessionDelete])
.catch((err) => this.logger.info({err}, 'TaskLlmAws_S2S:kill - error deleting session'));
this.notifyTaskDone();
}
/**
* Send function call output to the Aws server in the form of conversation.item.create
* per https://platform.aws.com/docs/guides/realtime/function-calls
*/
async processToolOutput(ep, tool_call_id, data) {
try {
this.logger.debug({tool_call_id, data}, 'TaskLlmAws_S2S:processToolOutput');
if (!data.type || data.type !== 'conversation.item.create') {
this.logger.info({data},
'TaskLlmAws_S2S:processToolOutput - invalid tool output, must be conversation.item.create');
}
else {
await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify(data)]);
// spec also recommends to send immediate response.create
await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify({type: 'response.create'})]);
}
} catch (err) {
this.logger.info({err}, 'TaskLlmAws_S2S:processToolOutput');
}
}
/**
* Send a session.update to the Aws server
* Note: creating and deleting conversation items also supported as well as interrupting the assistant
*/
async processLlmUpdate(ep, data, _callSid) {
try {
this.logger.debug({data, _callSid}, 'TaskLlmAws_S2S:processLlmUpdate');
if (!data.type || ![
'session.update',
'conversation.item.create',
'conversation.item.delete',
'response.cancel'
].includes(data.type)) {
this.logger.info({data}, 'TaskLlmAws_S2S:processLlmUpdate - invalid mid-call request');
}
else {
await this._api(ep, [ep.uuid, ClientEvent, JSON.stringify(data)]);
}
} catch (err) {
this.logger.info({err}, 'TaskLlmAws_S2S:processLlmUpdate');
}
}
async _startListening(cs, ep) {
this._registerHandlers(ep);
try {
const args = [ep.uuid, 'session.create', this.host, this.path, this.authType, this.apiKey];
await this._api(ep, args);
} catch (err) {
this.logger.error({err}, 'TaskLlmAws_S2S:_startListening');
this.notifyTaskDone();
}
}
async _sendClientEvent(ep, obj) {
let ok = true;
this.logger.debug({obj}, 'TaskLlmAws_S2S:_sendClientEvent');
try {
const args = [ep.uuid, ClientEvent, JSON.stringify(obj)];
await this._api(ep, args);
} catch (err) {
ok = false;
this.logger.error({err}, 'TaskLlmAws_S2S:_sendClientEvent - Error');
}
return ok;
}
async _sendInitialMessage(ep) {
let obj = {type: 'response.create', response: this.response_create};
if (!await this._sendClientEvent(ep, obj)) {
this.notifyTaskDone();
}
/* send immediate session.update if present */
else if (this.session_update) {
if (this.parent.isMcpEnabled) {
this.logger.debug('TaskLlmAws_S2S:_sendInitialMessage - mcp enabled');
const tools = await this.parent.mcpService.getAvailableMcpTools();
if (tools && tools.length > 0 && this.session_update) {
const convertedTools = tools.map((tool) => ({
name: tool.name,
type: 'function',
description: tool.description,
parameters: tool.inputSchema
}));
this.session_update.tools = [
...convertedTools,
...(this.session_update.tools || [])
];
}
}
obj = {type: 'session.update', session: this.session_update};
this.logger.debug({obj}, 'TaskLlmAws_S2S:_sendInitialMessage - sending session.update');
if (!await this._sendClientEvent(ep, obj)) {
this.notifyTaskDone();
}
}
}
_registerHandlers(ep) {
this.addCustomEventListener(ep, LlmEvents_Aws.Connect, this._onConnect.bind(this, ep));
this.addCustomEventListener(ep, LlmEvents_Aws.ConnectFailure, this._onConnectFailure.bind(this, ep));
this.addCustomEventListener(ep, LlmEvents_Aws.Disconnect, this._onDisconnect.bind(this, ep));
this.addCustomEventListener(ep, LlmEvents_Aws.ServerEvent, this._onServerEvent.bind(this, ep));
}
_unregisterHandlers() {
this.removeCustomEventListeners();
}
_onError(ep, evt) {
this.logger.info({evt}, 'TaskLlmAws_S2S:_onError');
this.notifyTaskDone();
}
_onConnect(ep) {
this.logger.debug('TaskLlmAws_S2S:_onConnect');
this._sendInitialMessage(ep);
}
_onConnectFailure(_ep, evt) {
this.logger.info(evt, 'TaskLlmAws_S2S:_onConnectFailure');
this.results = {completionReason: 'connection failure'};
this.notifyTaskDone();
}
_onDisconnect(_ep, evt) {
this.logger.info(evt, 'TaskLlmAws_S2S:_onConnectFailure');
this.results = {completionReason: 'disconnect from remote end'};
this.notifyTaskDone();
}
async _onServerEvent(ep, evt) {
let endConversation = false;
const type = evt.type;
this.logger.info({evt}, 'TaskLlmAws_S2S:_onServerEvent');
/* check for failures, such as rate limit exceeded, that should terminate the conversation */
if (type === 'response.done' && evt.response.status === 'failed') {
endConversation = true;
this.results = {
completionReason: 'server failure',
error: evt.response.status_details?.error
};
}
/* server errors of some sort */
else if (type === 'error') {
endConversation = true;
this.results = {
completionReason: 'server error',
error: evt.error
};
}
/* tool calls */
else if (type === 'response.output_item.done' && evt.item?.type === 'function_call') {
this.logger.debug({evt}, 'TaskLlmAws_S2S:_onServerEvent - function_call');
const {name, call_id} = evt.item;
const args = JSON.parse(evt.item.arguments);
const mcpTools = this.parent.isMcpEnabled ? await this.parent.mcpService.getAvailableMcpTools() : [];
if (mcpTools.some((tool) => tool.name === name)) {
this.logger.debug({call_id, name, args}, 'TaskLlmAws_S2S:_onServerEvent - calling mcp tool');
try {
const res = await this.parent.mcpService.callMcpTool(name, args);
this.logger.debug({res}, 'TaskLlmAws_S2S:_onServerEvent - function_call - mcp result');
this.processToolOutput(ep, call_id, {
type: 'conversation.item.create',
item: {
type: 'function_call_output',
call_id,
output: res.content[0]?.text || 'There is no output from the function call',
}
});
return;
} catch (err) {
this.logger.info({err, evt}, 'TaskLlmAws_S2S - error calling function');
this.results = {
completionReason: 'client error calling mcp function',
error: err
};
endConversation = true;
}
}
else if (!this.toolHook) {
this.logger.warn({evt}, 'TaskLlmAws_S2S:_onServerEvent - no toolHook defined!');
}
else {
try {
await this.parent.sendToolHook(call_id, {name, args});
} catch (err) {
this.logger.info({err, evt}, 'TaskLlmAws - error calling function');
this.results = {
completionReason: 'client error calling function',
error: err
};
endConversation = true;
}
}
}
/* check whether we should notify on this event */
if (this.includeEvents.length > 0 ? this.includeEvents.includes(type) : !this.excludeEvents.includes(type)) {
this.parent.sendEventHook(evt)
.catch((err) => this.logger.info({err}, 'TaskLlmAws_S2S:_onServerEvent - error sending event hook'));
}
if (endConversation) {
this.logger.info({results: this.results}, 'TaskLlmAws_S2S:_onServerEvent - ending conversation due to error');
this.notifyTaskDone();
}
}
_populateEvents(events) {
if (events.includes('all')) {
/* work by excluding specific events */
const exclude = events
.filter((evt) => evt.startsWith('-'))
.map((evt) => evt.slice(1));
if (exclude.length === 0) this.includeEvents = aws_server_events;
else this.excludeEvents = expandWildcards(exclude);
}
else {
/* work by including specific events */
const include = events
.filter((evt) => !evt.startsWith('-'));
this.includeEvents = expandWildcards(include);
}
this.logger.debug({
includeEvents: this.includeEvents,
excludeEvents: this.excludeEvents
}, 'TaskLlmAws_S2S:_populateEvents');
}
}
module.exports = TaskLlmAws_S2S;

View File

@@ -194,6 +194,13 @@
"Disconnect": "openai_s2s::disconnect", "Disconnect": "openai_s2s::disconnect",
"ServerEvent": "openai_s2s::server_event" "ServerEvent": "openai_s2s::server_event"
}, },
"LlmEvents_Aws": {
"Error": "error",
"Connect": "aws_s2s::connect",
"ConnectFailure": "aws_s2s::connect_failed",
"Disconnect": "aws_s2s::disconnect",
"ServerEvent": "aws_s2s::server_event"
},
"LlmEvents_Google": { "LlmEvents_Google": {
"Error": "error", "Error": "error",
"Connect": "google_s2s::connect", "Connect": "google_s2s::connect",