Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions include/ortx_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor**
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);
Expand All @@ -51,7 +51,8 @@ extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* cons
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t sizes[], size_t num_audios);
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t sizes[],
size_t num_audios);

/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
Expand All @@ -65,12 +66,50 @@ extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void*
* @param log_mel A pointer to an OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio,
OrtxTensorResult** log_mel);

/**
* @brief Splits an input audio signal and outputs the areas of high vs low energy based on the STFT analysis.
*
* This function takes an input waveform tensor and associated parameters such as sample rate,
* frame length, hop length, and energy threshold (in dB), and identifies contiguous segments
* of speech or sound activity. It writes the resulting segment start and end indices into
* the provided output tensor.
*
* The OrtxTensor handles are opaque: the implementation reinterprets each one as a typed
* ortc::Tensor without any runtime type check, so the caller must pass tensors whose element
* type matches the one listed for each parameter below.
*
* @param input The input waveform tensor (1D or 2D) containing audio samples.
* Element type: float.
* @param sr_tensor A tensor containing the sample rate of the input audio (in Hz).
* Element type: int64_t.
* @param frame_ms_tensor A tensor containing the frame size in milliseconds.
* Element type: int64_t.
* @param hop_ms_tensor A tensor containing the hop length in milliseconds.
* Element type: int64_t.
* @param energy_threshold_db_tensor A tensor specifying the energy threshold in decibels (dB)
* used to decide which frames are considered active.
* Element type: float.
* @param output0 A pointer to an output tensor where the resulting segments will be written.
* Each row contains two integers: [start_sample, end_sample] for a detected segment.
* Element type: int64_t.
* @return An extError_t value indicating the success or failure of the operation:
* kOrtxOK on success; kOrtxErrorInvalidArgument if any argument is a null pointer
* or if the underlying segmentation routine reports an error.
*/
extError_t ORTX_API_CALL OrtxSplitSignalSegments(const OrtxTensor* input, const OrtxTensor* sr_tensor,
const OrtxTensor* frame_ms_tensor, const OrtxTensor* hop_ms_tensor,
const OrtxTensor* energy_threshold_db_tensor, OrtxTensor* output0);

/**
* @brief Merges adjacent signal segments that are separated by short gaps.
*
* This function takes a tensor of detected segments (each row containing [start, end] indices)
* and merges any consecutive segments whose gap is smaller than the specified threshold (in milliseconds).
*
* The OrtxTensor handles are opaque: the implementation reinterprets all three as
* ortc::Tensor<int64_t> without any runtime type check, so the caller must pass int64_t tensors.
*
* @param segments_tensor The input tensor of detected segments, of shape [N, 2].
* Element type: int64_t.
* @param merge_gap_ms_tensor A tensor containing a single integer value representing
* the maximum allowed gap (in milliseconds) between consecutive segments to be merged.
* Element type: int64_t.
* @param output0 A pointer to an output tensor where the merged segments will be stored.
* Each row contains two integers: [merged_start_sample, merged_end_sample].
* Element type: int64_t.
* @return An extError_t value indicating the success or failure of the operation:
* kOrtxOK on success; kOrtxErrorInvalidArgument if any argument is a null pointer
* or if the underlying merge routine reports an error.
*/
extError_t ORTX_API_CALL OrtxMergeSignalSegments(const OrtxTensor* segments_tensor,
const OrtxTensor* merge_gap_ms_tensor, OrtxTensor* output0);

/**
* @brief Extracts log-mel features from raw audio data using a feature extractor.
*
*
* This function processes the input audio buffers through the provided feature extractor,
* producing log-mel spectrogram outputs suitable for inference or further signal analysis.
*
Expand All @@ -84,7 +123,8 @@ extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxR
* @return An extError_t value indicating success or error status. Returns
* EXT_SUCCESS on success, or an appropriate error code if extraction fails.
*/
extError_t ORTX_API_CALL OrtxFeatureExtraction(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** result);
extError_t ORTX_API_CALL OrtxFeatureExtraction(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio,
OrtxTensorResult** result);

#ifdef __cplusplus
}
Expand Down
51 changes: 48 additions & 3 deletions shared/api/c_api_feature_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "speech_extractor.h"

#include "c_api_utils.hpp"
#include <math/energy_stft_segmentation.hpp>

using namespace ort_extensions;

Expand All @@ -17,7 +18,7 @@ class RawAudiosObject : public OrtxObjectImpl {
};

extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t sizes[],
size_t num_audios) {
size_t num_audios) {
if (audios == nullptr || data == nullptr || sizes == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
Expand Down Expand Up @@ -99,6 +100,50 @@ extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxR
return status.Code();
}

/**
 * @brief C API wrapper around split_signal_segments.
 *
 * Validates that every handle is non-null, reinterprets the opaque OrtxTensor handles as the
 * typed ortc tensors expected by split_signal_segments, and forwards the call.
 *
 * @param input Waveform tensor handle; reinterpreted as ortc::Tensor<float>.
 * @param sr_tensor Sample-rate tensor handle; reinterpreted as ortc::Tensor<int64_t>.
 * @param frame_ms_tensor Frame-size tensor handle; reinterpreted as ortc::Tensor<int64_t>.
 * @param hop_ms_tensor Hop-length tensor handle; reinterpreted as ortc::Tensor<int64_t>.
 * @param energy_threshold_db_tensor Energy-threshold tensor handle; reinterpreted as ortc::Tensor<float>.
 * @param output0 Output tensor handle; reinterpreted as ortc::Tensor<int64_t> and filled by the callee.
 * @return kOrtxOK on success; kOrtxErrorInvalidArgument if any argument is null or if
 *         split_signal_segments returns a non-null status. In both failure cases
 *         ReturnableStatus::last_error_message_ is set.
 */
extError_t ORTX_API_CALL OrtxSplitSignalSegments(const OrtxTensor* input, const OrtxTensor* sr_tensor,
                                                 const OrtxTensor* frame_ms_tensor, const OrtxTensor* hop_ms_tensor,
                                                 const OrtxTensor* energy_threshold_db_tensor, OrtxTensor* output0) {
  if (input == nullptr || sr_tensor == nullptr || frame_ms_tensor == nullptr || hop_ms_tensor == nullptr ||
      energy_threshold_db_tensor == nullptr || output0 == nullptr) {
    ReturnableStatus::last_error_message_ = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  // The handles carry no type information at this layer; the casts below trust the caller
  // to have supplied tensors with exactly these element types.
  const ortc::Tensor<float>& input_tensor = *reinterpret_cast<const ortc::Tensor<float>*>(input);
  const ortc::Tensor<int64_t>& sr_t = *reinterpret_cast<const ortc::Tensor<int64_t>*>(sr_tensor);
  const ortc::Tensor<int64_t>& frame_t = *reinterpret_cast<const ortc::Tensor<int64_t>*>(frame_ms_tensor);
  const ortc::Tensor<int64_t>& hop_t = *reinterpret_cast<const ortc::Tensor<int64_t>*>(hop_ms_tensor);
  const ortc::Tensor<float>& threshold_t = *reinterpret_cast<const ortc::Tensor<float>*>(energy_threshold_db_tensor);
  ortc::Tensor<int64_t>& output_t = *reinterpret_cast<ortc::Tensor<int64_t>*>(output0);

  OrtStatusPtr status = split_signal_segments(input_tensor, sr_t, frame_t, hop_t, threshold_t, output_t);
  if (status) {
    // NOTE(review): the returned status is neither inspected for its message nor released
    // here — confirm who owns it and whether its text should be surfaced to the caller.
    ReturnableStatus::last_error_message_ = "split_signal_segments failed";
    return kOrtxErrorInvalidArgument;
  }

  // Use the named success code, matching OrtxMergeSignalSegments, rather than a
  // value-initialized extError_t.
  return kOrtxOK;
}

/**
 * @brief C API wrapper around merge_signal_segments.
 *
 * Rejects null handles, views the three opaque OrtxTensor handles as ortc::Tensor<int64_t>,
 * and delegates the merge to merge_signal_segments.
 *
 * @param segments_tensor Handle of the segments tensor to merge.
 * @param merge_gap_ms_tensor Handle of the tensor holding the merge gap.
 * @param output0 Handle of the tensor that receives the merged segments.
 * @return kOrtxOK on success; kOrtxErrorInvalidArgument when a handle is null or when
 *         merge_signal_segments returns a non-null status (last_error_message_ is set).
 */
extError_t ORTX_API_CALL OrtxMergeSignalSegments(const OrtxTensor* segments_tensor,
                                                 const OrtxTensor* merge_gap_ms_tensor, OrtxTensor* output0) {
  const bool has_null_handle = !segments_tensor || !merge_gap_ms_tensor || !output0;
  if (has_null_handle) {
    ReturnableStatus::last_error_message_ = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  // View the opaque handles as the typed tensors the merge routine operates on.
  const auto* segments = reinterpret_cast<const ortc::Tensor<int64_t>*>(segments_tensor);
  const auto* merge_gap = reinterpret_cast<const ortc::Tensor<int64_t>*>(merge_gap_ms_tensor);
  auto* merged = reinterpret_cast<ortc::Tensor<int64_t>*>(output0);

  if (OrtStatusPtr failure = merge_signal_segments(*segments, *merge_gap, *merged)) {
    ReturnableStatus::last_error_message_ = "merge_signal_segments failed";
    return kOrtxErrorInvalidArgument;
  }

  return kOrtxOK;
}

extError_t ORTX_API_CALL OrtxFeatureExtraction(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios,
OrtxTensorResult** result) {
if (extractor == nullptr || raw_audios == nullptr || result == nullptr) {
Expand All @@ -110,8 +155,8 @@ extError_t ORTX_API_CALL OrtxFeatureExtraction(OrtxFeatureExtractor* extractor,
auto audios_obj = static_cast<RawAudiosObject*>(raw_audios);

auto result_ptr = std::make_unique<TensorResult>();
ReturnableStatus status = extractor_ptr->Preprocess(
ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), *result_ptr);
ReturnableStatus status =
extractor_ptr->Preprocess(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), *result_ptr);
if (status.IsOk()) {
*result = static_cast<OrtxTensorResult*>(result_ptr.release());
} else {
Expand Down
39 changes: 23 additions & 16 deletions test/pp_api_test/test_feature_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ TEST(ExtractorTest, TestWhisperAudioOutput) {
ASSERT_LT(mismatch_percentage, 0.04) << "Mismatch percentage exceeds 4% threshold!";
}

TEST(ExtractorTest, TestStftEnergySegmentationAndMerge) {
TEST(ExtractorTest, TestSplitSignalSegments) {
const int64_t sample_rate = 16000;
const int64_t num_samples = sample_rate * 2;

Expand All @@ -295,42 +295,49 @@ TEST(ExtractorTest, TestStftEnergySegmentationAndMerge) {

auto* alloc = &CppAllocator::Instance();

ortc::Tensor<float> audio(alloc);
float* audio_data = audio.Allocate({1, num_samples});
std::memcpy(audio_data, pcm.data(), num_samples * sizeof(float));
ortc::Tensor<float> input(alloc);
float* in_data = input.Allocate({1, num_samples});
std::memcpy(in_data, pcm.data(), num_samples * sizeof(float));

ortc::Tensor<int64_t> sr(alloc);
sr.Allocate({1})[0] = sample_rate;

ortc::Tensor<int64_t> frame(alloc);
frame.Allocate({1})[0] = 25;
ortc::Tensor<int64_t> frame_ms(alloc);
frame_ms.Allocate({1})[0] = 25;

ortc::Tensor<int64_t> hop(alloc);
hop.Allocate({1})[0] = 10;
ortc::Tensor<int64_t> hop_ms(alloc);
hop_ms.Allocate({1})[0] = 10;

ortc::Tensor<float> thr(alloc);
thr.Allocate({1})[0] = -40.0f;
ortc::Tensor<float> energy_threshold_db(alloc);
// Difference of 40 decibels can be a reasonable diff between voice and silence (or background noise)
energy_threshold_db.Allocate({1})[0] = -40.0f;

ortc::Tensor<int64_t> output(alloc);

split_signal_segments(audio, sr, frame, hop, thr, output);
extError_t err = OrtxSplitSignalSegments(
reinterpret_cast<OrtxTensor*>(&input), reinterpret_cast<OrtxTensor*>(&sr),
reinterpret_cast<OrtxTensor*>(&frame_ms), reinterpret_cast<OrtxTensor*>(&hop_ms),
reinterpret_cast<OrtxTensor*>(&energy_threshold_db), reinterpret_cast<OrtxTensor*>(&output));

ASSERT_EQ(err, kOrtxOK);

const auto& out_shape = output.Shape();
ASSERT_EQ(out_shape.size(), 2u);
int64_t num_segments = out_shape[0];
ASSERT_EQ(out_shape[1], 2);
ASSERT_EQ(num_segments, 53);
ASSERT_EQ(out_shape[0], 53);

// Start merging
ortc::Tensor<int64_t> merge_gap(alloc);
merge_gap.Allocate({1})[0] = 50;

ortc::Tensor<int64_t> merged_segments(alloc);

merge_signal_segments(output, merge_gap, merged_segments);
err = OrtxMergeSignalSegments(reinterpret_cast<OrtxTensor*>(&output), reinterpret_cast<OrtxTensor*>(&merge_gap),
reinterpret_cast<OrtxTensor*>(&merged_segments));

ASSERT_EQ(err, kOrtxOK);

const auto& merged_shape = merged_segments.Shape();
ASSERT_EQ(merged_shape.size(), 2u);
ASSERT_EQ(merged_shape[1], 2);
ASSERT_EQ(merged_shape[0], 4);
}
}