diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -136,7 +136,7 @@
 
 /// This is a common class used for patterns of the form
 /// ```
-///    someop(memrefcast) -> someop
+///    someop(memrefcast(%src)) -> someop(%src)
 /// ```
 /// It folds the source of the memref.cast into the root operation directly.
 static LogicalResult foldMemRefCast(Operation *op) {
@@ -151,6 +151,44 @@
   return success(folded);
 }
 
+/// This is a specialization of `foldMemRefCast` used for patterns of the form
+/// ```
+///    tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
+/// ```
+/// It folds the source of the memref.cast into the root operation directly.
+static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
+  bool folded = false;
+  Location loc = op->getLoc();
+
+  Block *body = op.getBody();
+  OpBuilder b = OpBuilder::atBlockBegin(body);
+
+  // Update `input` and `output` operands and block arguments if necessary.
+  // Operands list: [lbs, ubs, steps, inputs, outputs].
+  // Block args list: [ivs, inputs, outputs].
+  for (size_t operandIndex = op.getNumControlOperands(),
+              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
+       operandIndex < e; ++operandIndex, ++bbArgIndex) {
+    OpOperand &operand = op->getOpOperand(operandIndex);
+
+    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
+    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      auto newBbArg =
+          body->insertArgument(bbArgIndex, castOp.getOperand().getType());
+      auto oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
+
+      // Insert memref.cast back to the original type.
+      oldBbArg.replaceAllUsesWith(
+          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
+      body->eraseArgument(oldBbArg.getArgNumber());
+
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -2054,6 +2092,63 @@
 
 namespace {
 
+static constexpr int64_t kNoMatch = -1;
+
+// Folds away TiledLoopOp input tensors if they have no uses within the body.
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
+//                                  %in_buf_ = %in_buf: memref<...>) {...}
+// Becomes
+//
+// linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
+struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const final {
+    SmallVector<Value, 2> newInputs, regionInputTensorArgs;
+    // Store ids of the corresponding old and new input operands.
+    SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
+                                            kNoMatch);
+    for (auto en : llvm::enumerate(
+             llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
+      Value in, bbArg;
+      size_t index = en.index();
+      std::tie(in, bbArg) = en.value();
+      if (!in.getType().isa<RankedTensorType>() || !bbArg.use_empty()) {
+        oldInputIdToNew[index] = newInputs.size();
+        newInputs.push_back(in);
+        continue;
+      }
+    }
+    if (newInputs.size() == tiledLoop.inputs().size())
+      return failure();
+    Location loc = tiledLoop.getLoc();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types());
+
+    // Clone the region.
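+    // Induction variables and output bbArgs map 1:1 to the new loop; input
+    // bbArgs that were kept map through `oldInputIdToNew`. Dropped inputs had
+    // no uses in the body, so they need no mapping.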
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(tiledLoop.getRegionOutputArgs(),
+            newTiledLoop.getRegionOutputArgs());
+    for (const auto &en : llvm::enumerate(oldInputIdToNew))
+      if (en.value() != kNoMatch)
+        bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
+                newTiledLoop.getRegionInputArgs()[en.value()]);
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+    for (auto &op : *tiledLoop.getBody())
+      innerBuilder.clone(op, bvm);
+    rewriter.replaceOp(tiledLoop, newTiledLoop->getResults());
+
+    return success();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2085,27 +2180,26 @@
 
     // Match the pattern and collect output buffers that will replace the output
    // tensors and also the ops that will be ignored when cloning the body.
-    SmallVector<Value, 2> newOutputOperands, newYieldArgs,
-        regionOutputTensorArgs;
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
     int resultId = 0;
     // Store ids of the corresponding old and new output operands.
-    SmallVector<std::pair<size_t, size_t>, 2> old_out_id_to_new;
-    for (auto item : llvm::enumerate(
+    SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
+                                             kNoMatch);
+    for (auto en : llvm::enumerate(
             llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
-      size_t index = item.index();
-      Value out = std::get<0>(item.value());
-      Value outRegionArg = std::get<1>(item.value());
+      size_t index = en.index();
+      Value out = std::get<0>(en.value());
+      Value outRegionArg = std::get<1>(en.value());
       if (!out.getType().isa<RankedTensorType>()) {
-        old_out_id_to_new.push_back({index, newOutputOperands.size()});
+        oldOutputIdToNew[index] = newOutputOperands.size();
         newOutputOperands.push_back(out);
-        regionOutputTensorArgs.push_back(outRegionArg);
         continue;
       }
       Value result = tiledLoop.getResult(resultId);
       Value yieldArg = yieldOp.getOperand(resultId);
       if (yieldArg != outRegionArg || !result.use_empty()) {
-        old_out_id_to_new.push_back({index, newOutputOperands.size()});
+        oldOutputIdToNew[index] = newOutputOperands.size();
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2119,14 +2213,18 @@
         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
         tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
 
-    // Clone the region ignoring the def-chain for linalg.yield args:
-    // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
+    // Clone the region.
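+    // Output bbArgs that are kept map to the bbArgs of the new loop; bbArgs of
+    // folded-away tensor outputs map to the corresponding output operand
+    // defined outside of the loop.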
     BlockAndValueMapping bvm;
     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
     bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
-    for (const auto &item : old_out_id_to_new)
-      bvm.map(tiledLoop.getRegionOutputArgs()[item.first],
-              newTiledLoop.getRegionOutputArgs()[item.second]);
+    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
+      if (en.value() != kNoMatch)
+        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+                newTiledLoop.getRegionOutputArgs()[en.value()]);
+      else
+        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+                tiledLoop.outputs()[en.index()]);
+    }
     OpBuilder innerBuilder =
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())
@@ -2141,12 +2239,12 @@
 
 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<TiledLoopResultsFolder>(context);
+  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
+  return foldMemRefCastInTiledLoopOp(*this);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -18,6 +18,31 @@
 
 // -----
 
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @memref_cast_into_tiled_loop(
+func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
+  %0 = memref.cast %arg0
+    : memref<192xf32> to memref<192xf32, #map>
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  // CHECK: linalg.tiled_loop
+  // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
+  linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
+      outs (%out = %0: memref<192xf32, #map>) {
+    %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
+    %16 = memref.subview %out[%arg3] [%14] [1]
+      : memref<192xf32, #map> to memref<?xf32, #map>
+    linalg.fill(%16, %cst) : memref<?xf32, #map>, f32
+    linalg.yield
+  }
+  return
+}
+
+// -----
+
 func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = linalg.tensor_reshape %arg0
@@ -889,6 +914,30 @@
 
 // -----
 
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192xf32>) -> ()
+
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+  %c0 = constant 0 : index
+  %c24 = constant 24 : index
+  %c192 = constant 192 : index
+  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
+    call @foo(%A_) : (memref<192xf32>) -> ()
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_inputs
+// CHECK: linalg.tiled_loop
+// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
+
+// -----
+
 func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
                     %arg3: f32) -> (index, index, index) {
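For reference, a sketch of the IR that the new @fold_tiled_loop_inputs test is expected to produce after canonicalization; this snippet is illustrative only (the value names are assumptions), and the test itself checks only the `ins` clause of the loop.

// Sketch (not part of the patch): @fold_tiled_loop_inputs after -canonicalize.
// TiledLoopInputsFolder drops the unused tensor input %AT_; the memref input
// is not a tensor and its block argument is used by the call, so it is kept.
func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c192 = constant 192 : index
  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
      ins (%A_ = %A: memref<192xf32>) {
    call @foo(%A_) : (memref<192xf32>) -> ()
    linalg.yield
  }
  return
}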