Refactor chunked computation to be neater, fix some int warnings

This commit is contained in:
Geraint 2025-02-27 14:38:11 +00:00
parent 1ad2839a04
commit 8c3852cae3
3 changed files with 82 additions and 81 deletions

View File

@ -105,7 +105,7 @@ int main(int argc, char* argv[]) {
signalsmith::Stopwatch stopwatch; signalsmith::Stopwatch stopwatch;
stopwatch.start(); stopwatch.start();
stretch.presetDefault(inWav.channels, inWav.sampleRate); stretch.presetDefault(int(inWav.channels), inWav.sampleRate);
stretch.setTransposeSemitones(semitones, tonality/inWav.sampleRate); stretch.setTransposeSemitones(semitones, tonality/inWav.sampleRate);
double initSeconds = stopwatch.lap(); double initSeconds = stopwatch.lap();
@ -130,7 +130,7 @@ int main(int argc, char* argv[]) {
stretch.seek(inWav, stretch.inputLatency(), 1/time); stretch.seek(inWav, stretch.inputLatency(), 1/time);
inWav.offset += stretch.inputLatency(); inWav.offset += stretch.inputLatency();
// Process it all in one call, although it works just the same if we split into smaller blocks // Process it all in one call, although it works just the same if we split into smaller blocks
stretch.process(inWav, inputLength, outWav, outputLength); stretch.process(inWav, int(inputLength), outWav, int(outputLength));
// Read the last bit of output without giving it any more input // Read the last bit of output without giving it any more input
outWav.offset += outputLength; outWav.offset += outputLength;
stretch.flush(outWav, tailSamples); stretch.flush(outWav, tailSamples);

View File

@ -66,39 +66,40 @@ public:
} }
}; };
unsigned int sampleRate = 48000; size_t sampleRate = 48000;
unsigned int channels = 1, offset = 0; size_t channels = 1, offset = 0;
std::vector<double> samples; std::vector<double> samples;
int length() const { size_t length() const {
return samples.size()/channels - offset; size_t perChannel = samples.size()/channels;
return (perChannel >= offset) ? perChannel - offset : 0;
} }
void resize(int length) { void resize(size_t length) {
samples.resize((offset + length)*channels, 0); samples.resize((offset + length)*channels, 0);
} }
template<bool isConst> template<bool isConst>
class ChannelReader { class ChannelReader {
using CSample = typename std::conditional<isConst, const double, double>::type; using CSample = typename std::conditional<isConst, const double, double>::type;
CSample *data; CSample *data;
int stride; size_t stride;
public: public:
ChannelReader(CSample *samples, int channels) : data(samples), stride(channels) {} ChannelReader(CSample *samples, size_t channels) : data(samples), stride(channels) {}
CSample & operator [](int i) { CSample & operator [](size_t i) {
return data[i*stride]; return data[i*stride];
} }
}; };
ChannelReader<false> operator [](int c) { ChannelReader<false> operator [](size_t c) {
return ChannelReader<false>(samples.data() + offset*channels + c, channels); return ChannelReader<false>(samples.data() + offset*channels + c, channels);
} }
ChannelReader<true> operator [](int c) const { ChannelReader<true> operator [](size_t c) const {
return ChannelReader<true>(samples.data() + offset*channels + c, channels); return ChannelReader<true>(samples.data() + offset*channels + c, channels);
} }
Result result = Result(Result::Code::OK); Result result = Result(Result::Code::OK);
Wav() {} Wav() {}
Wav(double sampleRate, int channels) : sampleRate(sampleRate), channels(channels) {} Wav(double sampleRate, size_t channels) : sampleRate(sampleRate), channels(channels) {}
Wav(double sampleRate, int channels, const std::vector<double> &samples) : sampleRate(sampleRate), channels(channels), samples(samples) {} Wav(double sampleRate, size_t channels, const std::vector<double> &samples) : sampleRate(sampleRate), channels(channels), samples(samples) {}
Wav(std::string filename) { Wav(std::string filename) {
result = read(filename).warn(); result = read(filename).warn();
} }
@ -141,9 +142,9 @@ public:
sampleRate = read32(file); sampleRate = read32(file);
if (sampleRate < 1) return result = Result(Result::Code::FORMAT_ERROR, "Cannot have zero sampleRate"); if (sampleRate < 1) return result = Result(Result::Code::FORMAT_ERROR, "Cannot have zero sampleRate");
unsigned int expectedBytesPerSecond = read32(file); size_t expectedBytesPerSecond = read32(file);
unsigned int bytesPerFrame = read16(file); size_t bytesPerFrame = read16(file);
unsigned int bitsPerSample = read16(file); size_t bitsPerSample = read16(file);
if (!formatIsValid(formatInt, bitsPerSample)) return result = Result(Result::Code::UNSUPPORTED, "Unsupported format:bits: " + std::to_string(formatInt) + ":" + std::to_string(bitsPerSample)); if (!formatIsValid(formatInt, bitsPerSample)) return result = Result(Result::Code::UNSUPPORTED, "Unsupported format:bits: " + std::to_string(formatInt) + ":" + std::to_string(bitsPerSample));
// Since it's plain WAVE, we can do some extra checks for consistency // Since it's plain WAVE, we can do some extra checks for consistency
if (bitsPerSample*channels != bytesPerFrame*8) return result = Result(Result::Code::FORMAT_ERROR, "Format sizes don't add up"); if (bitsPerSample*channels != bytesPerFrame*8) return result = Result(Result::Code::FORMAT_ERROR, "Format sizes don't add up");
@ -191,7 +192,7 @@ public:
file.open(filename, std::ios::binary); file.open(filename, std::ios::binary);
if (!file.is_open()) return result = Result(Result::Code::IO_ERROR, "Failed to open file: " + filename); if (!file.is_open()) return result = Result(Result::Code::IO_ERROR, "Failed to open file: " + filename);
int bytesPerSample; size_t bytesPerSample;
switch (format) { switch (format) {
case Format::PCM: case Format::PCM:
bytesPerSample = 2; bytesPerSample = 2;
@ -199,30 +200,30 @@ public:
} }
// File size - 44 bytes is RIFF header, "fmt" block, and "data" block header // File size - 44 bytes is RIFF header, "fmt" block, and "data" block header
unsigned int dataLength = (samples.size() - offset*channels)*bytesPerSample; size_t dataLength = (samples.size() - offset*channels)*bytesPerSample;
unsigned int fileLength = 44 + dataLength; size_t fileLength = 44 + dataLength;
// RIFF chunk // RIFF chunk
write32(file, value_RIFF); write32(file, value_RIFF);
write32(file, fileLength - 8); // File length, excluding the RIFF header write32(file, uint32_t(fileLength - 8)); // File length, excluding the RIFF header
write32(file, value_WAVE); write32(file, value_WAVE);
// "fmt " block // "fmt " block
write32(file, value_fmt); write32(file, value_fmt);
write32(file, 16); // block length write32(file, 16); // block length
write16(file, (uint16_t)format); write16(file, uint16_t(format));
write16(file, channels); write16(file, uint16_t(channels));
write32(file, sampleRate); write32(file, uint32_t(sampleRate));
unsigned int expectedBytesPerSecond = sampleRate*channels*bytesPerSample; size_t expectedBytesPerSecond = sampleRate*channels*bytesPerSample;
write32(file, expectedBytesPerSecond); write32(file, uint32_t(expectedBytesPerSecond));
write16(file, channels*bytesPerSample); // Bytes per frame write16(file, uint16_t(channels*bytesPerSample)); // Bytes per frame
write16(file, bytesPerSample*8); // bits per sample write16(file, uint16_t(bytesPerSample*8)); // bits per sample
// "data" block // "data" block
write32(file, value_data); write32(file, value_data);
write32(file, dataLength); write32(file, uint32_t(dataLength));
switch (format) { switch (format) {
case Format::PCM: case Format::PCM:
for (unsigned int i = offset*channels; i < samples.size(); i++) { for (size_t i = offset*channels; i < samples.size(); i++) {
double value = samples[i]*32768; double value = samples[i]*32768;
if (value > 32767) value = 32767; if (value > 32767) value = 32767;
if (value <= -32768) value = -32768; if (value <= -32768) value = -32768;

View File

@ -41,16 +41,16 @@ struct SignalsmithStretch {
SignalsmithStretch(long seed) : randomEngine(seed) {} SignalsmithStretch(long seed) : randomEngine(seed) {}
int blockSamples() const { int blockSamples() const {
return stft.blockSamples(); return int(stft.blockSamples());
} }
int intervalSamples() const { int intervalSamples() const {
return stft.defaultInterval(); return int(stft.defaultInterval());
} }
int inputLatency() const { int inputLatency() const {
return stft.blockSamples() - stft.analysisOffset(); return int(stft.analysisLatency());
} }
int outputLatency() const { int outputLatency() const {
return stft.synthesisOffset() + stft.defaultInterval(); return int(stft.synthesisLatency() + stft.defaultInterval());
} }
void reset() { void reset() {
@ -84,7 +84,7 @@ struct SignalsmithStretch {
stashedOutput = stft.output; stashedOutput = stft.output;
tmpBuffer.resize(blockSamples + intervalSamples); tmpBuffer.resize(blockSamples + intervalSamples);
bands = stft.bands(); bands = int(stft.bands());
channelBands.assign(bands*channels, Band()); channelBands.assign(bands*channels, Band());
peaks.reserve(bands/2); peaks.reserve(bands/2);
@ -122,7 +122,7 @@ struct SignalsmithStretch {
tmpBuffer.resize(stft.blockSamples() + stft.defaultInterval()); tmpBuffer.resize(stft.blockSamples() + stft.defaultInterval());
int startIndex = std::max<int>(0, inputSamples - int(tmpBuffer.size())); // start position in input int startIndex = std::max<int>(0, inputSamples - int(tmpBuffer.size())); // start position in input
int padStart = tmpBuffer.size() - (inputSamples - startIndex); // start position in tmpBuffer int padStart = int(tmpBuffer.size() + startIndex) - inputSamples; // start position in tmpBuffer
Sample totalEnergy = 0; Sample totalEnergy = 0;
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
@ -152,7 +152,7 @@ struct SignalsmithStretch {
int prevCopiedInput = 0; int prevCopiedInput = 0;
auto copyInput = [&](int toIndex){ auto copyInput = [&](int toIndex){
int length = std::min<int>(stft.blockSamples() + stft.defaultInterval(), toIndex - prevCopiedInput); int length = std::min<int>(int(stft.blockSamples() + stft.defaultInterval()), toIndex - prevCopiedInput);
tmpBuffer.resize(length); tmpBuffer.resize(length);
int offset = toIndex - length; int offset = toIndex - length;
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
@ -217,13 +217,48 @@ struct SignalsmithStretch {
for (int outputIndex = 0; outputIndex < outputSamples; ++outputIndex) { for (int outputIndex = 0; outputIndex < outputSamples; ++outputIndex) {
Sample processRatio = Sample(blockProcess.samplesSinceLast)/stft.defaultInterval(); Sample processRatio = Sample(blockProcess.samplesSinceLast)/stft.defaultInterval();
size_t processToStep = std::min<size_t>(blockProcess.steps, blockProcess.steps*processRatio); if (processRatio >= 1) { // we're ready to start a new block
processRatio = 0;
blockProcess.step = 0;
blockProcess.steps = 0; // how many steps
blockProcess.samplesSinceLast = 0;
// Time to process a spectrum! Where should it come from in the input?
int inputOffset = std::round(outputIndex*Sample(inputSamples)/outputSamples);
int inputInterval = inputOffset - prevInputOffset;
prevInputOffset = inputOffset;
copyInput(inputOffset);
stashedInput = stft.input; // save the input state, since that's what we'll analyse later
stashedOutput = stft.output; // save the current output, and read from it
stft.moveOutput(stft.defaultInterval()); // the actual input jumps forward in time by one interval, ready for the synthesis
blockProcess.newSpectrum = didSeek || (inputInterval > 0);
blockProcess.mappedFrequencies = customFreqMap || freqMultiplier != 1;
if (blockProcess.newSpectrum) {
// make sure the previous input is the correct distance in the past (give or take 1 sample)
blockProcess.reanalysePrev = didSeek || std::abs(inputInterval - int(stft.defaultInterval())) > 1;
if (blockProcess.reanalysePrev) blockProcess.steps += stft.analyseSteps() + 1;
// analyse a new input
blockProcess.steps += stft.analyseSteps() + 1;
}
blockProcess.timeFactor = didSeek ? seekTimeFactor : stft.defaultInterval()/std::max<Sample>(1, inputInterval);
didSeek = false;
updateProcessSpectrumSteps();
blockProcess.steps += processSpectrumSteps;
blockProcess.steps += stft.synthesiseSteps() + 1;
}
size_t processToStep = std::min<size_t>(blockProcess.steps, (blockProcess.steps + 1)*processRatio);
while (blockProcess.step < processToStep) { while (blockProcess.step < processToStep) {
size_t step = blockProcess.step++; size_t step = blockProcess.step++;
#ifdef SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP #ifdef SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP
SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP(step, blockProcess.steps); SIGNALSMITH_STRETCH_PROFILE_PROCESS_STEP(step, blockProcess.steps);
#endif #endif
if (blockProcess.newSpectrum) { if (blockProcess.newSpectrum) {
if (blockProcess.reanalysePrev) { if (blockProcess.reanalysePrev) {
// analyse past input // analyse past input
@ -294,41 +329,6 @@ struct SignalsmithStretch {
continue; continue;
} }
} }
if (processRatio >= 1) { // we *should* have just written a block, and are now ready to start a new one
blockProcess.step = 0;
blockProcess.steps = 0; // how many steps
blockProcess.samplesSinceLast = 0;
// Time to process a spectrum! Where should it come from in the input?
int inputOffset = std::round(outputIndex*Sample(inputSamples)/outputSamples);
int inputInterval = inputOffset - prevInputOffset;
prevInputOffset = inputOffset;
copyInput(inputOffset);
stashedInput = stft.input; // save the input state, since that's what we'll analyse later
stashedOutput = stft.output; // save the current output, and read from it
stft.moveOutput(stft.defaultInterval()); // the actual input jumps forward in time by one interval, ready for the synthesis
blockProcess.newSpectrum = didSeek || (inputInterval > 0);
blockProcess.mappedFrequencies = customFreqMap || freqMultiplier != 1;
if (blockProcess.newSpectrum) {
// make sure the previous input is the correct distance in the past (give or take 1 sample)
blockProcess.reanalysePrev = didSeek || std::abs(inputInterval - int(stft.defaultInterval())) > 1;
if (blockProcess.reanalysePrev) blockProcess.steps += stft.analyseSteps() + 1;
// analyse a new input
blockProcess.steps += stft.analyseSteps() + 1;
}
blockProcess.timeFactor = didSeek ? seekTimeFactor : stft.defaultInterval()/std::max<Sample>(1, inputInterval);
didSeek = false;
updateProcessSpectrumSteps();
blockProcess.steps += processSpectrumSteps;
blockProcess.steps += stft.synthesiseSteps() + 1;
blockProcess.steps += 1; // planning the next block
}
#ifdef SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP #ifdef SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP
SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP(); SIGNALSMITH_STRETCH_PROFILE_PROCESS_ENDSTEP();
#endif #endif
@ -355,7 +355,7 @@ struct SignalsmithStretch {
// Read the remaining output, providing no further input. `outputSamples` should ideally be at least `.outputLatency()` // Read the remaining output, providing no further input. `outputSamples` should ideally be at least `.outputLatency()`
template<class Outputs> template<class Outputs>
void flush(Outputs &&outputs, int outputSamples) { void flush(Outputs &&outputs, int outputSamples) {
int plainOutput = std::min<int>(outputSamples, stft.blockSamples()); int plainOutput = std::min<int>(outputSamples, int(stft.blockSamples()));
int foldedBackOutput = std::min<int>(outputSamples, int(stft.blockSamples()) - plainOutput); int foldedBackOutput = std::min<int>(outputSamples, int(stft.blockSamples()) - plainOutput);
stft.finishOutput(1); stft.finishOutput(1);
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
@ -495,7 +495,7 @@ private:
} }
// If RandomEngine=void, use std::default_random_engine; // If RandomEngine=void, use std::default_random_engine;
using RandomEngineImpl = std::conditional< using RandomEngineImpl = typename std::conditional<
std::is_void<RandomEngine>::value, std::is_void<RandomEngine>::value,
std::default_random_engine, std::default_random_engine,
RandomEngine RandomEngine
@ -527,7 +527,7 @@ private:
if (blockProcess.newSpectrum) { if (blockProcess.newSpectrum) {
if (step < size_t(channels)) { if (step < size_t(channels)) {
int channel = step; int channel = int(step);
auto bins = bandsForChannel(channel); auto bins = bandsForChannel(channel);
Complex rot = std::polar(Sample(1), bandToFreq(0)*stft.defaultInterval()*Sample(2*M_PI)); Complex rot = std::polar(Sample(1), bandToFreq(0)*stft.defaultInterval()*Sample(2*M_PI));
@ -572,7 +572,7 @@ private:
return; return;
} }
if (step < size_t(channels)) { if (step < size_t(channels)) {
size_t c = step; int c = int(step);
Band *bins = bandsForChannel(c); Band *bins = bandsForChannel(c);
auto *predictions = predictionsForChannel(c); auto *predictions = predictionsForChannel(c);
for (int b = 0; b < bands; ++b) { for (int b = 0; b < bands; ++b) {
@ -598,9 +598,9 @@ private:
if (step < splitMainPrediction) { if (step < splitMainPrediction) {
// Re-predict using phase differences between frequencies // Re-predict using phase differences between frequencies
int chunk = step; size_t chunk = step;
int startB = bands*chunk/splitMainPrediction; int startB = int(bands*chunk/splitMainPrediction);
int endB = bands*(chunk + 1)/splitMainPrediction; int endB = int(bands*(chunk + 1)/splitMainPrediction);
for (int b = startB; b < endB; ++b) { for (int b = startB; b < endB; ++b) {
// Find maximum-energy channel and calculate that // Find maximum-energy channel and calculate that
int maxChannel = 0; int maxChannel = 0;