Input
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)
linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
Naïve Codegen
for (int i = 0; i < M; ++i)
  for (int j = 0; j < N; ++j) {
    float dot = 0.f;
    for (int k = 0; k < K; ++k)
      dot += A[i][k] * B[k][j];
    C[i][j] += dot;
  }
(Multi-level) Tiling Codegen
// Assumes M, N, K are multiples of TileM, TileN, TileK.
for (int i = 0; i < M; i += TileM)
  for (int j = 0; j < N; j += TileN)
    for (int k = 0; k < K; k += TileK)
      for (int ii = 0; ii < TileM; ++ii)
        for (int jj = 0; jj < TileN; ++jj) {
          float dot = 0.f;
          for (int kk = 0; kk < TileK; ++kk)
            dot += A[i + ii][k + kk] * B[k + kk][j + jj];
          C[i + ii][j + jj] += dot;
        }
mlir terms:
- linalg: mlir linear algebra dialect
- linalg.matmul: matrix multiplication operation
- memref: memory reference
- memref<?x?xf32>: a 2-D ranked memref type with f32 elements and dynamic (unknown) sizes.
- memref<*xf32>: an unranked memref type; it can have any rank.
- (d0, d1)[s0] -> (d0 * s0 + d1): given indices d0, d1 and the stride s0 (the number of elements in a row), element[d0][d1] is found at *(base + d0 * s0 + d1).
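To make the layout map concrete, here is a minimal C++ sketch that evaluates (d0, d1)[s0] -> (d0 * s0 + d1) on a flat row-major buffer. The helper names `strided2D` and `load2D` are hypothetical, chosen for illustration; they are not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Evaluates the affine map (d0, d1)[s0] -> (d0 * s0 + d1):
// d0 = row index, d1 = column index, s0 = row stride (elements per row).
inline std::size_t strided2D(std::size_t d0, std::size_t d1, std::size_t s0) {
  return d0 * s0 + d1;
}

// Reads element [d0][d1] of a row-major matrix stored in a flat buffer.
inline float load2D(const std::vector<float> &buf, std::size_t d0,
                    std::size_t d1, std::size_t s0) {
  std::size_t idx = strided2D(d0, d1, s0);
  assert(idx < buf.size());  // stay inside the buffer
  return buf[idx];
}
```

For a 3x4 matrix (s0 = 4), element [2][3] lands at flat index 2 * 4 + 3 = 11.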
Git Clone llvm-project and Build llvm/clang/mlir
git clone https://github.com/llvm/llvm-project.git
mkdir llvm-project/build && cd llvm-project/build
cmake -G Ninja ../llvm \
  -DLLVM_ENABLE_PROJECTS="clang;mlir" \
  -DLLVM_BUILD_EXAMPLES=ON \
  -DLLVM_TARGETS_TO_BUILD="host" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=ON
cmake --build .
This differs slightly from the official MLIR Getting Started guide: we also build clang and the other LLVM tools used in the experiments below.
Modify and Add Test Files
Add tile-test.mlir:
- This file creates a matmul test; A, B and C are all 1024x1024 matrices of f32.
- Matrices A, B and C are initialized with 2.0, 1.0 and 10.0 respectively.
- It uses linalg.matmul to compute C += A x B and returns C[6][7] as the result.
#strided1D = (d0) -> (d0)
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)

// Creates and returns a 1-D buffer of size %s filled with the value %f
func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %s4 = muli %s, %c4: index
  %buf = alloc(%s4) {alignment = 256} : memref<?xi8>
  %V = view %buf[%s][] : memref<?xi8> to memref<?xf32, #strided1D>
  linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
  return %buf : memref<?xi8>
}

func @matmul() -> f32 {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %c7 = constant 7 : index
  %m = constant 1024 : index
  %k = constant 1024 : index
  %n = constant 1024 : index
  %mk = constant 1048576 : index
  %kn = constant 1048576 : index
  %mn = constant 1048576 : index
  %f1 = constant 1.00000e+00 : f32
  %f2 = constant 2.00000e+00 : f32
  %f10 = constant 10.00000e+00 : f32
  %bA = call @alloc_filled_f32(%mk, %f2) : (index, f32) -> (memref<?xi8>)
  %bB = call @alloc_filled_f32(%kn, %f1) : (index, f32) -> (memref<?xi8>)
  %bC = call @alloc_filled_f32(%mn, %f10) : (index, f32) -> (memref<?xi8>)
  %A = view %bA[][%m, %k] : memref<?xi8> to memref<?x?xf32, #strided2D>
  %B = view %bB[][%k, %n] : memref<?xi8> to memref<?x?xf32, #strided2D>
  %C = view %bC[][%m, %n] : memref<?xi8> to memref<?x?xf32, #strided2D>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
  %res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>
  dealloc %bC : memref<?xi8>
  dealloc %bB : memref<?xi8>
  dealloc %bA : memref<?xi8>
  return %res : f32
}
mlir terms:
- (standard op) view: "converts a 1-D memref with i8 element type, to an N-D memref with arbitrary element type."
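As a rough C++ analogy (not how MLIR actually lowers these ops), @alloc_filled_f32 plus the 2-D views behave like a raw byte buffer that is reinterpreted as floats and then indexed with the #strided2D map. The helpers `allocFilledF32` and `viewLoad2D` are hypothetical; the sketch ignores the 256-byte alignment request and uses memcpy to avoid type-punning problems.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>

// Rough analogue of @alloc_filled_f32: a byte buffer of n * sizeof(float)
// bytes whose n float elements are all set to f.
std::unique_ptr<unsigned char[]> allocFilledF32(std::size_t n, float f) {
  std::unique_ptr<unsigned char[]> buf(new unsigned char[n * sizeof(float)]);
  for (std::size_t i = 0; i < n; ++i)
    std::memcpy(buf.get() + i * sizeof(float), &f, sizeof(float));
  return buf;
}

// Rough analogue of loading through a 2-D view with sizes [rows, cols] and no
// offset: element [d0][d1] lives at float index d0 * cols + d1.
float viewLoad2D(const unsigned char *buf, std::size_t cols, std::size_t d0,
                 std::size_t d1) {
  float v;
  std::memcpy(&v, buf + (d0 * cols + d1) * sizeof(float), sizeof(float));
  return v;
}
```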
Original patterns in TestLinalgTransformPatterns.td:
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
          [(Constraint<Or<[HasNoLinalgTransformMarker,
                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[200, 300, 400], "L2"> $op),
          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[20, 30, 40], "L1"> $op),
          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2, 3, 4], "REG"> $op),
          [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
Each pattern matches a matmul carrying the previous level's marker (or no marker / "MEM" for the first level), tiles it, and tags the new inner matmul with the next marker, so the patterns fire one after another. The original patterns create four levels of tiling; for simplicity, this post creates only three levels and uses the same tile size for all three loop dimensions (M, N, K) at each level.
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[256, 256, 256], "L3"> $op),
          [(Constraint<Or<[HasNoLinalgTransformMarker,
                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[32, 32, 32], "L2"> $op),
          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[8, 8, 8], "L1"> $op),
          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
The 3rd (innermost) level's tile size is 8 in every dimension, which means we compute the sub-matmul of
$C[m, n] += $A[m, k] x $B[k, n]
in a tile with size
$C"[8x8] += $A"[8x8] x $B"[8x8].
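The three patterns together produce a nine-deep loop nest. A C++ sketch of an equivalent nest is below; `tiledMatmul3` is a hypothetical name, the matrices are assumed square and row-major, and n is assumed to be a multiple of every tile size (as 1024 is of 256, 32 and 8). The loop order follows the generated code: rows, columns, then the reduction dimension at each level.

```cpp
#include <cassert>
#include <vector>

// C += A x B for n x n row-major matrices, tiled at three levels
// (T1 > T2 > T3), mirroring the 256 / 32 / 8 tiling in the .td patterns.
// Assumes n % T1 == 0, T1 % T2 == 0 and T2 % T3 == 0.
void tiledMatmul3(const std::vector<float> &A, const std::vector<float> &B,
                  std::vector<float> &C, int n, int T1, int T2, int T3) {
  assert(n % T1 == 0 && T1 % T2 == 0 && T2 % T3 == 0);
  for (int i1 = 0; i1 < n; i1 += T1)
    for (int j1 = 0; j1 < n; j1 += T1)
      for (int k1 = 0; k1 < n; k1 += T1)
        for (int i2 = i1; i2 < i1 + T1; i2 += T2)
          for (int j2 = j1; j2 < j1 + T1; j2 += T2)
            for (int k2 = k1; k2 < k1 + T1; k2 += T2)
              for (int i3 = i2; i3 < i2 + T2; i3 += T3)
                for (int j3 = j2; j3 < j2 + T2; j3 += T3)
                  for (int k3 = k2; k3 < k2 + T2; k3 += T3)
                    // Innermost T3 x T3 x T3 sub-matmul: the work done by
                    // the small linalg.matmul on 8x8 subviews.
                    for (int i = i3; i < i3 + T3; ++i)
                      for (int j = j3; j < j3 + T3; ++j) {
                        float dot = 0.f;
                        for (int k = k3; k < k3 + T3; ++k)
                          dot += A[i * n + k] * B[k * n + j];
                        C[i * n + j] += dot;
                      }
}
```

For the post's configuration the call would be `tiledMatmul3(A, B, C, 1024, 256, 32, 8)`.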
The memory footprint of one 8x8 tile is 8 * 8 * 4 bytes (float) * 3 (sub-matrices A", B", C") = 768 bytes.
The 2nd-level tile is 32x32 (12 KB in total), and the 1st-level tile is 256x256 (768 KB in total).
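The footprint numbers are plain arithmetic, so they can be checked at compile time. A small sketch; the helper `tileFootprintBytes` is hypothetical and assumes a 4-byte float.

```cpp
#include <cassert>
#include <cstddef>

// Bytes touched by one T x T x T matmul tile: three T x T float
// sub-matrices (A", B", C").
constexpr std::size_t tileFootprintBytes(std::size_t T) {
  return T * T * sizeof(float) * 3;
}

static_assert(sizeof(float) == 4, "sketch assumes 4-byte float");
static_assert(tileFootprintBytes(8) == 768, "8x8 tile: 768 bytes");
static_assert(tileFootprintBytes(32) == 12 * 1024, "32x32 tile: 12 KB");
static_assert(tileFootprintBytes(256) == 768 * 1024, "256x256 tile: 768 KB");
```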
We can dump the transformation result to see the difference:
Without tiling:
./bin/mlir-opt tile-test.mlir
func @matmul() -> f32 {
  ...
  %c1024 = constant 1024 : index
  %c1024_0 = constant 1024 : index
  %c1024_1 = constant 1024 : index
  %c1048576 = constant 1048576 : index
  %c1048576_2 = constant 1048576 : index
  %c1048576_3 = constant 1048576 : index
  %cst = constant 1.000000e+00 : f32
  %cst_4 = constant 2.000000e+00 : f32
  %cst_5 = constant 1.000000e+01 : f32
  %0 = call @alloc_filled_f32(%c1048576, %cst_4) : (index, f32) -> memref<?xi8>
  %1 = call @alloc_filled_f32(%c1048576_2, %cst) : (index, f32) -> memref<?xi8>
  %2 = call @alloc_filled_f32(%c1048576_3, %cst_5) : (index, f32) -> memref<?xi8>
  %3 = std.view %0[][%c1024, %c1024_0] : memref<?xi8> to memref<?x?xf32, #map0>
  %4 = std.view %1[][%c1024_0, %c1024_1] : memref<?xi8> to memref<?x?xf32, #map0>
  %5 = std.view %2[][%c1024, %c1024_1] : memref<?xi8> to memref<?x?xf32, #map0>
  linalg.matmul(%3, %4, %5) : memref<?x?xf32, #map0>, memref<?x?xf32, #map0>, memref<?x?xf32, #map0>
  %6 = load %5[%c6, %c7] : memref<?x?xf32, #map0>
  ...
With tiling:
./bin/mlir-opt tile-test.mlir -test-linalg-transform-patterns
#map0 = (d0, d1)[s0] -> (d0 * s0 + d1)
#map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)
module {
  func @matmul() -> f32 {
    ...
    %3 = std.view %0[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    %4 = std.view %1[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    %5 = std.view %2[][%c1024, %c1024] : memref<?xi8> to memref<?x?xf32, #map0>
    loop.for %arg0 = %c0 to %c1024 step %c256 {
      loop.for %arg1 = %c0 to %c1024 step %c256 {
        loop.for %arg2 = %c0 to %c1024 step %c256 {
          %7 = std.subview %3[%arg0, %arg2][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          %8 = std.subview %4[%arg2, %arg1][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          %9 = std.subview %5[%arg0, %arg1][%c256, %c256][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map1>
          loop.for %arg3 = %c0 to %c256 step %c32 {
            loop.for %arg4 = %c0 to %c256 step %c32 {
              loop.for %arg5 = %c0 to %c256 step %c32 {
                %10 = std.subview %7[%arg3, %arg5][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                %11 = std.subview %8[%arg5, %arg4][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                %12 = std.subview %9[%arg3, %arg4][%c32, %c32][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                loop.for %arg6 = %c0 to %c32 step %c8 {
                  loop.for %arg7 = %c0 to %c32 step %c8 {
                    loop.for %arg8 = %c0 to %c32 step %c8 {
                      %13 = std.subview %10[%arg6, %arg8][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      %14 = std.subview %11[%arg8, %arg7][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      %15 = std.subview %12[%arg6, %arg7][%c8, %c8][%c1, %c1] : memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
                      linalg.matmul(%13, %14, %15) : memref<?x?xf32, #map1>, memref<?x?xf32, #map1>, memref<?x?xf32, #map1>
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    %6 = load %5[%c6, %c7] : memref<?x?xf32, #map0>
mlir terms:
- loop.for %arg0 = %c0 to %c1024 step %c256 {
  - The loop dialect's for operation.
  - Expressed in C: for (int arg0 = 0; arg0 < 1024; arg0 += 256)
- std.subview base_memref[offsets...][sizes...][strides...] : memref<...> to memref<...>
  - For example:
    %13 = std.subview %10[%arg6, %arg8][%c8, %c8][%c1, %c1] :
        memref<?x?xf32, #map1> to memref<?x?xf32, #map1>
  - Roughly expressed in C:
    typedef float _2d_f32_ty[8][8];
    _2d_f32_ty *%13 = (_2d_f32_ty *)&%10[%arg6][%arg8];
    (plus the offset and stride information carried by #map1: the subview is not a contiguous 8x8 array, since its rows keep the parent's row stride of 1024 floats)
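A subview does not copy data; it only produces a new descriptor, so its effect can be sketched with integer arithmetic. The struct `Strided2D` and the helpers below are hypothetical simplifications for illustration, not MLIR APIs, and they assume the reading of #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2) in which s0 is the offset and s1, s2 are the row and column strides.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified 2-D strided descriptor (offset, sizes, strides).
struct Strided2D {
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

// Descriptor arithmetic for subview base[o0, o1][n0, n1][t0, t1]:
// no data moves, only offset/sizes/strides change.
Strided2D subview(const Strided2D &base, int64_t o0, int64_t o1, int64_t n0,
                  int64_t n1, int64_t t0, int64_t t1) {
  Strided2D r;
  r.offset = base.offset + o0 * base.strides[0] + o1 * base.strides[1];
  r.sizes[0] = n0;
  r.sizes[1] = n1;
  r.strides[0] = base.strides[0] * t0;
  r.strides[1] = base.strides[1] * t1;
  return r;
}

// #map1: (d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2),
// with s0 = offset, s1 = strides[0], s2 = strides[1].
int64_t linearIndex(const Strided2D &m, int64_t d0, int64_t d1) {
  return d0 * m.strides[0] + m.offset + d1 * m.strides[1];
}
```

Chaining three subviews as the tiled loop nest does (256, then 32, then 8) leaves the row stride at 1024, which is why the innermost 8x8 tile is scattered across 8 rows of the original buffer.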
Add main.cpp
- Call tile-test.mlir's matmul function and measure the execution time.
#include <chrono>
#include <iostream>

// Provided by tile-test.mlir after lowering to LLVM IR.
extern "C" float matmul();

int main(int argc, char *argv[]) {
  // steady_clock is monotonic, which makes it the better choice for timing.
  auto startTime = std::chrono::steady_clock::now();
  float v = matmul();
  std::chrono::duration<double> elapsedTime =
      std::chrono::steady_clock::now() - startTime;
  std::cout << "matmul result = " << v << ", time = " << elapsedTime.count()
            << " seconds." << std::endl;
  return 0;
}
Add cblas.cpp
- Implement linalg.fill (type: memref<?xf32, #strided1D>, f32), and
- linalg.matmul (type: memref<?x?xf32, #strided2D>, .., ..)
#include "include/cblas.h"
#include <assert.h>

extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
                                          float f) {
  for (unsigned i = 0; i < X->sizes[0]; ++i)
    *(X->data + X->offset + i * X->strides[0]) = f;
}

// Naive row-major SGEMM: C = alpha * (A x B) + beta * C.
static void sgemm(const enum CBLAS_ORDER Order,
                  const enum CBLAS_TRANSPOSE TransA,
                  const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
                  const int K, const float alpha, const float *A,
                  const int lda, const float *B, const int ldb,
                  const float beta, float *C, const int ldc) {
  assert(Order == CBLAS_ORDER::CblasRowMajor);
  assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
  assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
  for (int m = 0; m < M; ++m) {
    auto *pA = A + m * lda;
    auto *pC = C + m * ldc;
    for (int n = 0; n < N; ++n) {
      float c = pC[n];
      float res = 0.0f;
      for (int k = 0; k < K; ++k) {
        auto *pB = B + k * ldb;
        res += pA[k] * pB[n];
      }
      pC[n] = alpha * res + beta * c;
    }
  }
}

__attribute__((always_inline)) extern "C" void
linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
    StridedMemRefType<float, 2> *C) {
  // Debug: dump the memref descriptors.
  // printMemRefMetaData(std::cerr, *A); std::cerr << std::endl;
  // printMemRefMetaData(std::cerr, *B); std::cerr << std::endl;
  // printMemRefMetaData(std::cerr, *C); std::cerr << std::endl;
  // C += A x B (alpha = beta = 1); sgemm assumes unit inner-dimension stride.
  sgemm(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
        CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
        1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
        B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
}
Notes:
- linalg.matmul(%13, %14, %15) : memref<?x?xf32, #map1>, memref<?x?xf32, #map1>, memref<?x?xf32, #map1>
  is name-mangled into a call to "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32".
- The 'always_inline' attribute forces clang to inline 'linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32' when running LLVM optimizations.
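To see what each library call receives, here is a self-contained sketch of the descriptor-to-pointer translation that the wrapper above performs (base pointer = data + offset, leading dimension = strides[0]). `MemRef2DF32` is a hypothetical stand-in mirroring only the StridedMemRefType fields cblas.cpp reads; it is not the real MLIR runtime type. The math is written as plain loops instead of calling sgemm so the sketch compiles on its own.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the fields cblas.cpp reads from
// StridedMemRefType<float, 2>.
struct MemRef2DF32 {
  float *basePtr;
  float *data;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

// C += A x B on the views described by the descriptors; assumes unit stride
// in the inner dimension, as the sgemm call in cblas.cpp does.
void matmulOnDescriptors(const MemRef2DF32 *A, const MemRef2DF32 *B,
                         MemRef2DF32 *C) {
  assert(A->strides[1] == 1 && B->strides[1] == 1 && C->strides[1] == 1);
  const int64_t M = C->sizes[0], N = C->sizes[1], K = A->sizes[1];
  const float *pA = A->data + A->offset;  // first element of the A view
  const float *pB = B->data + B->offset;
  float *pC = C->data + C->offset;
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n) {
      float res = 0.f;
      for (int64_t k = 0; k < K; ++k)
        res += pA[m * A->strides[0] + k] * pB[k * B->strides[0] + n];
      pC[m * C->strides[0] + n] += res;
    }
}
```

In the tiled program, each call sees an 8x8 view with a non-zero offset and strides[0] == 1024, so only the tile's elements are touched.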
Build Test Program
Test program flow:
- main.cpp: main() -> tile-test.mlir: matmul()
- tile-test.mlir: matmul() -> cblas.cpp: linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
cmake --build .
We changed TestLinalgTransformPatterns.td, so rebuild to pick up the changes.
./bin/clang++ -O3 -emit-llvm -S cblas.cpp -std=c++14 \
  -I ../mlir/test/mlir-cpu-runner \
  -I /Users/cycheng/llvm-project/libcxx/include \
  -o cblas.ll
Build cblas.cpp with -O3 and emit textual LLVM IR.
./bin/mlir-opt tile-test.mlir -test-linalg-transform-patterns -convert-linalg-to-llvm \
  | ./bin/mlir-translate --mlir-to-llvmir > tile-test.ll
Apply the linalg transform patterns to tile linalg.matmul into loop-dialect loops around smaller linalg.matmul ops.
Then lower to the LLVM dialect; the remaining linalg.matmul ops become calls to linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32.
Then translate the LLVM dialect to LLVM IR.
./bin/llvm-link -S -o tile-test-all.ll cblas.ll tile-test.ll
Link cblas.ll and tile-test.ll into tile-test-all.ll.
./bin/clang++ -emit-llvm -c -S -O3 tile-test-all.ll -o tile-test-all-o3.ll
Then run -O3 optimization on the linked module.
Clang inlines the calls to linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32, the cblas.cpp implementation of linalg.matmul.
./bin/clang++ -O3 -I /Users/cycheng/llvm-project/libcxx/include main.cpp \ tile-test-all-o3.ll -o tile-test
Finally, build main.cpp and run the tile-test program.
./tile-test
On my machine, I got
matmul result = 2058, time = 0.246217 seconds.
The test machine has a 128KB 8-way L1D, a 1MB 8-way L2 cache and a 6MB 12-way shared L3, running at 2.2 GHz.
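The printed value 2058 is what we expect: every element of A is 2.0 and every element of B is 1.0, so each dot product sums 1024 terms of 2.0, and C[i][j] = 10.0 + 2048 = 2058 for every i, j. A minimal C++ check that computes only the one element C[6][7]; the function names are hypothetical and not part of the test harness.

```cpp
#include <cassert>
#include <vector>

// One element of "C += A x B" for row-major M x K and K x N matrices:
// returns cInit + sum_k A[i][k] * B[k][j].
float matmulElement(const std::vector<float> &A, const std::vector<float> &B,
                    float cInit, int K, int N, int i, int j) {
  float dot = 0.f;
  for (int k = 0; k < K; ++k)
    dot += A[i * K + k] * B[k * N + j];
  return cInit + dot;
}

// Recreates the tile-test.mlir setup: 1024x1024 matrices filled with
// 2.0 (A), 1.0 (B) and 10.0 (C), then reads C[6][7].
float expectedTileTestResult() {
  const int M = 1024, N = 1024, K = 1024;
  std::vector<float> A(M * K, 2.0f), B(K * N, 1.0f);
  return matmulElement(A, B, 10.0f, K, N, 6, 7);
}
```

All partial sums are small integers, so the float result is exact and the same value comes out of both the tiled and naïve builds.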
We can also test Naïve Codegen by removing '-test-linalg-transform-patterns' when compiling tile-test.mlir:
./bin/mlir-opt tile-test.mlir -convert-linalg-to-llvm \ | ./bin/mlir-translate --mlir-to-llvmir > tile-test.ll
This time I got
matmul result = 2058, time = 2.28546 seconds.
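The two timings can be turned into throughput: a 1024x1024x1024 matmul performs 2 * 1024^3 (about 2.15 billion) floating-point operations, so 0.246 s is roughly 8.7 GFLOP/s for the tiled build versus roughly 0.94 GFLOP/s for the naïve one, a speedup of about 9.3x. Since the timer in main.cpp also covers the buffer allocation and fill inside matmul(), these are lower bounds. A small sketch of the arithmetic (helper names are hypothetical):

```cpp
#include <cassert>

// Floating-point operations in an n x n x n matmul: one multiply and one
// add per (i, j, k) triple.
double matmulFlops(double n) { return 2.0 * n * n * n; }

// Achieved throughput in GFLOP/s for a run that took `seconds`.
double gflops(double n, double seconds) {
  return matmulFlops(n) / seconds / 1e9;
}
```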
Result
I also tested several different matrix sizes; the results are as follows:
The tiled version is much faster than the non-tiled version (about 9x for the 1024x1024 case above: 0.246 s vs. 2.285 s).
Next
- There are other interesting patterns in TestLinalgTransformPatterns.td; it would be interesting to test them.
- Spend time tracing the tiling transformation source code.
- Read more about polyhedral-model techniques.
- Read paper: Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code
- http://tiramisu-compiler.org/
(I learned tiramisu compiler from https://zhuanlan.zhihu.com/p/65356717)
Reference
- Structured Ops in MLIR: Compiling Loops, Libraries and DSLs
  - Highly recommended!
- MLIR Toy tutorials
  https://github.com/llvm/llvm-project/tree/master/mlir/docs/Tutorials/Toy
- MLIR open design weekly meeting
  https://docs.google.com/document/d/1y_9f1AbfgcoVdJh4_aM6-BaSHvrHl8zuA5G4jv_94K8/edit#