Start replacing order-sorting with two-pass setup

This commit is contained in:
Geraint 2022-11-28 15:40:39 +00:00
parent 598d037212
commit 96eeee7a6f

View File

@ -41,9 +41,8 @@ struct SignalsmithStretch {
} }
// manual parameters // manual parameters
Sample freqWeight = 1, timeWeight = 2, channelWeight = 0.5; Sample freqWeight = 1, timeWeight = 2, channelWeight = 0.5, maxWeight = 2;
bool sortOrder = true; // Assemble output spectrum highest-magnitude first bool sortOrder = true; // Assemble output spectrum highest-magnitude first
Sample maxProportion = 0.75; // How much the strongest prediction overrides everything else
/// Manual setup /// Manual setup
void configure(int nChannels, int blockSamples, int intervalSamples) { void configure(int nChannels, int blockSamples, int intervalSamples) {
@ -63,7 +62,7 @@ struct SignalsmithStretch {
energy.resize(stft.bands()); energy.resize(stft.bands());
smoothedEnergy.resize(stft.bands()); smoothedEnergy.resize(stft.bands());
outputMap.resize(stft.bands()); outputMap.resize(stft.bands());
observationOrder.resize(channels*stft.bands()); channelPredictions.resize(channels*stft.bands());
maxEnergyChannel.resize(stft.bands()); maxEnergyChannel.resize(stft.bands());
} }
@ -173,9 +172,6 @@ private:
struct Band { struct Band {
Complex input, prevInput{0}; Complex input, prevInput{0};
Complex output, prevOutput{0}; Complex output, prevOutput{0};
Complex timeChange{0};
Sample energy, prevEnergy;
bool ready = false;
}; };
std::vector<Band> channelBands; std::vector<Band> channelBands;
Band * bandsForChannel(int channel) { Band * bandsForChannel(int channel) {
@ -234,18 +230,17 @@ private:
}; };
std::vector<PitchMapPoint> outputMap; std::vector<PitchMapPoint> outputMap;
struct OrderPoint { struct Prediction {
int channel, outputBand;
Sample inputIndex;
Sample energy; Sample energy;
Complex input; Complex input;
Complex freqPrediction;
// For sorting in descending order Complex shortVerticalTwist, longVerticalTwist;
bool operator<(const OrderPoint &other) const { Complex channelTwist;
return other.energy < energy;
}
}; };
std::vector<OrderPoint> observationOrder; std::vector<Prediction> channelPredictions;
/// Pointer to the first Prediction belonging to channel `c`.
/// `channelPredictions` is sized `channels*stft.bands()`, laid out as one
/// run of `stft.bands()` entries per channel. No bounds check is performed.
Prediction * predictionsForChannel(int c) {
const auto channelOffset = c*stft.bands();
return channelPredictions.data() + channelOffset;
}
std::vector<int> maxEnergyChannel; std::vector<int> maxEnergyChannel;
void processSpectrum(int inputInterval) { void processSpectrum(int inputInterval) {
@ -269,24 +264,74 @@ private:
findPeaks(smoothingBins); findPeaks(smoothingBins);
updateOutputMap(smoothingBins); updateOutputMap(smoothingBins);
int longVerticalStep = std::round(smoothingBins);
auto *predictions0 = predictionsForChannel(0);
for (auto &c : maxEnergyChannel) c = -1;
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
auto *order = observationOrder.data() + c*stft.bands(); auto *predictions = predictionsForChannel(c);
for (int b = 0; b < stft.bands(); ++b) { for (int b = 0; b < stft.bands(); ++b) {
auto &outputBin = bins[b];
auto mapPoint = outputMap[b]; auto mapPoint = outputMap[b];
int lowIndex = std::floor(mapPoint.inputBin); int lowIndex = std::floor(mapPoint.inputBin);
Sample fracIndex = mapPoint.inputBin - std::floor(mapPoint.inputBin); Sample fracIndex = mapPoint.inputBin - std::floor(mapPoint.inputBin);
Sample outputEnergy = getFractional<&Band::energy>(c, lowIndex, fracIndex); Prediction prediction;
Complex input = getFractional<&Band::input>(c, lowIndex, fracIndex);
outputEnergy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor prediction.energy = getFractional<&Band::energy>(c, lowIndex, fracIndex);
order[b] = {c, b, mapPoint.inputBin, outputEnergy, input}; prediction.energy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor
prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex);
bins[b].ready = false; Complex prevInput = getFractional<&Band::prevInput>(c, lowIndex, fracIndex);
Complex freqTwist = prediction.input*std::conj(prevInput);
prediction.freqPrediction = outputBin.prevOutput*freqTwist;
// Rotation of this channel's input relative to channel 0's input at the same output band.
// Channels are visited in ascending order, so channel 0's prediction for band b is already stored.
if (c > 0) {
// std::conj needs the complex input value, not the whole Prediction record
prediction.channelTwist = prediction.input*std::conj(predictions0[b].input);
} else {
prediction.channelTwist = 0;
}
if (b > 0) {
Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - rate);
prediction.shortVerticalTwist = prediction.input*std::conj(downInput);
if (b > longVerticalStep) {
Complex longDownInput = getFractional<&Band::input>(c, mapPoint.inputBin - longVerticalStep*rate);
prediction.longVerticalTwist = prediction.input*std::conj(longDownInput);
} else {
prediction.longVerticalTwist = 0;
}
} else {
prediction.shortVerticalTwist = prediction.longVerticalTwist = 0;
}
predictions[b] = prediction;
// Rough output prediction based on phase-vocoder, sensitive to previous input/output magnitude
outputBin.output = prediction.freqPrediction/(prediction.energy + 1e-10);
}
}
// Second pass: for each output band, pick the channel whose prediction has the most energy,
// then store every other channel's input as a rotation relative to that channel's input.
for (int b = 0; b < stft.bands(); ++b) {
int maxChannel = 0;
Sample maxEnergy = predictions0[b].energy;
for (int c = 1; c < channels; ++c) {
Sample e = predictionsForChannel(c)[b].energy;
if (e > maxEnergy) {
maxChannel = c;
maxEnergy = e;
}
}
maxEnergyChannel[b] = maxChannel;
// Complex input of the strongest channel at this band (used as the reference below)
Complex channelInput = predictionsForChannel(maxChannel)[b].input;
for (int c = 0; c < channels; ++c) {
Prediction &prediction = predictionsForChannel(c)[b];
if (c == maxChannel) {
// The reference channel has no rotation relative to itself
prediction.channelTwist = 0;
} else {
prediction.channelTwist = prediction.input*std::conj(channelInput);
}
} }
} }
if (sortOrder) std::sort(observationOrder.begin(), observationOrder.end());
for (auto &c : maxEnergyChannel) c = -1; for (auto &c : maxEnergyChannel) c = -1;
for (auto &ordered : observationOrder) { for (auto &ordered : observationOrder) {
@ -297,12 +342,12 @@ private:
Sample fracIndex = ordered.inputIndex - std::floor(ordered.inputIndex); Sample fracIndex = ordered.inputIndex - std::floor(ordered.inputIndex);
// We always have the phase-vocoder prediction // We always have the phase-vocoder prediction
Complex prevInput = getFractional<&Band::prevInput>(ordered.channel, lowIndex, fracIndex);
Complex timeChange = ordered.input*std::conj(prevInput); Complex timeChange = ordered.input*std::conj(prevInput);
Complex prediction = outputBin.prevOutput*timeChange*freqWeight; Complex freqPrediction = outputBin.prevOutput*timeChange;
Complex prediction = freqPrediction*freqWeight;
// Track the strongest prediction // Track the strongest prediction
Complex maxPrediction = prediction; Complex maxPrediction = freqPrediction;
Sample maxPredictionNorm = std::norm(maxPrediction); Sample maxPredictionNorm = std::norm(maxPrediction);
// vertical upwards, if it exists // vertical upwards, if it exists
@ -311,8 +356,8 @@ private:
if (outputDownBin.ready) { if (outputDownBin.ready) {
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - rate); Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - rate);
Complex freqChange = ordered.input*std::conj(downInput); Complex freqChange = ordered.input*std::conj(downInput);
Complex newPrediction = outputDownBin.output*freqChange*timeWeight; Complex newPrediction = outputDownBin.output*freqChange;
prediction += newPrediction; prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
@ -325,8 +370,8 @@ private:
if (outputDownBin.ready) { if (outputDownBin.ready) {
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + rate); Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + rate);
Complex freqChange = ordered.input*std::conj(downInput); Complex freqChange = ordered.input*std::conj(downInput);
Complex newPrediction = outputDownBin.output*freqChange*timeWeight; Complex newPrediction = outputDownBin.output*freqChange;
prediction += newPrediction; prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
@ -334,14 +379,13 @@ private:
} }
} }
// longer verticals // longer verticals
int longStep = std::round(smoothingBins);
if (ordered.outputBand > longStep) { if (ordered.outputBand > longStep) {
auto &outputDownBin = bins[ordered.outputBand - longStep]; auto &outputDownBin = bins[ordered.outputBand - longStep];
if (outputDownBin.ready) { if (outputDownBin.ready) {
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - longStep*rate); Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - longStep*rate);
Complex freqChange = ordered.input*std::conj(downInput); Complex freqChange = ordered.input*std::conj(downInput);
Complex newPrediction = outputDownBin.output*freqChange*timeWeight; Complex newPrediction = outputDownBin.output*freqChange;
prediction += newPrediction; prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
@ -353,8 +397,8 @@ private:
if (outputDownBin.ready) { if (outputDownBin.ready) {
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + longStep*rate); Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + longStep*rate);
Complex freqChange = ordered.input*std::conj(downInput); Complex freqChange = ordered.input*std::conj(downInput);
Complex newPrediction = outputDownBin.output*freqChange*timeWeight; Complex newPrediction = outputDownBin.output*freqChange;
prediction += newPrediction; prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
@ -370,8 +414,8 @@ private:
auto *otherBins = bandsForChannel(maxChannel); auto *otherBins = bandsForChannel(maxChannel);
Complex otherOutputOutput = otherBins[ordered.outputBand].output; Complex otherOutputOutput = otherBins[ordered.outputBand].output;
Complex newPrediction = otherOutputOutput*channelRot*channelWeight; Complex newPrediction = otherOutputOutput*channelRot;
prediction += newPrediction; prediction += newPrediction*channelWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
@ -380,7 +424,7 @@ private:
maxChannel = ordered.channel; maxChannel = ordered.channel;
} }
prediction += (maxPrediction - prediction)*maxProportion; prediction += maxPrediction*maxWeight;
Sample predictionNorm = std::norm(prediction); Sample predictionNorm = std::norm(prediction);
if (predictionNorm > 1e-15) { if (predictionNorm > 1e-15) {
@ -388,14 +432,11 @@ private:
} else { } else {
outputBin.output = ordered.input; outputBin.output = ordered.input;
} }
outputBin.ready = true;
} }
for (auto &bin : channelBands) { for (auto &bin : channelBands) {
bin.prevOutput = bin.output; bin.prevOutput = bin.output;
bin.prevInput = bin.input; bin.prevInput = bin.input;
bin.prevEnergy = bin.energy;
} }
} }
@ -407,7 +448,6 @@ private:
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
for (int b = 0; b < stft.bands(); ++b) { for (int b = 0; b < stft.bands(); ++b) {
Sample e = std::norm(bins[b].input); Sample e = std::norm(bins[b].input);
bins[b].energy = e;
energy[b] += e; energy[b] += e;
} }
} }