diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,15 @@
   ];

   let extraClassDeclaration = [{
+    /// Number of loops
+    unsigned getNumLoops() { return step().size(); }
+
+    /// Number of input operands
+    unsigned getNumInputs() { return inputs().size(); }
+
+    /// Number of output operands
+    unsigned getNumOutputs() { return outputs().size(); }
+
     /// Number of operands controlling the loop: lbs, ubs, steps
     unsigned getNumControlOperands() { return 3 * getNumLoops(); }

@@ -597,7 +606,6 @@
       return getBody()->getArguments().take_back(outputs().size());
     }

-
     void setLowerBounds(ValueRange lowerBounds) {
       unsigned numLoops = getNumLoops();
       assert(lowerBounds.size() == numLoops &&
@@ -622,6 +630,16 @@
         setOperand(pos, steps[i]);
     }

+    /// Block argument that corresponds to the `input` or `output` operand.
+    BlockArgument getTiedBlockArgument(OpOperand& operand) {
+      auto operandIndex = operand.getOperandNumber();
+      assert(
+          operandIndex >= getNumControlOperands() &&
+          operandIndex < getNumOperands() &&
+          "tied block arg is defined only for `input` and `output` arguments");
+      return getBody()->getArgument(operandIndex - 2 * getNumLoops());
+    }
+
     /// Result that corresponds to the `outputs` argument of tensor type.
     OpResult getTiedOpResult(OpOperand& opOperand) {
       // No result can correspond to a memref argument.
@@ -642,7 +660,76 @@
       return getOperation()->getResult(tensorId);
     }

-    unsigned getNumLoops() { return step().size(); }
+    /// Append `operand` to the `input` arguments.
+    OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+
+      getOperation()->insertOperands(getNumControlOperands() + numInputs,
+                                     operand);
+      getBody()->insertArgument(numLoops + numInputs, operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
+    }
+
+    /// Append `operand` to the `output` arguments.
+    OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+
+      getOperation()->insertOperands(
+          getNumControlOperands() + numInputs + numOutputs, operand);
+      getBody()->insertArgument(numLoops + numInputs + numOutputs,
+                                operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
+                                          numOutputs);
+    }
+
+    /// Erase `operand` from the `input` or `output` arguments.
+    void eraseOperand(OpBuilder& builder, OpOperand& operand) {
+      int numInputs = getNumInputs();
+      int numLoops = getNumLoops();
+      int numOutputs = getNumOutputs();
+      int numControlOperands = getNumControlOperands();
+
+      auto operandIndex = operand.getOperandNumber();
+      assert(operandIndex >= numControlOperands &&
+             operandIndex < getNumOperands() &&
+             "Can erase only `input` or `output` operand");
+
+      if (operandIndex >= numControlOperands + numInputs)
+        --numOutputs;
+      else
+        --numInputs;
+
+      getOperation()->eraseOperand(operandIndex);
+      getBody()->eraseArgument(operandIndex - 2 * numLoops);
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs}));
+    }
+
+    OpOperand* findInputOperand(Value value) {
+      OperandRange::iterator it = llvm::find(inputs(), value);
+      if (it == inputs().end()) return nullptr;
+      return it.getBase();
+    }
+
+    OpOperand* findOutputOperand(Value value) {
+      OperandRange::iterator it = llvm::find(outputs(), value);
+      if (it == outputs().end()) return nullptr;
+      return it.getBase();
+    }
   }];

   let hasCanonicalizer = 1;
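The helpers above keep the operand list, the body block arguments, and `operand_segment_sizes` in sync. As a reference for how they are meant to compose, here is a minimal, hypothetical sketch (not part of the patch): the helper name `replaceTiledLoopInput` and the include set are assumptions, and only the methods declared in the `extraClassDeclaration` above are relied upon.

```cpp
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"

using namespace mlir;

// Operand layout of linalg.tiled_loop:  [lbs | ubs | steps | inputs | outputs]
// Body block-argument layout:           [ivs | input args | output args]
// so an input/output operand #i is tied to block argument #(i - 2 * numLoops),
// which is exactly what getTiedBlockArgument computes.
//
// Hypothetical helper: reroute one existing `input` operand of a tiled_loop
// to `newTensor`, using only the methods declared above.
static void replaceTiledLoopInput(OpBuilder &builder, linalg::TiledLoopOp loop,
                                  unsigned inputOperandNumber,
                                  Value newTensor) {
  // Append `newTensor` as a fresh `input`; this also creates the tied block
  // argument and updates operand_segment_sizes.
  OpOperand &newOperand = loop.appendInputOperand(builder, newTensor);
  BlockArgument newArg = loop.getTiedBlockArgument(newOperand);

  // Appending inserts at the end of the `input` section, so the position of
  // the pre-existing input operand is unchanged; re-fetch it by number rather
  // than holding an OpOperand reference across the mutation.
  OpOperand &oldOperand = loop->getOpOperand(inputOperandNumber);
  loop.getTiedBlockArgument(oldOperand).replaceAllUsesWith(newArg);

  // Drop the old operand together with its block argument.
  loop.eraseOperand(builder, oldOperand);
}
```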
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -110,6 +110,66 @@
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }

+// Return tiled operands for the fused producer op. When fusing into
+// `linalg.tiled_loop` one has to update `input` and `output` arguments of the
+// loop correspondingly.
+// Each input tensor of the producer op has to be added to `inputs` of the
+// `tiled_loop` if it is not present there already. Each output tensor has to
+// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
+// on whether the corresponding result is an input or an output to the loop.
+//
+// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
+// intermediate result is not used by any other operation but the consumer. A
+// more generic way is to append all missing output tensors of the producer to
+// the tiled loop outputs and hence modify the number of the results, since we
+// would need to add the intermediate results to `linalg.yield`. After that a
+// canonicalization pass would move the unused output args of the `tiled_loop`
+// to the `input` section.
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
+  if (!tiledLoop)
+    return llvm::to_vector<4>(producer.getShapedOperands());
+
+  SmallVector<Value> tiledOperands;
+  assert(producer.hasTensorSemantics() &&
+         "only fusion on tensors is currently supported for TiledLinalgOp");
+
+  for (auto producerInput : producer.getInputTensors()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+    if (addedInput == nullptr)
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
+    Value producerOutput = en.value();
+
+    Value result = producer->getResult(en.index());
+    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
+    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
+    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
+           "The result should be present in `input` or `output` args of "
+           "`tiled_loop`");
+
+    bool isInput = resultInputOperand;
+    int opNumber = isInput ? resultInputOperand->getOperandNumber()
+                           : resultOutputOperand->getOperandNumber();
+
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    if (addedOutput == nullptr)
+      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
+                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+
+    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
+    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
+    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
+    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
+    tiledLoop.eraseOperand(b, resultOperand);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  return tiledOperands;
+}
+
 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
 /// provides the loop range information for the fused loops. The rest are
 /// obtained from the producer itself, since they are not tiled + fused.
@@ -117,6 +177,7 @@
                       const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
   SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
   SmallVector<Range, 8> loopRanges;
+  Location loc = producer.getLoc();

   auto zero = b.create<ConstantIndexOp>(loc, 0);
   auto one = b.create<ConstantIndexOp>(loc, 1);
@@ -146,8 +207,8 @@
   clonedShapes.reserve(producer.getNumShapedOperands());

   // Compute subranges for all tensor input/output operands.
-  auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
-  clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
+  clonedShapes.append(makeTiledShapes(b, loc, producer,
+                                      getTiledOperands(b, producer), ivs,
                                       tileSizes, sizeBounds));

   // Append the other operands.
@@ -770,18 +831,17 @@

 /// Tile the fused loops in the root operation, by setting the tile sizes for
 /// all other loops to zero (those will be tiled later).
-static Optional<TiledLinalgOp>
-tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
-                  const LinalgTilingOptions &options,
-                  const std::set<unsigned> &fusedLoops) {
+static Optional<TiledLinalgOp> tileRootOperation(
+    OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
+    const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
-  auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
+  auto zero = std_constant_index(0);
   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
     if (!fusedLoops.count(i))
       tileSizes[i] = zero;
   LinalgTilingOptions tileFusedLoopsOptions = options;
   tileFusedLoopsOptions.setTileSizes(tileSizes);
-  return tileLinalgOp(b, op, tileFusedLoopsOptions);
+  return tileLinalgOp(builder, op, tileFusedLoopsOptions);
 }

 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
@@ -789,19 +849,19 @@
 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
 /// `tiledOp`.
 static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
                ArrayRef<LinalgOp> fusionCandidates,
                const FusableOpDependencesTy &fusableDependences,
                const std::set<unsigned> &fusedLoops) {
   LinalgOp tiledOp = tiledLinalgOp.op;
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(tiledOp);
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(tiledOp);

   DenseMap<unsigned, Range> fusedLoopsAndRanges;
   for (unsigned loop : fusedLoops) {
     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
-        b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
+        builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
   }

   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
@@ -809,12 +869,13 @@
   origOpToFusedOp[rootOp.getOperation()] = tiledOp;
   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
     LinalgOp origOp = candidate.value();
-    LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
+    LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
     origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;

-    // Prepare the b for the next insertion point.
-    auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
+    // Prepare the builder for the next insertion point.
+    auto guard =
+        llvm::make_scope_exit([&]() { builder.setInsertionPoint(fusedOp); });

     if (!origOp.hasTensorSemantics())
       continue;
@@ -849,18 +910,18 @@
     // 2. encode destructive updates that may be inplaceable by bufferization.
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
-    for (auto &operand :
-         cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
-      if (auto opResult = operand.get().dyn_cast<OpResult>())
-        if (opResult.getOwner() == origOp)
-          operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
+      for (auto &operand : forOp.getIterOpOperands())
+        if (auto opResult = operand.get().dyn_cast<OpResult>())
+          if (opResult.getOwner() == origOp)
+            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    }
   }
   return fusedOps;
 }

-template <typename LoopType>
 static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgTilingOptions &tilingOptions) {
   if (ops.size() < 2)
@@ -884,9 +945,9 @@
     return llvm::None;
   }

-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(rootOp);
-  ScopedContext scope(b, rootOp.getLoc());
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(rootOp);
+  ScopedContext scope(builder, rootOp.getLoc());

   // Find all the producers.
   LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
@@ -911,9 +972,9 @@

   // Tile the fused loops in the last operation in the list.
   SmallVector<Value, 4> tileSizeVector =
-      tilingOptions.tileSizeComputationFunction(b, rootOp);
+      tilingOptions.tileSizeComputationFunction(builder, rootOp);
   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
-      b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
+      builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
   if (!tiledRootOp) {
     rootOp.emitRemark("failed to tile the fused loops");
     return llvm::None;
@@ -922,23 +983,23 @@
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

   // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
-                                      fusableDependences, ret.fusedLoopDims);
+  ret.fusedProducers =
+      fuseOperations(builder, rootOp, *tiledRootOp, ops.drop_back(),
+                     fusableDependences, ret.fusedLoopDims);

   return ret;
 }

 Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
+mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                                    const LinalgDependenceGraph &dependenceGraph,
                                    const LinalgTilingOptions &tilingOptions) {
   switch (tilingOptions.loopType) {
   case LinalgTilingLoopType::Loops:
-    return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
-                                                tilingOptions);
   case LinalgTilingLoopType::ParallelLoops:
-    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
-                                                     tilingOptions);
+  case LinalgTilingLoopType::TiledLoops:
+    return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph,
+                                    tilingOptions);
   default:;
   }
   return llvm::None;
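With `LinalgTilingLoopType::TiledLoops` now routed through the same implementation, a caller can request `linalg.tiled_loop` generation directly from the tiling options. A hedged usage sketch follows: the wrapper name `fuseIntoTiledLoop`, the concrete tile sizes, and the include paths are illustrative assumptions, while `tileAndFuseLinalgOps` and its arguments follow the signature shown in the hunk above.

```cpp
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Sketch: tile-and-fuse a producer/consumer chain (root op last in `ops`) and
// materialize the inter-tile loops as a linalg.tiled_loop.
static llvm::Optional<linalg::TiledAndFusedLinalgOps>
fuseIntoTiledLoop(OpBuilder &b, ArrayRef<linalg::LinalgOp> ops,
                  const linalg::LinalgDependenceGraph &dependenceGraph) {
  linalg::LinalgTilingOptions options;
  options.setTileSizes({32, 64, 16}) // example tile sizes for a matmul chain
      .setLoopType(linalg::LinalgTilingLoopType::TiledLoops);
  return linalg::tileAndFuseLinalgOps(b, ops, dependenceGraph, options);
}
```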
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,15 +1,16 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP

 module {
-  func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-                      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
-                      %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
-    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> //
-    %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
-      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> //
-    return %1 : tensor<?x?xf32>
+  func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                      %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
+                      %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> //
+    %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> //
+    return %ABC : tensor<?x?xf32>
   }
 }
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
@@ -90,6 +91,64 @@
 // CHECK: }
 // CHECK: return %[[RESULT]]

+// TLOOP-LABEL: func @matmul_fusion(
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C16:.*]] = constant 16 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+
+// TLOOP: %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]])
+// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
+// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
+
+// TLOOP: %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
+
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C64]], %[[C16]])
+// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[C__:.*]] = %[[C_]]: [[TY]])
+// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["parallel", "reduction"] {
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
+// TLOOP: %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
+// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+
+// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+
+// TLOOP: %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
+// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP: linalg.yield %[[RES0]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: linalg.yield %[[RES1]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[ABC]] : [[TY]]
+
 // -----

 module {
@@ -144,6 +203,48 @@
 // CHECK: scf.yield %[[YIELD]]
 // CHECK: return %[[RESULT]]

+// TLOOP-LABEL: func @matmul_plus_matmul
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
+
+// TLOOP: %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME:      %[[AB_:.*]] = %[[AB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
+
+// TLOOP: %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
+// TLOOP: %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[AB_SUB_INIT]] : [[TY]])
+
+// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
+// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
+
+// TLOOP: %[[RESULT_SUB:.*]] = subtensor_insert
+// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[RESULT]] : [[TY]]
+
 // -----

 module {
@@ -174,3 +275,53 @@
 // CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
 // CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
 // CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
+
+
+// TLOOP-LABEL: func @matmul_out_fusion(
+// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]])
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[AB]] : [[TY]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -278,6 +278,13 @@
           "Test Linalg on tensor fusion transformation "
           "patterns by applying them greedily.");
 }
+void registerTestLinalgTiledLoopFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
+      testTiledLoopFusionTransformsPass(
+          "test-linalg-tiled-loop-fusion-transform-patterns",
+          "Test Linalg on tensor fusion transformation "
+          "patterns by applying them greedily.");
+}
 void registerTestLinalgGreedyFusion() {
   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
       "test-linalg-greedy-fusion",
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,7 @@
 void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
+void registerTestLinalgTiledLoopFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTileAndFuseSequencePass();
@@ -159,6 +160,7 @@
   test::registerTestPushExpandingReshape();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
+  test::registerTestLinalgTiledLoopFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
   test::registerTestLinalgTileAndFuseSequencePass();
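For out-of-tree users, the registration pattern added to mlir-opt.cpp above can be mirrored in a custom opt-style driver. The sketch below is an assumption-laden illustration (the driver scaffolding and the need to link MLIR's test-pass library are not part of the patch); only the declaration and call of `registerTestLinalgTiledLoopFusionTransforms` follow the hunks above.

```cpp
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"

// Forward declaration, mirroring the mlir-opt.cpp hunk above. The definition
// lives in MLIR's test library (TestLinalgFusionTransforms.cpp), which this
// hypothetical driver must link against.
namespace mlir {
namespace test {
void registerTestLinalgTiledLoopFusionTransforms();
} // namespace test
} // namespace mlir

int main(int argc, char **argv) {
  mlir::registerAllPasses();
  // Make -test-linalg-tiled-loop-fusion-transform-patterns available.
  mlir::test::registerTestLinalgTiledLoopFusionTransforms();

  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  return mlir::failed(mlir::MlirOptMain(
             argc, argv, "driver with tiled_loop fusion test passes\n",
             registry))
             ? 1
             : 0;
}
```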