diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -231,6 +231,27 @@
 /// ```
 Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
 
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (m, n, k) = (par, par, seq)
+///    |
+///    |  C(m, n) = sum_k(A(m, k) * B(k, n))
+/// ```
+/// and returns the tensor `C`.
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (m, n, k) = (par, par, seq)
+///    |
+///    |  D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
+/// ```
+/// and returns the tensor `D`.
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                         RankedTensorType tD);
+
 template <typename Container> Operation *linalg_matmul(Container values) {
   assert(values.size() == 3 && "Expected exactly 3 values");
   return linalg_matmul(values[0], values[1], values[2]);
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -46,8 +46,8 @@
   enter(body, /*prev=*/1);
 }
 
-ValueHandle mlir::edsc::LoopRangeBuilder::
-operator()(std::function<void(void)> fun) {
+ValueHandle
+mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
   if (fun)
     fun();
   exit();
@@ -77,8 +77,8 @@
     : LoopNestRangeBuilder(
           ivs, SmallVector<Value, 4>(ranges.begin(), ranges.end())) {}
 
-ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::
-operator()(std::function<void(void)> fun) {
+ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
+    std::function<void(void)> fun) {
   if (fun)
     fun();
   for (auto &lit : reverse(loops)) {
@@ -205,6 +205,13 @@
   return op;
 }
 
+static void mulRegionBuilder(ArrayRef<BlockArgument> args) {
+  using edsc::op::operator*;
+  assert(args.size() == 2 && "expected 2 block arguments");
+  ValueHandle a(args[0]), b(args[1]);
+  linalg_yield((a * b).getValue());
+}
+
 void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
   using edsc::op::operator+;
   using edsc::op::operator*;
@@ -298,6 +305,34 @@
   // clang-format on
 }
 
+Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
+                                          RankedTensorType tC) {
+  // clang-format off
+  AffineExpr m, n, k;
+  bindDims(ScopedContext::getContext(), m, n, k);
+  StructuredIndexed A(vA), B(vB), C(tC);
+  return makeGenericLinalgOp(
+    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {A({m, k}), B({k, n})},
+    {C({m, n})},
+    mulRegionBuilder);
+  // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
+                                          ValueHandle vC, RankedTensorType tD) {
+  // clang-format off
+  AffineExpr m, n, k;
+  bindDims(ScopedContext::getContext(), m, n, k);
+  StructuredIndexed A(vA), B(vB), C(vC), D(tD);
+  return makeGenericLinalgOp(
+    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {A({m, k}), B({k, n}), C({m, n})},
+    {D({m, n})},
+    macRegionBuilder);
+  // clang-format on
+}
+
 Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                              ValueHandle vO,
                                              ArrayRef<int> strides,
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -872,49 +872,6 @@
 }
 
 // clang-format off
-// CHECK-LABEL: func @linalg_pointwise_mixed_tensors
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
-// CHECK: addf
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
-// CHECK: cmpf "ogt"
-// CHECK: select
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
-// CHECK: tanh
-// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
-// clang-format on
-TEST_FUNC(linalg_pointwise_mixed_tensors_test) {
-  using namespace edsc;
-  using namespace edsc::ops;
-
-  auto f32Type = FloatType::getF32(&globalContext());
-  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
-  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
-  auto f = makeFunction("linalg_pointwise_mixed_tensors", {},
-                        {tensorType, memrefType});
-
-  OpBuilder builder(f.getBody());
-  ScopedContext scope(builder, f.getLoc());
-  ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
-  AffineExpr i, j;
-  bindDims(&globalContext(), i, j);
-  StructuredIndexed SA(A), SB(B), SC(tensorType);
-  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
-
-  f.print(llvm::outs());
-  f.erase();
-}
-
-// clang-format off
 // CHECK-LABEL: func @linalg_matmul
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
@@ -1032,6 +989,66 @@
   f.erase();
 }
 
+// clang-format off
+// CHECK-LABEL: func @linalg_tensors
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: addf
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: cmpf "ogt"
+// CHECK: select
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: tanh
+// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: mulf
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK: mulf
+// CHECK: addf
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_tensors_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
+  auto f = makeFunction("linalg_tensors", {}, {tensorType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
+  AffineExpr i, j;
+  bindDims(&globalContext(), i, j);
+  StructuredIndexed SA(A), SB(B), SC(tensorType);
+  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
+  Value o1 = linalg_matmul(A, B, tensorType)->getResult(0);
+  linalg_matmul(A, B, ValueHandle(o1), tensorType);
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;