diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fb88281 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +all: out/stretch + +out/stretch: signalsmith-stretch.h cmd/main.cpp cmd/util/*.h + mkdir -p out + g++ cmd/main.cpp -o out/stretch -std=c++11 -Ofast -g + +clean: + rm -rf out \ No newline at end of file diff --git a/cmd/main.cpp b/cmd/main.cpp new file mode 100644 index 0000000..e22022f --- /dev/null +++ b/cmd/main.cpp @@ -0,0 +1,72 @@ +#include +#define LOG_EXPR(expr) std::cout << #expr << " = " << (expr) << "\n"; + +#include + +#include "../signalsmith-stretch.h" +#include "util/simple-args.h" +#include "util/wav.h" + +int main(int argc, char* argv[]) { + SimpleArgs args(argc, argv); + + std::string inputWav = args.arg("input.wav", "16-bit WAV file"); + std::string outputWav = args.arg("output.wav", "output WAV file"); + + double semitones = args.flag("semitones", "pitch-shift amount", 0); + double tonality = args.flag("tonality", "tonality limit (Hz)", 8000); + double time = args.flag("time", "time-stretch factor", 1); + bool exactLength = args.hasFlag("exact", "trims the start/end so the output has the correct length"); + args.errorExit(); + + Wav inWav; + if (!inWav.read(inputWav).warn()) args.errorExit("failed to read WAV"); + size_t inputLength = inWav.samples.size()/inWav.channels; + + Wav outWav; + outWav.channels = inWav.channels; + outWav.sampleRate = inWav.sampleRate; + int outputLength = std::round(inputLength*time); + + signalsmith::stretch::SignalsmithStretch stretch; + stretch.presetDefault(inWav.channels, inWav.sampleRate); + stretch.setTransposeSemitones(semitones, tonality/inWav.sampleRate); + + // pad the input at the end, since we'll be reading slightly ahead + size_t paddedInputLength = inputLength + stretch.inputLatency(); + inWav.samples.resize(paddedInputLength*inWav.channels); + // pad the output at the end, since we have output latency as well + int tailSamples = exactLength ? stretch.outputLatency() : (stretch.outputLatency() + stretch.inputLatency()); // if we don't need exact length, add a bit more output to catch any wobbles past the end + int paddedOutputLength = outputLength + tailSamples; + outWav.samples.resize(paddedOutputLength*outWav.channels); + + // The simplest way to deal with input latency is to always be slightly ahead in the input + stretch.seek(inWav, stretch.inputLatency(), 1/time); + + // Process it all in one call, although it works just the same if we split into smaller blocks + inWav.offset += stretch.inputLatency(); + + // These lengths in the right ratio to get the time-stretch + stretch.process(inWav, inputLength, outWav, outputLength); + + // Read the last bit of output without giving it any more input + outWav.offset += outputLength; + stretch.flush(outWav, tailSamples); + outWav.offset -= outputLength; + + if (exactLength) { + // The start has some extra output - we could just trim it, but we might as well fold it back into the output + for (int c = 0; c < outWav.channels; ++c) { + for (int i = 0; i < stretch.outputLatency(); ++i) { + double trimmed = outWav[stretch.outputLatency() - 1 - i][c]; + outWav[stretch.outputLatency() + i][c] -= trimmed; // reversed in time and negated + } + } + // Skips the output + outWav.offset += stretch.outputLatency(); + + // the `.flush()` call already handled foldback stuff at the end (since we asked for a shorter `tailSamples`) + } + + if (!outWav.write(outputWav).warn()) args.errorExit("failed to write WAV"); +} diff --git a/cmd/util/console-colours.h b/cmd/util/console-colours.h new file mode 100644 index 0000000..00b4319 --- /dev/null +++ b/cmd/util/console-colours.h @@ -0,0 +1,41 @@ +#pragma once +#ifndef _CONSOLE_COLOURS_H +#define _CONSOLE_COLOURS_H + +#include + +namespace Console { + std::string Reset = "\x1b[0m"; + std::string Bright = "\x1b[1m"; + std::string Dim = "\x1b[2m"; + std::string Underscore = "\x1b[4m"; + std::string Blink = "\x1b[5m"; + std::string Reverse = "\x1b[7m"; + std::string Hidden = "\x1b[8m"; + + namespace Foreground { + std::string Black = "\x1b[30m"; + std::string Red = "\x1b[31m"; + std::string Green = "\x1b[32m"; + std::string Yellow = "\x1b[33m"; + std::string Blue = "\x1b[34m"; + std::string Magenta = "\x1b[35m"; + std::string Cyan = "\x1b[36m"; + std::string White = "\x1b[37m"; + } + + namespace Background { + std::string Black = "\x1b[40m"; + std::string Red = "\x1b[41m"; + std::string Green = "\x1b[42m"; + std::string Yellow = "\x1b[43m"; + std::string Blue = "\x1b[44m"; + std::string Magenta = "\x1b[45m"; + std::string Cyan = "\x1b[46m"; + std::string White = "\x1b[47m"; + } + + using namespace Foreground; +} + +#endif \ No newline at end of file diff --git a/cmd/util/simple-args.h b/cmd/util/simple-args.h new file mode 100644 index 0000000..eb2d69a --- /dev/null +++ b/cmd/util/simple-args.h @@ -0,0 +1,322 @@ +#include +#include +#include +#include +#include +#include +#include // exit() and codes + +#include "console-colours.h" + +/** Expected use: + + SimpleArgs args(argc, argv); + + // positional argument + std::string foo = args.arg("foo"); + // optional argument + std::string bar = args.arg("bar", "a string for Bar", "default"); + // --flag=value + double = args.flag("baz", "an optional flag", 5); + + // Exits if "foo" not supplied + args.errorExit(); + + If you have multiple commands, each with their own options: + + // Switches based on a command + if (args.command("bink", "Bink description")) { + // collect arguments for the command + } + // Exits with a help message (and list of commands) if no command matched + args.errorCommand(); + + By default, a flag of "-h" (or a command of "help", if any commands are used) prints a help message. To override: + SimpleArgs args(argc, argv); + args.helpFlag("h"); + args.helpCommand("help"); + +**/ +class SimpleArgs { + int argc; + const char* const* argv; + + template + T valueFromString(const char *arg); + + std::string parsedCommand; + struct Keywords { + std::string keyword; + std::string description; + bool isHelp; + }; + std::vector keywordOptions; + std::vector argDetails; + std::vector flagOptions; + std::set flagSet; + void clearKeywords() { + keywordOptions.resize(0); + flagSet.clear(); + flagOptions.clear(); + } + + bool helpMode = false; + bool checkedHelpCommand = false; + bool hasError = false; + std::string errorMessage; + void setError(std::string message) { + if (!hasError) { + hasError = true; + errorMessage = message; + } + } + + std::map flagMap; + void consumeFlags() { + while (index < argc && std::strlen(argv[index]) > 1 && argv[index][0] == '-') { + const char* arg = argv[index++]; + size_t length = strlen(arg); + + size_t keyStart = 1, keyEnd = keyStart + 1; + size_t valueStart = keyEnd; + // If it's "--long-arg" format + if (length > 1 && arg[1] == '-') { + keyStart++; + while (keyEnd < length && arg[keyEnd] != '=') { + keyEnd++; + } + valueStart = keyEnd; + if (keyEnd < length) valueStart++; + } + + std::string key = std::string(arg + keyStart, keyEnd - keyStart); + std::string value = std::string(arg + valueStart); + + if (key == "help") { + helpMode = true; + } + + flagMap[key] = value; + } + } + + int index = 1; +public: + SimpleArgs(int argc, const char* const argv[]) : argc(argc), argv(argv) { + std::string cmd = argv[0]; + size_t slashPos = cmd.find_last_of("\\/"); + if (slashPos != std::string::npos) cmd = cmd.substr(slashPos + 1); + parsedCommand = cmd; + } + + void help(std::ostream& out=std::cerr) const { + std::string parsedCommand = this->parsedCommand; + if (keywordOptions.size() > 0) { + parsedCommand += std::string(" "); + } + out << "Usage:\n\t" << parsedCommand << "\n\n"; + if (keywordOptions.size() > 0) { + out << "Commands:\n"; + for (unsigned int i = 0; i < keywordOptions.size(); i++) { + out << "\t" << keywordOptions[i].keyword; + if (keywordOptions[i].isHelp) out << " [command...]"; + if (keywordOptions[i].description.size()) out << " - " << keywordOptions[i].description; + out << "\n"; + } + out << "\n"; + } + if (argDetails.size() > 0) { + out << "Arguments:\n"; + for (Keywords const &arg : argDetails) { + out << "\t" << arg.keyword; + if (arg.description.size()) out << " - " << arg.description; + out << "\n"; + } + out << "\n"; + } + if (flagOptions.size() > 0) { + out << "Options: " << Console::Dim << "(--arg=value)" << Console::Reset << "\n"; + for (Keywords const &pair : flagOptions) { + out << "\t" << (pair.keyword.length() > 1 ? "--" : "-") << pair.keyword; + if (pair.description.size()) out << " - " << pair.description; + out << "\n"; + } + out << "\n"; + } + } + + bool isHelp() const { + return helpMode; + } + bool finished() const { + return index >= argc; + } + std::string peek() const { + return (index >= argc) ? "" : argv[index]; + } + + int errorExit(std::ostream& out=std::cerr) const { + if (hasError || helpMode) { + help(out); + if (!helpMode) { + out << Console::Red << errorMessage << Console::Reset << "\n"; + } + std::exit((!helpMode && hasError) ? EXIT_FAILURE : EXIT_SUCCESS); + } + return 0; + } + int errorExit(std::string forcedError, std::ostream& out=std::cerr) const { + if (hasError) return errorExit(out); // Argument errors take priority + out << Console::Red << forcedError << Console::Reset << "\n"; + std::exit(EXIT_FAILURE); + return 0; + } + int errorCommand(std::string message="", std::ostream& out=std::cerr) const { + if (keywordOptions.size()) { + // We expected a command, but didn't match on any + if (helpMode) return errorExit(out); + if (index >= argc) help(out); + if (message.length() == 0) { + message = (index < argc) ? std::string("Unknown command: ") + argv[index] : "Missing command"; + } + errorExit(message, out); + } + return 0; + } + + template + T arg(std::string name, std::string longName, T defaultValue) { + consumeFlags(); + if (index < argc) clearKeywords(); + parsedCommand += std::string(" [") + name + "]"; + argDetails.push_back(Keywords{name, longName, false}); + + if (index >= argc) return defaultValue; + return valueFromString(argv[index++]); + } + + template + T arg(std::string name, std::string longName="") { + consumeFlags(); + if (index < argc) clearKeywords(); + parsedCommand += std::string(" <") + name + ">"; + argDetails.push_back(Keywords{name, longName, false}); + + if (index >= argc) { + if (longName.length() > 0) { + setError("Missing " + longName + " <" + name + ">"); + } else { + setError("Missing argument <" + name + ">"); + } + return T(); + } + + return valueFromString(argv[index++]); + } + + bool command(std::string keyword, std::string description="", bool isHelp=false) { + consumeFlags(); + if (index == 1) { + helpCommand(); + } + if (index < argc && !keyword.compare(argv[index])) { + clearKeywords(); + index++; + if (!isHelp) parsedCommand += std::string(" ") + keyword; + return true; + } + keywordOptions.push_back(Keywords{keyword, description, isHelp}); + return false; + } + bool helpCommand(std::string keyword="help") { + if (!checkedHelpCommand && index == 1) { + keywordOptions.push_back(Keywords{keyword, "", true}); + if (index < argc && !keyword.compare(argv[index])) { + index++; + helpMode = true; + } + } + checkedHelpCommand = true; + return helpMode; + } + + template + T flag(std::string key, std::string description, T defaultValue) { + consumeFlags(); + if (!hasFlag(key, description)) return defaultValue; + + auto iterator = flagMap.find(key); + return valueFromString(iterator->second.c_str()); + } + template + T flag(std::string key, T defaultValue) { + consumeFlags(); + if (!hasFlag(key, "")) return defaultValue; + + auto iterator = flagMap.find(key); + return valueFromString(iterator->second.c_str()); + } + template + T flag(std::string key) { + return flag(key, T()); + } + bool hasFlag(std::string key, std::string description="") { + consumeFlags(); + auto iterator = flagSet.find(key); + if (iterator == flagSet.end()) { + flagSet.insert(key); + flagOptions.push_back(Keywords{key, description, false}); + } else if (description.length() > 0) { + bool found = false; + for (auto &option : flagOptions) { + if (option.keyword == key) { + option.description = description; + found = true; + break; + } + } + if (!found) { + flagOptions.push_back(Keywords{key, description, false}); + } + } + + auto mapIterator = flagMap.find(key); + return mapIterator != flagMap.end(); + } + bool helpFlag(std::string key, std::string description="shows this help") { + consumeFlags(); + hasFlag(key, description); + auto iterator = flagMap.find(key); + helpMode = (iterator != flagMap.end()); + return helpMode; + } +}; + +template<> +std::string SimpleArgs::valueFromString(const char *arg) { + return arg; +} +template<> +const char * SimpleArgs::valueFromString(const char *arg) { + return arg; +} +template<> +int SimpleArgs::valueFromString(const char *arg) { + return std::stoi(arg); +} +template<> +long SimpleArgs::valueFromString(const char *arg) { + return std::stol(arg); +} +template<> +unsigned long SimpleArgs::valueFromString(const char *arg) { + return std::stoul(arg); +} +template<> +float SimpleArgs::valueFromString(const char *arg) { + return std::stof(arg); +} +template<> +double SimpleArgs::valueFromString(const char *arg) { + return std::stod(arg); +} diff --git a/cmd/util/wav.h b/cmd/util/wav.h new file mode 100644 index 0000000..496e197 --- /dev/null +++ b/cmd/util/wav.h @@ -0,0 +1,261 @@ +#ifndef RIFF_WAVE_H_ +#define RIFF_WAVE_H_ + +#include +#include +#include + +// TODO: something better here that doesn't assume little-endian architecture +template +struct BigEndian { + static uint32_t read16(std::istream& in) { + unsigned char a[2]; + in.read((char*)a, sizeof(a)); + return ((uint32_t)a[0]) + ((uint32_t)a[1])*256; + } + static uint32_t read32(std::istream& in) { + unsigned char a[4]; + in.read((char*)a, sizeof(a)); + return ((uint32_t)a[0]&0xff) + ((uint32_t)a[1])*256 + ((uint32_t)a[2])*65536 + ((uint32_t)a[3])*256*65536; + } + + static void write16(std::ostream& out, uint16_t value) { + char a[2] = {(char)(value>>0), (char)(value>>8)}; + out.write(a, sizeof(a)); + } + static void write32(std::ostream& out, uint32_t value) { + char a[4] = {(char)(value>>0), (char)(value>>8), (char)(value>>16), (char)(value>>24)}; + out.write(a, sizeof(a)); + } +}; + +class Wav : BigEndian { + // Little-endian versions of text values + uint32_t value_RIFF = 0x46464952; + uint32_t value_WAVE = 0x45564157; + uint32_t value_fmt = 0x20746d66; + uint32_t value_data = 0x61746164; + + using BigEndian::read16; + using BigEndian::read32; + using BigEndian::write16; + using BigEndian::write32; + +public: + struct Result { + enum class Code { + OK = 0, + IO_ERROR, + FORMAT_ERROR, + UNSUPPORTED, + WEIRD_CONFIG + }; + Code code = Code::OK; + std::string reason; + + Result(Code code, std::string reason="") : code(code), reason(reason) {}; + Result(const Result &other) = default; + Result & operator=(const Result &other) { + if (code == Code::OK) { + code = other.code; + reason = other.reason; + } + return *this; + } + // Used to neatly test for success + explicit operator bool () const { + return code == Code::OK; + }; + const Result & warn(std::ostream& output=std::cerr) const { + if (!(bool)*this) { + output << "WAV error: " << reason << std::endl; + } + return *this; + } + }; + + unsigned int sampleRate = 48000; + unsigned int channels = 1, offset = 0; + std::vector samples; + int length() const { + return samples.size()/channels - offset; + } + void resize(int length) { + samples.resize((offset + length)*channels, 0); + } + template + class ChannelReader { + using CSample = typename std::conditional::type; + CSample *data; + int stride; + public: + ChannelReader(CSample *samples, int channels) : data(samples), stride(channels) {} + + CSample & operator [](int i) { + return data[i*stride]; + } + }; + ChannelReader operator [](int c) { + return ChannelReader(samples.data() + offset*channels + c, channels); + } + ChannelReader operator [](int c) const { + return ChannelReader(samples.data() + offset*channels + c, channels); + } + + Result result = Result(Result::Code::OK); + + Wav() {} + Wav(double sampleRate, int channels) : sampleRate(sampleRate), channels(channels) {} + Wav(double sampleRate, int channels, const std::vector &samples) : sampleRate(sampleRate), channels(channels), samples(samples) {} + Wav(std::string filename) { + result = read(filename).warn(); + } + + enum class Format { + PCM=1 + }; + bool formatIsValid(uint16_t format, uint16_t bits) const { + if (format == (uint16_t)Format::PCM) { + if (bits == 16) { + return true; + } + } + return false; + } + + Result read(std::string filename) { + std::ifstream file; + file.open(filename, std::ios::binary); + if (!file.is_open()) return result = Result(Result::Code::IO_ERROR, "Failed to open file: " + filename); + + // RIFF chunk + if (read32(file) != value_RIFF) return result = Result(Result::Code::FORMAT_ERROR, "Input is not a RIFF file"); + read32(file); // File length - we don't check this + if (read32(file) != value_WAVE) return result = Result(Result::Code::FORMAT_ERROR, "Input is not a plain WAVE file"); + + auto blockStart = file.tellg(); // start of the blocks - we will seek back to here periodically + bool hasFormat = false, hasData = false; + + Format format = Format::PCM; // Shouldn't matter, we should always read the `fmt ` chunk before `data` + while (!file.eof()) { + auto blockType = read32(file), blockLength = read32(file); + if (!hasFormat && blockType == value_fmt) { + auto formatInt = read16(file); + format = (Format)formatInt; + channels = read16(file); + if (channels < 1) return result = Result(Result::Code::FORMAT_ERROR, "Cannot have zero channels"); + + sampleRate = read32(file); + if (sampleRate < 1) return result = Result(Result::Code::FORMAT_ERROR, "Cannot have zero sampleRate"); + + unsigned int expectedBytesPerSecond = read32(file); + unsigned int bytesPerFrame = read16(file); + unsigned int bitsPerSample = read16(file); + if (!formatIsValid(formatInt, bitsPerSample)) return result = Result(Result::Code::UNSUPPORTED, "Unsupported format:bits: " + std::to_string(formatInt) + ":" + std::to_string(bitsPerSample)); + // Since it's plain WAVE, we can do some extra checks for consistency + if (bitsPerSample*channels != bytesPerFrame*8) return result = Result(Result::Code::FORMAT_ERROR, "Format sizes don't add up"); + if (expectedBytesPerSecond != sampleRate*bytesPerFrame) return result = Result(Result::Code::FORMAT_ERROR, "Format sizes don't add up"); + + hasFormat = true; + file.clear(); + file.seekg(blockStart); + } else if (hasFormat && blockType == value_data) { + std::vector samples(0); + switch (format) { + case Format::PCM: + samples.reserve(blockLength/2); + for (size_t i = 0; i < blockLength/2; ++i) { + uint16_t value = read16(file); + if (file.eof()) break; + if (value >= 32768) { + samples.push_back(((double)value - 65536)/32768); + } else { + samples.push_back((double)value/32768); + } + } + } + while (samples.size()%channels != 0) { + samples.push_back(0); + } + this->samples = samples; + offset = 0; + hasData = true; + } else { + // We either don't recognise + file.ignore(blockLength); + } + } + if (!hasFormat) return result = Result(Result::Code::FORMAT_ERROR, "missing `fmt ` block"); + if (!hasData) return result = Result(Result::Code::FORMAT_ERROR, "missing `data` block"); + return result = Result(Result::Code::OK); + } + + Result write(std::string filename, Format format=Format::PCM) { + if (channels == 0 || channels > 65535) return result = Result(Result::Code::WEIRD_CONFIG, "Invalid channel count"); + if (sampleRate <= 0 || sampleRate > 0xFFFFFFFFu) return result = Result(Result::Code::WEIRD_CONFIG, "Invalid sample rate"); + + std::ofstream file; + file.open(filename, std::ios::binary); + if (!file.is_open()) return result = Result(Result::Code::IO_ERROR, "Failed to open file: " + filename); + + int bytesPerSample; + switch (format) { + case Format::PCM: + bytesPerSample = 2; + break; + } + + // File size - 44 bytes is RIFF header, "fmt" block, and "data" block header + unsigned int dataLength = (samples.size() - offset*channels)*bytesPerSample; + unsigned int fileLength = 44 + dataLength; + + // RIFF chunk + write32(file, value_RIFF); + write32(file, fileLength - 8); // File length, excluding the RIFF header + write32(file, value_WAVE); + // "fmt " block + write32(file, value_fmt); + write32(file, 16); // block length + write16(file, (uint16_t)format); + write16(file, channels); + write32(file, sampleRate); + unsigned int expectedBytesPerSecond = sampleRate*channels*bytesPerSample; + write32(file, expectedBytesPerSecond); + write16(file, channels*bytesPerSample); // Bytes per frame + write16(file, bytesPerSample*8); // bist per sample + + // "data" block + write32(file, value_data); + write32(file, dataLength); + switch (format) { + case Format::PCM: + for (unsigned int i = offset*channels; i < samples.size(); i++) { + double value = samples[i]*32768; + if (value > 32767) value = 32767; + if (value <= -32768) value = -32768; + if (value < 0) value += 65536; + write16(file, (uint16_t)value); + } + break; + } + return result = Result(Result::Code::OK); + } + + void makeMono() { + std::vector newSamples(samples.size()/channels, 0); + + for (size_t channel = 0; channel < channels; ++channel) { + for (size_t i = 0; i < newSamples.size(); ++i) { + newSamples[i] += samples[i*channels + channel]; + } + } + for (size_t i = 0; i < newSamples.size(); ++i) { + newSamples[i] /= channels; + } + + channels = 1; + samples = newSamples; + } +}; + +#endif // RIFF_WAVE_H_