diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -18,6 +18,9 @@
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 
 namespace mlir {
 class AffineForOp;
@@ -127,7 +130,11 @@
 // EDSC builders for linalg generic operations.
 //===----------------------------------------------------------------------===//
 
-/// Build the body of a region to compute a multiply-accumulate, under the
+/// Build the body of a region to compute a scalar multiply, under the current
+/// ScopedContext, at the current insert point.
+void mulRegionBuilder(ArrayRef<BlockArgument> args);
+
+/// Build the body of a region to compute a scalar multiply-accumulate, under the
 /// current ScopedContext, at the current insert point.
 void macRegionBuilder(ArrayRef<BlockArgument> args);
 
@@ -182,6 +189,8 @@
 
 // TODO(ntv): Implement more useful pointwise operations on a per-need basis.
 
+using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
+
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
 /// ```
@@ -189,7 +198,8 @@
 ///    |
 ///    |  C(m, n) += A(m, k) * B(k, n)
 /// ```
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
 /// ```
 ///    |
 ///    |  C(m, n) = sum_k(A(m, k) * B(k, n))
 /// ```
 /// and returns the tensor `C`.
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
+                         MatmulRegionBuilder regionBuilder = mulRegionBuilder);
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
@@ -210,11 +221,14 @@
 /// ```
 /// and returns the tensor `D`.
 Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
-                         RankedTensorType tD);
+                         RankedTensorType tD,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
 
-template <typename Container> Operation *linalg_matmul(Container values) {
+template <typename Container>
+Operation *linalg_matmul(Container values,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_matmul(values[0], values[1], values[2]);
+  return linalg_matmul(values[0], values[1], values[2], regionBuilder);
 }
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -212,6 +212,14 @@
   linalg_yield((a * b).getValue());
 }
 
+void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  assert(args.size() == 2 && "expected 2 block arguments");
+  ValueHandle a(args[0]), b(args[1]);
+  linalg_yield((a * b).getValue());
+}
+
 void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
   using edsc::op::operator+;
   using edsc::op::operator*;
@@ -291,7 +299,8 @@
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC) {
+                                          ValueHandle vC,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -300,12 +309,13 @@
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
-    macRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          RankedTensorType tC) {
+                                          RankedTensorType tC,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -314,12 +324,13 @@
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
-    mulRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC, RankedTensorType tD) {
+                                          ValueHandle vC, RankedTensorType tD,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -328,7 +339,7 @@
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n}), C({m, n})},
     {D({m, n})},
-    macRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -982,17 +982,20 @@
   f.erase();
 }
 
-// CHECK-LABEL: func @vector_matmul_test(
-// CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
-// CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
-// CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
-// CHECK: vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
-// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
-// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
-// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
-// CHECK-SAME: %[[A]], %[[B]], %[[C]]
-// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
-TEST_FUNC(vector_matmul_test) {
+// CHECK-LABEL: func @memref_vector_matmul_test(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
+// CHECK-SAME: %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
+// CHECK: linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
+// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
+// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+// CHECK: memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
+// CHECK-SAME: memref<?x?xvector<4x8xf32>>
+TEST_FUNC(memref_vector_matmul_test) {
   using namespace edsc;
   using namespace edsc::ops;
@@ -1001,13 +1004,29 @@
   auto mkVectorType = VectorType::get({M, K}, f32Type);
   auto knVectorType = VectorType::get({K, N}, f32Type);
   auto mnVectorType = VectorType::get({M, N}, f32Type);
-  auto f = makeFunction("vector_matmul_test", {},
-                        {mkVectorType, knVectorType, mnVectorType});
+  auto typeA =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      mkVectorType, {}, 0);
+  auto typeB =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      knVectorType, {}, 0);
+  auto typeC =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      mnVectorType, {}, 0);
+  auto f = makeFunction("memref_vector_matmul_test", {}, {typeA, typeB, typeC});
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
   ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
-  vector_matmul(A, B, C);
+  auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
+    assert(args.size() == 3 && "expected 3 block arguments");
+    (linalg_yield(vector_matmul(args[0], args[1], args[2])));
+  };
+  linalg_matmul(A, B, C, contractionBuilder);
+
   f.print(llvm::outs());
   f.erase();
 }