Wednesday, January 1, 2020

MLIR Polyhedral: Tiling Experiments

Based on the "Structured Ops in MLIR" presentation (pp. 56-60), this post runs some experiments on matrix multiplication to test MLIR's multi-level tiling.

Input
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)
linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
Naïve Codegen
for (int i = 0..M)
  for (int j = 0..N)
    float dot = 0.f;
    for (int k = 0..K)
      dot += A[i][k] * B[k][j];
    C[i][j] += dot;
(Multi-level) Tiling Codegen
for (int i = 0..M, step = TileM)
  for (int j = 0..N, step = TileN)
    for (int k = 0..K, step = TileK)
      for (int ii = 0..TileM)
        for (int jj = 0..TileN)
          float dot = 0.f;
          for (int kk = 0..TileK)
            dot += A[i + ii][k + kk] * B[k + kk][j + jj];
          C[i + ii][j + jj] += dot;
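For reference, here is a minimal, runnable C++ sketch of the two loop nests above (the sizes and tile sizes are arbitrary choices for the example and are assumed to divide evenly); it checks that the tiled order produces the same result as the naïve one for these constant inputs:

#include <cassert>
#include <vector>

// Naïve matmul: C[i][j] += sum over k of A[i][k] * B[k][j].
static void matmulNaive(const std::vector<float> &A, const std::vector<float> &B,
                        std::vector<float> &C, int M, int N, int K) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float dot = 0.f;
      for (int k = 0; k < K; ++k)
        dot += A[i * K + k] * B[k * N + j];
      C[i * N + j] += dot;
    }
}

// Single-level tiled matmul; assumes TileM/TileN/TileK divide M/N/K evenly.
static void matmulTiled(const std::vector<float> &A, const std::vector<float> &B,
                        std::vector<float> &C, int M, int N, int K,
                        int TileM, int TileN, int TileK) {
  for (int i = 0; i < M; i += TileM)
    for (int j = 0; j < N; j += TileN)
      for (int k = 0; k < K; k += TileK)
        // Compute one TileM x TileN x TileK sub-matmul.
        for (int ii = 0; ii < TileM; ++ii)
          for (int jj = 0; jj < TileN; ++jj) {
            float dot = 0.f;
            for (int kk = 0; kk < TileK; ++kk)
              dot += A[(i + ii) * K + (k + kk)] * B[(k + kk) * N + (j + jj)];
            C[(i + ii) * N + (j + jj)] += dot;
          }
}

int main() {
  const int M = 64, N = 64, K = 64;
  std::vector<float> A(M * K, 2.f), B(K * N, 1.f), C0(M * N, 10.f), C1(M * N, 10.f);
  matmulNaive(A, B, C0, M, N, K);
  matmulTiled(A, B, C1, M, N, K, 8, 8, 8);
  assert(C0 == C1); // same result, different iteration order
  return 0;
}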
mlir terms:
  • linalg: mlir linear algebra dialect
  • linalg.matmul: matrix multiplication operation
  • memref: memory reference
    • memref<?x?xf32>: a 2-D ranked memref type whose elements are f32 and whose sizes are unknown (dynamic).
    • memref<*xf32>: an unranked memref type; it can have any rank.
    • (d0, d1)[s0] -> (d0 * s0 + d1): given indices d0, d1 and the stride s0 (the number of elements in a row), element [d0][d1] is read from *(base + d0 * s0 + d1); see the sketch below.
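A tiny C++ illustration of that addressing rule (the function name is mine, for illustration only):

#include <cassert>
#include <cstdint>
#include <vector>

// #strided2D = (d0, d1)[s0] -> (d0 * s0 + d1):
// element [d0][d1] of a view with row stride s0 lives at base + d0 * s0 + d1.
static float &elementAt(float *base, int64_t d0, int64_t d1, int64_t s0) {
  return *(base + d0 * s0 + d1);
}

int main() {
  const int64_t rows = 4, cols = 8; // for a contiguous row-major buffer, s0 == cols
  std::vector<float> buf(rows * cols, 0.f);
  elementAt(buf.data(), 2, 5, cols) = 42.f;
  assert(buf[2 * cols + 5] == 42.f);
  return 0;
}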

Git Clone llvm-project and Build llvm/clang/mlir

git clone https://github.com/llvm/llvm-project.git
mkdir llvm-project/build && cd llvm-project/build
cmake -G Ninja ../llvm \
   -DLLVM_ENABLE_PROJECTS="clang;mlir" \
   -DLLVM_BUILD_EXAMPLES=ON \
   -DLLVM_TARGETS_TO_BUILD="host" \
   -DCMAKE_BUILD_TYPE=Release \
   -DLLVM_ENABLE_ASSERTIONS=ON
cmake --build .
This differs slightly from the official MLIR getting-started document: we also build clang and other LLVM tools for our experiments.

Modify and Add Test Files

Add tile-test.mlir:
  • This file creates a matmul test; A, B, and C are all 1024x1024 matrices of f32.
  • Matrices A, B, and C are initialized to 2.0, 1.0, and 10.0 respectively.
  • linalg.matmul computes the matrix multiplication, and MatrixC[6][7] is returned as the result.
#strided1D = (d0) -> (d0)
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)

// Creates and returns a 1-D buffer of size %s filled with the value %f
func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %s4 = muli %s, %c4: index
  %buf = alloc(%s4) {alignment = 256} : memref<?xi8>
  %V = view %buf[%s][] : memref<?xi8> to memref<?xf32, #strided1D>
  linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
  return %buf : memref<?xi8>
}

func @matmul() -> f32 {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %c7 = constant 7 : index

  %m = constant 1024 : index
  %k = constant 1024 : index
  %n = constant 1024 : index
  %mk = constant 1048576 : index
  %kn = constant 1048576 : index
  %mn = constant 1048576 : index

  %f1 = constant 1.00000e+00 : f32
  %f2 = constant 2.00000e+00 : f32
  %f10 = constant 10.00000e+00 : f32

  %bA = call @alloc_filled_f32(%mk, %f2) : (index, f32) -> (memref<?xi8>)
  %bB = call @alloc_filled_f32(%kn, %f1) : (index, f32) -> (memref<?xi8>)
  %bC = call @alloc_filled_f32(%mn, %f10) : (index, f32) -> (memref<?xi8>)

  %A = view %bA[][%m, %k] : memref<?xi8> to memref<?x?xf32, #strided2D>
  %B = view %bB[][%k, %n] : memref<?xi8> to memref<?x?xf32, #strided2D>
  %C = view %bC[][%m, %n] : memref<?xi8> to memref<?x?xf32, #strided2D>

  linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
  %res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>

  dealloc %bC : memref<?xi8>
  dealloc %bB : memref<?xi8>
  dealloc %bA : memref<?xi8>

  return %res : f32
}
mlir terms:
  • (standard op) view: "converts a 1-D memref with i8 element type, to an N-D memref with arbitrary element type."
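To make the test's semantics concrete, here is a rough C++ sketch (hand-written, not generated code) of what @alloc_filled_f32 and @matmul above compute: A, B, and C are filled with 2.0, 1.0, and 10.0, A x B is accumulated into C, and C[6][7] is read back, which should be 10 + 2 * 1 * 1024 = 2058:

#include <cstdint>
#include <cstdio>
#include <vector>

// Rough analogue of @alloc_filled_f32. The MLIR version allocates an i8 buffer
// of s * 4 bytes and `view`s it as f32; here we simply allocate s floats.
static std::vector<float> allocFilledF32(int64_t s, float f) {
  return std::vector<float>(s, f);
}

int main() {
  const int64_t m = 1024, k = 1024, n = 1024;
  std::vector<float> A = allocFilledF32(m * k, 2.0f);  // %A
  std::vector<float> B = allocFilledF32(k * n, 1.0f);  // %B
  std::vector<float> C = allocFilledF32(m * n, 10.0f); // %C

  // C += A x B, mirroring linalg.matmul's accumulate-into-C semantics.
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j) {
      float dot = 0.f;
      for (int64_t kk = 0; kk < k; ++kk)
        dot += A[i * k + kk] * B[kk * n + j];
      C[i * n + j] += dot;
    }

  // Prints 2058 = 10 + 2 * 1 * 1024, matching the result reported below.
  std::printf("C[6][7] = %f\n", C[6 * n + 7]);
  return 0;
}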
Modify mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td#L39:
Original:
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
          [(Constraint<Or<[HasNoLinalgTransformMarker,
                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[200, 300, 400], "L2"> $op),
          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[20, 30, 40], "L1"> $op),
          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2, 3, 4], "REG"> $op),
          [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
The original patterns create four levels of tiles; for simplicity, this post creates only three levels and uses the same size in every dimension of each tile.
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[256, 256, 256], "L3"> $op),
          [(Constraint<Or<[HasNoLinalgTransformMarker,
                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[32, 32, 32], "L2"> $op),
          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[8, 8, 8], "L1"> $op),
          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
The innermost tile (marked "L1" in the patterns above) is 8x8, which means the sub-matmul
  C[m, n] += A[m, k] x B[k, n]

is computed one tile at a time as
  C''[8x8] += A''[8x8] x B''[8x8].

The memory access footprint of an 8x8 tile is 8 * 8 * 4 bytes (f32) * 3 sub-matrices (A'', B'', C'') = 768 bytes.

The "L2"-marked tile is 32x32, for a 12 KB footprint, and the outermost "L3"-marked tile is 256x256, for 768 KB.
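These footprint numbers are easy to double-check; here is a throwaway C++ snippet (names are mine, not from the post's sources) that verifies the arithmetic at compile time:

#include <cstddef>

// Working set of one tile level: three square f32 sub-matrices (A'', B'', C'').
constexpr std::size_t footprintBytes(std::size_t tile) {
  return tile * tile * sizeof(float) * 3;
}

static_assert(footprintBytes(8) == 768, "innermost (L1-marked) tile: 768 bytes");
static_assert(footprintBytes(32) == 12 * 1024, "L2-marked tile: 12 KB");
static_assert(footprintBytes(256) == 768 * 1024, "L3-marked tile: 768 KB");

int main() { return 0; }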

We can dump the transformation result to see the difference:
Without tiling:
./bin/mlir-opt tile-test.mlir
  func @matmul() -> f32 {
    ...
    %c1024 = constant 1024 : index
    %c1024_0 = constant 1024 : index
    %c1024_1 = constant 1024 : index
    %c1048576 = constant 1048576 : index
    %c1048576_2 = constant 1048576 : index
    %c1048576_3 = constant 1048576 : index
    %cst = constant 1.000000e+00 : f32
    %cst_4 = constant 2.000000e+00 : f32
    %cst_5 = constant 1.000000e+01 : f32
    %0 = call @alloc_filled_f32(%c1048576, %cst_4) : (index, f32) -> memref<?xi8>
    %1 = call @alloc_filled_f32(%c1048576_2, %cst) : (index, f32) -> memref<?xi8>
    %2 = call @alloc_filled_f32(%c1048576_3, %cst_5) : (index, f32) -> memref<?xi8>
    %3 = std.view %0[][%c1024, %c1024_0] : memref<?xi8> to memref<?x?xf32, #map0>
    %4 = std.view %1[][%c1024_0, %c1024_1] : memref<?xi8> to memref<?x?xf32, #map0>
    %5 = std.view %2[][%c1024, %c1024_1] : memref<?xi8> to memref<?x?xf32, #map0>
    linalg.matmul(%3, %4, %5) : memref<?x?xf32, #map0>, memref<?x?xf32, #map0>, memref<?x?xf32, #map0>
    %6 = load %5[%c6, %c7] : memref<?x?xf32, #map0>
    ...

With tiling:
./bin/mlir-opt tile-test.mlir -test-linalg-transform-patterns
#map0 = (d0, d1)[s0] -> (d0 * s0 + d1)
#map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)
module {
  func @matmul() -> f32 {
    ...
    %3 = std.view %0[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    %4 = std.view %1[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    %5 = std.view %2[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    loop.for %arg0 = %c0 to %c1024 step %c256 {
      loop.for %arg1 = %c0 to %c1024 step %c256 {
        loop.for %arg2 = %c0 to %c1024 step %c256 {
          %7 = std.subview %3[%arg0, %arg2][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          %8 = std.subview %4[%arg2, %arg1][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          %9 = std.subview %5[%arg0, %arg1][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          loop.for %arg3 = %c0 to %c256 step %c32 {
            loop.for %arg4 = %c0 to %c256 step %c32 {
              loop.for %arg5 = %c0 to %c256 step %c32 {
                %10 = std.subview %7[%arg3, %arg5][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                %11 = std.subview %8[%arg5, %arg4][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                %12 = std.subview %9[%arg3, %arg4][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                loop.for %arg6 = %c0 to %c32 step %c8 {
                  loop.for %arg7 = %c0 to %c32 step %c8 {
                    loop.for %arg8 = %c0 to %c32 step %c8 {
                      %13 = std.subview %10[%arg6, %arg8][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      %14 = std.subview %11[%arg8, %arg7][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      %15 = std.subview %12[%arg6, %arg7][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      linalg.matmul(%13, %14, %15) : memref<?x?xf32, #map1>, memref<?x?xf32, #map1>, memref<?x?xf32, #map1>
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    %6 = load %5[%c6, %c7] : memref<?x?xf32, #map0>
Add main.cpp:
  • Calls tile-test.mlir's matmul function and measures the execution time.
#include <chrono>
#include <iostream>

extern "C" float matmul();

int main(int argc, char *argv[]) {
  std::chrono::time_point<std::chrono::system_clock> startTime =
      std::chrono::system_clock::now();
  float v = matmul();
  std::chrono::duration<double> elapsedTime =
      std::chrono::system_clock::now() - startTime;

  std::cout << "matmul result = " << v << ", time = " << elapsedTime.count()
            << " seconds." << std::endl;
  return 0;
}

Add cblas.cpp:
  • Implements linalg.fill (type: memref<?xf32, #strided1D>, f32), and
  • linalg.matmul (type: memref<?x?xf32, #strided2D>, .., ..)
#include "include/cblas.h"
#include <assert.h>

extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
                                          float f) {
  for (unsigned i = 0; i < X->sizes[0]; ++i)
    *(X->data + X->offset + i * X->strides[0]) = f;
}

static void sgemm(const enum CBLAS_ORDER Order,
                  const enum CBLAS_TRANSPOSE TransA,
                  const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
                  const int K, const float alpha, const float *A, const int lda,
                  const float *B, const int ldb, const float beta, float *C,
                  const int ldc) {
  assert(Order == CBLAS_ORDER::CblasRowMajor);
  assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
  assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
  for (int m = 0; m < M; ++m) {
    auto *pA = A + m * lda;
    auto *pC = C + m * ldc;
    for (int n = 0; n < N; ++n) {
      float c = pC[n];
      float res = 0.0f;
      for (int k = 0; k < K; ++k) {
        auto *pB = B + k * ldb;
        res += pA[k] * pB[n];
      }
      pC[n] = alpha * c + beta * res;
    }
  }
}

__attribute__((always_inline)) extern "C" void
linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
    StridedMemRefType<float, 2> *C) {
  sgemm(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
        CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
        1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
        B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
}
Notes:
  • linalg.matmul(%13, %14, %15) :
        memref<?x?xf32, #map1>,
        memref<?x?xf32, #map1>,
        memref<?x?xf32, #map1>

    is name-mangled into "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32".
  • The 'always_inline' attribute forces clang to inline 'linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32' when the LLVM optimizations run.
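For reference, the StridedMemRefType descriptor that cblas.cpp receives has roughly the following shape (a sketch from memory; the exact definition lives in the mlir-cpu-runner headers and should be taken from there):

#include <cassert>
#include <cstdint>

// Approximate layout of the memref descriptor the lowered MLIR code passes to
// the C functions above (sketch only).
template <typename T, int N> struct StridedMemRefType {
  T *basePtr;         // pointer returned by the allocator
  T *data;            // aligned pointer actually used for addressing
  int64_t offset;     // element offset from `data` to the first element
  int64_t sizes[N];   // extent of each dimension
  int64_t strides[N]; // distance, in elements, between consecutive indices per dimension
};

int main() {
  // A 2x3 row-major view: element [i][j] is data[offset + i*strides[0] + j*strides[1]].
  float storage[6] = {0, 1, 2, 3, 4, 5};
  StridedMemRefType<float, 2> M{storage, storage, 0, {2, 3}, {3, 1}};
  assert(M.data[M.offset + 1 * M.strides[0] + 2 * M.strides[1]] == 5.f);
  return 0;
}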

Build Test Program

Test program flow:
  1. main.cpp: main() -> tile-test.mlir: matmul()
  2. tile-test.mlir: matmul() -> cblas.cpp: linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
cmake --build .
Since we changed TestLinalgTransformPatterns.td, rebuild so the test pass picks up our changes.
./bin/clang++ -O3 -emit-llvm -S cblas.cpp -std=c++14 \
  -I ../mlir/test/mlir-cpu-runner \
  -I /Users/cycheng/llvm-project/libcxx/include \
  -o cblas.ll
Build cblas.cpp with -O3 and emit LLVM IR in textual form.
./bin/mlir-opt tile-test.mlir -test-linalg-transform-patterns -convert-linalg-to-llvm \
       | ./bin/mlir-translate --mlir-to-llvmir > tile-test.ll
Apply the linalg transform patterns, which tile linalg.matmul into loop-dialect loops; then convert/lower to the LLVM dialect; then translate the LLVM dialect to LLVM IR.
./bin/llvm-link -S -o tile-test-all.ll cblas.ll tile-test.ll
Link cblas.ll and tile-test.ll into tile-test-all.ll
./bin/clang++ -emit-llvm -c -S -O3 tile-test-all.ll -o tile-test-all-o3.ll
Then run -O3 optimization on the linked file; clang will inline the calls to the linalg.matmul implementation.
./bin/clang++ -O3 -I /Users/cycheng/llvm-project/libcxx/include main.cpp \
  tile-test-all-o3.ll -o tile-test
Finally, build main.cpp together with the optimized IR, and run the tile-test program.
./tile-test
On my machine, I got
matmul result = 2058, time = 0.246217 seconds.

The test machine has a 128 KB 8-way L1D cache, a 1 MB 8-way L2 cache, and a 6 MB 12-way shared L3 cache, and runs at 2.2 GHz.
We can also test Naïve Codegen by removing '-test-linalg-transform-patterns' when compiling tile-test.mlir:
./bin/mlir-opt tile-test.mlir -convert-linalg-to-llvm \
       | ./bin/mlir-translate --mlir-to-llvmir > tile-test.ll
Finally I got
matmul result = 2058, time = 2.28546 seconds.

Result

I also tested a few different matrix sizes; the results are as follows:

Matrix sizes         tile (seconds)   non-tile (seconds)   non-tile / tile
512 x 512 x 512      0.026631         0.223485             8.391911682
1024 x 512 x 512     0.054033         0.43497              8.050080506
1024 x 512 x 1024    0.120866         0.882928             7.305015472
1024 x 1024 x 1024   0.239172         2.01186              8.411770609

The tiled version is much faster than the non-tiled version: roughly 7.3x to 8.4x across these sizes.
