diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -141,6 +141,25 @@
   TileUsingSCFForOp tilingPattern;
 };
 
+/// Pattern to lower operations that implement the `TilingInterface` to
+/// loops/scalar IR using `scf.for`.
+struct LowerToLoopsUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR, i.e. the generated `scf.for` loops.
+  FailureOr<SmallVector<scf::ForOp>>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  /// Pattern entry point; forwards to `returningMatchAndRewrite` and keeps
+  /// only its success/failure status.
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -48,7 +48,7 @@
 
 /// Given an array of values, try to extract a constant Attribute from each
 /// value. If this fails, return the original value.
-SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
 
 /// Convert `arrayAttr` to a vector of OpFoldResult.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -167,6 +167,28 @@
       /*defaultImplementation=*/[{
         return failure();
       }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Generates the scalar implementation of the operation.
+
+        Given the list `ivs` that represents a point in the iteration space
+        (as specified by `getIterationDomain()`), generates the scalar
+        operations that represent the computation at that point.
+        This method is typically used as the "exit path", i.e. once all
+        transformations are done, this method can be used to lower to scalar
+        code that can then be lowered to LLVM or SPIR-V dialects.
+      }],
+      /*retType=*/"LogicalResult",
+      /*methodName=*/"generateScalarImplementation",
+      /*args=*/(ins
+        "OpBuilder &":$b,
+        "Location ":$loc,
+        "ValueRange ":$ivs),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return failure();
+      }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -13,14 +13,68 @@
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
+//===----------------------------------------------------------------------===//
+// Utility methods for implementation of Tiling Interface for Linalg ops
+//===----------------------------------------------------------------------===//
+
+/// Return the SSA values that represent the data point accessed using a given
+/// `indexingMap` for a given point in the iteration space represented by `ivs`.
+static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
+                                              AffineMap indexingMap,
+                                              ValueRange ivs) {
+  SmallVector<Value> indices;
+  indices.reserve(indexingMap.getNumResults());
+  for (AffineExpr result : indexingMap.getResults()) {
+    // Apply each result expression separately to get one index per result.
+    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
+                                 indexingMap.getNumSymbols(), result);
+    indices.push_back(b.create<AffineApplyOp>(loc, m, ivs));
+  }
+  return indices;
+}
+
+/// Method to inline the payload of a `linalgOp` given the iteration space
+/// point and values for the arguments of the payload.
+static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
+                                   ValueRange ivs, ValueRange argValues) {
+  Block *body = linalgOp.getBlock();
+  BlockAndValueMapping map;
+  map.map(body->getArguments(), argValues);
+  for (auto &op : body->without_terminator()) {
+    // Replace `linalg.index` with the corresponding induction variable.
+    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
+      map.map(indexOp.getResult(), ivs[indexOp.dim()]);
+      continue;
+    }
+    b.clone(op, map);
+  }
+
+  // Store every yielded value into the output operand at the same position.
+  Operation *terminator = body->getTerminator();
+  Location loc = terminator->getLoc();
+  for (auto operand : llvm::enumerate(terminator->getOperands())) {
+    Value toStore = map.lookupOrDefault(operand.value());
+    OpOperand *storeInto = linalgOp.getOutputOperand(operand.index());
+    auto indices = getIndicesForAccess(
+        b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs);
+    b.create<memref::StoreOp>(loc, toStore, storeInto->get(), indices);
+  }
+  return success();
+}
 
+//===----------------------------------------------------------------------===//
+// External Model for implementing `TilingInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+namespace {
 /// External model implementation of TilingInterface for LinalgOps. An external
 /// model implementation is used for now till the use of `TilingInterface` is
 /// on-par with the current Linalg tiling + fusion patterns. Once it is
@@ -167,6 +221,46 @@
 
     return tiledOp[0]->getResult(resultNumber);
   }
+
+  /// Emits the scalar computation of `op` for the iteration-space point `ivs`
+  /// at the insertion point of `builder`: loads the operand values the payload
+  /// reads, inlines the payload and stores the yielded values. Emits an op
+  /// error and fails unless `op` has buffer semantics.
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    if (!linalgOp.hasBufferSemantics())
+      return op->emitOpError("expected operation to have buffer semantics");
+
+    // NOTE(review): loads are tagged with the op's own location rather than
+    // `loc`; the caller in this patch passes the op's location, so both agree.
+    Location linalgOpLoc = op->getLoc();
+    SmallVector<Value> indexedValues;
+    indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+    // Compute one value per input/output operand, in block-argument order.
+    for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
+      // Operands the payload never reads get a null placeholder so that the
+      // remaining values stay aligned with the payload block arguments.
+      if (!linalgOp.payloadUsesValueFromOperand(operand)) {
+        indexedValues.push_back(nullptr);
+        continue;
+      }
+      // Scalar operands are used directly; no load is needed.
+      if (linalgOp.isScalar(operand)) {
+        indexedValues.push_back(operand->get());
+        continue;
+      }
+      SmallVector<Value> indices = getIndicesForAccess(
+          builder, linalgOpLoc, linalgOp.getTiedIndexingMap(operand), ivs);
+      Value load =
+          builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
+      indexedValues.push_back(load);
+    }
+
+    // Inline the op payload and store the result.
+    return inlinePayload(builder, linalgOp, ivs, indexedValues);
+  }
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -494,3 +494,55 @@
       tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }
+
+//===----------------------------------------------------------------------===//
+// LowerToLoopsUsingSCFForOp
+//===----------------------------------------------------------------------===//
+
+/// Lowers `op` to a nest of `scf.for` loops, one per range of its iteration
+/// domain, with the scalar implementation of the op generated in the innermost
+/// body. Returns the created loops, outermost first.
+FailureOr<SmallVector<scf::ForOp>>
+scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  // Bail out before `getIterationDomain` gets a chance to create any IR, so
+  // that a match failure leaves the IR untouched.
+  // TODO: Handle cases where the op has results if needed.
+  if (op->getNumResults() > 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to lower to loops operations with return values");
+  }
+
+  Location loc = op.getLoc();
+  SmallVector<Range> domain = op.getIterationDomain(rewriter);
+  SmallVector<Value> ivs;
+  SmallVector<scf::ForOp> loops;
+  ivs.reserve(domain.size());
+  loops.reserve(domain.size());
+  for (const Range &loopRange : domain) {
+    Value offsetVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+    Value sizeVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+    Value strideVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+    // NOTE(review): `size` is passed as the loop upper bound, which is only
+    // the end of the range when the offset is zero -- confirm this holds for
+    // every `getIterationDomain` implementation.
+    auto loop = rewriter.create<scf::ForOp>(loc, offsetVal, sizeVal, strideVal,
+                                            ValueRange{});
+    loops.push_back(loop);
+    ivs.push_back(loop.getInductionVar());
+    // Nest the next loop (or the scalar body) inside the loop just created.
+    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
+  }
+  if (failed(op.generateScalarImplementation(rewriter, loc, ivs))) {
+    // Drop the partially built loop nest so that a failed rewrite does not
+    // leave it behind; erasing the outermost loop removes the whole nest.
+    if (!loops.empty())
+      rewriter.eraseOp(loops.front());
+    return failure();
+  }
+  rewriter.eraseOp(op);
+  return loops;
+}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -62,7 +62,7 @@
 
 /// Given an array of values, try to extract a constant Attribute from each
 /// value. If this fails, return the original value.
-SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
   return llvm::to_vector<4>(
       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
 }
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
+
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+// CHECK-LABEL: func @gemm
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+
+// -----
+
+func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
+    %arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
+  linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : memref<200x300xi32>, memref<300xi16>, memref<200xi8>)
+      outs(%arg3 : memref<300x200xi64>) {
+    ^bb0(%b0 : i32, %b1 : i16, %b2 : i8, %b3 : i64):
+      %0 = linalg.index 0 : index
+      %1 = arith.index_cast %0 : index to i16
+      %2 = arith.muli %b1, %1 : i16
+      %3 = linalg.index 1 : index
+      %4 = arith.index_cast %3 : index to i8
+      %5 = arith.muli %b2, %4 : i8
+      %6 = arith.extsi %2 : i16 to i32
+      %7 = arith.extsi %5 : i8 to i32
+      %8 = arith.addi %6, %7 : i32
+      %9 = arith.addi %8, %b0 : i32
+      %10 = arith.extsi %9 : i32 to i64
+      linalg.yield %10 : i64
+    }
+  return
+}
+// CHECK-LABEL: func @indexed_generic
+//  CHECK-SAME:     %[[ARG0:.+]]: memref<200x300xi32>
+//  CHECK-SAME:     %[[ARG1:.+]]: memref<300xi16>
+//  CHECK-SAME:     %[[ARG2:.+]]: memref<200xi8>
+//  CHECK-SAME:     %[[ARG3:.+]]: memref<300x200xi64>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C200:.+]] = arith.constant 200 : index
+//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C200]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C1]]
+//   CHECK-DAG:       %[[B0:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV1]]]
+//   CHECK-DAG:       %[[B1:.+]] = memref.load %[[ARG1]][%[[IV1]]]
+//   CHECK-DAG:       %[[B2:.+]] = memref.load %[[ARG2]][%[[IV0]]]
+//   CHECK-DAG:       %[[T1:.+]] = arith.index_cast %[[IV0]]
+//   CHECK-DAG:       %[[T2:.+]] = arith.muli %[[B1]], %[[T1]]
+//   CHECK-DAG:       %[[T4:.+]] = arith.index_cast %[[IV1]]
+//   CHECK-DAG:       %[[T5:.+]] = arith.muli %[[B2]], %[[T4]]
+//   CHECK-DAG:       %[[T6:.+]] = arith.extsi %[[T2]]
+//   CHECK-DAG:       %[[T7:.+]] = arith.extsi %[[T5]]
+//   CHECK-DAG:       %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
+//   CHECK-DAG:       %[[T9:.+]] = arith.addi %[[T8]], %[[B0]]
+//   CHECK-DAG:       %[[T10:.+]] = arith.extsi %[[T9]]
+//   CHECK-DAG:       memref.store %[[T10]], %[[ARG3]][%[[IV1]], %[[IV0]]]
+
+// -----
+
+func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
+    %arg2 : memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[1, 2]> : tensor<2xi64>,
+      dilations = dense<[3, 4]> : tensor<2xi64>}
+      ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+      outs(%arg2 : memref<?x?x?x?xf32>)
+  return
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
+//       CHECK: func @conv_strides_and_dilation(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
+//   CHECK-DAG:   %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[F:.+]] = memref.dim %[[ARG1]], %[[C3]]
+//   CHECK-DAG:   %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
+//       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[F]] step %[[C1]]
+//       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
+//       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
+//       CHECK:               scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
+//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
+//   CHECK-DAG:                 %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
+//   CHECK-DAG:                 %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//   CHECK-DAG:                 %[[T12:.+]] = arith.mulf %[[T9]], %[[T10]]
+//   CHECK-DAG:                 %[[T13:.+]] = arith.addf %[[T11]], %[[T12]]
+//       CHECK:                 memref.store %[[T13]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+
+// -----
+
+func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?x?x?xf32>) {
+  linalg.pooling_nhwc_max {
+      strides = dense<[1, 2]> : tensor<2xi64>,
+      dilations = dense<[3, 4]> : tensor<2xi64>}
+      ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?x?x?xf32>)
+  return
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
+//       CHECK: func @pool_strides_and_dilation
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
+//   CHECK-DAG:   %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
+//       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
+//       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
+//       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
+//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
+//   CHECK-DAG:               %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//   CHECK-DAG:               %[[T10:.+]] = arith.maxf %[[T9]], %[[T8]]
+//       CHECK:               memref.store %[[T10]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -65,7 +65,7 @@
   linalg::LinalgTransformationFilter filter;
 };
 
-/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
+/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
 /// ops for iterating over the tiles) while using a `filter` to avoid recursive
 /// application.
@@ -138,6 +138,12 @@
                      "with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testLoweringToScalar{
+      *this, "lower-to-scalar-using-scf-for",
+      llvm::cl::desc("Test lowering to scalar implementation using "
+                     "TilingInterface with scf.for operations"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override;
 
 private:
@@ -199,6 +205,10 @@
         context, patterns, "gemm_sequence_fusion", {10});
     return;
   }
+  if (testLoweringToScalar) {
+    patterns.add<scf::LowerToLoopsUsingSCFForOp>(context);
+    return;
+  }
 }
 
 void TestTilingInterfacePass::runOnOperation() {