
From Demucs to demucs.cpp

How I adapted the Demucs neural network's PyTorch inference to C++ for WebAssembly, tackling a more complicated architecture than Open-Unmix

Published at: 2024-02-26

Demucs

Demucs is a leading neural network for music source separation with consistently high performance. Every time a new Demucs architecture came out, it pushed the state of the art forward. Right now, the Demucs v4 hybrid transformer architecture is dominant: it was used heavily in the ensemble model that scored highly in the most recent Sound Demixing Challenge 2023, after the Demucs v3 hybrid time-frequency model won the Music Demixing Challenge 2021.

My site https://freemusicdemixer.com and paid product https://pro.freemusicdemixer.com feature Demucs inference and ensembles that produce some remarkable stem separation. The beating heart of my site is demucs.cpp, a C++ translation of the Demucs v4 hybrid transformer model inference using only the header-only linear algebra library Eigen. Demucs.cpp easily compiles to WebAssembly with the Emscripten toolchain to run the world's best demixing AI model in the browser.

In my last post, I described a few details of my C++ implementation of inference for the Open-Unmix neural network architecture, a venerable but simple model. In this post, I'll describe the bigger bag of tricks I needed in my Eigen/PyTorch translation toolbox to correctly implement the more complicated operations of Demucs.

Most crucial tool: Tensor print debugging

The most important tool I had for implementing demucs.cpp was printing summary statistics of tensors in the same format from both the PyTorch code of Demucs and my own C++ code for demucs.cpp, so the two could be compared side by side.

Printing PyTorch tensors from Demucs Python code

For adding print statements to the PyTorch code of an external model or library like Demucs, my procedure is the following:

  • Clone the Demucs source code repo locally
  • Create a new Python environment (I use conda/mamba but a virtualenv should be the same)
    • This is just standard best practice; if you know what you're doing, use whichever Python you want
  • Pip install your local copy of Demucs into the Python env in editable mode: cd my-demucs-clone/ && pip install -e .
  • Now, your local changes to the source code will be reflected when you run demucs, the command-line separation tool

When I run demucs (the command-line separation tool provided as part of the Demucs pip package), you can see all my different print statements:

$ demucs -n 'htdemucs' ~/Music/Cypress\ Hill\ -\ Cock\ the\ Hammer\ \(Official\ Audio\)\ \[Sk9T6dfEl4M\].opus  -o demucs_out
Selected model is a bag of 1 models. You will see that many progress bars per track.
Separated tracks will be stored in /home/sevagh/repos/demucs.cpp/demucs_out/htdemucs
Separating track /home/sevagh/Music/Cypress Hill - Cock the Hammer (Official Audio) [Sk9T6dfEl4M].opus
Debugging tensor!: mix in shift!
        shape: (1, 2, 11613008)
        min: -4.325106620788574
        max: 4.32819938659668
        mean: 1.3361121098398598e-08
        stddev: 1.0053081512451172
        sum: 0.31032562255859375
        min idx: (0, 1, 3485316)
        max idx: (0, 1, 4180205)
FINISHED DEBUG FOR TENSOR mix in shift!
Debugging tensor!: padded_mix
        shape: (1, 2, 11657108)
        min: -4.325106620788574
        max: 4.32819938659668
        mean: 1.3281122868136208e-08
        stddev: 1.0034047365188599
        sum: 0.30963897705078125
        min idx: (0, 1, 3507366)
        max idx: (0, 1, 4202255)
FINISHED DEBUG FOR TENSOR padded_mix
1., apply_model w/ shift, offset, shifted shape: 1337, [1, 2, 11633721]
Debugging tensor!: shifted_audio
        shape: (1, 2, 11633721)
        min: -4.325106620788574
        max: 4.32819938659668
        mean: 1.3340612170509303e-08
        stddev: 1.0044127702713013
        sum: 0.31040191650390625
        min idx: (0, 1, 3506029)
        max idx: (0, 1, 4200918)

I have a ton of print statements inserted throughout the codebase, following how the inference works from beginning (loading the input file) to end (writing the output files):

/home/sevagh/scratch/demucs.cpp-private/demucs-working-copy/demucs/utils.py
26:    print(f"Debugging tensor!: {name}")
27:    print(f"\tshape: {tuple(x.shape)}")
33:    print(f"\tmin: {x_min.item()}")
34:    print(f"\tmax: {x_max.item()}")
35:    print(f"\tmean: {x_mean.item()}")
36:    print(f"\tstddev: {x_stddev.item()}")
37:    print(f"\tsum: {x_sum.item()}")
38:    print(f"\tmin idx: {tuple(np.unravel_index(x_min_idx.item(), x.shape))}")
39:    print(f"\tmax idx: {tuple(np.unravel_index(x_max_idx.item(), x.shape))}")
40:    print(f"FINISHED DEBUG FOR TENSOR {name}")
76:        print(f"shape before trim: {tensor.shape}")
77:        print(f"delta: {delta}")
78:        print(f"limits: {delta // 2} {-(delta - delta // 2)}")
80:        print(f"shape after trim: {tensor.shape}")

/home/sevagh/scratch/demucs.cpp-private/demucs-working-copy/demucs/transformer.py
650:        print(f"Crosstransformer inference!")

/home/sevagh/scratch/demucs.cpp-private/demucs-working-copy/demucs/htdemucs.py
445:        print(f"in ispec: hl: {hl} z: {z.shape}")
447:        print(f"in ispec: z pad 1: {z.shape}")
449:        print(f"in ispec: z pad 2: {z.shape}")
453:        print(f"in ispec: x: {x.shape}, {pad=} {length=}")
455:        print(f"in ispec: x: {x.shape}")
534:        print(f"mix: {mix.shape}")
558:        print(f"mean: {mean}")
560:        print(f"std: {std}")
634:            print_idx = self.depth - idx - 1
637:            debug_tensor_demucscpp(x, f"x_{print_idx} pre-decoder-{print_idx}")
638:            debug_tensor_demucscpp(skip, f"skip pre-decoder-{print_idx}")
640:            debug_tensor_demucscpp(x, f"x_{print_idx} post-decoder-{print_idx}")
654:                    debug_tensor_demucscpp(xt, f"xt_{print_idx} pre-tdecoder-{print_idx}")
655:                    debug_tensor_demucscpp(skip, f"skip pre-tdecoder-{print_idx}")
657:                    debug_tensor_demucscpp(xt, f"xt_{print_idx} post-tdecoder-{print_idx}")

The tensor metric print function is debug_tensor_demucscpp:

import numpy as np
import torch


def debug_tensor_demucscpp(x, name):
    # check if x is a TensorChunk (which wraps a tensor with an offset and length)
    if hasattr(x, 'tensor'):
        # extract the subchunk x.offset : x.offset + x.length
        x = x.tensor[..., x.offset:x.offset+x.length]

    print(f"Debugging tensor!: {name}")
    print(f"\tshape: {tuple(x.shape)}")
    x_min, x_min_idx = torch.min(x.reshape(-1), dim=0)
    x_max, x_max_idx = torch.max(x.reshape(-1), dim=0)
    x_mean = torch.mean(x)
    x_stddev = torch.std(x)
    x_sum = torch.sum(x)
    print(f"\tmin: {x_min.item()}")
    print(f"\tmax: {x_max.item()}")
    print(f"\tmean: {x_mean.item()}")
    print(f"\tstddev: {x_stddev.item()}")
    print(f"\tsum: {x_sum.item()}")
    print(f"\tmin idx: {tuple(np.unravel_index(x_min_idx.item(), x.shape))}")
    print(f"\tmax idx: {tuple(np.unravel_index(x_max_idx.item(), x.shape))}")
    print(f"FINISHED DEBUG FOR TENSOR {name}")

Why I picked these metrics

In umx.cpp, I originally printed only the min and max values, which I thought was enough for debugging. In demucs.cpp, after many errors resulting from incorrect dimensions, I had to introduce the min and max indices alongside the min and max values. Similarly, the mean, stddev, and sum of all elements added more comparison points for the many situations where the min and max values of a layer were accidentally correct but the rest of the tensor was mangled or incorrect.
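
To see why min and max alone are not enough, here is a tiny illustrative sketch (hypothetical tensors, not from Demucs): two tensors can share their extremes while differing everywhere else, and only the mean, stddev, and sum expose the difference:

import torch

# two tensors with identical min/max but different contents
a = torch.tensor([-2.0, 0.1, 0.2, 3.0])
b = torch.tensor([-2.0, 1.5, -1.4, 3.0])

assert a.min() == b.min() and a.max() == b.max()  # extremes agree
print(a.mean(), b.mean())  # 0.3250 vs 0.2750: means differ
print(a.std(), b.std())    # stddevs differ too
print(a.sum(), b.sum())    # 1.3000 vs 1.1000: sums differ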

Printing Eigen tensors from demucs.cpp C++ code

In my own C++ code, I added the same print statements so I could compare both codebases side by side and eyeball each tensor for correctness:

$ ./demucs.cpp.main ../ggml-demucs/ggml-model-htdemucs-4s-f16.bin ~/Music/MDX-datasets/MUSDB18-HQ/test/Zeno\ -\ Signs/mixture.wav ./demucs-out
Loaded model (533 tensors,  80.08 MB) in 0.180710 s
demucs_model_load returned true
Starting demucs inference

Debugging matrix!: full_audio
        shape: (2, 10336331)
        min: -1.00003051757812500000
        max: 0.96710103750228881836
        mean: -0.00014591435319744051
        stddev: 0.12570409476757049561
        sum: -3016.43798828125000000000
        min idx: (0, 9704872)
        max idx: (1, 9028866)
FINISHED DEBUG for tensor: full_audio
Debugging matrix!: normalized_audio
        shape: (2, 10336331)
        min: -8.08255100250244140625
        max: 7.81872510910034179688
        mean: 0.00000042776349573614
        stddev: 1.01614403724670410156
        sum: 8.84300994873046875000
        min idx: (0, 9704872)
        max idx: (1, 9028866)
FINISHED DEBUG for tensor: normalized_audio
Debugging matrix!: mix in shift!
        shape: (2, 10336331)
        min: -8.08255100250244140625
        max: 7.81872510910034179688
        mean: 0.00000042776349573614
        stddev: 1.01614403724670410156
        sum: 8.84300994873046875000
        min idx: (0, 9704872)
        max idx: (1, 9028866)
FINISHED DEBUG for tensor: mix in shift!
Debugging matrix!: padded_mix
        shape: (2, 10380431)
        min: -8.08255100250244140625
        max: 7.81872510910034179688
        mean: 0.00000042594618321345
        stddev: 1.01398324966430664062
        sum: 8.84300994873046875000
        min idx: (0, 9726922)
        max idx: (1, 9050916)
FINISHED DEBUG for tensor: padded_mix
1., apply model w/ shift, offset: 1337

The print calls are spread throughout my layers:

(system) sevagh@pop-os:~/scratch/demucs.cpp-private/build$ rg -iuu 'demucscppdebug::debug' ../src/
../src/crosstransformer.cpp
224:    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer pre-layers");
225:    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer pre-tlayers");
239:    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-0");
...

../src/model_inference.cpp
81:    // demucscppdebug::debug_tensor_3dxcf(buffers.z, "z!");
105:    demucscppdebug::debug_tensor_3dxf(buffers.x, "x pre-std/mean");
124:    demucscppdebug::debug_tensor_3dxf(buffers.x, "x post-std/mean");
135:    demucscppdebug::debug_tensor_3dxf(buffers.xt, "xt pre-std/mean");
146:    demucscppdebug::debug_tensor_3dxf(buffers.xt, "xt post-std/mean");
155:    demucscppdebug::debug_tensor_3dxf(buffers.xt, "xt pre-encoder");
...
482:    demucscppdebug::debug_tensor_3dxf(buffers.targets_out,

../src/model_apply.cpp
64:    demucscppdebug::debug_matrix_xf(full_audio, "full_audio");
81:    demucscppdebug::debug_matrix_xf(normalized_audio, "normalized_audio");
110:    demucscppdebug::debug_matrix_xf(full_audio, "mix in shift!");

The implementation of these print statements is a bit more verbose than in PyTorch, given the stronger tensor types in C++ whose ranks are fixed at compile time (a deliberate choice for more robust and understandable code, which is hopefully also more performant than the fully dynamic Eigen tensor API).

Example of the 3d Tensor print function:

// For Tensor3dXf
inline void debug_tensor_3dxf(const Eigen::Tensor3dXf &x,
                              const std::string &name)
{
    std::cout << "Debugging tensor!: " << name << std::endl;
    std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1)
              << ", " << x.dimension(2) << ")" << std::endl;

    auto x_min = x.minimum();
    auto x_max = x.maximum();
    Eigen::Tensor<float, 0> x_sum_tensor = x.sum();
    float x_sum = x_sum_tensor(0);
    Eigen::Tensor<float, 0> x_mean_tensor = x.mean();
    float x_mean = x_mean_tensor(0);
    Eigen::Tensor<float, 0> x_stddev_tensor =
        ((x - x_mean).square()).mean().sqrt();
    float x_stddev = x_stddev_tensor(0);

    // Loop over all elements to find the indices of the min and max values
    int x_min_idx_0 = -1, x_min_idx_1 = -1, x_min_idx_2 = -1;
    int x_max_idx_0 = -1, x_max_idx_1 = -1, x_max_idx_2 = -1;
    float min_val = std::numeric_limits<float>::max();
    float max_val = std::numeric_limits<float>::lowest();

    for (int i = 0; i < x.dimension(0); ++i)
    {
        for (int j = 0; j < x.dimension(1); ++j)
        {
            for (int k = 0; k < x.dimension(2); ++k)
            {
                float val = x(i, j, k);
                if (val < min_val)
                {
                    min_val = val;
                    x_min_idx_0 = i;
                    x_min_idx_1 = j;
                    x_min_idx_2 = k;
                }
                if (val > max_val)
                {
                    max_val = val;
                    x_max_idx_0 = i;
                    x_max_idx_1 = j;
                    x_max_idx_2 = k;
                }
            }
        }
    }

    std::cout << "\tmin: " << x_min << std::endl;
    std::cout << "\tmax: " << x_max << std::endl;
    std::cout << "\tmean: " << x_mean << std::endl;
    std::cout << "\tstddev: " << x_stddev << std::endl;
    std::cout << "\tsum: " << x_sum << std::endl;
    std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ", "
              << x_min_idx_2 << ")" << std::endl;
    std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ", "
              << x_max_idx_2 << ")" << std::endl;

    std::cout << "FINISHED DEBUG for tensor: " << name << std::endl;
}

Row-major vs. col-major, PyTorch vs. Eigen

The error I made most frequently throughout the development of demucs.cpp was loading a row-major tensor from PyTorch (indirectly, through the ggml model file format) into a column-major Eigen tensor without accounting for the difference in storage order.

Wikipedia describes how C stores multidimensional arrays in row-major order while Fortran stores them in column-major order. In NumPy, you can choose between the two. Eigen uses column-major storage by default. This PyTorch blog post is a good resource as well. It's funny that I can't find written documentation for PyTorch that states whether its tensors are row- or column-major, but I'm 100% certain they are row-major.
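
As a quick NumPy illustration of the bug (a sketch with a made-up buffer, not Demucs data): interpreting the same raw buffer with order="C" versus order="F" produces two different logical tensors:

import numpy as np

buf = np.arange(12, dtype=np.float32)  # raw buffer, e.g. read from a weights file

rm = buf.reshape((3, 4), order="C")    # row-major interpretation (PyTorch-style)
cm = buf.reshape((3, 4), order="F")    # column-major interpretation (Eigen default)

print(rm[0, 1], cm[0, 1])  # 1.0 vs 3.0: same bytes, different logical values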

Printing the min and max index alongside the min and max value is the best way to easily spot this error:

# same tensor data, loaded as RowMajor
Debugging tensor!: tensor_test_1
        shape: (2, 6, 336)
        min: -8.08255100250244140625
        max: 7.81872510910034179688
        mean: 0.00000042594618321345
        stddev: 1.01398324966430664062
        sum: 8.84300994873046875000
        min idx: (0, 4, 272)
        max idx: (1, 2, 183)
FINISHED DEBUG for tensor: tensor_test_1

# same tensor data, loaded as ColMajor
Debugging tensor!: tensor_test_2
        shape: (2, 6, 336)
        min: -8.08255100250244140625
        max: 7.81872510910034179688
        mean: 0.00000042594618321345
        stddev: 1.01398324966430664062
        sum: 8.84300994873046875000
        min idx: (0, 4, 134)
        max idx: (1, 1, 239)
FINISHED DEBUG for tensor: tensor_test_2

The subtlety is that unless you inspect the min and max indices, you may falsely believe your tensors are correct! I found this out the hard way.

ChatGPT explains the above

Index Mapping

To map the row-major indices to column-major, we have to understand how the linear index would change. The linear index of a position (i, j, k) in a 3D array with dimensions (D1, D2, D3) in row-major is:

Linear Index row-major = i×(D2×D3)+j×D3+k

And in column-major, it is:

Linear Index column-major = k×(D1×D2)+j×D1+i

To find the column-major indices corresponding to the row-major indices (0, 4, 272) and (1, 2, 183), we convert them to linear indices and then back to 3D indices under column-major rules.

For a tensor with dimensions (2, 6, 336) interpreted first as row-major and then as column-major, the mappings are as follows:

  • The row-major min index of (0, 4, 272) maps to the column-major index (0, 4, 134).
  • The row-major max index of (1, 2, 183) maps to the column-major index (1, 1, 239).
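
This mapping can be verified mechanically with NumPy, whose index helpers take an order argument for C-style (row-major) versus Fortran-style (column-major) indexing. A small sketch using the shape and indices from the debug output above:

import numpy as np

shape = (2, 6, 336)

# row-major (C-order) linear offsets of the min and max positions
lin_min = np.ravel_multi_index((0, 4, 272), shape, order="C")
lin_max = np.ravel_multi_index((1, 2, 183), shape, order="C")

# the same linear offsets, re-interpreted under column-major (F-order) rules
print(np.unravel_index(lin_min, shape, order="F"))  # (0, 4, 134)
print(np.unravel_index(lin_max, shape, order="F"))  # (1, 1, 239)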

How do I swap between the two? In demucs.cpp, I'm very conservative about this:

  • I use intermediate Eigen types of row-major half-type (float16) to match the float16 weights of the Demucs v4 hybrid transformer model (a decision made by the creator of Demucs to easily save 50% of the disk space of float32 weights without fancy quantization):
    namespace Eigen
    {
    // half/float16 typedefs for weights
    typedef Tensor<Eigen::half, 3, Eigen::RowMajor> Tensor3dXh;
    typedef Tensor<std::complex<Eigen::half>, 3, Eigen::RowMajor> Tensor3dXch;
    typedef Tensor<Eigen::half, 1, Eigen::RowMajor> Tensor1dXh;
    typedef Tensor<Eigen::half, 4, Eigen::RowMajor> Tensor4dXh;
    typedef Vector<Eigen::half, Dynamic> VectorXh;
    }
    
  • For the loaded model weights, I use column-major Eigen types of float-type (float32), since it is standard to do the forward pass in float32 even when the stored weights are quantized smaller:
    namespace Eigen
    {
    // define MatrixXh for some layers in demucs
    typedef Matrix<Eigen::half, Dynamic, Dynamic, Eigen::RowMajor> MatrixXh;
    
    // define Tensor3dXf, Tensor3dXcf for spectrograms etc.
    typedef Tensor<float, 4> Tensor4dXf;
    typedef Tensor<float, 3> Tensor3dXf;
    typedef Tensor<float, 2> Tensor2dXf;
    typedef Tensor<float, 1> Tensor1dXf;
    typedef Tensor<std::complex<float>, 3> Tensor3dXcf;
    } // namespace Eigen
    
  • As an intermediary between the two, I just iterate over their dimensions and let Eigen handle the conversion for me (a NumPy analogue of this conversion follows this list):
    static size_t load_single_tensor3d(FILE *f, std::string &name,
                                       Eigen::Tensor3dXf &tensor, int *ne,
                                       int32_t nelements)
    {
        ...
        // Create a temporary Tensor3dXh to load the data
        // respecting the original pytorch row-major representation
        Eigen::Tensor3dXh tmp_tensor(tensor.dimensions());
    
        // loading weights
        const size_t bpe = sizeof(Eigen::half);
        auto nbytes_tensor = tmp_tensor.size() * bpe;
    
        fread(tmp_tensor.data(), bpe, nelements, f);
    
        printf("%16s: [%5d, %5d, %5d], type = float, %6.2f MB\n", name.data(),
               ne[0], ne[1], ne[2], nbytes_tensor / 1024.0 / 1024.0);
    
        // Create a column-major Tensor3dXf with the same dimensions as tmp_tensor
        Eigen::Tensor3dXf tensor_cm(ne[0], ne[1], ne[2]);
    
        // Manually copy the data from tmp_tensor to tensor_cm
        for (int i = 0; i < ne[0]; ++i)
        {
            for (int j = 0; j < ne[1]; ++j)
            {
                for (int k = 0; k < ne[2]; ++k)
                {
                    tensor_cm(i, j, k) = static_cast<float>(tmp_tensor(i, j, k));
                }
            }
        }
    
        // Assign tensor_cm to tensor
        tensor = tensor_cm;
        return nbytes_tensor;
    }
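
Here is a NumPy sketch of what that copy loop accomplishes (made-up dimensions, not the real model shapes): the raw row-major buffer is first viewed with C-order semantics, then copied into Fortran-order storage, so every logical index agrees while the physical layout changes:

import numpy as np

buf = np.arange(24, dtype=np.float32)    # raw buffer as written by PyTorch (row-major)

tmp = buf.reshape((2, 3, 4), order="C")  # like tmp_tensor: row-major view of the bytes
cm = np.asfortranarray(tmp)              # like tensor_cm: col-major storage, same values

assert np.array_equal(tmp, cm)           # every logical index (i, j, k) agrees
assert not np.array_equal(tmp.ravel(order="K"), cm.ravel(order="K"))  # memory differs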
    

Next tool: printing and pausing after every layer

In the tensor print functions, to enable layer-by-layer debugging, you can add the following lines:

// In C++:
std::cin.ignore(); // pause execution until user presses enter

// In Python
input() # pause execution until user presses enter

Finally, use the print statements before and after each layer of execution, after which there will be a pause while you eyeball the values and make sure everything is correct. This was an indispensable tool for implementing the entire inference end-to-end: I could proceed past each layer with confidence that it worked correctly:

// x0, x1, x2 are pre-allocated in the expected shapes
// of each encoder layer of demucs
demucscppdebug::debug_tensor_3dxf(x0, "x0 pre-encoder");

demucscpp::apply_time_encoder(model, x0, x1, 0);
demucscppdebug::debug_tensor_3dxf(x1, "x1 encoder-1");

demucscpp::apply_time_encoder(model, x1, x2, 1);
demucscppdebug::debug_tensor_3dxf(x2, "x2 encoder-2");

Layer tests and unit tests with simple inputs

All of the above uses a real execution of Demucs inference (in both the PyTorch and C++ command-line tools), pauses between each layer, and allows inspection of the values. This is not a unit test or lightweight test; setting up the testbench, so to speak, is very manual and involved.

In some cases I had to double down on a difficult layer or operation, and using a full-sized input file was too much to start with.

For those, I implemented a bunch of unit-test-like tests in both Python and C++ for individual layers, or for groups of layers that belong together (for example, the frequency encoder).

Python/PyTorch script for the first 4 frequency encoder layers in isolation, using a simple input tensor of [-1, 1, -1, ...]:

import sys

import torch

from demucs.apply import apply_model
from demucs.pretrained import get_model
from demucs.pretrained import SOURCES
from demucs.utils import debug_tensor_demucscpp


if __name__ == '__main__':
    # demucs v4 hybrid transformer
    model = get_model('htdemucs')
    print(model)

    try:
        test_name = sys.argv[1]
    except IndexError:
        test_name = "all"

    if test_name == "all" or test_name == "freq-enc":
        # get the henclayer
        henclayer_0 = model.models[0].encoder[0]

        # create a fake tensor of shape (1, 4, 2048, 336)
        x = torch.ones((1, 4, 2048, 336))

        # set alternating values (even indices along the last axis) to -1
        x[..., ::2] = -1

        x_enc_0 = henclayer_0(x)

        debug_tensor_demucscpp(x, "x")
        debug_tensor_demucscpp(x_enc_0, "x_enc_0")

        # continue for the rest of the encoder layers
        # generate tensors for each layer
        # shapes are:
        #    (96, 128, 336) -> (192, 32, 336) -> (384, 8, 336)
        # continue with x_enc_1,2,3

        henclayer_1 = model.models[0].encoder[1]
        x_enc_1 = henclayer_1(x_enc_0)

        debug_tensor_demucscpp(x_enc_1, "x_enc_1")

        henclayer_2 = model.models[0].encoder[2]
        x_enc_2 = henclayer_2(x_enc_1)

        debug_tensor_demucscpp(x_enc_2, "x_enc_2")

        henclayer_3 = model.models[0].encoder[3]
        x_enc_3 = henclayer_3(x_enc_2)

        debug_tensor_demucscpp(x_enc_3, "x_enc_3")

Note that this is just a Python script, not a real unit-test framework; I invoke it with python scripts/demucs_pytorch_layer_test.py 'freq-enc'. Note also that the full Demucs model is loaded, and that I manually use its deeper inner layers to test single layers in isolation:

model = get_model('htdemucs')

# you can access any of the model's layers by inspecting the model
# architecture and how pytorch saves it as a class
henclayer_0 = model.models[0].encoder[0]

C++ unit test code corresponding to the above using googletest, with the same input tensor [-1, 1, ...]:

#include "encdec.hpp"
#include "layers.hpp"
#include "model.hpp"
#include "tensor.hpp"
#include <gtest/gtest.h>

// google test global setup for model before all tests
// ('model' is a file-scope global holding the loaded demucs model, declared elsewhere)
static void setUpTestSuite()
{
    // load model from "../ggml-demucs/ggml-model-htdemucs-f16.bin"
    std::string model_file = "../ggml-demucs/ggml-model-htdemucs-f16.bin";
    auto ret = load_demucs_model_4s(model_file, &model);
}

// basic test case for the frequency encoder layers
TEST(DemucsCPPLayers, FreqEncoders)
{
    setUpTestSuite();

    std::cout << std::fixed << std::setprecision(20) << std::endl;

    Eigen::Tensor3dXf x_fake(4, 2048, 336);

    // fill with -1, 1 alternating
#pragma omp parallel for collapse(3)
    for (size_t i = 0; i < 4; ++i)
    {
        for (size_t j = 0; j < 2048; ++j)
        {
            for (size_t k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    x_fake(i, j, k) = -1.0;
                }
                else
                {
                    x_fake(i, j, k) = 1.0;
                }
            }
        }
    }

    Eigen::Tensor3dXf x_fake_enc_0(48, 512, 336);
    demucscpp::apply_freq_encoder(model, 0, x_fake, x_fake_enc_0);

    demucscppdebug::debug_tensor_3dxf(x_fake, "x_fake");
    demucscppdebug::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0");

    Eigen::Tensor3dXf x_fake_enc_1(96, 128, 336);
    demucscpp::apply_freq_encoder(model, 1, x_fake_enc_0, x_fake_enc_1);
    demucscppdebug::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1");

    Eigen::Tensor3dXf x_fake_enc_2(192, 32, 336);
    demucscpp::apply_freq_encoder(model, 2, x_fake_enc_1, x_fake_enc_2);
    demucscppdebug::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2");

    Eigen::Tensor3dXf x_fake_enc_3(384, 8, 336);
    demucscpp::apply_freq_encoder(model, 3, x_fake_enc_2, x_fake_enc_3);
    demucscppdebug::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3");
}

Zooming in on a specific tensor by saving to a file

In a very desperate case, in the middle of the pause-and-inspect inference method, I would save a tensor to a file from Python and then load it in C++ (a primitive step before I figured out the cleaner unit testing above):

# generate fake tensors with random values and write to .txt file
# fake_x and fake_xt of the same shape as x and xt
print("FAKE DEBUGGING!")

# Define the range of values for the fake tensors
x_min = -19.62506866455078
x_max = 13.808989524841309
xt_min = -17.259153366088867
xt_max = 29.980199813842773

# Generate fake tensors with random values in the same range as x and xt
fake_x = (x_max - x_min) * torch.rand_like(x) + x_min
fake_xt = (xt_max - xt_min) * torch.rand_like(xt) + xt_min

# Store col-major for Eigen
fake_x_permuted = fake_x.permute(2, 1, 0)
fake_xt_permuted = fake_xt.permute(2, 1, 0)

# Save as "./fake_x.txt" and "./fake_xt.txt" respectively
fake_x_np = fake_x_permuted.cpu().numpy()
fake_xt_np = fake_xt_permuted.cpu().numpy()

# Save the NumPy arrays to text files
np.savetxt("./fake_x.txt", fake_x_np.reshape(-1, fake_x_np.shape[-1]))
np.savetxt("./fake_xt.txt", fake_xt_np.reshape(-1, fake_xt_np.shape[-1]))

debug_tensor_sevag(fake_x, "fake_x")
debug_tensor_sevag(fake_xt, "fake_xt")

for idx in range(self.num_layers):
    if idx % 2 == self.classic_parity:
        print("DEBUG MYTRANSFORMERLAYER")
        fake_x_tmp = self.layers[idx](fake_x)
        fake_xt_tmp = self.layers_t[idx](fake_xt)

        debug_tensor_sevag(fake_x_tmp, f"fake_x after layer {idx}")
        debug_tensor_sevag(fake_xt_tmp, f"fake_xt after layer {idx}")
    else:
        print("DEBUG CROSSTRANSFORMERLAYER")

        old_fake_x = fake_x
        fake_x_tmp = self.layers[idx](fake_x, fake_xt)
        fake_xt_tmp = self.layers_t[idx](fake_xt, old_fake_x)

        debug_tensor_sevag(fake_x_tmp, f"fake_x after layer {idx}")
        debug_tensor_sevag(fake_xt_tmp, f"fake_xt after layer {idx}")

    input(f"finish layer {idx} of crosstransformer")

input("END FAKE DEBUGGING!")

Loading the fake tensor in C++:

// Debugging experiment: load the fake inputs saved from Python
// ("../fake_x.txt" and "../fake_xt.txt") into new Tensor3dXf tensors,
// run the layers on them, and compare the outputs against the Python run

Eigen::Tensor3dXf fake_x(x.dimension(0), x.dimension(1), x.dimension(2));
Eigen::Tensor3dXf fake_xt(xt.dimension(0), xt.dimension(1), xt.dimension(2));

// Load fake_x from file
std::ifstream file("../fake_x.txt");
if (file.is_open()) {
    float value;
    for (int i = 0; i < fake_x.size(); ++i) {
        file >> value;
        fake_x.data()[i] = value;
    }
    file.close();
} else {
    std::cerr << "Failed to open file" << std::endl;
}

// Load fake_xt from file
std::ifstream file2("../fake_xt.txt");
if (file2.is_open()) {
    float value;
    for (int i = 0; i < fake_xt.size(); ++i) {
        file2 >> value;
        fake_xt.data()[i] = value;
    }
    file2.close();
} else {
    std::cerr << "Failed to open file" << std::endl;
}

demucscppdebug::debug_tensor_3dxf(fake_x, "fake_x");
demucscppdebug::debug_tensor_3dxf(fake_xt, "fake_xt");

demucscpp::my_transformer_encoder_layer(
    model, fake_x, 0, 0);
demucscpp::my_transformer_encoder_layer(
    model, fake_xt, 1, 0);

demucscppdebug::debug_tensor_3dxf(fake_x, "fake_x after layer 0");
demucscppdebug::debug_tensor_3dxf(fake_xt, "fake_xt after layer 0");

std::cout << "finish layer 0 of crosstransformer" << std::endl;
std::cin.ignore();

Last tool: reimplementing PyTorch layers in Python

In some cases, a PyTorch layer is too opaque to give any hints or help on the implementation. This was most relevant when I was implementing the CrossTransformer class of Demucs.

Here is the custom attention function of the CrossTransformerEncoderLayer:

def _ca_block(self, q, k, attn_mask=None):
    x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
    return self.dropout1(x)

This uses the forward function of the PyTorch module MultiheadAttention:

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)

This is difficult to follow into the Torch C++ implementation, so I went a different route: using ChatGPT and various online resources, it was easier to create a from-scratch PyTorch reimplementation of the bigger PyTorch module, and then translate that to the final C++/Eigen form. As an example, to create the cross-attention C++ function for demucs.cpp, I first recreated it in Python, using PyTorch tensors but not the built-in MultiheadAttention module:

def custom_ca_block(q_norm, k_norm, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, embed_dim=512, num_heads=8):
    B, T, C = q_norm.size()
    print(f"q_norm: {q_norm.shape}")
    B_k, S, C_k = k_norm.size()

    assert B == B_k and C == C_k  # Ensure batch size and channel dimensions match

    # Compute Q, K, V matrices
    qkv_weights = in_proj_weight
    Q_weight, K_weight, V_weight = qkv_weights.chunk(3, 0)

    Q = torch.matmul(q_norm.view(T * B, C), Q_weight.t())
    K = torch.matmul(k_norm.view(S * B, C), K_weight.t())
    V = torch.matmul(k_norm.view(S * B, C), V_weight.t())  # Assuming V is computed from k_norm

    # Split heads
    Q = Q.view(T, B, num_heads, embed_dim // num_heads).transpose(1, 2)
    K = K.view(S, B, num_heads, embed_dim // num_heads).transpose(1, 2)
    V = V.view(S, B, num_heads, embed_dim // num_heads).transpose(1, 2)

    qkv_bias = in_proj_bias
    q_bias, k_bias, v_bias = qkv_bias.chunk(3)

    head_dim = embed_dim // num_heads
    q_bias = q_bias.view(1, num_heads, 1, head_dim).expand(T, num_heads, B, head_dim)
    k_bias = k_bias.view(1, num_heads, 1, head_dim).expand(S, num_heads, B, head_dim)
    v_bias = v_bias.view(1, num_heads, 1, head_dim).expand(S, num_heads, B, head_dim)

    Q = Q + q_bias
    K = K + k_bias
    V = V + v_bias

    # Compute cross-attention scores
    # Specifically, given Q of shape (L, N, E) and K of shape (S, N, E), the attention scores are computed as
    # scores = matmul(Q, K.transpose(-2, -1)) / sqrt(d_k).
    # Here, L is the target sequence length, S is the source sequence length, N is the batch size, and d_k is the dimension of keys.
    # The resulting scores tensor has shape (N, L, S),

    print(f"Q shape: {Q.shape}")
    print(f"K shape: {K.shape}")
    input("PAUSE!")

    # shape 1, 8, 2688, 64
    Q = Q.permute(2, 1, 0, 3)
    # shape of 1, 8, 64, 1344
    K = K.permute(2, 1, 3, 0)

    scores = torch.matmul(Q, K) / (head_dim ** 0.5)

    print(f"scores shape: {scores.shape}")
    debug_tensor_sevag(scores, "scores")
    input("PAUSE!")

    # Apply softmax to scores
    scores = F.softmax(scores, dim=-1)

    debug_tensor_sevag(V, "V")
    V = V.permute(2, 1, 0, 3)
    debug_tensor_sevag(V, "V")

    # Compute cross-attention output
    cross_attn_out = torch.matmul(scores, V)

    # Merge heads
    cross_attn_out = cross_attn_out.transpose(1, 2).contiguous().view(T, B, embed_dim)

    # Apply output projection
    out_proj = F.linear(cross_attn_out, out_proj_weight, out_proj_bias)

    return out_proj

In the middle of the Demucs crosstransformer inference, I added print statements to compare the custom cross-attention block with the built-in PyTorch one:

if self.norm_first:
    debug_tensor_sevag(q, "q")
    debug_tensor_sevag(k, "k")

    q_norm = self.norm1(q)
    k_norm = self.norm2(k)

    debug_tensor_sevag(q_norm, "q_norm")
    debug_tensor_sevag(k_norm, "k_norm")

    ca = self._ca_block(q_norm, k_norm, mask)
    ca_custom = custom_ca_block(q_norm, k_norm, self.cross_attn.in_proj_weight, self.cross_attn.in_proj_bias, self.cross_attn.out_proj.weight, self.cross_attn.out_proj.bias)
    debug_tensor_sevag(ca, "cross-attn")
    debug_tensor_sevag(ca_custom, "cross-attn custom")
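
Beyond eyeballing the printed statistics, a programmatic check along these lines could confirm that the two paths agree (a sketch; the tolerance is an assumption, chosen loosely to absorb accumulated float32 rounding):

import torch

# ca comes from the built-in MultiheadAttention path, ca_custom from the
# from-scratch reimplementation; they should match within float32 tolerance
assert torch.allclose(ca, ca_custom, atol=1e-5), "custom cross-attention diverges"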

Finally, once that proved correct, it wasn't so bad to implement the above custom cross-attention code in C++ (from demucs.cpp):

// Normalize x using the norm1 weights and biases
Eigen::Tensor3dXf q_norm =
    demucscpp::layer_norm(q, norm1_weight, norm1_bias, eps);

Eigen::Tensor3dXf k_norm;
if (self_attention)
{
    k_norm = q_norm;
}
else
{
    k_norm = demucscpp::layer_norm(k, norm2_weight, norm2_bias, eps);
}

// Cross-attention block
// Compute Q, K, V matrices

int B = q.dimension(0);
int T = q.dimension(1);
int C = q.dimension(2);

int B_k = k.dimension(0);
int S = k.dimension(1);
int C_k = k.dimension(2);

// Reshape q_norm, k_norm to 2D matrices of dimensions (T, C) and (S, C)
// (the batch dimension B is 1 at inference time)

// Use Eigen::Map to avoid manual loops for reshaping
Eigen::MatrixXf q_norm_2d =
    Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C);
Eigen::MatrixXf k_norm_2d =
    Eigen::Map<const Eigen::MatrixXf>(k_norm.data(), S, C);

// Compute Q, K, V matrices
Eigen::MatrixXf Q =
    q_norm_2d * in_proj_weight.block(0, 0, C, C).transpose();
Eigen::MatrixXf K =
    k_norm_2d * in_proj_weight.block(C, 0, C, C).transpose();
Eigen::MatrixXf V =
    k_norm_2d * in_proj_weight.block(2 * C, 0, C, C).transpose();

Eigen::VectorXf q_bias = in_proj_bias.segment(0, C);
Eigen::VectorXf k_bias = in_proj_bias.segment(C, C);
Eigen::VectorXf v_bias = in_proj_bias.segment(2 * C, C);

// copied from linear layer: ff1.rowwise() += linear1_bias.transpose();
Q.rowwise() += q_bias.transpose();
K.rowwise() += k_bias.transpose();
V.rowwise() += v_bias.transpose();

int head_split = C / num_heads;

// map matrices to tensors
Eigen::Tensor3dXf Q_heads =
    Eigen::TensorMap<Eigen::Tensor3dXf>(Q.data(), T, head_split, num_heads);
Eigen::Tensor3dXf K_heads =
    Eigen::TensorMap<Eigen::Tensor3dXf>(K.data(), S, head_split, num_heads);
Eigen::Tensor3dXf V_heads =
    Eigen::TensorMap<Eigen::Tensor3dXf>(V.data(), S, head_split, num_heads);

Eigen::MatrixXf cross_attn_out(T, C);

for (int h = 0; h < num_heads; ++h)
{
    // Extract the h-th head from Q_heads and K_heads
    Eigen::Tensor2dXf Q_head_tensor = Q_heads.chip(h, 2);
    Eigen::Tensor2dXf K_head_tensor = K_heads.chip(h, 2);
    Eigen::Tensor2dXf V_head_tensor = V_heads.chip(h, 2);

    // Reshape the tensors to matrices
    Eigen::Map<Eigen::MatrixXf> Q_head(Q_head_tensor.data(), T, head_split);
    Eigen::Map<Eigen::MatrixXf> K_head(K_head_tensor.data(), S, head_split);
    Eigen::Map<Eigen::MatrixXf> V_head(V_head_tensor.data(), S, head_split);

    // Compute the dot product of Q_head and K_head
    Eigen::MatrixXf dot_product =
        Q_head * K_head.transpose() / std::sqrt((float)head_split);

    // Apply softmax to the dot product
    Eigen::ArrayXf max_vals = dot_product.rowwise().maxCoeff();
    Eigen::MatrixXf max_vals_expanded = max_vals.replicate(1, S);
    Eigen::MatrixXf softmax_scores =
        (dot_product - max_vals_expanded).array().exp().matrix();
    Eigen::VectorXf row_sums = softmax_scores.rowwise().sum();
    Eigen::MatrixXf divisor = row_sums.replicate(1, S);
    softmax_scores = (softmax_scores.array() / divisor.array()).matrix();

    Eigen::MatrixXf cross_attn_head = softmax_scores * V_head;
    cross_attn_out.block(0, h * head_split, T, head_split) =
        cross_attn_head;
}

// Map q to a 2D matrix q_2d (a view over q's data, not a copy)
Eigen::Map<Eigen::MatrixXf> q_2d(q.data(), T, C);

// Apply output projection with gamma1_scale
Eigen::MatrixXf out_proj = cross_attn_out * out_proj_weight.transpose();
out_proj.array().rowwise() += out_proj_bias.transpose().array();
out_proj = out_proj.array().rowwise() * gamma1_scale.transpose().array();

// Add to q
q_2d += out_proj;

// before feedforward, apply norm3 to x i.e. q
q_norm = demucscpp::layer_norm(q, norm3_weight, norm3_bias, eps);
q_norm_2d = Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C);

// Feedforward block
// Linear layer 1
Eigen::MatrixXf ff1 = q_norm_2d * linear1_weight.transpose();
ff1.rowwise() += linear1_bias.transpose();

ff1 = demucscpp::gelu(ff1);

// Linear layer 2
Eigen::MatrixXf ff2 = ff1 * linear2_weight.transpose();
ff2.rowwise() += linear2_bias.transpose();

// Apply gamma_2 scale directly on 2D matrix
ff2 = ff2.array().rowwise() * gamma2_scale.transpose().array();

// now q = q + self.gamma_2(self._ff_block(self.norm3(q)))
q_2d += ff2;

// Map the 2D data back into a 3D tensor with dimensions (T, B, C)
q = Eigen::TensorMap<Eigen::Tensor3dXf>(q_2d.data(), T, B, C);

// Shuffle the (T, B, C) tensor into dimensions (B, C, T)
Eigen::array<int, 3> permute_dims_3 = {1, 2, 0};
Eigen::Tensor3dXf q_shuf = q.shuffle(permute_dims_3);

// Normalize the output with norm_out/MyGroupNorm
q = demucscpp::group_norm(q_shuf, norm_out_weight, norm_out_bias, 1, eps);

Eigen::array<int, 3> permute_dims_4 = {0, 2, 1};
q_shuf = q.shuffle(permute_dims_4);

q = q_shuf;

Conclusion

Translating Demucs to demucs.cpp was a very manual, hand-written process. There are more rigorous and generalized projects out there, for example frugally-deep, which runs Eigen code from Keras model file definitions. However, it was an incredibly rewarding project, and having the hand-written demucs.cpp inference hardcoded to the Demucs v4 Hybrid Transformer model gives me a lot of leeway for performance optimizations.

In a bittersweet twist, the scientist behind Demucs is no longer working on it, so I now have fast, hardcoded C++ inference for what may be the final official variant of Demucs (for now?).
