Generalizing SIMD vector sizes

We've already seen massive performance improvements in several real-world scenarios in the previous posts. In this post I'd like to show a way of templating your code so it runs on machines with varying vector lengths, without unnecessary code duplication and effort. Not all of your users are going to have the latest Intel Core i9 / Skylake-X CPU with AVX-512, but you still want your code to run on as many machines as possible.

For this example, I'm going to implement a very simple General Matrix Multiplication (GEMM), an algorithm that is gaining popularity because it's used extensively in various artificial intelligence workloads such as neural networks and deep learning.

I'm going to use some templating boilerplate code already written by an open source project called Embree. This code is easily reusable and has a fairly permissive license, so I encourage you to use it in your own projects. Since it's written in C++, the rest of this example will use C++ as well, but there is no reason why you couldn't produce the same results in C using macros, or in any other language for that matter.
The formula is quite simple, so I'll let you DuckDuckGo "matrix multiplication formula" for more details on that.
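In short: A is m x k, B is k x n, C is m x n, and each element of C is the dot product of a row of A with a column of B. For reference, a plain scalar version of the loop we're about to vectorize looks like this (my own sketch, using the same row-major layout and loop order as the templated code below):

// Reference scalar GEMM: C += A * B, all matrices row-major.
// A is m x k, B is k x n, C is m x n.
void gemm_scalar(const float *a, const float *b, float *c,
                 const int k, const int m, const int n)
{
    for (int mi = 0; mi < m; mi++)
        for (int ki = 0; ki < k; ki++)
            for (int ni = 0; ni < n; ni++)
                c[(mi * n) + ni] += a[(k * mi) + ki] * b[(ki * n) + ni];
}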
Let's get straight to it.
// clang++-5.0 -std=c++11 -Iembree-git example.cpp -mavx -Wall -g -O3

#include "common/simd/simd.h"

using namespace embree;

// C += A * B, all matrices row-major.
// A is m x k, B is k x n, C is m x n.
// V is the Embree vector type (e.g. vfloat<8>), T its element type.
template <class V, class T>
void gemm(T *a, T *b, T *c, const int k, const int m, const int n)
{
    V vectorA;
    V vectorB;
    V vectorC;

    const int vectorSize = vectorC.size;

    for (int mi = 0; mi < m; mi++)
    {
        for (int ki = 0; ki < k; ki++)
        {
            // Broadcast a single element of A across the whole vector.
            vectorA = V::broadcast(a + (k*mi) + ki);

            // Walk a row of B and C, vectorSize elements at a time.
            for (int ni = 0; ni < n; ni += vectorSize)
            {
                vectorC = V::loadu(c + (mi*n) + ni);
                vectorB = V::loadu(b + (ki*n) + ni);

                // Fused multiply-add: C = A * B + C
                vectorC = madd(vectorA, vectorB, vectorC);

                V::storeu(c + (mi*n) + ni, vectorC);
            }
        }
    }
}

From a performance standpoint, this isn't the fastest way to do a GEMM. Blocking the matrices so the working set fits in your cache is critical; that alone can net a 5x speedup for very large matrices. Moreover, if memory is cheap, eliminating the constant load+store by transposing matrix B first is faster still. Also notice that this code doesn't handle matrix sizes that are not multiples of vectorSize (one way to handle that is sketched below). But this post is about templating, so we'll keep the code as simple as possible.
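For completeness, here is one way the leftover columns could be handled. This is my own sketch, not part of the example above: stop the vector loop at the last full vector and finish the remaining n % vectorSize columns with plain scalar code.

// Sketch of remainder handling: the vector loop covers the first nVec
// columns, then a scalar loop cleans up the leftover columns.
template <class V, class T>
void gemm_with_tail(T *a, T *b, T *c, const int k, const int m, const int n)
{
    const int vectorSize = V::size;
    const int nVec = n - (n % vectorSize);  // last column covered by full vectors

    for (int mi = 0; mi < m; mi++)
    {
        for (int ki = 0; ki < k; ki++)
        {
            V vectorA = V::broadcast(a + (k*mi) + ki);

            for (int ni = 0; ni < nVec; ni += vectorSize)
            {
                V vectorC = V::loadu(c + (mi*n) + ni);
                V vectorB = V::loadu(b + (ki*n) + ni);
                V::storeu(c + (mi*n) + ni, madd(vectorA, vectorB, vectorC));
            }

            // Scalar tail for the columns that don't fill a whole vector.
            for (int ni = nVec; ni < n; ni++)
                c[(mi*n) + ni] += a[(k*mi) + ki] * b[(ki*n) + ni];
        }
    }
}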

As you can tell from my compiler line, I cloned the Embree git tree into a subfolder called embree-git, included their "common/simd/simd.h" header directly, passed -Iembree-git to the compiler, and pulled in the embree namespace. That's all I needed to do in order to use their code. Nice and clean!

With this one version of gemm(), I can now target a wide range of data types and machines without any code duplication. This one method can do a matrix multiplication for int32, int64, float, or double matrices, on 128-bit SSE2, 256-bit AVX, and 512-bit AVX-512 machines, and it can easily be extended to other data types using the same templating features.

Here is an example of how to run our GEMM on a 128-bit SSE machine using float matrices.
float *a = ...
float *b = ...
float *c = ...
gemm<vfloat<4>,float>(a,b,c,k,m,n);
And if you wanted to do double matrices on an AVX-512 machine, it would simply be:
double *a = ...
double *b = ...
double *c = ...
gemm<vdouble<8>,double>(a,b,c,k,m,n);
Look how easy that is!
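
If you'd rather not hard-code the vector width, you can also pick the instantiation at runtime. Here's a sketch of one way to do it; it assumes GCC/Clang's __builtin_cpu_supports() and that the file is built with -mavx as in the compile line above, so both vfloat<4> and vfloat<8> are available.

// Runtime dispatch sketch: use the 8-wide AVX instantiation when the CPU
// supports AVX, otherwise fall back to the 4-wide SSE one.
void gemm_dispatch(float *a, float *b, float *c,
                   const int k, const int m, const int n)
{
    if (__builtin_cpu_supports("avx"))
        gemm<vfloat<8>, float>(a, b, c, k, m, n);
    else
        gemm<vfloat<4>, float>(a, b, c, k, m, n);
}

In a real build you'd typically compile each instantiation in its own translation unit with the matching -m flags, so the fallback path doesn't accidentally contain AVX instructions.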

Let's see how this ultra simple templated gemm() method scales across the different vector sizes on my 10-core i9-7900X Skylake-X machine.
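If you want to run the same comparison on your own machine, a minimal timing harness along these lines will do. This is my own sketch: a GEMM performs 2*m*n*k floating point operations, the matrix sizes here are arbitrary placeholders, and you can add a 16-wide case if you build with AVX-512 flags.

#include <chrono>
#include <cstdio>
#include <vector>

// Sketch of a timing harness: runs one gemm instantiation and reports GFLOPS.
template <class V, class T>
double benchmark(const int k, const int m, const int n)
{
    std::vector<T> a(m * k, T(1)), b(k * n, T(1)), c(m * n, T(0));

    auto start = std::chrono::high_resolution_clock::now();
    gemm<V, T>(a.data(), b.data(), c.data(), k, m, n);
    auto stop = std::chrono::high_resolution_clock::now();

    double seconds = std::chrono::duration<double>(stop - start).count();
    return (2.0 * m * n * k) / seconds / 1e9;  // GFLOPS
}

int main()
{
    printf("SSE (4-wide): %.2f GFLOPS\n", benchmark<vfloat<4>, float>(512, 512, 512));
    printf("AVX (8-wide): %.2f GFLOPS\n", benchmark<vfloat<8>, float>(512, 512, 512));
    return 0;
}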
There you have it: free performance with a bit of C++ boilerplate to generalize SIMD programming. The Intel Skylake-X CPU is capable of more than 50 GFLOPS per core, so exercise for the reader: optimize it!
