diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -110,11 +110,14 @@ operator Value() const /* implicit */ { return value; } ArrayRef getExprs() { return exprs; } + Type getType() { return value.getType(); } private: StructuredIndexed(Value v, ArrayRef indexings) : value(v), exprs(indexings.begin(), indexings.end()) { - assert(v.getType().isa() && "MemRefType expected"); + assert((v.getType().isa() || + v.getType().isa()) && + "MemRef or RankedTensor expected"); } StructuredIndexed(ValueHandle v, ArrayRef indexings) : StructuredIndexed(v.getValue(), indexings) {} @@ -125,9 +128,21 @@ inline void defaultRegionBuilder(ArrayRef args) {} +/// Build a `linalg.generic` op with the specified inputs, outputs and region. +/// +/// `otherValues` and `otherAttributes` may be passed and will be appended as +/// operands and attributes respectively. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. Operation *makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputs, ArrayRef resultTensorTypes = {}, function_ref)> regionBuilder = defaultRegionBuilder, ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); @@ -167,32 +182,77 @@ /// with in-place semantics and parallelism. /// Unary pointwise operation (with broadcast) entry point. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. 
Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, - StructuredIndexed I, StructuredIndexed O); + StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = tanh(I)`. The client is responsible for specifying the proper /// indexings when creating the StructuredIndexed. -Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Binary pointwise operation (with broadcast) entry point. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. 
using BinaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = I1 + I2`. The client is responsible for specifying the proper /// indexings when creating the StructuredIndexed. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = max(I!, I2)`. The client is responsible for specifying the /// proper indexings when creating the StructuredIndexed. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); // TODO(ntv): Implement more useful pointwise operations on a per-need basis. 
@@ -203,11 +263,23 @@ /// | /// | C(m, n) += A(m, k) * B(k, n) /// ``` -Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC, + ArrayRef resultTensorTypes = {}); -template Operation *linalg_matmul(Container values) { +template +Operation *linalg_matmul(Container values, + ArrayRef resultTensorTypes = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_matmul(values[0], values[1], values[2]); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); + return linalg_matmul(values[0], values[1], values[2], resultTensorTypes); } /// Build a linalg.generic, under the current ScopedContext, at the current @@ -231,16 +303,28 @@ /// /// For now `...` must be empty (i.e. only 2-D convolutions are supported). /// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +// // TODO(ntv) Extend convolution rank with some template magic. 
Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes = {}, ArrayRef strides = {}, ArrayRef dilations = {}); template -Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, - ArrayRef dilations = {}) { +Operation * +linalg_conv_nhwc(Container values, ArrayRef resultTensorTypes = {}, + ArrayRef strides = {}, ArrayRef dilations = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); + return linalg_conv_nhwc(values[0], values[1], values[2], resultTensorTypes, + strides, dilations); } /// Build a linalg.generic, under the current ScopedContext, at the current @@ -249,7 +333,7 @@ /// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = /// | (par, par, par, [par, par, ...], [red, red, ...]) /// | -/// | O(batch, [h, w, ...], c * depth_multiplier) += +/// | O(batch, [h, w, ...], c * depthMultiplier) += /// | I(batch, /// | [ /// | stride[0] * h + dilations[0] * kh, @@ -257,26 +341,40 @@ /// ], /// | c) /// | * -/// | W([kh, kw, ...], c, depth_multiplier) +/// | W([kh, kw, ...], c, depthMultiplier) /// ``` /// If `dilations` or `strides` are left empty, the default value of `1` is used /// along each relevant dimension. /// /// For now `...` must be empty (i.e. only 2-D convolutions are supported). /// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +// // TODO(ntv) Extend convolution rank with some template magic. 
Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, - ValueHandle vO, int depth_multiplier = 1, + ValueHandle vO, + ArrayRef resultTensorTypes = {}, + int depthMultiplier = 1, ArrayRef strides = {}, ArrayRef dilations = {}); template -Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier, +Operation *linalg_dilated_conv_nhwc(Container values, + ArrayRef resultTensorTypes = {}, + int depthMultiplier = 1, ArrayRef strides = {}, ArrayRef dilations = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); return linalg_dilated_conv_nhwc(values[0], values[1], values[2], - depth_multiplier, strides, dilations); + resultTensorTypes, depthMultiplier, strides, + dilations); } } // namespace ops diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -128,16 +128,20 @@ Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputBuffers, ArrayRef resultTensorTypes, function_ref)> regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { + assert( + llvm::all_of(llvm::make_range(outputBuffers.begin(), outputBuffers.end()), + [](Value v) { return v.getType().isa(); }) && + "output operands must all be buffers."); auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); unsigned nInputs = inputs.size(); - unsigned nOutputs = outputs.size(); + unsigned nOutputs = outputBuffers.size() + resultTensorTypes.size(); unsigned maxPos = 0; getMaxDimIndex(inputs, maxPos); - getMaxDimIndex(outputs, maxPos); + getMaxDimIndex(outputBuffers, maxPos); // maxPos is 0 indexed, need to turn this into a count (i.e. 
+1) unsigned nDims = maxPos + 1; @@ -146,7 +150,7 @@ for (auto in : inputs) maps.push_back( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs())); - for (auto out : outputs) + for (auto out : outputBuffers) maps.push_back( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs())); @@ -154,7 +158,7 @@ SmallVector values; values.reserve(nViews); values.append(inputs.begin(), inputs.end()); - values.append(outputs.begin(), outputs.end()); + values.append(outputBuffers.begin(), outputBuffers.end()); auto iteratorStrTypes = functional::map(toString, iteratorTypes); // clang-format off @@ -162,7 +166,7 @@ edsc::ScopedContext::getBuilder() .create( edsc::ScopedContext::getLocation(), - ArrayRef{}, // TODO(ntv): support tensors + resultTensorTypes, values, IntegerAttr::get(IntegerType::get(64, ctx), nInputs), IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), @@ -207,7 +211,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, - StructuredIndexed O) { + StructuredIndexed O, + ArrayRef resultTensorTypes) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); auto fun = [&unaryOp](ArrayRef args) { @@ -215,22 +220,30 @@ ValueHandle a(args[0]); linalg_yield(unaryOp(a)); }; - return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); + + // Distinguish between tensor and buffer semantics. 
+ if (O.getType().isa()) { + assert(resultTensorTypes.empty()); + return makeGenericLinalgOp(iterTypes, {I}, {O}, {}, fun); + } + return makeGenericLinalgOp(iterTypes, {I, O}, {}, resultTensorTypes, fun); } -Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes) { ; using edsc::intrinsics::tanh; UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); }); - return linalg_pointwise(unOp, I, O); + return linalg_pointwise(unOp, I, O, resultTensorTypes); } /// Binary pointwise operation (with broadcast) entry point. Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O) { + StructuredIndexed O, + ArrayRef resultTensorTypes) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); auto fun = [&binaryOp](ArrayRef args) { @@ -238,45 +251,62 @@ ValueHandle a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; - return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); + // Distinguish between tensor and buffer semantics. 
+ if (O.getType().isa()) { + assert(resultTensorTypes.empty()); + return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, {}, fun); + } + return makeGenericLinalgOp(iterTypes, {I1, I2, O}, {}, resultTensorTypes, + fun); } -Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, + StructuredIndexed I2, StructuredIndexed O, + ArrayRef resultTensorTypes) { using edsc::op::operator+; BinaryPointwiseOpBuilder binOp( [](ValueHandle a, ValueHandle b) -> Value { return a + b; }); - return linalg_pointwise(binOp, I1, I2, O); + return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes); } -Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, + StructuredIndexed I2, StructuredIndexed O, + ArrayRef resultTensorTypes) { BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value { using edsc::intrinsics::select; using edsc::op::operator>; return select(a > b, a, b).getValue(); }); - return linalg_pointwise(binOp, I1, I2, O); + return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes); } Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB, - ValueHandle vC) { - // clang-format off + ValueHandle vC, + ArrayRef resultTensorTypes) { AffineExpr m, n, k; bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC); + + assert(!C.getType().isa() || resultTensorTypes.empty()); + StructuredIndexed allIndexed[3]{A({m, k}), B({k, n}), C({m, n})}; + ArrayRef inputs = + (C.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (C.getType().isa()) + ? 
ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; return makeGenericLinalgOp( - {IterType::Parallel, IterType::Parallel, IterType::Reduction}, - {A({m, k}), B({k, n})}, - {C({m, n})}, - macRegionBuilder); - // clang-format on + {IterType::Parallel, IterType::Parallel, IterType::Reduction}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes, ArrayRef strides, ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); @@ -294,23 +324,33 @@ bindDims(ctx, b, f, h, w, kh, kw, c); unsigned numDims = c.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); + + assert(!O.getType().isa() || resultTensorTypes.empty()); + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions. // clang-format off - return makeGenericLinalgOp( - {par, par, par, par, red, red, red}, { + StructuredIndexed allIndexed[3] = { I({b, - // Roundtrip to flattened form to serve as canonicalization and ensure - // consistent ordering of subexpressions. simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, f})}, { - O({b, h, w, f})}, - macRegionBuilder); + W({kh, kw, c, f}), + O({b, h, w, f})}; // clang-format on + auto inputs = (O.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (O.getType().isa()) + ? 
ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; + return makeGenericLinalgOp({par, par, par, par, red, red, red}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc( - ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier, + ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes, int depthMultiplier, ArrayRef strides, ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); // TODO(ntv) some template magic to make everything rank-polymorphic. @@ -328,16 +368,26 @@ bindDims(ctx, b, dm, c, h, w, kh, kw); unsigned numDims = kw.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); - return makeGenericLinalgOp( - {par, par, par, par, par, red, red}, { + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions. + // clang-format off + StructuredIndexed allIndexed[3] = { I({b, // Roundtrip to flattened form to serve as canonicalization and ensure // consistent ordering of subexpressions. simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, dm})}, { - O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})}, - macRegionBuilder); + W({kh, kw, c, dm}), + O({b, h, w, simplifyAffineExpr(c * depthMultiplier + dm, numDims, 0)})}; // clang-format on + auto inputs = (O.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (O.getType().isa()) + ? 
ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; + return makeGenericLinalgOp({par, par, par, par, par, red, red}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -467,7 +467,8 @@ auto i1Type = IntegerType::get(1, &globalContext()); auto i8Type = IntegerType::get(8, &globalContext()); auto memrefType = MemRefType::get({}, i1Type, {}, 0); - auto f = makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType}); + auto f = + makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); @@ -795,10 +796,12 @@ } // CHECK-LABEL: func @affine_if_op -// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} +// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}} +// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}} // CHECK-NOT: else -// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} -// CHECK-NEXT: } else { +// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}} +// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}} +// CHECK-NEXT: } else { TEST_FUNC(affine_if_op) { using namespace edsc; using namespace edsc::intrinsics; @@ -900,6 +903,36 @@ } // clang-format off +// CHECK-LABEL: func @linalg_matmul_mixed_tensors +// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} +/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 +// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 +// CHECK: linalg.yield %[[a4]] : f32 +// CHECK: }: tensor, memref, tensor -> 
tensor +// clang-format on +TEST_FUNC(linalg_matmul_mixed_tensors_test) { + using namespace edsc; + using namespace edsc::ops; + + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); + auto tensorType = RankedTensorType::get({-1, -1}, f32Type); + auto f = makeFunction("linalg_matmul_mixed_tensors", {}, + {tensorType, memrefType, tensorType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())), + tensorType); + + f.print(llvm::outs()); + f.erase(); +} + +// clang-format off // CHECK-LABEL: func @linalg_conv_nhwc // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>, @@ -923,7 +956,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), + linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), {}, /*strides=*/{3, 4}, /*dilations=*/{5, 6}); f.print(llvm::outs()); @@ -956,6 +989,7 @@ ScopedContext scope(builder, f.getLoc()); linalg_dilated_conv_nhwc( makeValueHandles(llvm::to_vector<3>(f.getArguments())), + /*resultTensorTypes=*/{}, /*depth_multiplier=*/7, /*strides=*/{3, 4}, /*dilations=*/{5, 6});