Two-pass setup

This commit is contained in:
Geraint 2022-11-29 10:39:03 +00:00
parent 96eeee7a6f
commit 8787460488

View File

@ -37,12 +37,10 @@ struct SignalsmithStretch {
configure(nChannels, sampleRate*0.12, sampleRate*0.03); configure(nChannels, sampleRate*0.12, sampleRate*0.03);
freqWeight = 1; freqWeight = 1;
timeWeight = 2; timeWeight = 2;
channelWeight = 0.5;
} }
// manual parameters // manual parameters
Sample freqWeight = 1, timeWeight = 2, channelWeight = 0.5, maxWeight = 2; Sample freqWeight = 1, timeWeight = 2, maxWeight = 2;
bool sortOrder = true; // Assemble output spectrum highest-magnitude first
/// Manual setup /// Manual setup
void configure(int nChannels, int blockSamples, int intervalSamples) { void configure(int nChannels, int blockSamples, int intervalSamples) {
@ -172,6 +170,7 @@ private:
struct Band { struct Band {
Complex input, prevInput{0}; Complex input, prevInput{0};
Complex output, prevOutput{0}; Complex output, prevOutput{0};
Sample inputEnergy;
}; };
std::vector<Band> channelBands; std::vector<Band> channelBands;
Band * bandsForChannel(int channel) { Band * bandsForChannel(int channel) {
@ -235,7 +234,6 @@ private:
Complex input; Complex input;
Complex freqPrediction; Complex freqPrediction;
Complex shortVerticalTwist, longVerticalTwist; Complex shortVerticalTwist, longVerticalTwist;
Complex channelTwist;
}; };
std::vector<Prediction> channelPredictions; std::vector<Prediction> channelPredictions;
Prediction * predictionsForChannel(int c) { Prediction * predictionsForChannel(int c) {
@ -278,7 +276,7 @@ private:
Prediction prediction; Prediction prediction;
prediction.energy = getFractional<&Band::energy>(c, lowIndex, fracIndex); prediction.energy = getFractional<&Band::inputEnergy>(c, lowIndex, fracIndex);
prediction.energy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor prediction.energy *= std::max<Sample>(0, mapPoint.freqGrad); // scale the energy according to local stretch factor
prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex); prediction.input = getFractional<&Band::input>(c, lowIndex, fracIndex);
@ -286,11 +284,6 @@ private:
Complex freqTwist = prediction.input*std::conj(prevInput); Complex freqTwist = prediction.input*std::conj(prevInput);
prediction.freqPrediction = outputBin.prevOutput*freqTwist; prediction.freqPrediction = outputBin.prevOutput*freqTwist;
if (c > 0) {
prediction.channelTwist = prediction.input*std::conj(predictions0[b]);
} else {
prediction.channelTwist = 0;
}
if (b > 0) { if (b > 0) {
Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - rate); Complex downInput = getFractional<&Band::input>(c, mapPoint.inputBin - rate);
prediction.shortVerticalTwist = prediction.input*std::conj(downInput); prediction.shortVerticalTwist = prediction.input*std::conj(downInput);
@ -307,12 +300,12 @@ private:
predictions[b] = prediction; predictions[b] = prediction;
// Rough output prediction based on phase-vocoder, sensitive to previous input/output magnitude // Rough output prediction based on phase-vocoder, sensitive to previous input/output magnitude
outputBin.output = prediction.freqPrediction/(prediction.energy + 1e-10); outputBin.output = prediction.freqPrediction/(prediction.energy + Sample(1e-10));
} }
} }
for (int b = 0; b < stft.bands(); ++b) { for (int b = 0; b < stft.bands(); ++b) {
// Find maximum-energy channel and calculate that
int maxChannel = 0; int maxChannel = 0;
maxEnergyChannel[b] = 0;
Sample maxEnergy = predictions0[b].energy; Sample maxEnergy = predictions0[b].energy;
for (int c = 1; c < channels; ++c) { for (int c = 1; c < channels; ++c) {
Sample e = predictionsForChannel(c)[b].energy; Sample e = predictionsForChannel(c)[b].energy;
@ -321,116 +314,85 @@ private:
maxEnergy = e; maxEnergy = e;
} }
} }
maxEnergyChannel[b] = maxChannel;
Sample channelInput = predictionsForChannel(maxChannel)
for (int c = 0; c < channels; ++c) {
Prediction &prediction = predictionsForChannel(c)[b];
if (c == maxChannel) {
prediction.channelTwist = 0;
} else {
prediction.channelTwist = prediction.input*std::conj(channelInput);
}
}
}
for (auto &c : maxEnergyChannel) c = -1; auto *predictions = predictionsForChannel(maxChannel);
for (auto &ordered : observationOrder) { auto &prediction = predictions[b];
auto *bins = bandsForChannel(ordered.channel); auto *bins = bandsForChannel(maxChannel);
auto &outputBin = bins[ordered.outputBand]; auto &outputBin = bins[b];
auto mapPoint = outputMap[b];
int lowIndex = std::floor(ordered.inputIndex); Complex phase = prediction.freqPrediction*freqWeight;
Sample fracIndex = ordered.inputIndex - std::floor(ordered.inputIndex);
// We always have the phase-vocoder prediction
Complex timeChange = ordered.input*std::conj(prevInput);
Complex freqPrediction = outputBin.prevOutput*timeChange;
Complex prediction = freqPrediction*freqWeight;
// Track the strongest prediction // Track the strongest prediction
Complex maxPrediction = freqPrediction; Complex maxPrediction = prediction.freqPrediction;
Sample maxPredictionNorm = std::norm(maxPrediction); Sample maxPredictionNorm = std::norm(maxPrediction);
// vertical upwards, if it exists // Short steps
if (ordered.outputBand > 0) { if (b > 0) {
auto &outputDownBin = bins[ordered.outputBand - 1]; auto &otherBin = bins[b - 1];
if (outputDownBin.ready) { Complex newPrediction = otherBin.output*prediction.shortVerticalTwist;
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - rate); phase += newPrediction*timeWeight;
Complex freqChange = ordered.input*std::conj(downInput); if (std::norm(newPrediction) > maxPredictionNorm) {
Complex newPrediction = outputDownBin.output*freqChange; maxPredictionNorm = std::norm(newPrediction);
prediction += newPrediction*timeWeight; maxPrediction = newPrediction;
if (std::norm(newPrediction) > maxPredictionNorm) { }
maxPredictionNorm = std::norm(newPrediction); }
maxPrediction = newPrediction; if (b < stft.bands() - 1) {
} auto &otherBin = bins[b + 1];
} auto &otherPrediction = predictions[b + 1];
} Complex newPrediction = otherBin.output*std::conj(otherPrediction.shortVerticalTwist);
// vertical downwards, if it exists phase += newPrediction*timeWeight;
if (ordered.outputBand < stft.bands() - 1) { if (std::norm(newPrediction) > maxPredictionNorm) {
auto &outputDownBin = bins[ordered.outputBand + 1]; maxPredictionNorm = std::norm(newPrediction);
if (outputDownBin.ready) { maxPrediction = newPrediction;
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + rate); }
Complex freqChange = ordered.input*std::conj(downInput); }
Complex newPrediction = outputDownBin.output*freqChange; // longer verticals
prediction += newPrediction*timeWeight; if (b > longVerticalStep) {
if (std::norm(newPrediction) > maxPredictionNorm) { auto &otherBin = bins[b - longVerticalStep];
maxPredictionNorm = std::norm(newPrediction); Complex newPrediction = otherBin.output*prediction.longVerticalTwist;
maxPrediction = newPrediction; phase += newPrediction*timeWeight;
} if (std::norm(newPrediction) > maxPredictionNorm) {
} maxPredictionNorm = std::norm(newPrediction);
} maxPrediction = newPrediction;
// longer verticals }
if (ordered.outputBand > longStep) { }
auto &outputDownBin = bins[ordered.outputBand - longStep]; if (b < stft.bands() - longVerticalStep) {
if (outputDownBin.ready) { auto &otherBin = bins[b + longVerticalStep];
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex - longStep*rate); auto &otherPrediction = predictions[b + longVerticalStep];
Complex freqChange = ordered.input*std::conj(downInput); Complex newPrediction = otherBin.output*std::conj(otherPrediction.longVerticalTwist);
Complex newPrediction = outputDownBin.output*freqChange; phase += newPrediction*timeWeight;
prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction;
}
}
}
if (ordered.outputBand < stft.bands() - longStep) {
auto &outputDownBin = bins[ordered.outputBand + longStep];
if (outputDownBin.ready) {
Complex downInput = getFractional<&Band::input>(ordered.channel, ordered.inputIndex + longStep*rate);
Complex freqChange = ordered.input*std::conj(downInput);
Complex newPrediction = outputDownBin.output*freqChange;
prediction += newPrediction*timeWeight;
if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction;
}
}
}
// Inter-channel prediction, if it exists
int &maxChannel = maxEnergyChannel[ordered.outputBand];
if (maxChannel >= 0) {
Complex otherInput = getFractional<&Band::input>(maxChannel, lowIndex, fracIndex);
Complex channelRot = ordered.input*std::conj(otherInput);
auto *otherBins = bandsForChannel(maxChannel);
Complex otherOutputOutput = otherBins[ordered.outputBand].output;
Complex newPrediction = otherOutputOutput*channelRot;
prediction += newPrediction*channelWeight;
if (std::norm(newPrediction) > maxPredictionNorm) { if (std::norm(newPrediction) > maxPredictionNorm) {
maxPredictionNorm = std::norm(newPrediction); maxPredictionNorm = std::norm(newPrediction);
maxPrediction = newPrediction; maxPrediction = newPrediction;
} }
} else {
maxChannel = ordered.channel;
} }
prediction += maxPrediction*maxWeight; phase += maxPrediction*maxWeight;
Sample predictionNorm = std::norm(prediction); Sample phaseNorm = std::norm(phase);
if (predictionNorm > 1e-15) { if (phaseNorm > 1e-15) {
outputBin.output = prediction*std::sqrt(ordered.energy/predictionNorm); outputBin.output = phase*std::sqrt(prediction.energy/phaseNorm);
} else { } else {
outputBin.output = ordered.input; outputBin.output = prediction.input;
}
// All other bins are locked in phase
for (int c = 0; c < channels; ++c) {
if (c != maxChannel) {
auto &channelBin = bandsForChannel(c)[b];
auto &channelPrediction = predictionsForChannel(c)[b];
Complex channelTwist = prediction.input*std::conj(channelPrediction.input);
Complex channelPhase = outputBin.output*channelTwist;
Sample channelPhaseNorm = std::norm(channelPhase);
if (channelPhaseNorm > 1e-15) {
channelBin.output = channelPhase*std::sqrt(prediction.energy/channelPhaseNorm);
} else {
channelBin.output = channelPrediction.input;
}
}
} }
} }
@ -448,6 +410,7 @@ private:
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
for (int b = 0; b < stft.bands(); ++b) { for (int b = 0; b < stft.bands(); ++b) {
Sample e = std::norm(bins[b].input); Sample e = std::norm(bins[b].input);
bins[b].inputEnergy = e; // Used for interpolating prediction energy
energy[b] += e; energy[b] += e;
} }
} }