From 8787460488ab5d881699632a1ae391624db52169 Mon Sep 17 00:00:00 2001 From: Geraint Date: Tue, 29 Nov 2022 10:39:03 +0000 Subject: [PATCH] Two-pass setup --- signalsmith-stretch.h | 185 +++++++++++++++++------------------------- 1 file changed, 74 insertions(+), 111 deletions(-) diff --git a/signalsmith-stretch.h b/signalsmith-stretch.h index 7768e83..8732ee3 100644 --- a/signalsmith-stretch.h +++ b/signalsmith-stretch.h @@ -37,12 +37,10 @@ struct SignalsmithStretch { configure(nChannels, sampleRate*0.12, sampleRate*0.03); freqWeight = 1; timeWeight = 2; - channelWeight = 0.5; } // manual parameters - Sample freqWeight = 1, timeWeight = 2, channelWeight = 0.5, maxWeight = 2; - bool sortOrder = true; // Assemble output spectrum highest-magnitude first + Sample freqWeight = 1, timeWeight = 2, maxWeight = 2; /// Manual setup void configure(int nChannels, int blockSamples, int intervalSamples) { @@ -172,6 +170,7 @@ private: struct Band { Complex input, prevInput{0}; Complex output, prevOutput{0}; + Sample inputEnergy; }; std::vector channelBands; Band * bandsForChannel(int channel) { @@ -235,7 +234,6 @@ private: Complex input; Complex freqPrediction; Complex shortVerticalTwist, longVerticalTwist; - Complex channelTwist; }; std::vector channelPredictions; Prediction * predictionsForChannel(int c) { @@ -278,7 +276,7 @@ private: Prediction prediction; - prediction.energy = getFractional<&Band::energy>(c, lowIndex, fracIndex); + prediction.energy = getFractional<&Band::inputEnergy>(c, lowIndex, fracIndex); prediction.energy *= std::max(0, mapPoint.freqGrad); // scale the energy according to local stretch factor prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex); @@ -286,11 +284,6 @@ private: Complex freqTwist = prediction.input*std::conj(prevInput); prediction.freqPrediction = outputBin.prevOutput*freqTwist; - if (c > 0) { - prediction.channelTwist = prediction.input*std::conj(predictions0[b]); - } else { - prediction.channelTwist = 0; - } if (b > 0) { Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - rate); prediction.shortVerticalTwist = prediction.input*std::conj(downInput); @@ -307,12 +300,12 @@ private: predictions[b] = prediction; // Rough output prediction based on phase-vocoder, sensitive to previous input/output magnitude - outputBin.output = prediction.freqPrediction/(prediction.energy + 1e-10); + outputBin.output = prediction.freqPrediction/(prediction.energy + Sample(1e-10)); } } for (int b = 0; b < stft.bands(); ++b) { + // Find maximum-energy channel and calculate that int maxChannel = 0; - maxEnergyChannel[b] = 0; Sample maxEnergy = predictions0[b].energy; for (int c = 1; c < channels; ++c) { Sample e = predictionsForChannel(c)[b].energy; @@ -321,116 +314,85 @@ private: maxEnergy = e; } } - maxEnergyChannel[b] = maxChannel; - Sample channelInput = predictionsForChannel(maxChannel) - for (int c = 0; c < channels; ++c) { - Prediction &prediction = predictionsForChannel(c)[b]; - if (c == maxChannel) { - prediction.channelTwist = 0; - } else { - prediction.channelTwist = prediction.input*std::conj(channelInput); - } - } - } - - for (auto &c : maxEnergyChannel) c = -1; - for (auto &ordered : observationOrder) { - auto *bins = bandsForChannel(ordered.channel); - auto &outputBin = bins[ordered.outputBand]; - - int lowIndex = std::floor(ordered.inputIndex); - Sample fracIndex = ordered.inputIndex - std::floor(ordered.inputIndex); - - // We always have the phase-vocoder prediction - Complex timeChange = ordered.input*std::conj(prevInput); - Complex freqPrediction = outputBin.prevOutput*timeChange; - Complex prediction = freqPrediction*freqWeight; + auto *predictions = predictionsForChannel(maxChannel); + auto &prediction = predictions[b]; + auto *bins = bandsForChannel(maxChannel); + auto &outputBin = bins[b]; + auto mapPoint = outputMap[b]; + + Complex phase = prediction.freqPrediction*freqWeight; + // Track the strongest prediction - Complex maxPrediction = freqPrediction; + Complex maxPrediction = prediction.freqPrediction; Sample maxPredictionNorm = std::norm(maxPrediction); - // vertical upwards, if it exists - if (ordered.outputBand > 0) { - auto &outputDownBin = bins[ordered.outputBand - 1]; - if (outputDownBin.ready) { - Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - rate); - Complex freqChange = ordered.input*std::conj(downInput); - Complex newPrediction = outputDownBin.output*freqChange; - prediction += newPrediction*timeWeight; - if (std::norm(newPrediction) > maxPredictionNorm) { - maxPredictionNorm = std::norm(newPrediction); - maxPrediction = newPrediction; - } - } - } - // vertical downwards, if it exists - if (ordered.outputBand < stft.bands() - 1) { - auto &outputDownBin = bins[ordered.outputBand + 1]; - if (outputDownBin.ready) { - Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + rate); - Complex freqChange = ordered.input*std::conj(downInput); - Complex newPrediction = outputDownBin.output*freqChange; - prediction += newPrediction*timeWeight; - if (std::norm(newPrediction) > maxPredictionNorm) { - maxPredictionNorm = std::norm(newPrediction); - maxPrediction = newPrediction; - } - } - } - // longer verticals - if (ordered.outputBand > longStep) { - auto &outputDownBin = bins[ordered.outputBand - longStep]; - if (outputDownBin.ready) { - Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - longStep*rate); - Complex freqChange = ordered.input*std::conj(downInput); - Complex newPrediction = outputDownBin.output*freqChange; - prediction += newPrediction*timeWeight; - if (std::norm(newPrediction) > maxPredictionNorm) { - maxPredictionNorm = std::norm(newPrediction); - maxPrediction = newPrediction; - } - } - } - if (ordered.outputBand < stft.bands() - longStep) { - auto &outputDownBin = bins[ordered.outputBand + longStep]; - if (outputDownBin.ready) { - Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + longStep*rate); - Complex freqChange = ordered.input*std::conj(downInput); - Complex newPrediction = outputDownBin.output*freqChange; - prediction += newPrediction*timeWeight; - if (std::norm(newPrediction) > maxPredictionNorm) { - maxPredictionNorm = std::norm(newPrediction); - maxPrediction = newPrediction; - } - } - } - - // Inter-channel prediction, if it exists - int &maxChannel = maxEnergyChannel[ordered.outputBand]; - if (maxChannel >= 0) { - Complex otherInput = getFractional<&Band::input>(maxChannel, lowIndex, fracIndex); - Complex channelRot = ordered.input*std::conj(otherInput); - - auto *otherBins = bandsForChannel(maxChannel); - Complex otherOutputOutput = otherBins[ordered.outputBand].output; - Complex newPrediction = otherOutputOutput*channelRot; - prediction += newPrediction*channelWeight; + // Short steps + if (b > 0) { + auto &otherBin = bins[b - 1]; + Complex newPrediction = otherBin.output*prediction.shortVerticalTwist; + phase += newPrediction*timeWeight; if (std::norm(newPrediction) > maxPredictionNorm) { maxPredictionNorm = std::norm(newPrediction); maxPrediction = newPrediction; } + } + if (b < stft.bands() - 1) { + auto &otherBin = bins[b + 1]; + auto &otherPrediction = predictions[b + 1]; + Complex newPrediction = otherBin.output*std::conj(otherPrediction.shortVerticalTwist); + phase += newPrediction*timeWeight; + if (std::norm(newPrediction) > maxPredictionNorm) { + maxPredictionNorm = std::norm(newPrediction); + maxPrediction = newPrediction; + } + } + // longer verticals + if (b > longVerticalStep) { + auto &otherBin = bins[b - longVerticalStep]; + Complex newPrediction = otherBin.output*prediction.longVerticalTwist; + phase += newPrediction*timeWeight; + if (std::norm(newPrediction) > maxPredictionNorm) { + maxPredictionNorm = std::norm(newPrediction); + maxPrediction = newPrediction; + } + } + if (b < stft.bands() - longVerticalStep) { + auto &otherBin = bins[b + longVerticalStep]; + auto &otherPrediction = predictions[b + longVerticalStep]; + Complex newPrediction = otherBin.output*std::conj(otherPrediction.longVerticalTwist); + phase += newPrediction*timeWeight; + if (std::norm(newPrediction) > maxPredictionNorm) { + maxPredictionNorm = std::norm(newPrediction); + maxPrediction = newPrediction; + } + } + + phase += maxPrediction*maxWeight; + + Sample phaseNorm = std::norm(phase); + if (phaseNorm > 1e-15) { + outputBin.output = phase*std::sqrt(prediction.energy/phaseNorm); } else { - maxChannel = ordered.channel; + outputBin.output = prediction.input; } - prediction += maxPrediction*maxWeight; - - Sample predictionNorm = std::norm(prediction); - if (predictionNorm > 1e-15) { - outputBin.output = prediction*std::sqrt(ordered.energy/predictionNorm); - } else { - outputBin.output = ordered.input; + // All other bins are locked in phase + for (int c = 0; c < channels; ++c) { + if (c != maxChannel) { + auto &channelBin = bandsForChannel(c)[b]; + auto &channelPrediction = predictionsForChannel(c)[b]; + + Complex channelTwist = prediction.input*std::conj(channelPrediction.input); + Complex channelPhase = outputBin.output*channelTwist; + + Sample channelPhaseNorm = std::norm(channelPhase); + if (channelPhaseNorm > 1e-15) { + channelBin.output = channelPhase*std::sqrt(prediction.energy/channelPhaseNorm); + } else { + channelBin.output = channelPrediction.input; + } + } } } @@ -448,6 +410,7 @@ private: Band *bins = bandsForChannel(c); for (int b = 0; b < stft.bands(); ++b) { Sample e = std::norm(bins[b].input); + bins[b].inputEnergy = e; // Used for interpolating prediction energy energy[b] += e; } }