diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -3,6 +3,11 @@ C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } +ods_def: +def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) { + C(n, m) = std_addf(std_mulf(A(k, m), B(n, k))); +} + ods_def: def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { x(m) = std_addf(std_mulf(A(m, n), y(n))); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -143,6 +143,10 @@ }]; let verifier = [{ return ::verify(*this); }]; + let assemblyFormat = [{ + `(` operands `)` attr-dict `:` type(operands) + }]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir @@ -0,0 +1,99 @@ +// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \ +// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \ + +// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \ +// RUN: mlir-opt 
-canonicalize -convert-scf-to-std -convert-vector-to-llvm | \ +// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// Activate to dump assembly +// R_UN: -dump-object-file -object-filename=/tmp/a.o \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \ +// Use tee to both print to stderr and FileCheck +// RUN: tee -a /dev/stderr | FileCheck %s + + +!row_major_A = type memref<${M}x${K}xf32> +!row_major_B = type memref<${K}x${N}xf32> +!row_major_C = type memref<${M}x${N}xf32> + +func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C) +// TODO: activate manually for now. +// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} +{ + linalg.matmul ins(%a, %b : !row_major_A, !row_major_B) + outs(%c: !row_major_C) + return +} + +func @print_perf(%iters: index, %total_time: f64) { + %c2 = constant 2 : index + %cM = constant ${M} : index + %cN = constant ${N} : index + %cK = constant ${K} : index + + %mn = muli %cM, %cN : index + %mnk = muli %mn, %cK : index + + // 2*M*N*K. + %flops_per_iter = muli %c2, %mnk : index + %flops = muli %iters, %flops_per_iter : index + %flops_i64 = index_cast %flops : index to i64 + %flops_f = sitofp %flops_i64 : i64 to f64 + %flops_per_s = divf %flops_f, %total_time : f64 + vector.print %flops_per_s : f64 + + return +} + +func @main() { + %f0 = constant 0.0 : f32 + %f1 = constant 1.0 : f32 + + %A = alloc() : !row_major_A + %B = alloc() : !row_major_B + %C = alloc() : !row_major_C + + linalg.fill(%A, %f1) : !row_major_A, f32 + linalg.fill(%B, %f1) : !row_major_B, f32 + linalg.fill(%C, %f0) : !row_major_C, f32 + + %c0 = constant 0: index + %c1 = constant 1: index + %iters = constant ${ITERS}: index + + /// Run and dump performance for matmul. 
+ /// Preheating run: + scf.for %arg0 = %c0 to %iters step %c1 { + linalg.fill(%C, %f0) : !row_major_C, f32 + call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () + } + %t_start_matmul = call @rtclock() : () -> f64 + scf.for %arg0 = %c0 to %iters step %c1 { + // linalg.matmul writes %C in place, need to reset it to zero every time. + // This accounts for about 10-15% perf hit on small sizes. + // Once linalg on tensors is ready, fusing fill at the register level will + // be easy. + linalg.fill(%C, %f0) : !row_major_C, f32 + call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () + } + %t_end_matmul = call @rtclock() : () -> f64 + %tmatmul = subf %t_end_matmul, %t_start_matmul: f64 + call @print_perf(%iters, %tmatmul) : (index, f64) -> () + + %res = load %C[%c0, %c0]: !row_major_C + // CHECK: 64 + vector.print %res: f32 + + dealloc %A : !row_major_A + dealloc %B : !row_major_B + dealloc %C : !row_major_C + + return +} + +func private @rtclock() -> f64 + +// TODO: init with random, run and check output. +// func private @fill_random_f32(memref<*xf32>) diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir @@ -0,0 +1,98 @@ +// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \ +// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \ + +// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed. 
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \ + +// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \ +// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \ +// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// Activate to dump assembly +// R_UN: -dump-object-file -object-filename=/tmp/a.o \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \ +// Use tee to both print to stderr and FileCheck +// RUN: tee -a /dev/stderr | FileCheck %s + +!row_major_A = type memref<${M}x${K}xf32> +!row_major_B = type memref<${K}x${N}xf32> +!row_major_C = type memref<${M}x${N}xf32> +!column_major_A = type memref<${K}x${M}xf32> +!column_major_B = type memref<${N}x${K}xf32> +!column_major_C = type memref<${N}x${M}xf32> + +func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C) +// TODO: activate manually for now. +// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} +{ + linalg.matmul_column_major ins(%a, %b : !column_major_A, !column_major_B) + outs(%c: !column_major_C) + return +} + +func @print_perf(%iters: index, %total_time: f64) { + %c2 = constant 2 : index + %cM = constant ${M} : index + %cN = constant ${N} : index + %cK = constant ${K} : index + + %mn = muli %cM, %cN : index + %mnk = muli %mn, %cK : index + + // 2*M*N*K. 
+ %flops_per_iter = muli %c2, %mnk : index + %flops = muli %iters, %flops_per_iter : index + %flops_i64 = index_cast %flops : index to i64 + %flops_f = sitofp %flops_i64 : i64 to f64 + %flops_per_s = divf %flops_f, %total_time : f64 + vector.print %flops_per_s : f64 + + return +} + +func @main() { + %f0 = constant 0.0 : f32 + %f1 = constant 1.0 : f32 + + %cA = alloc() : !column_major_A + %cB = alloc() : !column_major_B + %cC = alloc() : !column_major_C + + linalg.fill(%cA, %f1) : !column_major_A, f32 + linalg.fill(%cB, %f1) : !column_major_B, f32 + linalg.fill(%cC, %f0) : !column_major_C, f32 + + %c0 = constant 0: index + %c1 = constant 1: index + %iters = constant ${ITERS}: index + + /// Run and dump performance for matmul_column_major. + %t_start_matmul_column_major = call @rtclock() : () -> f64 + scf.for %arg0 = %c0 to %iters step %c1 { + // matmul_column_major writes %cC in place; reset it to zero every time. + // This accounts for about 10-15% perf hit on small sizes. + // Once linalg on tensors is ready, fusing fill at the register level will + // be easy. + linalg.fill(%cC, %f0) : !column_major_C, f32 + call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> () + } + %t_end_matmul_column_major = call @rtclock() : () -> f64 + %tmatmul_column_major = subf %t_end_matmul_column_major, %t_start_matmul_column_major: f64 + call @print_perf(%iters, %tmatmul_column_major) : (index, f64) -> () + + %res = load %cC[%c0, %c0]: !column_major_C + // CHECK: 64 + vector.print %res: f32 + + dealloc %cA : !column_major_A + dealloc %cB : !column_major_B + dealloc %cC : !column_major_C + + return +} + +func private @rtclock() -> f64 + +// TODO: init with random, run and check output. 
+// func private @fill_random_f32(memref<*xf32>) diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir @@ -0,0 +1,116 @@ +// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \ +// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \ + +// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed. 
+// R_UN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \ + +// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \ +// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \ +// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// Activate to dump assembly +// R_UN: -dump-object-file -object-filename=/tmp/a.o \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \ +// Use tee to both print to stderr and FileCheck +// RUN: tee -a /dev/stderr | FileCheck %s + +!row_major_A = type memref<${M}x${K}xf32> +!row_major_B = type memref<${K}x${N}xf32> +!row_major_C = type memref<${M}x${N}xf32> +!column_major_A = type memref<${K}x${M}xf32> +!column_major_B = type memref<${N}x${K}xf32> +!column_major_C = type memref<${N}x${M}xf32> + +func @matmul_column_major_as_row_major( + %ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C, + %a: !row_major_A, %b: !row_major_B, %c: !row_major_C) +// TODO: activate manually for now. +// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} +{ + linalg.copy(%ca, %a) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_A, !row_major_A + linalg.copy(%cb, %b) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_B, !row_major_B + linalg.matmul ins(%a, %b : !row_major_A, !row_major_B) + outs(%c: !row_major_C) + linalg.copy(%c, %cc) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !row_major_C, !column_major_C + return +} + +func @print_perf(%iters: index, %total_time: f64) { + %c2 = constant 2 : index + %cM = constant ${M} : index + %cN = constant ${N} : index + %cK = constant ${K} : index + + %mn = muli %cM, %cN : index + %mnk = muli %mn, %cK : index + + // 2*M*N*K. 
+ %flops_per_iter = muli %c2, %mnk : index + %flops = muli %iters, %flops_per_iter : index + %flops_i64 = index_cast %flops : index to i64 + %flops_f = sitofp %flops_i64 : i64 to f64 + %flops_per_s = divf %flops_f, %total_time : f64 + vector.print %flops_per_s : f64 + + return +} + +func @main() { + %f0 = constant 0.0 : f32 + %f1 = constant 1.0 : f32 + + %cA = alloc() : !column_major_A + %cB = alloc() : !column_major_B + %cC = alloc() : !column_major_C + + linalg.fill(%cA, %f1) : !column_major_A, f32 + linalg.fill(%cB, %f1) : !column_major_B, f32 + linalg.fill(%cC, %f0) : !column_major_C, f32 + + %c0 = constant 0: index + %c1 = constant 1: index + %iters = constant ${ITERS}: index + + /// Run and dump performance for matmul_column_major as a row-major + %A = alloc() : !row_major_A + %B = alloc() : !row_major_B + %C = alloc() : !row_major_C + %t_start_matmul_column_major_as_row_major = call @rtclock() : () -> f64 + scf.for %arg0 = %c0 to %iters step %c1 { + // linalg.matmul writes %C in place, need to reset it to zero every time. + // This accounts for about 10-15% perf hit on small sizes. + // Once linalg on tensors is ready, fusing fill at the register level will + // be easy. 
+ linalg.fill(%C, %f0) : !row_major_C, f32 + call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) : + (!column_major_A, !column_major_B, !column_major_C, + !row_major_A, !row_major_B, !row_major_C) -> () + } + %t_end_matmul_column_major_as_row_major = call @rtclock() : () -> f64 + %tmatmul_column_major_as_row_major = subf %t_end_matmul_column_major_as_row_major, %t_start_matmul_column_major_as_row_major: f64 + call @print_perf(%iters, %tmatmul_column_major_as_row_major) : (index, f64) -> () + + %res = load %cC[%c0, %c0]: !column_major_C + // CHECK: 64 + vector.print %res: f32 + %res2 = load %C[%c0, %c0]: !row_major_C + // CHECK: 64 + vector.print %res2: f32 + + dealloc %A : !row_major_A + dealloc %B : !row_major_B + dealloc %C : !row_major_C + + dealloc %cA : !column_major_A + dealloc %cB : !column_major_B + dealloc %cC : !column_major_C + + return +} + +func private @rtclock() -> f64 + +// TODO: init with random, run and check output. +// func private @fill_random_f32(memref<*xf32>) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -51,10 +51,11 @@ // Some of these may be too aggressive as a stage 3 that is applied on each // stage 1 application and may have to be split out to post staged patterns // application (in which case they could just be passes, TBD). 
- PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(op->getParentOfType()))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); + op->walk([&](LoopLikeOpInterface loopLike) { + LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n")); + if (failed(moveLoopInvariantCode(loopLike))) + llvm_unreachable("unexpected LICM failure"); + }); promoteSingleIterationLoops(cast(op)); hoistViewAllocOps(cast(op)); hoistRedundantVectorTransfers(cast(op)); @@ -67,13 +68,11 @@ // Post staged patterns transforms //===--------------------------------------------------------------------===// - ModuleOp module = func->getParentOfType(); - // Programmatic splitting of slow/fast path vector transfers. OwningRewritePatternList patterns; patterns.insert( context, vectorTransformsOptions); - applyPatternsAndFoldGreedily(module, std::move(patterns)); + applyPatternsAndFoldGreedily(func, std::move(patterns)); // Programmatic controlled lowering of vector.contract only. OwningRewritePatternList vectorContractLoweringPatterns; @@ -81,17 +80,16 @@ .insert( vectorTransformsOptions, context); - applyPatternsAndFoldGreedily(module, - std::move(vectorContractLoweringPatterns)); + applyPatternsAndFoldGreedily(func, std::move(vectorContractLoweringPatterns)); // Programmatic controlled lowering of vector.transfer only. OwningRewritePatternList vectorToLoopsPatterns; populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, vectorToSCFOptions); - applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); + applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns)); // Ensure we drop the marker in the end. 
- module.walk([](LinalgOp op) { + func.walk([](LinalgOp op) { op.removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -68,8 +68,8 @@ // TODO: Should be Tablegen'd from a single source that generates the op itself. static LogicalResult isContraction(Operation *op) { // TODO: interface for named ops. - if (isa(op)) + if (isa(op)) return success(); auto genericOp = dyn_cast(op); diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,7 +1,6 @@ add_subdirectory(Bindings) add_subdirectory(CAPI) add_subdirectory(EDSC) -add_subdirectory(mlir-cpu-runner) add_subdirectory(SDBM) add_subdirectory(lib) @@ -54,8 +53,6 @@ mlir-sdbm-api-test mlir-tblgen mlir-translate - mlir_test_cblas - mlir_test_cblas_interface mlir_runner_utils mlir_c_runner_utils mlir_async_runtime diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir --- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir +++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s -// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s +// RUN: mlir-opt %s 
-test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER // CHECK-LABEL: func @matmul( // OUTER-LABEL: func @matmul( diff --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp @@ -42,6 +42,9 @@ // clang-format on } + template + void applyStrategyToNamedLinalgOp(); + void runOnFunction() override; ListOption tileSizes{*this, "tile-sizes", @@ -91,11 +94,21 @@ *this, "unroll-vector-transfers", llvm::cl::desc("Enable full unrolling of vector.transfer operations"), llvm::cl::init(false)}; + Option anchorOpName{ + *this, "anchor-op", + llvm::cl::desc( + "Which single linalg op is the anchor for the codegen strategy to " + "latch on:\n" + "\tlinalg.matmul: anchor on linalg.matmul\n" + "\tlinalg.matmul_column_major: anchor on linalg.matmul_column_major\n" + "\tlinalg.copy: anchor on linalg.copy\n" + "\tlinalg.fill: anchor on linalg.fill\n"), + llvm::cl::init("")}; }; } // end anonymous namespace -/// Apply transformations specified as patterns. 
-void TestLinalgCodegenStrategy::runOnFunction() { +template +void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() { LinalgTilingOptions tilingOptions; if (!tileSizes.empty()) tilingOptions = tilingOptions.setTileSizes(tileSizes); @@ -121,27 +134,42 @@ .Default(vector::VectorTransferSplit::None); CodegenStrategy strategy; - strategy.tileIf(!tileSizes.empty(), tilingOptions) - .promoteIf(promote, - LinalgPromotionOptions() - .setAlignment(16) - .setUseFullTileBuffersByDefault(promoteFullTile)) - .tileIf(!registerTileSizes.empty(), registerTilingOptions) - .promoteIf(registerPromote, LinalgPromotionOptions() - .setAlignment(16) - .setUseFullTileBuffersByDefault( - registerPromoteFullTile)) - .vectorizeIf(vectorize) + strategy.template tileIf(!tileSizes.empty(), tilingOptions) + .template promoteIf( + promote, LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(promoteFullTile)) + .template tileIf(!registerTileSizes.empty(), + registerTilingOptions) + .template promoteIf( + registerPromote, + LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(registerPromoteFullTile)) + .template vectorizeIf(vectorize) .setVectorTransformsOptions( vector::VectorTransformsOptions() .setVectorTransformsOptions(vectorContractLowering) .setVectorTransferSplit(vectorTransferSplit)) .setVectorTransferToSCFOptions( VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers)); - strategy.transform(getFunction()); } +/// Apply transformations specified as patterns. 
+void TestLinalgCodegenStrategy::runOnFunction() { + if (anchorOpName == MatmulOp::getOperationName()) + applyStrategyToNamedLinalgOp(); + else if (anchorOpName == MatmulColumnMajorOp::getOperationName()) + applyStrategyToNamedLinalgOp(); + else if (anchorOpName == CopyOp::getOperationName()) + applyStrategyToNamedLinalgOp(); + else if (anchorOpName == FillOp::getOperationName()) + applyStrategyToNamedLinalgOp(); + else + llvm_unreachable("Unsupported anchor op"); +} + namespace mlir { namespace test { void registerTestLinalgCodegenStrategy() { diff --git a/mlir/test/mlir-cpu-runner/CMakeLists.txt b/mlir/test/mlir-cpu-runner/CMakeLists.txt deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -set(LLVM_OPTIONAL_SOURCES - mlir_test_cblas.cpp - mlir_test_cblas_interface.cpp - ) - -add_llvm_library(mlir_test_cblas SHARED mlir_test_cblas.cpp) -target_compile_definitions(mlir_test_cblas PRIVATE mlir_test_cblas_EXPORTS) - -add_llvm_library(mlir_test_cblas_interface SHARED mlir_test_cblas_interface.cpp) -target_link_libraries(mlir_test_cblas_interface PRIVATE mlir_test_cblas) -target_compile_definitions(mlir_test_cblas_interface PRIVATE mlir_test_cblas_interface_EXPORTS) - diff --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h +++ /dev/null @@ -1,49 +0,0 @@ -//===- mlir_test_cblas.h - Simple Blas subset -----------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_ -#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_ - -#include "mlir/ExecutionEngine/RunnerUtils.h" - -#ifdef _WIN32 -#ifndef MLIR_TEST_CBLAS_EXPORT -#ifdef mlir_test_cblas_EXPORTS -// We are building this library -#define MLIR_TEST_CBLAS_EXPORT __declspec(dllexport) -#else -// We are using this library -#define MLIR_TEST_CBLAS_EXPORT __declspec(dllimport) -#endif // mlir_test_cblas_EXPORTS -#endif // MLIR_TEST_CBLAS_EXPORT -#else -#define MLIR_TEST_CBLAS_EXPORT -#endif // _WIN32 - -/// This reproduces a minimal subset of mlir_test_cblas to allow integration -/// testing without explicitly requiring a dependence on an external library. -/// Without loss of generality, various mlir_test_cblas implementations may be -/// swapped in by including the proper headers and linking with the proper -/// library. -enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 }; -enum CBLAS_TRANSPOSE { - CblasNoTrans = 111, - CblasTrans = 112, - CblasConjTrans = 113 -}; - -extern "C" MLIR_TEST_CBLAS_EXPORT float -mlir_test_cblas_sdot(const int N, const float *X, const int incX, - const float *Y, const int incY); - -extern "C" MLIR_TEST_CBLAS_EXPORT void mlir_test_cblas_sgemm( - const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const float alpha, const float *A, const int lda, const float *B, - const int ldb, const float beta, float *C, const int ldc); - -#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_ diff --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h +++ /dev/null @@ -1,59 +0,0 @@ -//===- mlir_test_cblas_interface.h - Simple Blas 
subset interface ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_ -#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_ - -#include "mlir/ExecutionEngine/RunnerUtils.h" - -#ifdef _WIN32 -#ifndef MLIR_TEST_CBLAS_INTERFACE_EXPORT -#ifdef mlir_test_cblas_interface_EXPORTS -// We are building this library -#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllexport) -#else -// We are using this library -#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllimport) -#endif // mlir_test_cblas_interface_EXPORTS -#endif // MLIR_TEST_CBLAS_INTERFACE_EXPORT -#else -#define MLIR_TEST_CBLAS_INTERFACE_EXPORT -#endif // _WIN32 - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType *X, float f); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType *X, float f); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType *X, - float f); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType *I, - StridedMemRefType *O); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType *I, - StridedMemRefType *O); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32( - StridedMemRefType *I, StridedMemRefType *O); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void -_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32( - StridedMemRefType *X, StridedMemRefType *Y, - StridedMemRefType *Z); - -extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void 
-_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32( - StridedMemRefType *A, StridedMemRefType *B, - StridedMemRefType *C); - -#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_ diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ /dev/null @@ -1,99 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s 
-linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-std -convert-linalg-to-llvm \ -// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \ -// RUN: | FileCheck %s - -// Creates and returns a 1-D buffer of size %s filled with the value %f -func @alloc_filled_f32(%s : index, %f : f32) -> memref { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c4 = constant 4 : index - %s4 = muli %s, %c4: index - %buf = alloc(%s4) {alignment = 256} : memref - %V = view %buf[%c0][%s] : memref to memref - linalg.fill(%V, %f) : memref, f32 - return %buf : memref -} - -// Test for linalg.dot. -func @dot() -> f32 { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c16 = constant 16 : index - %f10 = constant 10.00000e+00 : f32 - %f1 = constant 1.00000e+00 : f32 - %f2 = constant 2.00000e+00 : f32 - - %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref) - %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref) - %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref) - - %A = view %bA[%c0][%c16] : memref to memref - %B = view %bB[%c0][%c16] : memref to memref - %C = view %bC[%c0][] : memref to memref - - linalg.dot ins(%A, %B : memref, memref) - outs(%C : memref) - %res = load %C[] : memref - - dealloc %bC : memref - dealloc %bB : memref - dealloc %bA : memref - - return %res : f32 -} - -// Test for linalg.matmul. 
-func @matmul() -> f32 { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c6 = constant 6 : index - %c7 = constant 7 : index - %c2 = constant 2 : index - %c16 = constant 16 : index - %c4 = constant 4 : index - %c32 = constant 32 : index - %f1 = constant 1.00000e+00 : f32 - %f2 = constant 2.00000e+00 : f32 - %f10 = constant 10.00000e+00 : f32 - - %bA = call @alloc_filled_f32(%c32, %f2) : (index, f32) -> (memref) - %bB = call @alloc_filled_f32(%c32, %f1) : (index, f32) -> (memref) - %bC = call @alloc_filled_f32(%c4, %f10) : (index, f32) -> (memref) - - %A = view %bA[%c0][%c2, %c16] : memref to memref - %B = view %bB[%c0][%c16, %c2] : memref to memref - %C = view %bC[%c0][%c2, %c2] : memref to memref - - linalg.matmul ins(%A, %B : memref, memref) - outs(%C : memref) - %res = load %C[%c0, %c1] : memref - - dealloc %bC : memref - dealloc %bB : memref - dealloc %bA : memref - - return %res : f32 -} - -// All tests return this value -// CHECK: 4.2{{0+}}e+01 diff --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp +++ /dev/null @@ -1,46 +0,0 @@ -//===- mlir_test_cblas.cpp - Simple Blas subset implementation ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Simple Blas subset implementation. 
-// -//===----------------------------------------------------------------------===// - -#include "include/mlir_test_cblas.h" -#include - -extern "C" float mlir_test_cblas_sdot(const int N, const float *X, - const int incX, const float *Y, - const int incY) { - float res = 0.0f; - for (int i = 0; i < N; ++i) - res += X[i * incX] * Y[i * incY]; - return res; -} - -extern "C" void mlir_test_cblas_sgemm( - const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const float alpha, const float *A, const int lda, const float *B, - const int ldb, const float beta, float *C, const int ldc) { - assert(Order == CBLAS_ORDER::CblasRowMajor); - assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans); - assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans); - for (int m = 0; m < M; ++m) { - auto *pA = A + m * lda; - auto *pC = C + m * ldc; - for (int n = 0; n < N; ++n) { - float c = pC[n]; - float res = 0.0f; - for (int k = 0; k < K; ++k) { - auto *pB = B + k * ldb; - res += pA[k] * pB[n]; - } - pC[n] = alpha * c + beta * res; - } - } -} diff --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp deleted file mode 100644 --- a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- mlir_test_cblas_interface.cpp - Simple Blas subset interface -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Simple Blas subset interface implementation. 
-// -//===----------------------------------------------------------------------===// - -#include "include/mlir_test_cblas_interface.h" -#include "include/mlir_test_cblas.h" -#include -#include - -extern "C" void -_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType *X, float f) { - X->data[X->offset] = f; -} - -extern "C" void -_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType *X, - float f) { - for (unsigned i = 0; i < X->sizes[0]; ++i) - *(X->data + X->offset + i * X->strides[0]) = f; -} - -extern "C" void -_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType *X, - float f) { - for (unsigned i = 0; i < X->sizes[0]; ++i) - for (unsigned j = 0; j < X->sizes[1]; ++j) - *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f; -} - -extern "C" void -_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType *I, - StridedMemRefType *O) { - O->data[O->offset] = I->data[I->offset]; -} - -extern "C" void -_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType *I, - StridedMemRefType *O) { - if (I->sizes[0] != O->sizes[0]) { - std::cerr << "Incompatible strided memrefs\n"; - printMemRefMetaData(std::cerr, *I); - printMemRefMetaData(std::cerr, *O); - return; - } - for (unsigned i = 0; i < I->sizes[0]; ++i) - O->data[O->offset + i * O->strides[0]] = - I->data[I->offset + i * I->strides[0]]; -} - -extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32( - StridedMemRefType *I, StridedMemRefType *O) { - if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) { - std::cerr << "Incompatible strided memrefs\n"; - printMemRefMetaData(std::cerr, *I); - printMemRefMetaData(std::cerr, *O); - return; - } - auto so0 = O->strides[0], so1 = O->strides[1]; - auto si0 = I->strides[0], si1 = I->strides[1]; - for (unsigned i = 0; i < I->sizes[0]; ++i) - for (unsigned j = 0; j < I->sizes[1]; ++j) - O->data[O->offset + i * so0 + j * so1] = - I->data[I->offset + i * si0 + j * si1]; -} - -extern "C" void 
_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32( - StridedMemRefType *X, StridedMemRefType *Y, - StridedMemRefType *Z) { - if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) { - std::cerr << "Incompatible strided memrefs\n"; - printMemRefMetaData(std::cerr, *X); - printMemRefMetaData(std::cerr, *Y); - printMemRefMetaData(std::cerr, *Z); - return; - } - Z->data[Z->offset] += - mlir_test_cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0], - Y->data + Y->offset, Y->strides[0]); -} - -extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32( - StridedMemRefType *A, StridedMemRefType *B, - StridedMemRefType *C) { - if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] || - A->strides[1] != 1 || A->sizes[0] < A->strides[1] || - B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] || - C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] || - A->sizes[1] != B->sizes[0]) { - printMemRefMetaData(std::cerr, *A); - printMemRefMetaData(std::cerr, *B); - printMemRefMetaData(std::cerr, *C); - return; - } - mlir_test_cblas_sgemm( - CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans, - CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1], - 1.0f, A->data + A->offset, A->strides[0], B->data + B->offset, - B->strides[0], 1.0f, C->data + C->offset, C->strides[0]); -}