Various speedups

This commit is contained in:
Geraint 2025-02-27 14:09:49 +00:00
parent d70e5c3d5d
commit 1ad2839a04
2 changed files with 32 additions and 12 deletions

View File

@ -2,15 +2,17 @@
#include <iostream> #include <iostream>
#define LOG_EXPR(expr) std::cout << #expr << " = " << (expr) << "\n"; #define LOG_EXPR(expr) std::cout << #expr << " = " << (expr) << "\n";
#ifdef PROFILE_PLOT_CHUNKS
size_t activeStepIndex = 0; size_t activeStepIndex = 0;
void profileProcessStart(int, int); void profileProcessStart(int, int);
void profileProcessEndStep(); void profileProcessEndStep();
void profileProcessStep(size_t, size_t); void profileProcessStep(size_t, size_t);
void profileProcessEnd(); void profileProcessEnd();
#define SIGNALSMITH_STRETCH_PROFILE_PROCESS_START profileProcessStart # define SIGNALSMITH_STRETCH_PROFILE_PROCESS_START profileProcessStart
#define SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP profileProcessStep # define SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP profileProcessStep
#define SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP profileProcessEndStep # define SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP profileProcessEndStep
#define SIGNALSMITH_STRETCH_PROFILE_PROCESS_END profileProcessEnd # define SIGNALSMITH_STRETCH_PROFILE_PROCESS_END profileProcessEnd
#endif
#include "signalsmith-stretch/signalsmith-stretch.h" #include "signalsmith-stretch/signalsmith-stretch.h"
@ -19,6 +21,7 @@ void profileProcessEnd();
#include "./util/simple-args.h" #include "./util/simple-args.h"
#include "./util/wav.h" #include "./util/wav.h"
#ifdef PROFILE_PLOT_CHUNKS
#include "plot/plot.h" #include "plot/plot.h"
std::vector<signalsmith::Stopwatch> processStopwatches; std::vector<signalsmith::Stopwatch> processStopwatches;
signalsmith::Stopwatch processStopwatchStart, processStopwatchEnd; signalsmith::Stopwatch processStopwatchStart, processStopwatchEnd;
@ -51,11 +54,14 @@ void profileProcessStep(size_t step, size_t count) {
void profileProcessEnd() { void profileProcessEnd() {
processStopwatchEnd.lap(); processStopwatchEnd.lap();
} }
#endif
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
signalsmith::stretch::SignalsmithStretch<float/*, std::ranlux48_base*/> stretch; // optional cheaper RNG for performance comparison signalsmith::stretch::SignalsmithStretch<float/*, std::ranlux48_base*/> stretch; // optional cheaper RNG for performance comparison
#ifdef PROFILE_PLOT_CHUNKS
processStopwatches.reserve(1000); processStopwatches.reserve(1000);
#endif
SimpleArgs args(argc, argv); SimpleArgs args(argc, argv);
@ -154,6 +160,7 @@ int main(int argc, char* argv[]) {
// the `.flush()` call already handled foldback stuff at the end (since we asked for a shorter `tailSamples`) // the `.flush()` call already handled foldback stuff at the end (since we asked for a shorter `tailSamples`)
} }
#ifdef PROFILE_PLOT_CHUNKS
signalsmith::plot::Figure figure; signalsmith::plot::Figure figure;
auto &plot = figure(0, 0).plot(400, 150); auto &plot = figure(0, 0).plot(400, 150);
plot.x.blank().label("step"); plot.x.blank().label("step");
@ -196,6 +203,7 @@ int main(int argc, char* argv[]) {
flatLine.add(0, 0); flatLine.add(0, 0);
flatLine.add(processStopwatches.size(), cumulativeTime); flatLine.add(processStopwatches.size(), cumulativeTime);
figure.write("profile.svg"); figure.write("profile.svg");
#endif
if (!outWav.write(outputWav).warn()) args.errorExit("failed to write WAV"); if (!outWav.write(outputWav).warn()) args.errorExit("failed to write WAV");

View File

@ -11,11 +11,12 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <random> #include <random>
#include <type_traits>
namespace signalsmith { namespace stretch { namespace signalsmith { namespace stretch {
namespace _impl { namespace _impl {
template <bool conjugateSecond=false, typename V> template<bool conjugateSecond=false, typename V>
static std::complex<V> mul(const std::complex<V> &a, const std::complex<V> &b) { static std::complex<V> mul(const std::complex<V> &a, const std::complex<V> &b) {
return conjugateSecond ? std::complex<V>{ return conjugateSecond ? std::complex<V>{
b.real()*a.real() + b.imag()*a.imag(), b.real()*a.real() + b.imag()*a.imag(),
@ -25,9 +26,14 @@ namespace _impl {
a.real()*b.imag() + a.imag()*b.real() a.real()*b.imag() + a.imag()*b.real()
}; };
} }
template<typename V>
static V norm(const std::complex<V> &a) {
V r = a.real(), i = a.imag();
return r*r + i*i;
}
} }
template<typename Sample=float, class RandomEngine=std::default_random_engine> template<typename Sample=float, class RandomEngine=void>
struct SignalsmithStretch { struct SignalsmithStretch {
static constexpr size_t version[3] = {1, 1, 1}; static constexpr size_t version[3] = {1, 1, 1};
@ -475,10 +481,10 @@ private:
Complex input; Complex input;
Complex makeOutput(Complex phase) { Complex makeOutput(Complex phase) {
Sample phaseNorm = std::norm(phase); Sample phaseNorm = _impl::norm(phase);
if (phaseNorm <= noiseFloor) { if (phaseNorm <= noiseFloor) {
phase = input; // prediction is too weak, fall back to the input phase = input; // prediction is too weak, fall back to the input
phaseNorm = std::norm(input) + noiseFloor; phaseNorm = _impl::norm(input) + noiseFloor;
} }
return phase*std::sqrt(energy/phaseNorm); return phase*std::sqrt(energy/phaseNorm);
} }
@ -488,7 +494,13 @@ private:
return channelPredictions.data() + c*bands; return channelPredictions.data() + c*bands;
} }
RandomEngine randomEngine; // If RandomEngine=void, use std::default_random_engine;
using RandomEngineImpl = std::conditional<
std::is_void<RandomEngine>::value,
std::default_random_engine,
RandomEngine
>::type;
RandomEngineImpl randomEngine;
size_t processSpectrumSteps = 0; size_t processSpectrumSteps = 0;
static constexpr size_t splitMainPrediction = 8; // it's just heavy, since we're blending up to 4 different phase predictions static constexpr size_t splitMainPrediction = 8; // it's just heavy, since we're blending up to 4 different phase predictions
@ -550,7 +562,7 @@ private:
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
for (int b = 0; b < bands; ++b) { for (int b = 0; b < bands; ++b) {
bins[b].inputEnergy = std::norm(bins[b].input); bins[b].inputEnergy = _impl::norm(bins[b].input);
} }
} }
for (int b = 0; b < bands; ++b) { for (int b = 0; b < bands; ++b) {
@ -687,7 +699,7 @@ private:
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
for (int b = 0; b < bands; ++b) { for (int b = 0; b < bands; ++b) {
Sample e = std::norm(bins[b].input); Sample e = _impl::norm(bins[b].input);
bins[b].inputEnergy = e; // Used for interpolating prediction energy bins[b].inputEnergy = e; // Used for interpolating prediction energy
energy[b] += e; energy[b] += e;
} }