Added AWS_LANGUAGE_MODEL_NAME (#99)

Co-authored-by: ajukes <ajukes@callable.io>
This commit is contained in:
Antony Jukes
2024-08-12 16:03:26 +01:00
committed by GitHub
parent 110a12d5a5
commit 81ceddf3d2

View File

@@ -43,16 +43,16 @@ public:
const char *sessionId, const char *sessionId,
const char *bugname, const char *bugname,
u_int16_t channels, u_int16_t channels,
char *lang, char *lang,
int interim, int interim,
uint32_t samples_per_second, uint32_t samples_per_second,
const char* region, const char* region,
const char* awsAccessKeyId, const char* awsAccessKeyId,
const char* awsSecretAccessKey, const char* awsSecretAccessKey,
const char* awsSessionToken, const char* awsSessionToken,
responseHandler_t responseHandler responseHandler_t responseHandler
) : m_sessionId(sessionId), m_bugname(bugname), m_finished(false), m_interim(interim), m_finishing(false), m_connected(false), m_connecting(false), ) : m_sessionId(sessionId), m_bugname(bugname), m_finished(false), m_interim(interim), m_finishing(false), m_connected(false), m_connecting(false),
m_packets(0), m_responseHandler(responseHandler), m_pStream(nullptr), m_packets(0), m_responseHandler(responseHandler), m_pStream(nullptr),
m_audioBuffer(320 * (samples_per_second == 8000 ? 1 : 2), 15) { m_audioBuffer(320 * (samples_per_second == 8000 ? 1 : 2), 15) {
Aws::Client::ClientConfiguration config; Aws::Client::ClientConfiguration config;
if (region != nullptr && strlen(region) > 0) config.region = region; if (region != nullptr && strlen(region) > 0) config.region = region;
@@ -71,7 +71,7 @@ public:
else { else {
m_client = Aws::MakeUnique<TranscribeStreamingServiceClient>(ALLOC_TAG, config); m_client = Aws::MakeUnique<TranscribeStreamingServiceClient>(ALLOC_TAG, config);
} }
m_handler.SetTranscriptEventCallback([this](const TranscriptEvent& ev) m_handler.SetTranscriptEventCallback([this](const TranscriptEvent& ev)
{ {
switch_core_session_t* psession = switch_core_session_locate(m_sessionId.c_str()); switch_core_session_t* psession = switch_core_session_locate(m_sessionId.c_str());
@@ -111,6 +111,9 @@ public:
if (var = switch_channel_get_variable(channel, "AWS_VOCABULARY_FILTER_METHOD")) { if (var = switch_channel_get_variable(channel, "AWS_VOCABULARY_FILTER_METHOD")) {
m_request.SetVocabularyFilterMethod(VocabularyFilterMethodMapper::GetVocabularyFilterMethodForName(var)); m_request.SetVocabularyFilterMethod(VocabularyFilterMethodMapper::GetVocabularyFilterMethodForName(var));
} }
if (var = switch_channel_get_variable(channel, "AWS_LANGUAGE_MODEL_NAME")) {
m_request.SetLanguageModelName(var);
}
switch_core_session_rwunlock(session); switch_core_session_rwunlock(session);
} }
@@ -132,7 +135,7 @@ public:
// send any buffered audio // send any buffered audio
int nFrames = m_audioBuffer.getNumItems(); int nFrames = m_audioBuffer.getNumItems();
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer %p got stream ready, %d buffered frames\n", this, nFrames); switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer %p got stream ready, %d buffered frames\n", this, nFrames);
if (nFrames) { if (nFrames) {
char *p; char *p;
do { do {
@@ -142,19 +145,19 @@ public:
} }
} while (p); } while (p);
} }
switch_core_session_rwunlock(psession); switch_core_session_rwunlock(psession);
} }
}; };
auto OnResponseCallback = [this](const TranscribeStreamingServiceClient* pClient, auto OnResponseCallback = [this](const TranscribeStreamingServiceClient* pClient,
const Model::StartStreamTranscriptionRequest& request, const Model::StartStreamTranscriptionRequest& request,
const Model::StartStreamTranscriptionOutcome& outcome, const Model::StartStreamTranscriptionOutcome& outcome,
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context) const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context)
{ {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer %p stream got final response\n", this); switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer %p stream got final response\n", this);
switch_core_session_t* psession = switch_core_session_locate(m_sessionId.c_str()); switch_core_session_t* psession = switch_core_session_locate(m_sessionId.c_str());
if (psession) { if (psession) {
if (!outcome.IsSuccess()) { if (!outcome.IsSuccess()) {
const TranscribeStreamingServiceError& err = outcome.GetError(); const TranscribeStreamingServiceError& err = outcome.GetError();
auto message = err.GetMessage(); auto message = err.GetMessage();
auto exception = err.GetExceptionName(); auto exception = err.GetExceptionName();
@@ -186,7 +189,7 @@ public:
~GStreamer() { ~GStreamer() {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer::~GStreamer wrote %u packets %p\n", m_packets, this); switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "GStreamer::~GStreamer wrote %u packets %p\n", m_packets, this);
} }
bool write(void* data, uint32_t datalen) { bool write(void* data, uint32_t datalen) {
@@ -228,7 +231,7 @@ public:
bool shutdownInitiated = false; bool shutdownInitiated = false;
while (true) { while (true) {
std::unique_lock<std::mutex> lk(m_mutex); std::unique_lock<std::mutex> lk(m_mutex);
m_cond.wait(lk, [&, this] { m_cond.wait(lk, [&, this] {
return (!m_deqAudio.empty() && !m_finishing) || m_transcript.TranscriptHasBeenSet() || m_finished || (m_finishing && !shutdownInitiated); return (!m_deqAudio.empty() && !m_finishing) || m_transcript.TranscriptHasBeenSet() || m_finished || (m_finishing && !shutdownInitiated);
}); });
@@ -264,7 +267,7 @@ public:
m_responseHandler(psession, s.str().c_str(), m_bugname.c_str()); m_responseHandler(psession, s.str().c_str(), m_bugname.c_str());
} }
TranscriptEvent empty; TranscriptEvent empty;
m_transcript = empty; m_transcript = empty;
switch_core_session_rwunlock(psession); switch_core_session_rwunlock(psession);
} }
@@ -362,7 +365,7 @@ extern "C" {
const char* secretAccessKey = std::getenv("AWS_SECRET_ACCESS_KEY"); const char* secretAccessKey = std::getenv("AWS_SECRET_ACCESS_KEY");
const char* region = std::getenv("AWS_REGION"); const char* region = std::getenv("AWS_REGION");
if (NULL == accessKeyId && NULL == secretAccessKey) { if (NULL == accessKeyId && NULL == secretAccessKey) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE,
"\"AWS_ACCESS_KEY_ID\" and/or \"AWS_SECRET_ACCESS_KEY\" env var not set; authentication will expect channel variables of same names to be set\n"); "\"AWS_ACCESS_KEY_ID\" and/or \"AWS_SECRET_ACCESS_KEY\" env var not set; authentication will expect channel variables of same names to be set\n");
} }
else { else {
@@ -370,7 +373,7 @@ extern "C" {
} }
Aws::SDKOptions options; Aws::SDKOptions options;
/* /*
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace; options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace;
Aws::Utils::Logging::InitializeAWSLogging( Aws::Utils::Logging::InitializeAWSLogging(
@@ -381,7 +384,7 @@ extern "C" {
return SWITCH_STATUS_SUCCESS; return SWITCH_STATUS_SUCCESS;
} }
switch_status_t aws_transcribe_cleanup() { switch_status_t aws_transcribe_cleanup() {
Aws::SDKOptions options; Aws::SDKOptions options;
/* /*
@@ -394,7 +397,7 @@ extern "C" {
} }
// start transcribe on a channel // start transcribe on a channel
switch_status_t aws_transcribe_session_init(switch_core_session_t *session, responseHandler_t responseHandler, switch_status_t aws_transcribe_session_init(switch_core_session_t *session, responseHandler_t responseHandler,
uint32_t samples_per_second, uint32_t channels, char* lang, int interim, char* bugname, void **ppUserData uint32_t samples_per_second, uint32_t channels, char* lang, int interim, char* bugname, void **ppUserData
) { ) {
switch_status_t status = SWITCH_STATUS_SUCCESS; switch_status_t status = SWITCH_STATUS_SUCCESS;
@@ -433,7 +436,7 @@ extern "C" {
std::getenv("AWS_REGION")) { std::getenv("AWS_REGION")) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "Using env vars for aws authentication\n"); switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "Using env vars for aws authentication\n");
strncpy(cb->awsAccessKeyId, std::getenv("AWS_ACCESS_KEY_ID"), 128); strncpy(cb->awsAccessKeyId, std::getenv("AWS_ACCESS_KEY_ID"), 128);
strncpy(cb->awsSecretAccessKey, std::getenv("AWS_SECRET_ACCESS_KEY"), 128); strncpy(cb->awsSecretAccessKey, std::getenv("AWS_SECRET_ACCESS_KEY"), 128);
strncpy(cb->region, std::getenv("AWS_REGION"), MAX_REGION); strncpy(cb->region, std::getenv("AWS_REGION"), MAX_REGION);
} }
else { else {
@@ -445,7 +448,7 @@ extern "C" {
if (switch_mutex_init(&cb->mutex, SWITCH_MUTEX_NESTED, pool) != SWITCH_STATUS_SUCCESS) { if (switch_mutex_init(&cb->mutex, SWITCH_MUTEX_NESTED, pool) != SWITCH_STATUS_SUCCESS) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "Error initializing mutex\n"); switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "Error initializing mutex\n");
status = SWITCH_STATUS_FALSE; status = SWITCH_STATUS_FALSE;
goto done; goto done;
} }
cb->interim = interim; cb->interim = interim;
@@ -455,7 +458,7 @@ extern "C" {
if (sampleRate != 8000) { if (sampleRate != 8000) {
cb->resampler = speex_resampler_init(1, sampleRate, 16000, SWITCH_RESAMPLE_QUALITY, &err); cb->resampler = speex_resampler_init(1, sampleRate, 16000, SWITCH_RESAMPLE_QUALITY, &err);
if (0 != err) { if (0 != err) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "%s: Error initializing resampler: %s.\n", switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "%s: Error initializing resampler: %s.\n",
switch_channel_get_name(channel), speex_resampler_strerror(err)); switch_channel_get_name(channel), speex_resampler_strerror(err));
status = SWITCH_STATUS_FALSE; status = SWITCH_STATUS_FALSE;
goto done; goto done;
@@ -488,7 +491,7 @@ extern "C" {
switch_vad_set_param(cb->vad, "silence_ms", silence_ms); switch_vad_set_param(cb->vad, "silence_ms", silence_ms);
switch_vad_set_param(cb->vad, "voice_ms", voice_ms); switch_vad_set_param(cb->vad, "voice_ms", voice_ms);
switch_vad_set_param(cb->vad, "debug", debug); switch_vad_set_param(cb->vad, "debug", debug);
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "%s: delaying connection until vad, voice_ms %d, mode %d\n", switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "%s: delaying connection until vad, voice_ms %d, mode %d\n",
switch_channel_get_name(channel), voice_ms, mode); switch_channel_get_name(channel), voice_ms, mode);
} }
} }
@@ -499,7 +502,7 @@ extern "C" {
switch_thread_create(&cb->thread, thd_attr, aws_transcribe_thread, cb, pool); switch_thread_create(&cb->thread, thd_attr, aws_transcribe_thread, cb, pool);
*ppUserData = cb; *ppUserData = cb;
done: done:
return status; return status;
} }
@@ -521,7 +524,7 @@ extern "C" {
do { do {
streamer = (GStreamer *) cb->streamer; streamer = (GStreamer *) cb->streamer;
if (streamer) break; if (streamer) break;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG,
"aws_transcribe_session_stop: waiting for streamer to come online..%s\n", bugname); "aws_transcribe_session_stop: waiting for streamer to come online..%s\n", bugname);
switch_yield(10000); // wait 10ms switch_yield(10000); // wait 10ms
} while (i++ < 3); } while (i++ < 3);
@@ -557,7 +560,7 @@ extern "C" {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_INFO, "%s Bug is not attached.\n", switch_channel_get_name(channel)); switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_INFO, "%s Bug is not attached.\n", switch_channel_get_name(channel));
return SWITCH_STATUS_FALSE; return SWITCH_STATUS_FALSE;
} }
switch_bool_t aws_transcribe_frame(switch_media_bug_t *bug, void* user_data) { switch_bool_t aws_transcribe_frame(switch_media_bug_t *bug, void* user_data) {
switch_core_session_t *session = switch_core_media_bug_get_session(bug); switch_core_session_t *session = switch_core_media_bug_get_session(bug);
uint8_t data[SWITCH_RECOMMENDED_BUFFER_SIZE]; uint8_t data[SWITCH_RECOMMENDED_BUFFER_SIZE];
@@ -587,7 +590,7 @@ extern "C" {
} }
if (cb->resampler) { if (cb->resampler) {
speex_resampler_process_interleaved_int(cb->resampler, (const spx_int16_t *) frame.data, (spx_uint32_t *) &in_len, &out[0], &out_len); speex_resampler_process_interleaved_int(cb->resampler, (const spx_int16_t *) frame.data, (spx_uint32_t *) &in_len, &out[0], &out_len);
streamer->write( &out[0], sizeof(spx_int16_t) * out_len); streamer->write( &out[0], sizeof(spx_int16_t) * out_len);
} }
else { else {
@@ -597,11 +600,11 @@ extern "C" {
} }
} }
else { else {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG,
"aws_transcribe_frame: not sending audio because aws channel has been closed\n"); "aws_transcribe_frame: not sending audio because aws channel has been closed\n");
} }
switch_mutex_unlock(cb->mutex); switch_mutex_unlock(cb->mutex);
} }
return SWITCH_TRUE; return SWITCH_TRUE;
} }
} }