diff --git a/mistralrs-audio/src/lib.rs b/mistralrs-audio/src/lib.rs index a97b14d7a5..9909330e6c 100644 --- a/mistralrs-audio/src/lib.rs +++ b/mistralrs-audio/src/lib.rs @@ -101,6 +101,46 @@ impl AudioInput { } mono } + + /// Normalize audio to prevent clipping + pub fn normalize(&mut self) -> &mut Self { + let max_amplitude = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max); + if max_amplitude > 0.0 && max_amplitude != 1.0 { + let scale = 1.0 / max_amplitude; + for sample in &mut self.samples { + *sample *= scale; + } + } + self + } + + /// Apply fade in/out to reduce audio artifacts + pub fn apply_fade(&mut self, fade_in_samples: usize, fade_out_samples: usize) -> &mut Self { + let len = self.samples.len(); + // Fade in + for i in 0..fade_in_samples.min(len) { + let factor = i as f32 / fade_in_samples as f32; + self.samples[i] *= factor; + } + // Fade out + for i in 0..fade_out_samples.min(len) { + let factor = (fade_out_samples - i) as f32 / fade_out_samples as f32; + self.samples[len - 1 - i] *= factor; + } + self + } + + /// Remove DC offset (audio centered around 0) + pub fn remove_dc_offset(&mut self) -> &mut Self { + if self.samples.is_empty() { + return self; + } + let mean = self.samples.iter().sum::() / self.samples.len() as f32; + for sample in &mut self.samples { + *sample -= mean; + } + self + } } #[cfg(test)] @@ -148,4 +188,44 @@ mod tests { assert_eq!(input.samples.len(), 80); assert_eq!(input.sample_rate, 8000); } + + #[test] + fn test_normalize() { + let mut input = AudioInput { + samples: vec![0.2, -0.5, 0.8, -1.0], + sample_rate: 16000, + channels: 1, + }; + input.normalize(); + let max = input.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max); + assert!((max - 1.0).abs() < 1e-6); + } + + #[test] + fn test_apply_fade() { + let mut input = AudioInput { + samples: vec![1.0; 10], + sample_rate: 16000, + channels: 1, + }; + input.apply_fade(3, 3); + assert!((input.samples[0] - 0.0).abs() < 1e-6); + assert!(input.samples[1] > 0.0 && input.samples[1] < 1.0); + assert!(input.samples[2] > 0.0 && input.samples[2] < 1.0); + assert!(input.samples[3] == 1.0); + assert!((input.samples[9] - 0.0).abs() < 1e-6); + } + + #[test] + fn test_remove_dc_offset() { + let mut input = AudioInput { + samples: vec![1.0, 1.0, 1.0, 1.0], + sample_rate: 16000, + channels: 1, + }; + input.remove_dc_offset(); + for s in input.samples { + assert!((s - 0.0).abs() < 1e-6); + } + } }