Restore up-front vertical phase-twist calculations and prediction weighting

This commit is contained in:
Geraint 2022-12-18 02:06:22 +00:00
parent f72fa99cfa
commit 7ca6c8d13c

View File

@ -311,9 +311,8 @@ private:
std::vector<PitchMapPoint> outputMap;
struct Prediction {
Sample energy;
Sample energy = 0;
Complex input;
bool hasShortVertical = false, hasLongVertical = false;
Complex shortVerticalTwist, longVerticalTwist;
Complex makeOutput(Complex phase) {
@ -347,6 +346,7 @@ private:
}
Sample smoothingBins = Sample(stft.fftSize())/stft.interval();
int longVerticalStep = std::round(smoothingBins);
findPeaks(smoothingBins);
updateOutputMap(smoothingBins);
@ -368,12 +368,24 @@ private:
Complex prevInput = getFractional<&Band::prevInput>(c, lowIndex, fracIndex);
Complex freqTwist = signalsmith::perf::mul<true>(prediction.input, prevInput);
Complex phase = signalsmith::perf::mul(outputBin.prevOutput, freqTwist);
outputBin.output = prediction.makeOutput(phase);
outputBin.output = phase/(prediction.energy + noiseFloor);
if (b > 0) {
Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - timeFactor);
prediction.shortVerticalTwist = signalsmith::perf::mul<true>(prediction.input, downInput);
if (b >= longVerticalStep) {
Complex longDownInput = getFractional<&Band::input>(c, mapPoint.inputBin - longVerticalStep*timeFactor);
prediction.longVerticalTwist = signalsmith::perf::mul<true>(prediction.input, longDownInput);
} else {
prediction.longVerticalTwist = 0;
}
} else {
prediction.shortVerticalTwist = prediction.longVerticalTwist = 0;
}
}
}
// Re-predict using phase differences between frequencies
int longVerticalStep = std::round(smoothingBins);
for (int b = 0; b < stft.bands(); ++b) {
// Find maximum-energy channel and calculate that
int maxChannel = 0;
@ -396,18 +408,10 @@ private:
// Upwards vertical steps
if (b > 0) {
if (!prediction.hasShortVertical) {
Complex downInput = getFractional<&Band::input>(maxChannel, mapPoint.inputBin - timeFactor);
prediction.shortVerticalTwist = signalsmith::perf::mul<true>(prediction.input, downInput);
}
auto &downBin = bins[b - 1];
phase += signalsmith::perf::mul(downBin.output, prediction.shortVerticalTwist);
if (b >= longVerticalStep) {
if (!prediction.hasLongVertical) {
Complex longDownInput = getFractional<&Band::input>(maxChannel, mapPoint.inputBin - longVerticalStep*timeFactor);
prediction.longVerticalTwist = signalsmith::perf::mul<true>(prediction.input, longDownInput);
}
auto &longDownBin = bins[b - longVerticalStep];
phase += signalsmith::perf::mul(longDownBin.output, prediction.longVerticalTwist);
}
@ -415,23 +419,11 @@ private:
// Downwards vertical steps
if (b < stft.bands() - 1) {
auto &upPrediction = predictions[b + 1];
{ // upwards short vertical twist
auto upMapPoint = outputMap[b + 1];
Complex upInput = getFractional<&Band::input>(maxChannel, upMapPoint.inputBin - timeFactor);
upPrediction.shortVerticalTwist = signalsmith::perf::mul<true>(upPrediction.input, upInput);
upPrediction.hasShortVertical = true;
}
auto &upBin = bins[b + 1];
phase += signalsmith::perf::mul<true>(upBin.output, upPrediction.shortVerticalTwist);
if (b < stft.bands() - longVerticalStep) {
auto &longUpPrediction = predictions[b + longVerticalStep];
{ // upwards long vertical twist
auto longUpMapPoint = outputMap[b + longVerticalStep];
Complex longUpInput = getFractional<&Band::input>(maxChannel, longUpMapPoint.inputBin - longVerticalStep*timeFactor);
longUpPrediction.longVerticalTwist = signalsmith::perf::mul<true>(longUpPrediction.input, longUpInput);
longUpPrediction.hasLongVertical = true;
}
auto &longUpBin = bins[b + longVerticalStep];
phase += signalsmith::perf::mul<true>(longUpBin.output, longUpPrediction.longVerticalTwist);
}