diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,15 @@
   ];

   let extraClassDeclaration = [{
+    /// Number of loops
+    unsigned getNumLoops() { return step().size(); }
+
+    /// Number of input operands
+    unsigned getNumInputs() { return inputs().size(); }
+
+    /// Number of output operands
+    unsigned getNumOutputs() { return outputs().size(); }
+
     /// Number of operands controlling the loop: lbs, ubs, steps
     unsigned getNumControlOperands() { return 3 * getNumLoops(); }

@@ -597,7 +606,6 @@
       return getBody()->getArguments().take_back(outputs().size());
     }

-
     void setLowerBounds(ValueRange lowerBounds) {
       unsigned numLoops = getNumLoops();
       assert(lowerBounds.size() == numLoops &&
@@ -622,6 +630,16 @@
         setOperand(pos, steps[i]);
     }

+    /// Block argument that corresponds to the `input` or `output` operand.
+    BlockArgument getTiedBlockArgument(OpOperand& operand) {
+      auto operandIndex = operand.getOperandNumber();
+      assert(
+          operandIndex >= getNumControlOperands() &&
+          operandIndex < getNumOperands() &&
+          "tied block arg is defined only for `input` and `output` arguments");
+      return getBody()->getArgument(operandIndex - 2 * getNumLoops());
+    }
+
     /// Result that corresponds to the `outputs` argument of tensor type.
     OpResult getTiedOpResult(OpOperand& opOperand) {
       // No result can correspond to a memref argument.
@@ -642,7 +660,76 @@
       return getOperation()->getResult(tensorId);
     }

-    unsigned getNumLoops() { return step().size(); }
+    /// Append `operand` to the `input` arguments.
+    OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+
+      getOperation()->insertOperands(getNumControlOperands() + numInputs,
+                                     operand);
+      getBody()->insertArgument(numLoops + numInputs, operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
+    }
+
+    /// Append `operand` to the `output` arguments.
+    OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+
+      getOperation()->insertOperands(
+          getNumControlOperands() + numInputs + numOutputs, operand);
+      getBody()->insertArgument(numLoops + numInputs + numOutputs,
+                                operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
+                                          numOutputs);
+    }
+
+    /// Erase `operand` from the `input` or `output` arguments.
+    void eraseOperand(OpBuilder& builder, OpOperand& operand) {
+      int numInputs = getNumInputs();
+      int numLoops = getNumLoops();
+      int numOutputs = getNumOutputs();
+      int numControlOperands = getNumControlOperands();
+
+      auto operandIndex = operand.getOperandNumber();
+      assert(operandIndex >= numControlOperands &&
+             operandIndex < getNumOperands() &&
+             "Can erase only `input` or `output` operand");
+
+      if (operandIndex >= numControlOperands + numInputs)
+        --numOutputs;
+      else
+        --numInputs;
+
+      getOperation()->eraseOperand(operandIndex);
+      getBody()->eraseArgument(operandIndex - 2 * numLoops);
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs}));
+    }
+
+    OpOperand* findInputOperand(Value value) {
+      OperandRange::iterator it = llvm::find(inputs(), value);
+      if (it == inputs().end()) return nullptr;
+      return it.getBase();
+    }
+
+    OpOperand* findOutputOperand(Value value) {
+      OperandRange::iterator it = llvm::find(outputs(), value);
+      if (it == outputs().end()) return nullptr;
+      return it.getBase();
+    }
   }];

   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -107,6 +107,66 @@
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }

+// Return tiled operands for the fused producer op. When fusing into
+// `linalg.tiled_loop` one has to update the `input` and `output` arguments of
+// the loop correspondingly.
+// Each input tensor of the producer op has to be added to `inputs` of the
+// `tiled_loop` if it is not present there already. Each output tensor has to
+// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
+// on whether the corresponding result is an input or an output to the loop.
+//
+// NOTE: This way of updating the arguments of the `tiled_loop` assumes that
+// the intermediate result is not used by any other operation but the
+// consumer. A more generic way is to append all missing output tensors of the
+// producer to the tiled loop outputs and hence modify the number of the
+// results, since we would need to add the intermediate results to
+// `linalg.yield`. After that a canonicalization pass would move the unused
+// output args of the `tiled_loop` to the `input` section.
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
+  if (!tiledLoop)
+    return llvm::to_vector<4>(producer.getShapedOperands());
+
+  SmallVector<Value> tiledOperands;
+  assert(producer.hasTensorSemantics() &&
+         "only fusion on tensors is currently supported for TiledLinalgOp");
+
+  for (auto producerInput : producer.getInputTensors()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+    if (addedInput == nullptr)
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
+    Value producerOutput = en.value();
+
+    Value result = producer->getResult(en.index());
+    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
+    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
+    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
+           "The result should be present in `input` or `output` args of "
+           "`tiled_loop`");
+
+    bool isInput = resultInputOperand;
+    int opNumber = isInput ? resultInputOperand->getOperandNumber()
+                           : resultOutputOperand->getOperandNumber();
+
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    if (addedOutput == nullptr)
+      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
+                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+
+    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
+    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
+    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
+    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
+    tiledLoop.eraseOperand(b, resultOperand);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  return tiledOperands;
+}
+
 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
 /// provides the loop range information for the fused loops. The rest are
 /// obtained from the producer itself, since they are not tiled + fused.
@@ -143,8 +203,8 @@
   clonedShapes.reserve(producer.getNumShapedOperands());

   // Compute subranges for all tensor input/output operands.
-  auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
-  clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
+  clonedShapes.append(makeTiledShapes(b, loc, producer,
+                                      getTiledOperands(b, producer), ivs,
                                       tileSizes, sizeBounds));

   // Append the other operands.
@@ -808,7 +868,7 @@
     origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;

-    // Prepare the b for the next insertion point.
+    // Prepare the builder for the next insertion point.
     auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
     if (!origOp.hasTensorSemantics())
       continue;
@@ -844,16 +904,16 @@
     // 2. encode destructive updates that may be inplaceable by bufferization.
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
-    for (auto &operand :
-         cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
-      if (auto opResult = operand.get().dyn_cast<OpResult>())
-        if (opResult.getOwner() == origOp)
-          operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
+      for (auto &operand : forOp.getIterOpOperands())
+        if (auto opResult = operand.get().dyn_cast<OpResult>())
+          if (opResult.getOwner() == origOp)
+            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    }
   }
   return fusedOps;
 }

-template <typename LoopType>
 static Optional<TiledAndFusedLinalgOps>
 tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
@@ -928,11 +988,9 @@
                          const LinalgTilingOptions &tilingOptions) {
   switch (tilingOptions.loopType) {
   case LinalgTilingLoopType::Loops:
-    return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
-                                                tilingOptions);
   case LinalgTilingLoopType::ParallelLoops:
-    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
-                                                     tilingOptions);
+  case LinalgTilingLoopType::TiledLoops:
+    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
   default:;
   }
   return llvm::None;
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,15 +1,16 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP

 module {
-  func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-                      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
-                      %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
-    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> //
-    %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
-      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> //
-    return %1 : tensor<?x?xf32>
+  func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                      %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
+                      %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> //
+    %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> //
+    return %ABC : tensor<?x?xf32>
   }
 }
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
@@ -90,6 +91,64 @@
 // CHECK: }
 // CHECK: return %[[RESULT]]

+// TLOOP-LABEL: func @matmul_fusion(
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C16:.*]] = constant 16 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+
+// TLOOP: %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]])
+// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
+// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
+
+// TLOOP: %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
+
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C64]], %[[C16]])
+// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[C__:.*]] = %[[C_]]: [[TY]])
+// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["parallel", "reduction"] {
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
+// TLOOP: %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
+// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+
+// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+
+// TLOOP: %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
+// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP: linalg.yield %[[RES0]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: linalg.yield %[[RES1]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[ABC]] : [[TY]]
+
 // -----

 module {
@@ -144,6 +203,48 @@
 // CHECK: scf.yield %[[YIELD]]
 // CHECK: return %[[RESULT]]

+// TLOOP-LABEL: func @matmul_plus_matmul
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
+
+// TLOOP: %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME:      %[[AB_:.*]] = %[[AB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
+
+// TLOOP: %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
+// TLOOP: %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[AB_SUB_INIT]] : [[TY]])
+
+// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
+// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
+
+// TLOOP: %[[RESULT_SUB:.*]] = subtensor_insert
+// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[RESULT]] : [[TY]]
+
 // -----

 module {
@@ -174,3 +275,53 @@
 // CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
 // CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
 // CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
+
+
+// TLOOP-LABEL: func @matmul_out_fusion(
+// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]])
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[AB]] : [[TY]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -278,6 +278,13 @@
           "Test Linalg on tensor fusion transformation "
           "patterns by applying them greedily.");
 }
+void registerTestLinalgTiledLoopFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
+      testTiledLoopFusionTransformsPass(
+          "test-linalg-tiled-loop-fusion-transform-patterns",
+          "Test Linalg on tensor fusion transformation "
+          "patterns by applying them greedily.");
+}
 void registerTestLinalgGreedyFusion() {
   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
       "test-linalg-greedy-fusion",
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,7 @@
 void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
+void registerTestLinalgTiledLoopFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTileAndFuseSequencePass();
@@ -159,6 +160,7 @@
   test::registerTestPushExpandingReshape();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
+  test::registerTestLinalgTiledLoopFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
   test::registerTestLinalgTileAndFuseSequencePass();
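
For reference, the TiledLoopOp helpers introduced above compose into one small idiom: look up or append an `ins` operand, then redirect the uses inside the loop body to the tied block argument; `getTiledOperands` applies exactly this to every producer input. A minimal C++ sketch of the idiom, assuming a `linalg::TiledLoopOp` as defined by this patch (the free function `threadValueThroughLoop` is illustrative only, not part of the change):

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/IR/Builders.h"

  using namespace mlir;

  // Illustrative only: make `value` available inside `loop` as an `ins`
  // block argument and rewrite the uses nested in the loop body to use it.
  static BlockArgument threadValueThroughLoop(OpBuilder &b,
                                              linalg::TiledLoopOp loop,
                                              Value value) {
    // Reuse an existing `ins` operand if the value is already threaded
    // through; appendInputOperand also updates `operand_segment_sizes`.
    OpOperand *in = loop.findInputOperand(value);
    if (!in)
      in = &loop.appendInputOperand(b, value);
    BlockArgument bbArg = loop.getTiedBlockArgument(*in);
    // Only uses nested inside the loop switch to the block argument; uses
    // outside the loop keep referring to the original SSA value.
    value.replaceUsesWithIf(bbArg, [&](OpOperand &use) {
      return loop->isProperAncestor(use.getOwner());
    });
    return bbArg;
  }

The output side is symmetric (`findOutputOperand`/`appendOutputOperand`), and `eraseOperand` drops an operand whose tied block argument has become dead, as `getTiledOperands` does after rerouting the intermediate result.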