Correct energy for preliminary (phase-vocoder) output.

This commit is contained in:
Geraint 2022-12-17 23:08:45 +00:00
parent c3addb7298
commit 3ffe6704ce

View File

@ -229,11 +229,11 @@ struct SignalsmithStretch {
} }
private: private:
using Complex = std::complex<Sample>;
static constexpr Sample noiseFloor{1e-15}; static constexpr Sample noiseFloor{1e-15};
int silenceCounter = 0; int silenceCounter = 0;
bool silenceFirst = true; bool silenceFirst = true;
using Complex = std::complex<Sample>;
Sample freqMultiplier = 1, freqTonalityLimit = 0.5; Sample freqMultiplier = 1, freqTonalityLimit = 0.5;
std::function<Sample(Sample)> customFreqMap = nullptr; std::function<Sample(Sample)> customFreqMap = nullptr;
@ -314,8 +314,14 @@ private:
struct Prediction { struct Prediction {
Sample energy; Sample energy;
Complex input; Complex input;
Complex freqPrediction; bool hasShortVertical = false, hasLongVertical = false;
Complex shortVerticalTwist, longVerticalTwist; Complex shortVerticalTwist, longVerticalTwist;
// Builds an output value that points in the direction of `phase` and has a squared magnitude of `energy`.
// When `phase` is too small to provide a usable direction, `input` is returned unchanged instead.
Complex makeOutput(Complex phase) {
	Sample phaseNorm = std::norm(phase); // squared magnitude of the phase estimate
	// Deliberately a negated `>` rather than `<=`: every comparison with NaN is false,
	// so this form also sends a NaN norm to the fallback instead of letting NaN reach the output.
	if (!(phaseNorm > noiseFloor)) return input;
	// |phase| * sqrt(energy/|phase|^2) == sqrt(energy), so only the direction of `phase` is kept
	return phase*std::sqrt(energy/phaseNorm);
}
}; };
std::vector<Prediction> channelPredictions; std::vector<Prediction> channelPredictions;
Prediction * predictionsForChannel(int c) { Prediction * predictionsForChannel(int c) {
@ -344,7 +350,6 @@ private:
updateOutputMap(smoothingBins); updateOutputMap(smoothingBins);
int longVerticalStep = std::round(smoothingBins); int longVerticalStep = std::round(smoothingBins);
auto *predictions0 = predictionsForChannel(0);
for (auto &c : maxEnergyChannel) c = -1; for (auto &c : maxEnergyChannel) c = -1;
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
@ -355,35 +360,20 @@ private:
int lowIndex = std::floor(mapPoint.inputBin); int lowIndex = std::floor(mapPoint.inputBin);
Sample fracIndex = mapPoint.inputBin - std::floor(mapPoint.inputBin); Sample fracIndex = mapPoint.inputBin - std::floor(mapPoint.inputBin);
Prediction prediction; Prediction &prediction = predictions[b];
prediction.energy = getFractional<&Band::inputEnergy>(c, lowIndex, fracIndex); prediction.energy = getFractional<&Band::inputEnergy>(c, lowIndex, fracIndex);
prediction.energy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor prediction.energy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor
prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex); prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex);
// Preliminary output prediction from phase-vocoder
Complex prevInput = getFractional<&Band::prevInput>(c, lowIndex, fracIndex); Complex prevInput = getFractional<&Band::prevInput>(c, lowIndex, fracIndex);
Complex freqTwist = signalsmith::perf::mul<true>(prediction.input, prevInput); Complex freqTwist = signalsmith::perf::mul<true>(prediction.input, prevInput);
prediction.freqPrediction = signalsmith::perf::mul(outputBin.prevOutput, freqTwist); Complex phase = signalsmith::perf::mul(outputBin.prevOutput, freqTwist);
outputBin.output = prediction.makeOutput(phase);
if (b > 0) {
Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - timeFactor);
prediction.shortVerticalTwist = signalsmith::perf::mul<true>(prediction.input, downInput);
if (b > longVerticalStep) {
Complex longDownInput = getFractional<&Band::input>(c, mapPoint.inputBin - longVerticalStep*timeFactor);
prediction.longVerticalTwist = signalsmith::perf::mul<true>(prediction.input, longDownInput);
} else {
prediction.longVerticalTwist = 0;
} }
} else {
prediction.shortVerticalTwist = prediction.longVerticalTwist = 0;
} }
predictions[b] = prediction; auto *predictions0 = predictionsForChannel(0);
// Rough output prediction based on phase-vocoder, sensitive to previous input/output magnitude
outputBin.output = prediction.freqPrediction/(prediction.energy + noiseFloor);
}
}
for (int b = 0; b < stft.bands(); ++b) { for (int b = 0; b < stft.bands(); ++b) {
// Find maximum-energy channel and calculate that // Find maximum-energy channel and calculate that
int maxChannel = 0; int maxChannel = 0;
@ -404,33 +394,48 @@ private:
Complex phase = 0; Complex phase = 0;
// Short steps // Short vertical step
if (b > 0) { if (b > 0) {
if (!prediction.hasShortVertical) {
Complex downInput = getFractional<&Band::input>(maxChannel, mapPoint.inputBin - timeFactor);
prediction.shortVerticalTwist = signalsmith::perf::mul<true>(prediction.input, downInput);
}
auto &otherBin = bins[b - 1]; auto &otherBin = bins[b - 1];
phase += signalsmith::perf::mul(otherBin.output, prediction.shortVerticalTwist); phase += signalsmith::perf::mul(otherBin.output, prediction.shortVerticalTwist);
} }
if (b < stft.bands() - 1) { if (b < stft.bands() - 1) {
auto &otherBin = bins[b + 1];
auto &otherPrediction = predictions[b + 1]; auto &otherPrediction = predictions[b + 1];
{ // upwards short vertical twist
auto otherMapPoint = outputMap[b + 1];
Complex downInput = getFractional<&Band::input>(maxChannel, otherMapPoint.inputBin - timeFactor);
otherPrediction.shortVerticalTwist = signalsmith::perf::mul<true>(otherPrediction.input, downInput);
otherPrediction.hasShortVertical = true;
}
auto &otherBin = bins[b + 1];
phase += signalsmith::perf::mul<true>(otherBin.output, otherPrediction.shortVerticalTwist); phase += signalsmith::perf::mul<true>(otherBin.output, otherPrediction.shortVerticalTwist);
} }
// longer verticals // Long vertical steps
if (b > longVerticalStep) { if (b > longVerticalStep) {
if (!prediction.hasLongVertical) {
Complex downInput = getFractional<&Band::input>(maxChannel, mapPoint.inputBin - longVerticalStep*timeFactor);
prediction.longVerticalTwist = signalsmith::perf::mul<true>(prediction.input, downInput);
}
auto &otherBin = bins[b - longVerticalStep]; auto &otherBin = bins[b - longVerticalStep];
phase += signalsmith::perf::mul(otherBin.output, prediction.longVerticalTwist); phase += signalsmith::perf::mul(otherBin.output, prediction.longVerticalTwist);
} }
if (b < stft.bands() - longVerticalStep) { if (b < stft.bands() - longVerticalStep) {
auto &otherBin = bins[b + longVerticalStep];
auto &otherPrediction = predictions[b + longVerticalStep]; auto &otherPrediction = predictions[b + longVerticalStep];
{ // upwards long vertical twist
auto otherMapPoint = outputMap[b + longVerticalStep];
Complex downInput = getFractional<&Band::input>(maxChannel, otherMapPoint.inputBin - longVerticalStep*timeFactor);
otherPrediction.longVerticalTwist = signalsmith::perf::mul<true>(otherPrediction.input, downInput);
otherPrediction.hasLongVertical = true;
}
auto &otherBin = bins[b + longVerticalStep];
phase += signalsmith::perf::mul<true>(otherBin.output, otherPrediction.longVerticalTwist); phase += signalsmith::perf::mul<true>(otherBin.output, otherPrediction.longVerticalTwist);
} }
Sample phaseNorm = std::norm(phase); outputBin.output = prediction.makeOutput(phase);
if (phaseNorm > noiseFloor) {
outputBin.output = phase*std::sqrt(prediction.energy/phaseNorm);
} else {
outputBin.output = prediction.input;
}
// All other bins are locked in phase // All other bins are locked in phase
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
@ -440,13 +445,7 @@ private:
Complex channelTwist = signalsmith::perf::mul<true>(channelPrediction.input, prediction.input); Complex channelTwist = signalsmith::perf::mul<true>(channelPrediction.input, prediction.input);
Complex channelPhase = signalsmith::perf::mul(outputBin.output, channelTwist); Complex channelPhase = signalsmith::perf::mul(outputBin.output, channelTwist);
channelBin.output = channelPrediction.makeOutput(channelPhase);
Sample channelPhaseNorm = std::norm(channelPhase);
if (channelPhaseNorm > noiseFloor) {
channelBin.output = channelPhase*std::sqrt(channelPrediction.energy/channelPhaseNorm);
} else {
channelBin.output = channelPrediction.input;
}
} }
} }
} }
@ -516,7 +515,6 @@ private:
} }
Sample avgBand = bandSum/energySum; Sample avgBand = bandSum/energySum;
Sample avgFreq = bandToFreq(avgBand); Sample avgFreq = bandToFreq(avgBand);
Sample avgEnergy = energySum/(end - start);
peaks.emplace_back(Peak{avgBand, freqToBand(mapFreq(avgFreq))}); peaks.emplace_back(Peak{avgBand, freqToBand(mapFreq(avgFreq))});
start = end; start = end;