* wip

* #103 wip

* wip

* wip

* support both grpc (legacy) and websockets api for aws transcribe

* renaming
This commit is contained in:
Dave Horton
2024-09-24 09:51:06 -04:00
committed by GitHub
parent d17a2aa9be
commit 8a3c001b59
12 changed files with 4061 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
include $(top_srcdir)/build/modmake.rulesam
MODNAME=mod_aws_transcribe_ws
mod_LTLIBRARIES = mod_aws_transcribe_ws.la
mod_aws_transcribe_ws_la_SOURCES = mod_aws_transcribe_ws.c aws_transcribe_glue.cpp transcribe_manager.cpp audio_pipe.cpp
mod_aws_transcribe_ws_la_CFLAGS = $(AM_CFLAGS)
mod_aws_transcribe_ws_la_CXXFLAGS = $(AM_CXXFLAGS) -std=c++11 -I${switch_srcdir}/libs/aws-sdk-cpp/aws-cpp-sdk-core/include -I${switch_srcdir}/libs/aws-sdk-cpp/aws-cpp-sdk-transcribestreaming/include -I${switch_srcdir}/libs/aws-sdk-cpp/build/.deps/install/include
mod_aws_transcribe_ws_la_LIBADD = $(switch_builddir)/libfreeswitch.la
mod_aws_transcribe_ws_la_LDFLAGS = -avoid-version -module -no-undefined -shared `pkg-config --libs libwebsockets`

View File

@@ -0,0 +1,58 @@
# mod_aws_transcribe
A Freeswitch module that generates real-time transcriptions on a Freeswitch channel by using AWS streaming transcription API
## API
### Commands
The freeswitch module exposes the following API commands:
```
aws_transcribe <uuid> start <lang-code> [interim]
```
Attaches media bug to channel and performs streaming recognize request.
- `uuid` - unique identifier of Freeswitch channel
- `lang-code` - a valid AWS [language code](https://docs.aws.amazon.com/transcribe/latest/dg/what-is-transcribe.html) that is supported for streaming transcription
- `interim` - If the 'interim' keyword is present then both interim and final transcription results will be returned; otherwise only final transcriptions will be returned
```
aws_transcribe <uuid> stop
```
Stop transcription on the channel.
### Authentication
The plugin will first look for channel variables, then environment variables. If neither are found, then the default AWS profile on the server will be used.
The names of the channel variables and environment variables are:
| variable | Description |
| --- | ----------- |
| AWS_ACCESS_KEY_ID | The Aws access key ID |
| AWS_SECRET_ACCESS_KEY | The Aws secret access key |
| AWS_REGION | The Aws region |
### Events
`aws_transcribe::transcription` - returns an interim or final transcription. The event contains a JSON body describing the transcription result:
```js
[
{
"is_final": true,
"alternatives": [{
"transcript": "Hello. Can you hear me?"
}]
}
]
```
## Usage
When using [drachtio-fsrmf](https://www.npmjs.com/package/drachtio-fsmrf), you can access this API command via the api method on the 'endpoint' object.
```js
ep.api('aws_transcribe', `${ep.uuid} start en-US interim`);
```
## Building
This uses the AWS websocket api.
## Examples
[aws_transcribe.js](../../examples/aws_transcribe.js)

View File

@@ -0,0 +1,612 @@
#include "audio_pipe.hpp"
#include "transcribe_manager.hpp"
#include "crc.h"
#include <switch.h>
#include <cassert>
#include <iostream>
#include <netinet/in.h>
#include <fstream>
#include <string>
/* discard incoming text messages over the socket that are longer than this */
#define MAX_RECV_BUF_SIZE (65 * 1024 * 10)
#define RECV_BUF_REALLOC_SIZE (8 * 1024)
#define AWS_PRELUDE_PLUS_HDRS_LEN (100)
using namespace aws;
namespace {
static const char *requestedTcpKeepaliveSecs = std::getenv("MOD_AUDIO_FORK_TCP_KEEPALIVE_SECS");
static int nTcpKeepaliveSecs = requestedTcpKeepaliveSecs ? ::atoi(requestedTcpKeepaliveSecs) : 55;
static uint8_t aws_prelude_and_headers[AWS_PRELUDE_PLUS_HDRS_LEN];
void writeToFile(const char* buffer, size_t bufferSize) {
static int writeCounter = 0; // Static variable to keep track of write count
// Write only the first three times
if (writeCounter >= 4) {
return;
}
// Generate a unique file name using the writeCounter
std::string filename = "/tmp/audio_data_" + std::to_string(writeCounter) + ".bin";
// Open a file in binary mode
std::ofstream outFile(filename, std::ios::binary);
// Check if the file is open
if (outFile.is_open()) {
// Write the buffer to the file
outFile.write(buffer, bufferSize);
outFile.close();
// Increment the write counter
writeCounter++;
} else {
// Handle error in file opening
std::cerr << "Unable to open file: " << filename << std::endl;
}
}
}
int AudioPipe::aws_lws_callback(struct lws *wsi,
enum lws_callback_reasons reason,
void *user, void *in, size_t len) {
struct AudioPipe::lws_per_vhost_data *vhd =
(struct AudioPipe::lws_per_vhost_data *) lws_protocol_vh_priv_get(lws_get_vhost(wsi), lws_get_protocol(wsi));
struct lws_vhost* vhost = lws_get_vhost(wsi);
AudioPipe ** ppAp = (AudioPipe **) user;
switch (reason) {
case LWS_CALLBACK_PROTOCOL_INIT:
vhd = (struct AudioPipe::lws_per_vhost_data *) lws_protocol_vh_priv_zalloc(lws_get_vhost(wsi), lws_get_protocol(wsi), sizeof(struct AudioPipe::lws_per_vhost_data));
vhd->context = lws_get_context(wsi);
vhd->protocol = lws_get_protocol(wsi);
vhd->vhost = lws_get_vhost(wsi);
break;
case LWS_CALLBACK_EVENT_WAIT_CANCELLED:
processPendingConnects(vhd);
processPendingDisconnects(vhd);
processPendingWrites();
break;
case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
{
AudioPipe* ap = findAndRemovePendingConnect(wsi);
int rc = lws_http_client_http_response(wsi);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_CONNECTION_ERROR: %s, response status %d\n", in ? (char *)in : "(null)", rc);
if (ap) {
ap->m_state = LWS_CLIENT_FAILED;
ap->m_callback(ap->m_uuid.c_str(), ap->m_bugname.c_str(), AudioPipe::CONNECT_FAIL, (char *) in, ap->isFinished());
}
else {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_ESTABLISHED %s unable to find wsi %p..\n", ap->m_uuid.c_str(), wsi);
}
}
break;
case LWS_CALLBACK_CLIENT_ESTABLISHED:
{
AudioPipe* ap = findAndRemovePendingConnect(wsi);
if (ap) {
*ppAp = ap;
ap->m_vhd = vhd;
ap->m_state = LWS_CLIENT_CONNECTED;
ap->m_callback(ap->m_uuid.c_str(), ap->m_bugname.c_str(), AudioPipe::CONNECT_SUCCESS, NULL, ap->isFinished());
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO,"%s connected\n", ap->m_uuid.c_str());
}
else {
lwsl_err("AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_ESTABLISHED %s unable to find wsi %p..\n", ap->m_uuid.c_str(), wsi);
}
}
break;
case LWS_CALLBACK_CLIENT_CLOSED:
{
AudioPipe* ap = *ppAp;
if (!ap) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_CLOSED %s unable to find wsi %p..\n", ap->m_uuid.c_str(), wsi);
return 0;
}
if (ap->m_state == LWS_CLIENT_DISCONNECTING) {
// closed by us
lwsl_debug("%s socket closed by us\n", ap->m_uuid.c_str());
ap->m_callback(ap->m_uuid.c_str(), ap->m_bugname.c_str(), AudioPipe::CONNECTION_CLOSED_GRACEFULLY, NULL, ap->isFinished());
}
else if (ap->m_state == LWS_CLIENT_CONNECTED) {
// closed by far end
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO,"%s socket closed by far end\n", ap->m_uuid.c_str());
ap->m_callback(ap->m_uuid.c_str(), ap->m_bugname.c_str(), AudioPipe::CONNECTION_DROPPED, NULL, ap->isFinished());
}
ap->m_state = LWS_CLIENT_DISCONNECTED;
ap->setClosed();
//NB: after receiving any of the events above, any holder of a
//pointer or reference to this object must treat is as no longer valid
//*ppAp = NULL;
//delete ap;
}
break;
case LWS_CALLBACK_CLIENT_RECEIVE:
{
AudioPipe* ap = *ppAp;
if (!ap) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_RECEIVE %s unable to find wsi %p..\n", ap->m_uuid.c_str(), wsi);
return 0;
}
if (lws_is_first_fragment(wsi)) {
// allocate a buffer for the entire chunk of memory needed
assert(nullptr == ap->m_recv_buf);
ap->m_recv_buf_len = len + lws_remaining_packet_payload(wsi);
ap->m_recv_buf = (uint8_t*) malloc(ap->m_recv_buf_len);
ap->m_recv_buf_ptr = ap->m_recv_buf;
}
size_t write_offset = ap->m_recv_buf_ptr - ap->m_recv_buf;
size_t remaining_space = ap->m_recv_buf_len - write_offset;
if (remaining_space < len) {
lwsl_err("AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_RECEIVE buffer realloc needed.\n");
size_t newlen = ap->m_recv_buf_len + RECV_BUF_REALLOC_SIZE;
if (newlen > MAX_RECV_BUF_SIZE) {
free(ap->m_recv_buf);
ap->m_recv_buf = ap->m_recv_buf_ptr = nullptr;
ap->m_recv_buf_len = 0;
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_RECEIVE max buffer exceeded, truncating message.\n");
}
else {
ap->m_recv_buf = (uint8_t*) realloc(ap->m_recv_buf, newlen);
if (nullptr != ap->m_recv_buf) {
ap->m_recv_buf_len = newlen;
ap->m_recv_buf_ptr = ap->m_recv_buf + write_offset;
}
}
}
if (nullptr != ap->m_recv_buf) {
if (len > 0) {
memcpy(ap->m_recv_buf_ptr, in, len);
ap->m_recv_buf_ptr += len;
}
if (lws_is_final_fragment(wsi)) {
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "AudioPipe::lws_service_thread - LWS_CALLBACK_CLIENT_RECEIVE received %d bytes.\n", len);
if (nullptr != ap->m_recv_buf) {
bool isError = false;
std::string payload;
std::string msg((char *)ap->m_recv_buf, ap->m_recv_buf_ptr - ap->m_recv_buf);
TranscribeManager::parseResponse(msg, payload, isError, true);
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_RECEIVE payload: %s.\n", payload.c_str());
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_RECEIVE response %s.\n", msg.c_str());
if (0 != payload.compare("{\"Transcript\":{\"Results\":[]}}")) {
ap->m_callback(ap->m_uuid.c_str(), ap->m_bugname.c_str(), AudioPipe::MESSAGE, payload.c_str(), ap->isFinished());
}
if (nullptr != ap->m_recv_buf) free(ap->m_recv_buf);
}
ap->m_recv_buf = ap->m_recv_buf_ptr = nullptr;
ap->m_recv_buf_len = 0;
}
}
}
break;
case LWS_CALLBACK_CLIENT_WRITEABLE:
{
AudioPipe* ap = *ppAp;
if (!ap) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_WRITEABLE %s unable to find wsi %p..\n", ap->m_uuid.c_str(), wsi);
return 0;
}
if (ap->m_state == LWS_CLIENT_DISCONNECTING) {
lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, NULL, 0);
return -1;
}
// check for audio packets
{
std::lock_guard<std::mutex> lk(ap->m_audio_mutex);
if (ap->m_audio_buffer_write_offset > LWS_PRE + AWS_PRELUDE_PLUS_HDRS_LEN || ap->isFinished()) {
// send a zero length audio packet to indicate end of stream
if (ap->isFinished()) {
ap->m_audio_buffer_write_offset = LWS_PRE + AWS_PRELUDE_PLUS_HDRS_LEN;
}
/**
* fill in
* [0..3] = total byte length
* [8..11] = prelude crc
* following the audio data: 4 bytes of Message CRC
*
*/
// copy in the prelude and headers
memcpy(ap->m_audio_buffer + LWS_PRE, aws_prelude_and_headers, AWS_PRELUDE_PLUS_HDRS_LEN);
// fill in the total byte length
uint32_t totalLen = ap->m_audio_buffer_write_offset - LWS_PRE + 4; // for the trailing Message CRC which is 4 bytes
//lwsl_err("AudioPipe - total length %u (decimal), 0x%X (hex)\n", totalLen, totalLen);
totalLen = htonl(totalLen);
//lwsl_err("AudioPipe - total length in network byte order %u (decimal), 0x%X (hex)\n", totalLen, totalLen);
memcpy(ap->m_audio_buffer + LWS_PRE, &totalLen, sizeof(uint32_t));
// fill in the prelude crc
uint32_t preludeCRC = CRC::Calculate(ap->m_audio_buffer + LWS_PRE, 8, CRC::CRC_32());
//lwsl_err("AudioPipe - prelude CRC %u (decimal), 0x%X (hex)\n", preludeCRC, preludeCRC);
preludeCRC = htonl(preludeCRC);
//lwsl_err("AudioPipe - prelude CRC in network order %u (decimal), 0x%X (hex)\n", preludeCRC, preludeCRC);
memcpy(ap->m_audio_buffer + LWS_PRE + 8, &preludeCRC, sizeof(uint32_t));
// fill in the message crc
uint32_t messageCRC = CRC::Calculate(ap->m_audio_buffer + LWS_PRE, ap->m_audio_buffer_write_offset - LWS_PRE, CRC::CRC_32());
messageCRC = htonl(messageCRC);
memcpy(ap->m_audio_buffer + ap->m_audio_buffer_write_offset, &messageCRC, sizeof(uint32_t));
ap->m_audio_buffer_write_offset + sizeof(uint32_t);
size_t datalen = ap->m_audio_buffer_write_offset - LWS_PRE + 4;
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO,"%s writing data length %lu\n", ap->m_uuid.c_str(), datalen);
// TMP: write data to a file
//writeToFile((const char *) ap->m_audio_buffer + LWS_PRE, datalen);
int sent = lws_write(wsi, (unsigned char *) ap->m_audio_buffer + LWS_PRE, datalen, LWS_WRITE_BINARY);
if (sent < datalen) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "AudioPipe::lws_service_thread LWS_CALLBACK_CLIENT_WRITEABLE %s attemped to send %lu only sent %d wsi %p..\n",
ap->m_uuid.c_str(), datalen, sent, wsi);
}
ap->binaryWritePtrResetToZero();
}
}
return 0;
}
break;
default:
break;
}
return lws_callback_http_dummy(wsi, reason, user, in, len);
}
// static members
static const lws_retry_bo_t retry = {
nullptr, // retry_ms_table
0, // retry_ms_table_count
0, // conceal_count
UINT16_MAX, // secs_since_valid_ping
UINT16_MAX, // secs_since_valid_hangup
0 // jitter_percent
};
struct lws_context *AudioPipe::contexts[] = {
nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr
};
unsigned int AudioPipe::numContexts = 0;
unsigned int AudioPipe::nchild = 0;
std::string AudioPipe::protocolName;
std::mutex AudioPipe::mutex_connects;
std::mutex AudioPipe::mutex_disconnects;
std::mutex AudioPipe::mutex_writes;
std::list<AudioPipe*> AudioPipe::pendingConnects;
std::list<AudioPipe*> AudioPipe::pendingDisconnects;
std::list<AudioPipe*> AudioPipe::pendingWrites;
AudioPipe::log_emit_function AudioPipe::logger;
std::mutex AudioPipe::mapMutex;
std::unordered_map<std::thread::id, bool> AudioPipe::stopFlags;
std::queue<std::thread::id> AudioPipe::threadIds;
void AudioPipe::processPendingConnects(lws_per_vhost_data *vhd) {
std::list<AudioPipe*> connects;
{
std::lock_guard<std::mutex> guard(mutex_connects);
for (auto it = pendingConnects.begin(); it != pendingConnects.end(); ++it) {
if ((*it)->m_state == LWS_CLIENT_IDLE) {
connects.push_back(*it);
(*it)->m_state = LWS_CLIENT_CONNECTING;
}
}
}
for (auto it = connects.begin(); it != connects.end(); ++it) {
AudioPipe* ap = *it;
ap->connect_client(vhd);
}
}
void AudioPipe::processPendingDisconnects(lws_per_vhost_data *vhd) {
std::list<AudioPipe*> disconnects;
{
std::lock_guard<std::mutex> guard(mutex_disconnects);
for (auto it = pendingDisconnects.begin(); it != pendingDisconnects.end(); ++it) {
if ((*it)->m_state == LWS_CLIENT_DISCONNECTING) disconnects.push_back(*it);
}
pendingDisconnects.clear();
}
for (auto it = disconnects.begin(); it != disconnects.end(); ++it) {
AudioPipe* ap = *it;
lws_callback_on_writable(ap->m_wsi);
}
}
void AudioPipe::processPendingWrites() {
std::list<AudioPipe*> writes;
{
std::lock_guard<std::mutex> guard(mutex_writes);
for (auto it = pendingWrites.begin(); it != pendingWrites.end(); ++it) {
if ((*it)->m_state == LWS_CLIENT_CONNECTED) writes.push_back(*it);
}
pendingWrites.clear();
}
for (auto it = writes.begin(); it != writes.end(); ++it) {
AudioPipe* ap = *it;
lws_callback_on_writable(ap->m_wsi);
}
}
AudioPipe* AudioPipe::findAndRemovePendingConnect(struct lws *wsi) {
AudioPipe* ap = NULL;
std::lock_guard<std::mutex> guard(mutex_connects);
std::list<AudioPipe* > toRemove;
for (auto it = pendingConnects.begin(); it != pendingConnects.end() && !ap; ++it) {
int state = (*it)->m_state;
if ((*it)->m_wsi == nullptr)
toRemove.push_back(*it);
if ((state == LWS_CLIENT_CONNECTING) &&
(*it)->m_wsi == wsi) ap = *it;
}
for (auto it = toRemove.begin(); it != toRemove.end(); ++it)
pendingConnects.remove(*it);
if (ap) {
pendingConnects.remove(ap);
}
return ap;
}
AudioPipe* AudioPipe::findPendingConnect(struct lws *wsi) {
AudioPipe* ap = NULL;
std::lock_guard<std::mutex> guard(mutex_connects);
for (auto it = pendingConnects.begin(); it != pendingConnects.end() && !ap; ++it) {
int state = (*it)->m_state;
if ((state == LWS_CLIENT_CONNECTING) &&
(*it)->m_wsi == wsi) ap = *it;
}
return ap;
}
void AudioPipe::addPendingConnect(AudioPipe* ap) {
{
std::lock_guard<std::mutex> guard(mutex_connects);
pendingConnects.push_back(ap);
lwsl_debug("%s after adding connect there are %lu pending connects\n",
ap->m_uuid.c_str(), pendingConnects.size());
}
lws_cancel_service(contexts[nchild++ % numContexts]);
}
void AudioPipe::addPendingDisconnect(AudioPipe* ap) {
ap->m_state = LWS_CLIENT_DISCONNECTING;
{
std::lock_guard<std::mutex> guard(mutex_disconnects);
pendingDisconnects.push_back(ap);
lwsl_debug("%s after adding disconnect there are %lu pending disconnects\n",
ap->m_uuid.c_str(), pendingDisconnects.size());
}
lws_cancel_service(ap->m_vhd->context);
}
void AudioPipe::addPendingWrite(AudioPipe* ap) {
{
std::lock_guard<std::mutex> guard(mutex_writes);
pendingWrites.push_back(ap);
}
lws_cancel_service(ap->m_vhd->context);
}
void AudioPipe::binaryWritePtrResetToZero(void) {
m_audio_buffer_write_offset = LWS_PRE + AWS_PRELUDE_PLUS_HDRS_LEN;
}
bool AudioPipe::lws_service_thread(unsigned int nServiceThread) {
struct lws_context_creation_info info;
std::thread::id this_id = std::this_thread::get_id();
const struct lws_protocols protocols[] = {
{
"",
AudioPipe::aws_lws_callback,
sizeof(void *),
1024,
},
{ NULL, NULL, 0, 0 }
};
memset(&info, 0, sizeof info);
info.port = CONTEXT_PORT_NO_LISTEN;
info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
info.protocols = protocols;
info.ka_time = nTcpKeepaliveSecs; // tcp keep-alive timer
info.ka_probes = 4; // number of times to try ka before closing connection
info.ka_interval = 5; // time between ka's
info.timeout_secs = 10; // doc says timeout for "various processes involving network roundtrips"
info.keepalive_timeout = 5; // seconds to allow remote client to hold on to an idle HTTP/1.1 connection
info.timeout_secs_ah_idle = 10; // secs to allow a client to hold an ah without using it
info.retry_and_idle_policy = &retry;
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO,"AudioPipe::lws_service_thread creating context\n");
contexts[nServiceThread] = lws_create_context(&info);
if (!contexts[nServiceThread]) {
lwsl_err("AudioPipe::lws_service_thread failed creating context in service thread %d..\n", nServiceThread);
return false;
}
int n;
do {
n = lws_service(contexts[nServiceThread], 0);
} while (n >= 0 && !stopFlags[this_id]);
// Cleanup once work is done or stopped
{
std::lock_guard<std::mutex> lock(mapMutex);
stopFlags.erase(this_id);
}
lwsl_notice("AudioPipe::lws_service_thread ending in service thread %d\n", nServiceThread);
return true;
}
void AudioPipe::initialize(unsigned int nThreads, int loglevel, log_emit_function logger) {
assert(nThreads > 0 && nThreads <= 10);
numContexts = nThreads;
lws_set_log_level(loglevel, logger);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE,"AudioPipe::initialize starting\n");
for (unsigned int i = 0; i < numContexts; i++) {
std::lock_guard<std::mutex> lock(mapMutex);
std::thread t(&AudioPipe::lws_service_thread, i);
stopFlags[t.get_id()] = false;
threadIds.push(t.get_id());
t.detach();
}
}
bool AudioPipe::deinitialize() {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE,"AudioPipe::deinitialize\n");
std::lock_guard<std::mutex> lock(mapMutex);
if (!threadIds.empty()) {
std::thread::id id = threadIds.front();
threadIds.pop();
stopFlags[id] = true;
}
for (unsigned int i = 0; i < numContexts; i++)
{
lwsl_notice("AudioPipe::deinitialize destroying context %d of %d\n", i + 1, numContexts);
lws_context_destroy(contexts[i]);
}
std::this_thread::sleep_for(std::chrono::seconds(2));
return true;
}
// instance members
AudioPipe::AudioPipe(const char* uuid, const char* bugname, const char* host, unsigned int port, const char* path,
size_t bufLen, size_t minFreespace, notifyHandler_t callback) :
m_uuid(uuid), m_host(host), m_port(port), m_path(path), m_finished(false), m_bugname(bugname),
m_audio_buffer_min_freespace(minFreespace), m_audio_buffer_max_len(bufLen), m_gracefulShutdown(false),
m_recv_buf(nullptr), m_recv_buf_ptr(nullptr),
m_state(LWS_CLIENT_IDLE), m_wsi(nullptr), m_vhd(nullptr), m_callback(callback) {
char headerBuffer[88];
char* buffer = headerBuffer;
m_audio_buffer = new uint8_t[m_audio_buffer_max_len];
// stamp out the template for the prelude and headers
memset(aws_prelude_and_headers, 0, AWS_PRELUDE_PLUS_HDRS_LEN);
// aws_prelude_and_headers[0..3] = total byte length (not known till message send time)
// aws_prelude_and_headers[4..7] = headers byte length
uint32_t headerLen = sizeof(headerBuffer);
headerLen = htonl(headerLen);
memcpy(&aws_prelude_and_headers[4], &headerLen, sizeof(uint32_t));
// aws_prelude_and_headers[8..11] = prelude crc (not known till message send time)
// aws_prelude_and_headers[12..99] = headers
TranscribeManager::writeHeader(&buffer, ":content-type", "application/octet-stream");
TranscribeManager::writeHeader(&buffer, ":event-type", "AudioEvent");
TranscribeManager::writeHeader(&buffer, ":message-type", "event");
memcpy(&aws_prelude_and_headers[12], headerBuffer, sizeof(headerBuffer));
// following this will be the audio data and a final message CRC (not known till message send time)
memcpy(m_audio_buffer + LWS_PRE, aws_prelude_and_headers, AWS_PRELUDE_PLUS_HDRS_LEN);
m_audio_buffer_write_offset = LWS_PRE + AWS_PRELUDE_PLUS_HDRS_LEN;
//writeToFile((const char *) m_audio_buffer + LWS_PRE, AWS_PRELUDE_PLUS_HDRS_LEN);
}
AudioPipe::~AudioPipe() {
if (m_audio_buffer) delete [] m_audio_buffer;
if (m_recv_buf) delete [] m_recv_buf;
}
void AudioPipe::connect(void) {
addPendingConnect(this);
}
bool AudioPipe::connect_client(struct lws_per_vhost_data *vhd) {
assert(m_audio_buffer != nullptr);
assert(m_vhd == nullptr);
struct lws_client_connect_info i;
memset(&i, 0, sizeof(i));
i.context = vhd->context;
i.port = m_port;
i.address = m_host.c_str();
i.path = m_path.c_str();
i.host = i.address;
i.origin = i.address;
i.ssl_connection = LCCSCF_USE_SSL;
//i.protocol = protocolName.c_str();
i.pwsi = &(m_wsi);
m_state = LWS_CLIENT_CONNECTING;
m_vhd = vhd;
m_wsi = lws_client_connect_via_info(&i);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG,"%s attempting connection, wsi is %p\n", m_uuid.c_str(), m_wsi);
return nullptr != m_wsi;
}
void AudioPipe::bufferForSending(const char* text) {
if (m_state != LWS_CLIENT_CONNECTED) return;
{
std::lock_guard<std::mutex> lk(m_text_mutex);
m_metadata.append(text);
}
addPendingWrite(this);
}
void AudioPipe::unlockAudioBuffer() {
if (m_audio_buffer_write_offset > LWS_PRE) addPendingWrite(this);
m_audio_mutex.unlock();
}
void AudioPipe::close() {
if (m_state != LWS_CLIENT_CONNECTED) return;
addPendingDisconnect(this);
}
void AudioPipe::finish() {
if (m_finished || m_state != LWS_CLIENT_CONNECTED) return;
m_finished = true;
addPendingWrite(this);
}
void AudioPipe::waitForClose() {
std::shared_future<void> sf(m_promise.get_future());
sf.wait();
return;
}

View File

@@ -0,0 +1,144 @@
#ifndef __AWS_AUDIO_PIPE_HPP__
#define __AWS_AUDIO_PIPE_HPP__
#include <string>
#include <list>
#include <mutex>
#include <future>
#include <queue>
#include <unordered_map>
#include <thread>
#include <libwebsockets.h>
namespace aws {
class AudioPipe {
public:
enum LwsState_t {
LWS_CLIENT_IDLE,
LWS_CLIENT_CONNECTING,
LWS_CLIENT_CONNECTED,
LWS_CLIENT_FAILED,
LWS_CLIENT_DISCONNECTING,
LWS_CLIENT_DISCONNECTED
};
enum NotifyEvent_t {
CONNECT_SUCCESS,
CONNECT_FAIL,
CONNECTION_DROPPED,
CONNECTION_CLOSED_GRACEFULLY,
MESSAGE
};
typedef void (*log_emit_function)(int level, const char *line);
typedef void (*notifyHandler_t)(const char *sessionId, const char* bugname, NotifyEvent_t event, const char* message, bool finished);
struct lws_per_vhost_data {
struct lws_context *context;
struct lws_vhost *vhost;
const struct lws_protocols *protocol;
};
static void initialize(unsigned int nThreads, int loglevel, log_emit_function logger);
static bool deinitialize();
static bool lws_service_thread(unsigned int nServiceThread);
// constructor
AudioPipe(const char* uuid, const char* bugname, const char* host, unsigned int port, const char* path,
size_t bufLen, size_t minFreespace, notifyHandler_t callback);
~AudioPipe();
LwsState_t getLwsState(void) { return m_state; }
std::string& getApiKey(void) {
return m_apiKey;
}
void connect(void);
void bufferForSending(const char* text);
size_t binarySpaceAvailable(void) {
return m_audio_buffer_max_len - m_audio_buffer_write_offset;
}
size_t binaryMinSpace(void) {
return m_audio_buffer_min_freespace;
}
char * binaryWritePtr(void) {
return (char *) m_audio_buffer + m_audio_buffer_write_offset;
}
void binaryWritePtrAdd(size_t len) {
m_audio_buffer_write_offset += len;
}
void binaryWritePtrResetToZero(void);
void lockAudioBuffer(void) {
m_audio_mutex.lock();
}
void unlockAudioBuffer(void) ;
void close() ;
void finish();
void waitForClose();
void setClosed() { m_promise.set_value(); }
bool isFinished() { return m_finished;}
// no default constructor or copying
AudioPipe() = delete;
AudioPipe(const AudioPipe&) = delete;
void operator=(const AudioPipe&) = delete;
private:
static int aws_lws_callback(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len);
static unsigned int nchild;
static struct lws_context *contexts[];
static unsigned int numContexts;
static std::string protocolName;
static std::mutex mutex_connects;
static std::mutex mutex_disconnects;
static std::mutex mutex_writes;
static std::list<AudioPipe*> pendingConnects;
static std::list<AudioPipe*> pendingDisconnects;
static std::list<AudioPipe*> pendingWrites;
static log_emit_function logger;
static std::mutex mapMutex;
static std::unordered_map<std::thread::id, bool> stopFlags;
static std::queue<std::thread::id> threadIds;
static AudioPipe* findAndRemovePendingConnect(struct lws *wsi);
static AudioPipe* findPendingConnect(struct lws *wsi);
static void addPendingConnect(AudioPipe* ap);
static void addPendingDisconnect(AudioPipe* ap);
static void addPendingWrite(AudioPipe* ap);
static void processPendingConnects(lws_per_vhost_data *vhd);
static void processPendingDisconnects(lws_per_vhost_data *vhd);
static void processPendingWrites(void);
bool connect_client(struct lws_per_vhost_data *vhd);
LwsState_t m_state;
std::string m_uuid;
std::string m_host;
unsigned int m_port;
std::string m_path;
std::string m_metadata;
std::mutex m_text_mutex;
std::mutex m_audio_mutex;
int m_sslFlags;
struct lws *m_wsi;
uint8_t *m_audio_buffer;
size_t m_audio_buffer_max_len;
size_t m_audio_buffer_write_offset;
size_t m_audio_buffer_min_freespace;
uint8_t* m_recv_buf;
uint8_t* m_recv_buf_ptr;
size_t m_recv_buf_len;
struct lws_per_vhost_data* m_vhd;
notifyHandler_t m_callback;
log_emit_function m_logger;
std::string m_apiKey;
bool m_gracefulShutdown;
bool m_finished;
std::string m_bugname;
std::promise<void> m_promise;
};
} // namespace deepgram
#endif

View File

@@ -0,0 +1,415 @@
#include <switch.h>
#include <switch_json.h>
#include <string.h>
#include <string>
#include <mutex>
#include <thread>
#include <list>
#include <algorithm>
#include <functional>
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <regex>
#include <iostream>
#include <unordered_map>
#include "mod_aws_transcribe_ws.h"
#include "simple_buffer.h"
//#include "parser.hpp"
#include "audio_pipe.hpp"
#include "transcribe_manager.hpp"
#define RTP_PACKETIZATION_PERIOD 20
#define FRAME_SIZE_8000 320 /*which means each 20ms frame as 320 bytes at 8 khz (1 channel only)*/
namespace {
static bool hasDefaultCredentials = false;
static const char* defaultApiKey = nullptr;
static const char *requestedBufferSecs = std::getenv("MOD_AUDIO_FORK_BUFFER_SECS");
static int nAudioBufferSecs = std::max(1, std::min(requestedBufferSecs ? ::atoi(requestedBufferSecs) : 2, 5));
static const char *requestedNumServiceThreads = std::getenv("MOD_AUDIO_FORK_SERVICE_THREADS");
static unsigned int nServiceThreads = std::max(1, std::min(requestedNumServiceThreads ? ::atoi(requestedNumServiceThreads) : 1, 5));
static unsigned int idxCallCount = 0;
static uint32_t playCount = 0;
static const char* emptyTranscript = "{\"Transcript\":{\"Results\":[]}}";
static const char* messageStart = "{\"Message\":";
static void reaper(private_t *tech_pvt) {
std::shared_ptr<aws::AudioPipe> pAp;
pAp.reset((aws::AudioPipe *)tech_pvt->pAudioPipe);
tech_pvt->pAudioPipe = nullptr;
std::thread t([pAp, tech_pvt]{
pAp->finish();
pAp->waitForClose();
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "%s (%u) got remote close\n", tech_pvt->sessionId, tech_pvt->id);
});
t.detach();
}
static void destroy_tech_pvt(private_t *tech_pvt) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO, "%s (%u) destroy_tech_pvt\n", tech_pvt->sessionId, tech_pvt->id);
if (tech_pvt) {
if (tech_pvt->pAudioPipe) {
aws::AudioPipe* p = (aws::AudioPipe *) tech_pvt->pAudioPipe;
delete p;
tech_pvt->pAudioPipe = nullptr;
}
if (tech_pvt->resampler) {
speex_resampler_destroy(tech_pvt->resampler);
tech_pvt->resampler = NULL;
}
if (tech_pvt->vad) {
switch_vad_destroy(&tech_pvt->vad);
tech_pvt->vad = nullptr;
}
}
}
static void eventCallback(const char* sessionId, const char* bugname,
aws::AudioPipe::NotifyEvent_t event, const char* message, bool finished) {
switch_core_session_t* session = switch_core_session_locate(sessionId);
if (session) {
switch_channel_t *channel = switch_core_session_get_channel(session);
switch_media_bug_t *bug = (switch_media_bug_t*) switch_channel_get_private(channel, bugname);
if (bug) {
private_t* tech_pvt = (private_t*) switch_core_media_bug_get_user_data(bug);
if (tech_pvt) {
switch (event) {
case aws::AudioPipe::CONNECT_SUCCESS:
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_INFO, "connection successful\n");
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_CONNECT_SUCCESS, NULL, tech_pvt->bugname, finished);
break;
case aws::AudioPipe::CONNECT_FAIL:
{
// first thing: we can no longer access the AudioPipe
std::stringstream json;
json << "{\"reason\":\"" << message << "\"}";
tech_pvt->pAudioPipe = nullptr;
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_CONNECT_FAIL, (char *) json.str().c_str(), tech_pvt->bugname, finished);
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_NOTICE, "connection failed: %s\n", message);
}
break;
case aws::AudioPipe::CONNECTION_DROPPED:
// first thing: we can no longer access the AudioPipe
tech_pvt->pAudioPipe = nullptr;
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_DISCONNECT, NULL, tech_pvt->bugname, finished);
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "connection dropped from far end\n");
break;
case aws::AudioPipe::CONNECTION_CLOSED_GRACEFULLY:
// first thing: we can no longer access the AudioPipe
tech_pvt->pAudioPipe = nullptr;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "connection closed gracefully\n");
break;
case aws::AudioPipe::MESSAGE:
if( strstr(message, emptyTranscript)) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "discarding empty aws transcript\n");
}
else if (0 == strncmp( message, messageStart, strlen(messageStart))) {
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_ERROR, message, tech_pvt->bugname, finished);
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "error message from aws: %s\n", message);
}
else {
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_RESULTS, message, tech_pvt->bugname, finished);
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "aws message: %s.\n", message);
}
break;
default:
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_NOTICE, "got unexpected msg from aws %d:%s\n", event, message);
break;
}
}
}
switch_core_session_rwunlock(session);
}
}
void lws_logger(int level, const char *line) {
switch_log_level_t llevel = SWITCH_LOG_DEBUG;
switch (level) {
case LLL_ERR: llevel = SWITCH_LOG_ERROR; break;
case LLL_WARN: llevel = SWITCH_LOG_WARNING; break;
case LLL_NOTICE: llevel = SWITCH_LOG_NOTICE; break;
case LLL_INFO: llevel = SWITCH_LOG_INFO; break;
break;
}
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "%s\n", line);
}
}
extern "C" {
switch_status_t aws_transcribe_init() {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "mod_aws_transcribe: audio buffer (in secs): %d secs\n", nAudioBufferSecs);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "mod_aws_transcribe: lws service threads: %d\n", nServiceThreads);
int logs = LLL_ERR | LLL_WARN | LLL_NOTICE || LLL_INFO | LLL_PARSER | LLL_HEADER | LLL_EXT | LLL_CLIENT | LLL_LATENCY | LLL_DEBUG ;
aws::AudioPipe::initialize(nServiceThreads, logs, lws_logger);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "AudioPipe::initialize completed\n");
return SWITCH_STATUS_SUCCESS;
}
switch_status_t aws_transcribe_cleanup() {
bool cleanup = false;
cleanup = aws::AudioPipe::deinitialize();
if (cleanup == true) {
return SWITCH_STATUS_SUCCESS;
}
return SWITCH_STATUS_FALSE;
}
// start transcribe on a channel
switch_status_t aws_transcribe_session_init(switch_core_session_t *session, responseHandler_t responseHandler,
uint32_t samples_per_second, uint32_t channels, char* lang, int interim, char* bugname, void **ppUserData
) {
switch_status_t status = SWITCH_STATUS_SUCCESS;
switch_channel_t *channel = switch_core_session_get_channel(session);
int err;
uint32_t desiredSampling = 8000;
switch_threadattr_t *thd_attr = NULL;
switch_memory_pool_t *pool = switch_core_session_get_pool(session);
auto read_codec = switch_core_session_get_read_codec(session);
uint32_t sampleRate = read_codec->implementation->actual_samples_per_second;
switch_codec_implementation_t read_impl;
switch_core_session_get_read_impl(session, &read_impl);
private_t* tech_pvt = (private_t *) switch_core_session_alloc(session, sizeof(private_t));
memset(tech_pvt, sizeof(tech_pvt), 0);
const char* awsAccessKeyId = switch_channel_get_variable(channel, "AWS_ACCESS_KEY_ID");
const char* awsSecretAccessKey = switch_channel_get_variable(channel, "AWS_SECRET_ACCESS_KEY");
const char* awsRegion = switch_channel_get_variable(channel, "AWS_REGION");
const char* awsSessionToken = switch_channel_get_variable(channel, "AWS_SECURITY_TOKEN");
tech_pvt->channels = channels;
strncpy(tech_pvt->sessionId, switch_core_session_get_uuid(session), MAX_SESSION_ID);
strncpy(tech_pvt->bugname, bugname, MAX_BUG_LEN);
if (awsAccessKeyId && awsSecretAccessKey && awsRegion && awsSessionToken) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "Using channel vars for aws authentication\n");
strncpy(tech_pvt->awsAccessKeyId, awsAccessKeyId, 128);
strncpy(tech_pvt->awsSecretAccessKey, awsSecretAccessKey, 128);
strncpy(tech_pvt->awsSessionToken, awsSessionToken, MAX_SESSION_TOKEN_LEN);
strncpy(tech_pvt->region, awsRegion, MAX_REGION);
}
else if (std::getenv("AWS_ACCESS_KEY_ID") &&
std::getenv("AWS_SECRET_ACCESS_KEY") &&
std::getenv("AWS_REGION")) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "Using env vars for aws authentication\n");
strncpy(tech_pvt->awsAccessKeyId, std::getenv("AWS_ACCESS_KEY_ID"), 128);
strncpy(tech_pvt->awsSecretAccessKey, std::getenv("AWS_SECRET_ACCESS_KEY"), 128);
strncpy(tech_pvt->region, std::getenv("AWS_REGION"), MAX_REGION);
}
else {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "No channel vars or env vars for aws authentication..will use default profile if found\n");
}
tech_pvt->responseHandler = responseHandler;
tech_pvt->interim = interim;
strncpy(tech_pvt->lang, lang, MAX_LANG);
tech_pvt->samples_per_second = sampleRate;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "sample rate of rtp stream is %d\n", samples_per_second);
const char* vocabularyName = switch_channel_get_variable(channel, "AWS_VOCABULARY_NAME");
const char* vocabularyFilterName = switch_channel_get_variable(channel, "AWS_VOCABULARY_FILTER_NAME");
const char* vocabularyFilterMethod = switch_channel_get_variable(channel, "AWS_VOCABULARY_FILTER_METHOD");
const char* piiEntityTypes = switch_channel_get_variable(channel, "AWS_PII_ENTITY_TYPES");
int shouldIdentifyPII = switch_true(switch_channel_get_variable(channel, "AWS_PII_IDENTIFY_ENTITIES"));
const char* languageModelName = switch_channel_get_variable(channel, "AWS_LANGUAGE_MODEL_NAME");
std::string host, path;
TranscribeManager::getSignedWebsocketUrl(
host,
path,
tech_pvt->awsAccessKeyId,
tech_pvt->awsSecretAccessKey,
tech_pvt->awsSessionToken,
tech_pvt->region,
lang,
vocabularyName,
vocabularyFilterName,
vocabularyFilterMethod,
piiEntityTypes,
shouldIdentifyPII,
languageModelName
);
host = host.substr(0, host.find(':'));
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "connecting to host %s, path %s\n", host.c_str(), path.c_str());
strncpy(tech_pvt->sessionId, switch_core_session_get_uuid(session), MAX_SESSION_ID);
strncpy(tech_pvt->host, host.c_str(), MAX_WS_URL_LEN);
tech_pvt->port = 8443;
strncpy(tech_pvt->path, path.c_str(), MAX_PATH_LEN);
tech_pvt->responseHandler = responseHandler;
tech_pvt->channels = channels;
tech_pvt->id = ++idxCallCount;
tech_pvt->buffer_overrun_notified = 0;
size_t buflen = LWS_PRE + (FRAME_SIZE_8000 * desiredSampling / 8000 * channels * 1000 / RTP_PACKETIZATION_PERIOD * nAudioBufferSecs);
aws::AudioPipe* ap = new aws::AudioPipe(tech_pvt->sessionId, tech_pvt->bugname, tech_pvt->host, tech_pvt->port, tech_pvt->path,
buflen, read_impl.decoded_bytes_per_packet, eventCallback);
if (!ap) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "Error allocating AudioPipe\n");
return SWITCH_STATUS_FALSE;
}
tech_pvt->pAudioPipe = static_cast<void *>(ap);
if (switch_mutex_init(&tech_pvt->mutex, SWITCH_MUTEX_NESTED, pool) != SWITCH_STATUS_SUCCESS) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "Error initializing mutex\n");
status = SWITCH_STATUS_FALSE;
goto done;
}
if (sampleRate != 8000) {
tech_pvt->resampler = speex_resampler_init(1, sampleRate, 16000, SWITCH_RESAMPLE_QUALITY, &err);
if (0 != err) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "%s: Error initializing resampler: %s.\n",
switch_channel_get_name(channel), speex_resampler_strerror(err));
status = SWITCH_STATUS_FALSE;
goto done;
}
}
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "connecting now\n");
ap->connect();
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "connection in progress\n");
*ppUserData = tech_pvt;
done:
return status;
}
switch_status_t aws_transcribe_session_stop(switch_core_session_t *session,int channelIsClosing, char* bugname) {
switch_channel_t *channel = switch_core_session_get_channel(session);
switch_media_bug_t *bug = (switch_media_bug_t*) switch_channel_get_private(channel, MY_BUG_NAME);
if (!bug) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "aws_transcribe_session_stop: no bug - websocket conection already closed\n");
return SWITCH_STATUS_FALSE;
}
private_t* tech_pvt = (private_t*) switch_core_media_bug_get_user_data(bug);
uint32_t id = tech_pvt->id;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "(%u) aws_transcribe_session_stop\n", id);
if (!tech_pvt) return SWITCH_STATUS_FALSE;
// close connection and get final responses
switch_mutex_lock(tech_pvt->mutex);
switch_channel_set_private(channel, bugname, NULL);
if (!channelIsClosing) switch_core_media_bug_remove(session, &bug);
aws::AudioPipe *pAudioPipe = static_cast<aws::AudioPipe *>(tech_pvt->pAudioPipe);
if (pAudioPipe) reaper(tech_pvt);
destroy_tech_pvt(tech_pvt);
switch_mutex_unlock(tech_pvt->mutex);
switch_mutex_destroy(tech_pvt->mutex);
tech_pvt->mutex = nullptr;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_DEBUG, "(%u) aws_transcribe_session_stop\n", id);
return SWITCH_STATUS_SUCCESS;
}
switch_bool_t aws_transcribe_frame(switch_media_bug_t *bug, void* user_data) {
switch_core_session_t *session = switch_core_media_bug_get_session(bug);
private_t* tech_pvt = (private_t*) switch_core_media_bug_get_user_data(bug);
size_t inuse = 0;
bool dirty = false;
char *p = (char *) "{\"msg\": \"buffer overrun\"}";
if (!tech_pvt) return SWITCH_TRUE;
if (switch_mutex_trylock(tech_pvt->mutex) == SWITCH_STATUS_SUCCESS) {
if (!tech_pvt->pAudioPipe) {
switch_mutex_unlock(tech_pvt->mutex);
return SWITCH_TRUE;
}
aws::AudioPipe *pAudioPipe = static_cast<aws::AudioPipe *>(tech_pvt->pAudioPipe);
if (pAudioPipe->getLwsState() != aws::AudioPipe::LWS_CLIENT_CONNECTED) {
switch_mutex_unlock(tech_pvt->mutex);
return SWITCH_TRUE;
}
pAudioPipe->lockAudioBuffer();
size_t available = pAudioPipe->binarySpaceAvailable();
if (NULL == tech_pvt->resampler) {
switch_frame_t frame = { 0 };
frame.data = pAudioPipe->binaryWritePtr();
frame.buflen = available;
while (true) {
// check if buffer would be overwritten; dump packets if so
if (available < pAudioPipe->binaryMinSpace()) {
if (!tech_pvt->buffer_overrun_notified) {
tech_pvt->buffer_overrun_notified = 1;
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_BUFFER_OVERRUN, NULL, tech_pvt->bugname, 0);
}
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "(%u) dropping packets!\n",
tech_pvt->id);
pAudioPipe->binaryWritePtrResetToZero();
frame.data = pAudioPipe->binaryWritePtr();
frame.buflen = available = pAudioPipe->binarySpaceAvailable();
}
switch_status_t rv = switch_core_media_bug_read(bug, &frame, SWITCH_TRUE);
if (rv != SWITCH_STATUS_SUCCESS) break;
if (frame.datalen) {
pAudioPipe->binaryWritePtrAdd(frame.datalen);
frame.buflen = available = pAudioPipe->binarySpaceAvailable();
frame.data = pAudioPipe->binaryWritePtr();
dirty = true;
}
}
}
else {
uint8_t data[SWITCH_RECOMMENDED_BUFFER_SIZE];
switch_frame_t frame = { 0 };
frame.data = data;
frame.buflen = SWITCH_RECOMMENDED_BUFFER_SIZE;
while (switch_core_media_bug_read(bug, &frame, SWITCH_TRUE) == SWITCH_STATUS_SUCCESS) {
if (frame.datalen) {
spx_uint32_t out_len = available >> 1; // space for samples which are 2 bytes
spx_uint32_t in_len = frame.samples;
speex_resampler_process_interleaved_int(tech_pvt->resampler,
(const spx_int16_t *) frame.data,
(spx_uint32_t *) &in_len,
(spx_int16_t *) ((char *) pAudioPipe->binaryWritePtr()),
&out_len);
if (out_len > 0) {
// bytes written = num samples * 2 * num channels
size_t bytes_written = out_len << tech_pvt->channels;
pAudioPipe->binaryWritePtrAdd(bytes_written);
available = pAudioPipe->binarySpaceAvailable();
dirty = true;
}
if (available < pAudioPipe->binaryMinSpace()) {
if (!tech_pvt->buffer_overrun_notified) {
tech_pvt->buffer_overrun_notified = 1;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "(%u) dropping packets!\n",
tech_pvt->id);
tech_pvt->responseHandler(session, TRANSCRIBE_EVENT_BUFFER_OVERRUN, NULL, tech_pvt->bugname, 0);
}
break;
}
}
}
}
pAudioPipe->unlockAudioBuffer();
switch_mutex_unlock(tech_pvt->mutex);
}
return SWITCH_TRUE;
}
}

View File

@@ -0,0 +1,11 @@
#ifndef __AWS_GLUE_H__
#define __AWS_GLUE_H__
switch_status_t aws_transcribe_init();
switch_status_t aws_transcribe_cleanup();
switch_status_t aws_transcribe_session_init(switch_core_session_t *session, responseHandler_t responseHandler,
uint32_t samples_per_second, uint32_t channels, char* lang, int interim, char* bugname, void **ppUserData);
switch_status_t aws_transcribe_session_stop(switch_core_session_t *session, int channelIsClosing, char* bugname);
switch_bool_t aws_transcribe_frame(switch_media_bug_t *bug, void* user_data);
#endif

2114
mod_aws_transcribe_ws/crc.h Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,212 @@
/*
*
* mod_aws_transcribe.c -- Freeswitch module for using aws streaming transcribe api
*
*/
#include "mod_aws_transcribe_ws.h"
#include "aws_transcribe_glue.h"
/* Prototypes */
SWITCH_MODULE_SHUTDOWN_FUNCTION(mod_aws_transcribe_ws_shutdown);
SWITCH_MODULE_LOAD_FUNCTION(mod_aws_transcribe_ws_load);
SWITCH_MODULE_DEFINITION(mod_aws_transcribe_ws, mod_aws_transcribe_ws_load, mod_aws_transcribe_ws_shutdown, NULL);
static switch_status_t do_stop(switch_core_session_t *session, char* bugname);
static void responseHandler(switch_core_session_t* session,
const char* eventName, const char * json, const char* bugname, int finished) {
switch_event_t *event;
switch_channel_t *channel = switch_core_session_get_channel(session);
switch_event_create_subclass(&event, SWITCH_EVENT_CUSTOM, eventName);
switch_channel_event_set_data(channel, event);
switch_event_add_header_string(event, SWITCH_STACK_BOTTOM, "transcription-vendor", "deepgram");
switch_event_add_header_string(event, SWITCH_STACK_BOTTOM, "transcription-session-finished", finished ? "true" : "false");
if (finished) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "responseHandler returning event %s, from finished recognition session\n", eventName);
}
if (json) switch_event_add_body(event, "%s", json);
if (bugname) switch_event_add_header_string(event, SWITCH_STACK_BOTTOM, "media-bugname", bugname);
switch_event_fire(&event);
}
static switch_bool_t capture_callback(switch_media_bug_t *bug, void *user_data, switch_abc_type_t type)
{
switch_core_session_t *session = switch_core_media_bug_get_session(bug);
switch (type) {
case SWITCH_ABC_TYPE_INIT:
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "Got SWITCH_ABC_TYPE_INIT.\n");
break;
case SWITCH_ABC_TYPE_CLOSE:
{
private_t *tech_pvt = (private_t*) switch_core_media_bug_get_user_data(bug);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "Got SWITCH_ABC_TYPE_CLOSE.\n");
aws_transcribe_session_stop(session, 1, tech_pvt->bugname);
//switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "Finished SWITCH_ABC_TYPE_CLOSE.\n");
}
break;
case SWITCH_ABC_TYPE_READ:
return aws_transcribe_frame(bug, user_data);
break;
case SWITCH_ABC_TYPE_WRITE:
default:
break;
}
return SWITCH_TRUE;
}
static switch_status_t start_capture(switch_core_session_t *session, switch_media_bug_flag_t flags,
char* lang, int interim, char* bugname)
{
switch_channel_t *channel = switch_core_session_get_channel(session);
switch_media_bug_t *bug;
switch_status_t status;
switch_codec_implementation_t read_impl = { 0 };
void *pUserData;
uint32_t samples_per_second;
if (switch_channel_get_private(channel, bugname)) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "removing bug from previous transcribe\n");
do_stop(session, bugname);
}
switch_core_session_get_read_impl(session, &read_impl);
if (switch_channel_pre_answer(channel) != SWITCH_STATUS_SUCCESS) {
return SWITCH_STATUS_FALSE;
}
samples_per_second = !strcasecmp(read_impl.iananame, "g722") ? read_impl.actual_samples_per_second : read_impl.samples_per_second;
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, " initializing aws speech session.\n");
if (SWITCH_STATUS_FALSE == aws_transcribe_session_init(session, responseHandler, samples_per_second,
flags & SMBF_STEREO ? 2 : 1, lang, interim, bugname, &pUserData)) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "Error initializing aws speech session.\n");
return SWITCH_STATUS_FALSE;
}
if ((status = switch_core_media_bug_add(session, bugname, NULL, capture_callback, pUserData, 0, flags, &bug)) != SWITCH_STATUS_SUCCESS) {
return status;
}
switch_channel_set_private(channel, bugname, bug);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_DEBUG, "added media bug for aws transcribe\n");
return SWITCH_STATUS_SUCCESS;
}
static switch_status_t do_stop(switch_core_session_t *session, char* bugname)
{
switch_status_t status = SWITCH_STATUS_SUCCESS;
switch_channel_t *channel = switch_core_session_get_channel(session);
switch_media_bug_t *bug = switch_channel_get_private(channel, bugname);
if (bug) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO, "Received user command command to stop transcribe on %s.\n", bugname);
status = aws_transcribe_session_stop(session, 0, bugname);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_INFO, "stopped transcribe.\n");
}
return status;
}
#define TRANSCRIBE_API_SYNTAX "<uuid> [start|stop] lang-code [interim] [stereo|mono] [bugname]"
SWITCH_STANDARD_API(aws_transcribe_function)
{
char *mycmd = NULL, *argv[6] = { 0 };
int argc = 0;
switch_status_t status = SWITCH_STATUS_FALSE;
switch_media_bug_flag_t flags = SMBF_READ_STREAM /* | SMBF_WRITE_STREAM | SMBF_READ_PING */;
if (!zstr(cmd) && (mycmd = strdup(cmd))) {
argc = switch_separate_string(mycmd, ' ', argv, (sizeof(argv) / sizeof(argv[0])));
}
if (zstr(cmd) ||
(!strcasecmp(argv[1], "stop") && argc < 2) ||
(!strcasecmp(argv[1], "start") && argc < 3) ||
zstr(argv[0])) {
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_ERROR, "Error with command %s %s %s.\n", cmd, argv[0], argv[1]);
stream->write_function(stream, "-USAGE: %s\n", TRANSCRIBE_API_SYNTAX);
goto done;
} else {
switch_core_session_t *lsession = NULL;
if ((lsession = switch_core_session_locate(argv[0]))) {
if (!strcasecmp(argv[1], "stop")) {
char *bugname = argc > 2 ? argv[2] : MY_BUG_NAME;
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_INFO, "stop transcribing\n");
status = do_stop(lsession, bugname);
} else if (!strcasecmp(argv[1], "start")) {
char* lang = argv[2];
int interim = argc > 3 && !strcmp(argv[3], "interim");
char *bugname = argc > 5 ? argv[5] : MY_BUG_NAME;
if (argc > 4 && !strcmp(argv[4], "stereo")) {
flags |= SMBF_WRITE_STREAM ;
flags |= SMBF_STEREO;
}
switch_log_printf(SWITCH_CHANNEL_SESSION_LOG(session), SWITCH_LOG_INFO, "start transcribing %s %s %s\n", lang, interim ? "interim": "complete", bugname);
status = start_capture(lsession, flags, lang, interim, bugname);
}
switch_core_session_rwunlock(lsession);
}
}
if (status == SWITCH_STATUS_SUCCESS) {
stream->write_function(stream, "+OK Success\n");
} else {
stream->write_function(stream, "-ERR Operation Failed\n");
}
done:
switch_safe_free(mycmd);
return SWITCH_STATUS_SUCCESS;
}
SWITCH_MODULE_LOAD_FUNCTION(mod_aws_transcribe_ws_load)
{
switch_api_interface_t *api_interface;
/* create/register custom event message type */
if (switch_event_reserve_subclass(TRANSCRIBE_EVENT_RESULTS) != SWITCH_STATUS_SUCCESS) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR, "Couldn't register subclass %s!\n", TRANSCRIBE_EVENT_RESULTS);
return SWITCH_STATUS_TERM;
}
/* connect my internal structure to the blank pointer passed to me */
*module_interface = switch_loadable_module_create_module_interface(pool, modname);
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "AWS Speech Transcription API loading..\n");
if (SWITCH_STATUS_FALSE == aws_transcribe_init()) {
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_CRIT, "Failed initializing aws speech interface\n");
}
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_NOTICE, "AWS Speech Transcription API successfully loaded\n");
SWITCH_ADD_API(api_interface, "uuid_aws_transcribe", "AWS Speech Transcription API", aws_transcribe_function, TRANSCRIBE_API_SYNTAX);
switch_console_set_complete("add uuid_aws_transcribe start lang-code [interim|final] [stereo|mono]");
switch_console_set_complete("add uuid_aws_transcribe stop ");
/* indicate that the module should continue to be loaded */
return SWITCH_STATUS_SUCCESS;
}
/*
Called when the system shuts down
Macro expands to: switch_status_t mod_aws_transcribe_ws_shutdown() */
SWITCH_MODULE_SHUTDOWN_FUNCTION(mod_aws_transcribe_ws_shutdown)
{
aws_transcribe_cleanup();
switch_event_free_subclass(TRANSCRIBE_EVENT_RESULTS);
return SWITCH_STATUS_SUCCESS;
}

View File

@@ -0,0 +1,60 @@
#ifndef __MOD_AWS_TRANSCRIBE_WS_H__
#define __MOD_AWS_TRANSCRIBE_WS_H__
#include <switch.h>
#include <speex/speex_resampler.h>
#include <unistd.h>
#define MY_BUG_NAME "aws_transcribe_ws"
#define MAX_BUG_LEN (64)
#define MAX_SESSION_ID (256)
#define TRANSCRIBE_EVENT_RESULTS "aws_transcribe::transcription"
#define TRANSCRIBE_EVENT_END_OF_TRANSCRIPT "aws_transcribe::end_of_transcript"
#define TRANSCRIBE_EVENT_NO_AUDIO_DETECTED "aws_transcribe::no_audio_detected"
#define TRANSCRIBE_EVENT_MAX_DURATION_EXCEEDED "aws_transcribe::max_duration_exceeded"
#define TRANSCRIBE_EVENT_VAD_DETECTED "aws_transcribe::vad_detected"
#define TRANSCRIBE_EVENT_CONNECT_SUCCESS "aws::connect"
#define TRANSCRIBE_EVENT_CONNECT_FAIL "aws::connect_failed"
#define TRANSCRIBE_EVENT_DISCONNECT "aws::disconnect"
#define TRANSCRIBE_EVENT_BUFFER_OVERRUN "aws::buffer_overrun"
#define TRANSCRIBE_EVENT_ERROR "jambonz_transcribe::error"
#define MAX_LANG (12)
#define MAX_REGION (32)
#define MAX_WS_URL_LEN (512)
#define MAX_PATH_LEN (4096)
#define MAX_SESSION_TOKEN_LEN (4096)
/* per-channel data */
typedef void (*responseHandler_t)(switch_core_session_t* session, const char* eventName, const char* json, const char* bugname, int finished);
struct private_data {
switch_mutex_t *mutex;
char sessionId[MAX_SESSION_ID+1];
char awsAccessKeyId[128];
char awsSecretAccessKey[128];
char awsSessionToken[MAX_SESSION_TOKEN_LEN+1];
SpeexResamplerState *resampler;
responseHandler_t responseHandler;
int interim;
char lang[MAX_LANG+1];
char region[MAX_REGION+1];
switch_vad_t * vad;
uint32_t samples_per_second;
void *pAudioPipe;
int ws_state;
char host[MAX_WS_URL_LEN+1];
unsigned int port;
char path[MAX_PATH_LEN+1];
char bugname[MAX_BUG_LEN+1];
int sampling;
int channels;
unsigned int id;
int buffer_overrun_notified:1;
int is_finished:1;
};
typedef struct private_data private_t;
#endif

View File

@@ -0,0 +1,51 @@
/**
* (very) simple and limited circular buffer,
* supporting only the use case of doing all of the adds
* and then subsquently retrieves.
*
*/
class SimpleBuffer {
public:
SimpleBuffer(uint32_t chunkSize, uint32_t numChunks) : numItems(0),
m_numChunks(numChunks), m_chunkSize(chunkSize) {
m_pData = new char[chunkSize * numChunks];
m_pNextWrite = m_pData;
}
~SimpleBuffer() {
delete [] m_pData;
}
void add(void *data, uint32_t datalen) {
if (datalen % m_chunkSize != 0) return;
int numChunks = datalen / m_chunkSize;
for (int i = 0; i < numChunks; i++) {
memcpy(m_pNextWrite, data, m_chunkSize);
data = static_cast<char*>(data) + m_chunkSize;
if (numItems < m_numChunks) numItems++;
uint32_t offset = (m_pNextWrite - m_pData) / m_chunkSize;
if (offset >= m_numChunks - 1) m_pNextWrite = m_pData;
else m_pNextWrite += m_chunkSize;
}
}
char* getNextChunk() {
if (numItems--) {
char *p = m_pNextWrite;
uint32_t offset = (m_pNextWrite - m_pData) / m_chunkSize;
if (offset >= m_numChunks - 1) m_pNextWrite = m_pData;
else m_pNextWrite += m_chunkSize;
return p;
}
return nullptr;
}
uint32_t getNumItems() { return numItems;}
private:
char *m_pData;
uint32_t numItems;
uint32_t m_chunkSize;
uint32_t m_numChunks;
char* m_pNextWrite;
};

View File

@@ -0,0 +1,324 @@
#include "transcribe_manager.hpp"
#include "crc.h"
#include <switch.h>
#include <openssl/sha.h>
#include <openssl/hmac.h>
#include <iomanip>
#include <regex>
#include <iostream>
#include <cstring>
#include <netinet/in.h>
using namespace std;
namespace {
std::string uri_encode(const std::string &value) {
std::string encoded;
char hex[4];
for (char c : value) {
if (isalnum(c) || c == '-' || c == '_' || c == '.' || c == '~') {
encoded += c;
} else {
sprintf(hex, "%%%02X", c);
encoded.append(hex);
}
}
return encoded;
}
}
// see
// https://docs.aws.amazon.com/transcribe/latest/dg/websocket.html#websocket-url
// https://docs.aws.amazon.com/transcribe/latest/dg/event-stream.html
void TranscribeManager::getSignedWebsocketUrl(string& host, string& path, const string& accessKey,
const string& secretKey, const string& securityToken, const string& region, const std::string& lang,
const char* vocabularyName, const char* vocabularyFilterName, const char* vocabularyFilterMethod,
const char* piiEntities, int shouldIdentifyPiiEntities, const char* languageModelName) {
string method = "GET";
string service = "transcribe";
string endpoint = "wss://transcribestreaming." + region + ".amazonaws.com";
host = "transcribestreaming." + region + ".amazonaws.com";
time_t now = time(0);
tm *gmtm = gmtime(&now);
char amzDate[21];
snprintf (amzDate, 21, "%04d%02d%02dT%02d%02d%02dZ",
1900 + gmtm->tm_year, 1 + gmtm->tm_mon, gmtm->tm_mday,
gmtm->tm_hour, gmtm->tm_min, gmtm->tm_sec);
char datestamp[9];
snprintf (datestamp, 9, "%04d%02d%02d", 1900 + gmtm->tm_year, 1 + gmtm->tm_mon, gmtm->tm_mday);
string canonical_uri = "/stream-transcription-websocket";
string canonical_headers = "host:" + host + "\n";
string signed_headers = "host";
string algorithm = "AWS4-HMAC-SHA256";
string credential_scope = string(datestamp) + "%2F" + region + "%2F" + service + "%2F" + "aws4_request";
// N.B.: The order of all of these query args are important!
// Otherwise, the signature will be invalid.
string canonical_querystring = "X-Amz-Algorithm=" + algorithm;
canonical_querystring += "&X-Amz-Credential=" + accessKey + "%2F" + credential_scope;
canonical_querystring += "&X-Amz-Date=" + string(amzDate);
canonical_querystring += "&X-Amz-Expires=300";
canonical_querystring += "&X-Amz-Security-Token=" + uri_encode(securityToken);
canonical_querystring += "&X-Amz-SignedHeaders=" + signed_headers;
if (piiEntities && shouldIdentifyPiiEntities) {
canonical_querystring += "&content-redaction-type=PII";
}
canonical_querystring += "&language-code=" + lang;
if (languageModelName) {
std::string str(languageModelName);
canonical_querystring += "&language-model-name=" + uri_encode(str);
}
canonical_querystring += "&media-encoding=pcm";
if (piiEntities) {
std::string str(piiEntities);
canonical_querystring += "&pii-entitytypes=" + uri_encode(str);
}
canonical_querystring += "&sample-rate=8000";
// custom vocabulary and filter
if (vocabularyFilterMethod) {
std::string str(vocabularyFilterMethod);
canonical_querystring += "&vocabulary-filter-method=" + str;
}
if (vocabularyFilterName) {
std::string str(vocabularyFilterName);
canonical_querystring += "&vocabulary-filter-name=" + str;
}
if (vocabularyName) {
std::string str(vocabularyName);
canonical_querystring += "&vocabulary-name=" + str;
}
string payload_hash = getSha256("");
string canonical_request = method + '\n'
+ canonical_uri + '\n'
+ canonical_querystring + '\n'
+ canonical_headers + '\n'
+ signed_headers + '\n'
+ payload_hash;
switch_log_printf(SWITCH_CHANNEL_LOG, SWITCH_LOG_ERROR,"TranscribeManager::getSignedWebsocketUrl canonical_request: %s\n", canonical_request.c_str());
string string_to_sign = algorithm + "\n"
+ amzDate + "\n"
+ regex_replace(credential_scope, regex("%2F"), "/") + "\n"
+ getSha256(canonical_request);
unsigned char signing_key[SHA256_DIGEST_LENGTH];
getSignatureKey(signing_key, secretKey, datestamp, region, service);
unsigned char signatureBinary[SHA256_DIGEST_LENGTH];
getHMAC(signatureBinary, signing_key, SHA256_DIGEST_LENGTH, string_to_sign);
string signature = toHex(signatureBinary);
canonical_querystring += "&X-Amz-Signature=" + signature;
string request_url = endpoint + canonical_uri + "?" + canonical_querystring;
path = canonical_uri + "?" + canonical_querystring;
return;
}
string TranscribeManager::getSha256(string str) {
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, str.c_str(), str.length());
unsigned char hash[SHA256_DIGEST_LENGTH] = { 0 };
SHA256_Final(hash, &ctx);
ostringstream os;
os << hex << setfill('0');
for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
os << setw(2) << static_cast<unsigned int>(hash[i]);
}
return os.str();
}
void TranscribeManager::getSignatureKey(unsigned char *signatureKey, const string& secretKey,
const string& datestamp, const string& region, const string& service) {
string key = string("AWS4") + secretKey;
unsigned char kDate[SHA256_DIGEST_LENGTH];
unsigned char kRegion[SHA256_DIGEST_LENGTH];
unsigned char kService[SHA256_DIGEST_LENGTH];
unsigned char kSigning[SHA256_DIGEST_LENGTH];
getHMAC(kDate, (unsigned char *)key.c_str(), key.length(), datestamp);
getHMAC(kRegion, kDate, SHA256_DIGEST_LENGTH, region);
getHMAC(kService, kRegion, SHA256_DIGEST_LENGTH, service);
getHMAC(kSigning, kService, SHA256_DIGEST_LENGTH, "aws4_request");
memcpy(signatureKey, kSigning, SHA256_DIGEST_LENGTH);
}
void TranscribeManager::getHMAC(unsigned char *hmac, unsigned char *key, int keyLen, const string& str) {
unsigned char *data = (unsigned char*)str.c_str();
unsigned char *result = HMAC(EVP_sha256(), key, keyLen, data, strlen((char *)data), NULL, NULL);
memcpy(hmac, result, SHA256_DIGEST_LENGTH);
}
string TranscribeManager::toHex(unsigned char *hmac) {
ostringstream os;
os << hex << setfill('0');
for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
os << setw(2) << static_cast<unsigned int>(hmac[i]);
}
return os.str();
}
///////////////////////////////////////////////////////////////////////////////////////////
bool TranscribeManager::parseResponse(const string& response, string& payload, bool& isError, bool verbose) {
const char* buffer = response.c_str();
uint32_t totalLen;
memcpy(&totalLen, &buffer[0], sizeof(uint32_t));
totalLen = ntohl(totalLen);
uint32_t headerLen;
memcpy(&headerLen, &buffer[4], sizeof(uint32_t));
headerLen = ntohl(headerLen);
if (!verifyCRC(buffer, totalLen)) {
return false;
}
buffer += 12; // bytes 0 - 11 are prelude
const int numberOfHeaders = 3;
for (int i = 0; i < numberOfHeaders; i++) {
parseHeader(&buffer, isError, verbose);
}
payload = string(buffer, totalLen - headerLen - 4*4);
return true;
}
bool TranscribeManager::verifyCRC(const char* buffer, const uint32_t totalLength) {
uint32_t preludeCRC;
memcpy(&preludeCRC, &buffer[8], 4);
preludeCRC = ntohl(preludeCRC);
uint32_t calculatedPreludeCRC = CRC::Calculate(&buffer[0], 8, CRC::CRC_32());
if (calculatedPreludeCRC != preludeCRC) {
cout << "Prelude CRC didn't match!" << endl;
return false;
}
uint32_t messageCRC;
memcpy(&messageCRC, &buffer[totalLength - 4], 4);
messageCRC = ntohl(messageCRC);
uint32_t calculatedMessageCRC = CRC::Calculate(buffer, totalLength - 4, CRC::CRC_32());
if (calculatedMessageCRC != messageCRC) {
cout << "Message CRC didn't match!" << endl;
return false;
}
return true;
}
void TranscribeManager::parseHeader(const char** buffer, bool& isError, bool verbose) {
uint8_t headerNameLen;
memcpy(&headerNameLen, *buffer, sizeof(uint8_t));
(*buffer)++;
string headerName(*buffer, headerNameLen);
*buffer += headerNameLen;
uint8_t headerType;
memcpy(&headerType, *buffer, sizeof(uint8_t));
(*buffer)++;
uint16_t headerValLen;
memcpy(&headerValLen, *buffer, sizeof(uint16_t));
headerValLen = ntohs(headerValLen);
*buffer += 2;
string headerVal(*buffer, headerValLen);
*buffer += headerValLen;
if (headerVal == "exception") {
isError = true;
}
if (verbose) {
cout << headerName << "(" << (int)headerType << "): " << headerVal << endl;
}
}
///////////////////////////////////////////////////////////////////////////////////////////
bool TranscribeManager::makeRequest(string& request, const vector<uint8_t>& data) {
char preludeAndCrcBuffer[4*3];
char headerBuffer[88];
char messageCrcBuffer[4];
// prelude
uint32_t totalLen = sizeof(preludeAndCrcBuffer) + sizeof(headerBuffer) + data.size() + sizeof(messageCrcBuffer);
uint32_t headerLen = sizeof(headerBuffer);
totalLen = htonl(totalLen);
headerLen = htonl(headerLen);
memcpy(&preludeAndCrcBuffer[0], &totalLen, sizeof(uint32_t));
memcpy(&preludeAndCrcBuffer[4], &headerLen, sizeof(uint32_t));
uint32_t preludeCRC = CRC::Calculate(&preludeAndCrcBuffer[0], 8, CRC::CRC_32());
preludeCRC = htonl(preludeCRC);
memcpy(&preludeAndCrcBuffer[8], &preludeCRC, sizeof(uint32_t));
// header
char* buffer = headerBuffer;
writeHeader(&buffer, ":content-type", "application/octet-stream");
writeHeader(&buffer, ":event-type", "AudioEvent");
writeHeader(&buffer, ":message-type", "event");
// write everything to response string except for the message CRC
request.append(preludeAndCrcBuffer, sizeof(preludeAndCrcBuffer));
request.append(headerBuffer, sizeof(headerBuffer));
request.append(data.begin(), data.end());
// message CRC
uint32_t messageCRC = CRC::Calculate(request.c_str(), request.length(), CRC::CRC_32());
messageCRC = htonl(messageCRC);
memcpy(messageCrcBuffer, &messageCRC, sizeof(uint32_t));
// write message CRC to response string
request.append(messageCrcBuffer, sizeof(messageCrcBuffer));
return true;
}
void TranscribeManager::writeHeader(char** buffer, const char* key, const char* val) {
uint8_t keyLen = strlen(key);
uint16_t valueLen = strlen(val);
memcpy(*buffer, &keyLen, sizeof(uint8_t));
(*buffer)++;
memcpy(*buffer, key, keyLen);
(*buffer) += keyLen;
uint8_t valueType = 7;
memcpy(*buffer, &valueType, sizeof(uint8_t));
(*buffer)++;
uint16_t valLen = htons(valueLen);
memcpy(*buffer, &valLen, sizeof(uint16_t));
(*buffer) += 2;
memcpy(*buffer, val, valueLen);
(*buffer) += valueLen;
}

View File

@@ -0,0 +1,50 @@
#ifndef TRANSCRIBEMANAGER_HPP_
#define TRANSCRIBEMANAGER_HPP_
#include <string>
#include <vector>
/** Usage
#include "transcribe_manager.hpp"
// get signed URL
const string url = TranscribeManager::getSignedWebsocketUrl(accessKey_, secretKey_, region_);
// connect to the url using a socket library (e.g. https://github.com/machinezone/IXWebSocket)
// build request string
string request;
TranscribeManager::makeRequest(request, audioData); // audioData is a const vector<uint8_t>
// send request to socket
*
*/
using namespace std;
class TranscribeManager {
public:
static void getSignedWebsocketUrl(string& host, string& path,
const std::string& accessKey, const std::string& secretKey, const std::string& securityToken,
const std::string& region, const std::string& lang, const char* vocabularyName,
const char* vocabularyFilterName, const char* vocabularyFilterMethod,
const char* piiEntities, int shouldIdentifyPiiEntities, const char* languageModelName);
static bool parseResponse(const std::string& response, std::string& payload, bool& isError, bool verbose = false);
static bool makeRequest(std::string& request, const std::vector<uint8_t>& data);
static void writeHeader(char** buffer, const char* key, const char* val);
private:
static std::string getSha256(std::string str);
static void getSignatureKey(unsigned char *signatureKey, const std::string& secretKey,
const std::string& datestamp, const std::string& region, const std::string& service);
static void getHMAC(unsigned char *hmac, unsigned char *key, int keyLen, const std::string& str);
static std::string toHex(unsigned char *hmac);
static bool verifyCRC(const char* buffer, const uint32_t totalLength);
static void parseHeader(const char** buffer, bool& isError, bool verbose = false);
};
#endif /* TRANSCRIBEMANAGER_HPP_ */