diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,9 @@
     }
     unsigned getNumLoops() { return step().size(); }
   }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1923,6 +1923,110 @@ static LogicalResult verify(TiledLoopOp op) {
   return success();
 }
 
+namespace {
+
+// Folds away TiledLoopOp output tensors when the following conditions are met:
+// * result of `linalg.tiled_loop` has no uses
+// * output tensor is the argument of `linalg.yield`
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) {
+//   ...
+//   linalg.yield %out : tensor ...
+// }
+//
+// Becomes
+//
+// linalg.tiled_loop ... outs (%out_buf:memref<...>) {
+//   ...
+//   linalg.yield
+// }
+//
+// Tensor outputs that do not meet both conditions are kept, and the uses of
+// their results are redirected to the matching results of the new loop.
+struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const final {
+    if (tiledLoop.getNumResults() == 0)
+      return failure();
+
+    Block *block = tiledLoop.getBody();
+    auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+
+    // Collect the outputs and yield operands of the new loop. Non-tensor
+    // outputs produce no results and are always kept. `replacements[i]` is
+    // the value that replaces old result `i`: the output tensor itself when
+    // the result is dropped (it is unused and yielded unchanged), or a null
+    // placeholder, filled in below with the matching new result, when kept.
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs, replacements;
+    unsigned resultId = 0;
+    for (Value out : tiledLoop.outputs()) {
+      if (!out.getType().isa<RankedTensorType>()) {
+        newOutputOperands.push_back(out);
+        continue;
+      }
+      Value result = tiledLoop.getResult(resultId);
+      Value yieldArg = yieldOp.getOperand(resultId);
+      if (yieldArg != out || !result.use_empty()) {
+        newOutputOperands.push_back(out);
+        newYieldArgs.push_back(yieldArg);
+        replacements.push_back(Value());
+      } else {
+        replacements.push_back(out);
+      }
+      ++resultId;
+    }
+    if (newOutputOperands.size() == tiledLoop.outputs().size())
+      return failure();
+
+    Location loc = tiledLoop.getLoc();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+
+    // Clone the body through the rewriter so that the pattern driver is
+    // notified about every op created inside the new loop.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToEnd(newTiledLoop.getBody());
+      for (auto &op : block->without_terminator())
+        rewriter.clone(op, bvm);
+      // Yield operands may be defined inside the old body; remap them to
+      // their clones so the new terminator never refers to erased values.
+      SmallVector<Value, 2> mappedYieldArgs;
+      mappedYieldArgs.reserve(newYieldArgs.size());
+      for (Value arg : newYieldArgs)
+        mappedYieldArgs.push_back(bvm.lookupOrDefault(arg));
+      rewriter.create<linalg::YieldOp>(loc, mappedYieldArgs);
+    }
+
+    // Forward still-used results to the new loop instead of erasing the old
+    // one: `eraseOp` requires that none of its results have uses.
+    unsigned newResultId = 0;
+    for (Value &replacement : replacements)
+      if (!replacement)
+        replacement = newTiledLoop.getResult(newResultId++);
+    rewriter.replaceOp(tiledLoop, replacements);
+    return success();
+  }
+};
+} // namespace
+
+void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<TiledLoopResultsFolder>(context);
+}
+
+LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template <typename LinalgPoolingOp>
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -802,3 +802,71 @@
   // CHECK: return
   return
 }
+
+// -----
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, %C: memref<192x192xf32>) -> ()
+
+func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                              %C: memref<192x192xf32>,
+                              %C_tensor: tensor<192x192xf32>) {
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+      step (%c24, %c16)
+      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+      outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
+    call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+    linalg.yield %C_tensor : tensor<192x192xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_results(
+// CHECK-SAME:   %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
+// CHECK-SAME:   %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK:  %[[C192:.*]] = constant 192 : index
+
+// CHECK-NOT: %{{.*}} = linalg.tiled_loop
+// CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
+// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
+// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-NEXT: linalg.yield
+
+// -----
+
+func private @foo(%A: memref<192xf32>) -> tensor<192xf32>
+
+func @fold_tiled_loop_results_keep_used(%A: memref<192xf32>,
+                                        %B_tensor: tensor<192xf32>,
+                                        %C_tensor: tensor<192xf32>)
+    -> tensor<192xf32> {
+  %c0 = constant 0 : index
+  %c24 = constant 24 : index
+  %c192 = constant 192 : index
+  %result:2 = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A: memref<192xf32>)
+      outs (%B_tensor, %C_tensor: tensor<192xf32>, tensor<192xf32>) {
+    %0 = call @foo(%A) : (memref<192xf32>) -> tensor<192xf32>
+    linalg.yield %0, %C_tensor : tensor<192xf32>, tensor<192xf32>
+  }
+  return %result#0 : tensor<192xf32>
+}
+
+// The used result is kept while the unused, unchanged output is dropped.
+// CHECK-LABEL: func @fold_tiled_loop_results_keep_used(
+// CHECK-SAME:    %[[A:.*]]: memref<192xf32>, %[[B_TENSOR:.*]]: tensor<192xf32>,
+// CHECK-SAME:    %{{.*}}: tensor<192xf32>) -> tensor<192xf32> {
+// CHECK:         %[[RESULT:.*]] = linalg.tiled_loop
+// CHECK-SAME:      ins (%[[A]]: memref<192xf32>)
+// CHECK-SAME:      outs (%[[B_TENSOR]]:tensor<192xf32>) {
+// CHECK-NEXT:      %[[FOO:.*]] = call @foo(%[[A]])
+// CHECK-NEXT:      linalg.yield %[[FOO]] : tensor<192xf32>
+// CHECK:         return %[[RESULT]] : tensor<192xf32>