Speech vendor/cobalt (#463)

* initial changes for cobalt speech

* wip

* wip

* update to drachtio-fsmrf that supports cobalt

* update to verb-specifications with cobalt speech support

* more wip

* lint

* use node 18 with gh actions

* support for compiling cobalt hints

* fix bug in uuid_cobalt_compile_context

* update verb-specifications

* remove repeated code

* cobalt support for transcribe

* update to verb specs
This commit is contained in:
Dave Horton
2023-09-13 09:47:30 -04:00
committed by GitHub
parent d220733dea
commit a1793ac359
11 changed files with 204 additions and 24 deletions

View File

@@ -6,6 +6,7 @@ const {
AzureTranscriptionEvents,
DeepgramTranscriptionEvents,
SonioxTranscriptionEvents,
CobaltTranscriptionEvents,
IbmTranscriptionEvents,
NvidiaTranscriptionEvents,
JambonzTranscriptionEvents
@@ -187,12 +188,18 @@ class TaskGather extends SttTask {
}
}
/* when using cobalt model is required */
if (this.vendor === 'cobalt' && !this.data.recognizer.model) {
this.notifyError({ msg: 'ASR error', details:'Cobalt requires a model to be specified'});
throw new Error('Cobalt requires a model to be specified');
}
const startListening = async(cs, ep) => {
this._startTimer();
if (this.isContinuousAsr && 0 === this.timeout) this._startAsrTimer();
if (this.input.includes('speech') && !this.listenDuringPrompt) {
try {
await this._initSpeech(cs, ep);
await this._setSpeechHandlers(cs, ep);
if (this.killed) {
this.logger.info('Gather:exec - task was quickly killed so do not transcribe');
return;
@@ -260,7 +267,7 @@ class TaskGather extends SttTask {
}
if (this.input.includes('speech') && this.listenDuringPrompt) {
await this._initSpeech(cs, ep);
await this._setSpeechHandlers(cs, ep);
this._startTranscribing(ep);
updateSpeechCredentialLastUsed(this.sttCredentials.speech_credential_sid)
.catch(() => {/*already logged error */});
@@ -334,7 +341,9 @@ class TaskGather extends SttTask {
}
}
async _initSpeech(cs, ep) {
async _setSpeechHandlers(cs, ep) {
if (this._speechHandlersSet) return;
this._speechHandlersSet = true;
const opts = this.setChannelVarsForStt(this, this.sttCredentials, this.data.recognizer);
switch (this.vendor) {
case 'google':
@@ -387,6 +396,28 @@ class TaskGather extends SttTask {
ep.addCustomEventListener(SonioxTranscriptionEvents.Transcription, this._onTranscription.bind(this, cs, ep));
break;
case 'cobalt':
this.bugname = 'cobalt_speech';
ep.addCustomEventListener(CobaltTranscriptionEvents.Transcription, this._onTranscription.bind(this, cs, ep));
/* special case: if using hints with cobalt we need to compile them */
if (this.vendor === 'cobalt' && opts.COBALT_SPEECH_HINTS) {
try {
const context = await this.compileHintsForCobalt(
ep,
opts.COBALT_SERVER_URI,
this.data.recognizer.model,
opts.COBALT_CONTEXT_TOKEN,
opts.COBALT_SPEECH_HINTS
);
if (context) opts.COBALT_COMPILED_CONTEXT_DATA = context;
delete opts.COBALT_SPEECH_HINTS;
} catch (err) {
this.logger.error({err}, 'Error compiling hints for cobalt');
}
}
break;
case 'ibm':
this.bugname = 'ibm_transcribe';
ep.addCustomEventListener(IbmTranscriptionEvents.Transcription, this._onTranscription.bind(this, cs, ep));

View File

@@ -1,6 +1,7 @@
const Task = require('./task');
const assert = require('assert');
const { TaskPreconditions } = require('../utils/constants');
const crypto = require('crypto');
const { TaskPreconditions, CobaltTranscriptionEvents } = require('../utils/constants');
class SttTask extends Task {
@@ -95,6 +96,44 @@ class SttTask extends Task {
this.data.recognizer.label = this.label;
this.sttCredentials = await this._initSpeechCredentials(this.cs, this.vendor, this.label);
}
async compileHintsForCobalt(ep, hostport, model, token, hints) {
const {retrieveKey} = this.cs.srf.locals.dbHelpers;
const hash = crypto.createHash('sha1');
hash.update(`${model}:${hints}`);
const key = `cobalt:${hash.digest('hex')}`;
this.context = await retrieveKey(key);
if (this.context) {
this.logger.debug({model, hints}, 'found cached cobalt context for supplied hints');
return this.context;
}
this.logger.debug({model, hints}, 'compiling cobalt context for supplied hints');
return new Promise((resolve, reject) => {
this.cobaltCompileResolver = resolve;
ep.addCustomEventListener(CobaltTranscriptionEvents.CompileContext, this._onCompileContext.bind(this, ep, key));
ep.api('uuid_cobalt_compile_context', [ep.uuid, hostport, model, token, hints], (err, evt) => {
if (err || 0 !== evt.getBody().indexOf('+OK')) {
ep.removeCustomEventListener(CobaltTranscriptionEvents.CompileContext);
return reject(err);
}
});
});
}
_onCompileContext(ep, key, evt) {
const {addKey} = this.cs.srf.locals.dbHelpers;
this.logger.debug({evt}, `received cobalt compile context event, will cache under ${key}`);
this.cobaltCompileResolver(evt.compiled_context);
ep.removeCustomEventListener(CobaltTranscriptionEvents.CompileContext);
this.cobaltCompileResolver = null;
//cache the compiled context
addKey(key, evt.compiled_context, 3600 * 12)
.catch((err) => this.logger.info({err}, `Error caching cobalt context for ${key}`));
}
}
module.exports = SttTask;

View File

@@ -7,6 +7,7 @@ const {
AzureTranscriptionEvents,
DeepgramTranscriptionEvents,
SonioxTranscriptionEvents,
CobaltTranscriptionEvents,
IbmTranscriptionEvents,
NvidiaTranscriptionEvents,
JambonzTranscriptionEvents,
@@ -103,6 +104,12 @@ class TaskTranscribe extends SttTask {
}
}
/* when using cobalt model is required */
if (this.vendor === 'cobalt' && !this.data.recognizer.model) {
this.notifyError({ msg: 'ASR error', details:'Cobalt requires a model to be specified'});
throw new Error('Cobalt requires a model to be specified');
}
try {
await this._startTranscribing(cs, ep, 1);
if (this.separateRecognitionPerChannel && ep2) {
@@ -163,7 +170,9 @@ class TaskTranscribe extends SttTask {
}
}
async _startTranscribing(cs, ep, channel) {
async _setSpeechHandlers(cs, ep, channel) {
if (this._speechHandlersSet) return;
this._speechHandlersSet = true;
const opts = this.setChannelVarsForStt(this, this.sttCredentials, this.data.recognizer);
switch (this.vendor) {
case 'google':
@@ -216,6 +225,29 @@ class TaskTranscribe extends SttTask {
ep.addCustomEventListener(SonioxTranscriptionEvents.Transcription,
this._onTranscription.bind(this, cs, ep, channel));
break;
case 'cobalt':
this.bugname = 'cobalt_transcribe';
ep.addCustomEventListener(CobaltTranscriptionEvents.Transcription,
this._onTranscription.bind(this, cs, ep, channel));
/* special case: if using hints with cobalt we need to compile them */
if (this.vendor === 'cobalt' && opts.COBALT_SPEECH_HINTS) {
try {
const context = await this.compileHintsForCobalt(
ep,
opts.COBALT_SERVER_URI,
this.data.recognizer.model,
opts.COBALT_CONTEXT_TOKEN,
opts.COBALT_SPEECH_HINTS
);
if (context) opts.COBALT_COMPILED_CONTEXT_DATA = context;
delete opts.COBALT_SPEECH_HINTS;
} catch (err) {
this.logger.error({err}, 'Error compiling hints for cobalt');
}
}
break;
case 'ibm':
this.bugname = 'ibm_transcribe';
ep.addCustomEventListener(IbmTranscriptionEvents.Transcription,
@@ -237,6 +269,7 @@ class TaskTranscribe extends SttTask {
ep.addCustomEventListener(NvidiaTranscriptionEvents.VadDetected,
this._onVadDetected.bind(this, cs, ep));
break;
default:
if (this.vendor.startsWith('custom:')) {
this.bugname = `${this.vendor}_transcribe`;
@@ -258,7 +291,10 @@ class TaskTranscribe extends SttTask {
ep.addCustomEventListener(JambonzTranscriptionEvents.Error, this._onJambonzError.bind(this, cs, ep));
await ep.set(opts)
.catch((err) => this.logger.info(err, 'Error setting channel variables'));
}
async _startTranscribing(cs, ep, channel) {
await this._setSpeechHandlers(cs, ep, channel);
await this._transcribe(ep);
/* start child span for this channel */