diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2286,6 +2286,44 @@
   }
 };
 
+} // namespace
+
+/// A simple, conservative analysis to determine if the loop is shape
+/// conserving. I.e., the type of the arg-th yielded value is the same as the
+/// type of the corresponding basic block argument of the loop.
+/// Note: This function handles only simple cases. Expand as needed.
+static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+  if (yieldOp.values().empty())
+    // Tiled loop either has no outputs or is a "memref-based version". In
+    // either case, the loop is shape conserving.
+    return true;
+  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
+         "arg is out of bounds");
+  Value value = yieldOp.values()[arg];
+  while (value) {
+    if (value == loopOp.getRegionOutputArgs()[arg])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    using tensor::InsertSliceOp;
+    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                .template Case<InsertSliceOp>(
+                    [&](InsertSliceOp op) { return op.dest(); })
+                .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
+                  return isShapePreserving(loopOp, opResult.getResultNumber())
+                             ? loopOp.outputs()[opResult.getResultNumber()]
+                             : Value();
+                })
+                .Default([&](auto op) { return Value(); });
+  }
+  return false;
+}
+
+namespace {
+
 /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
 /// to dim(y) where `y` is the initial input/output value of the argument.
 ///
@@ -2307,40 +2345,6 @@
 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  /// A simple, conservative analysis to determine if the loop is shape
-  /// conserving. I.e., the type of the arg-th yielded value is the same as the
-  /// type of the corresponding basic block argument of the loop.
-  /// Note: This function handles only simple cases. Expand as needed.
-  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
-    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
-    if (yieldOp.values().empty())
-      // Tiled loop either has no outputs or is a "memref-based version". In
-      // either case, the loop is shape conserving.
-      return true;
-    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
-           "arg is out of bounds");
-    Value value = yieldOp.values()[arg];
-    while (value) {
-      if (value == loopOp.getRegionOutputArgs()[arg])
-        return true;
-      OpResult opResult = value.dyn_cast<OpResult>();
-      if (!opResult)
-        return false;
-
-      using tensor::InsertSliceOp;
-      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-                  .template Case<InsertSliceOp>(
-                      [&](InsertSliceOp op) { return op.dest(); })
-                  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
-                    return isShapePreserving(loopOp, opResult.getResultNumber())
-                               ? loopOp.outputs()[opResult.getResultNumber()]
-                               : Value();
-                  })
-                  .Default([&](auto op) { return Value(); });
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const final {
     auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2380,6 +2384,45 @@
   }
 };
 
+/// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y`
+/// is the initial output value of the loop.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+///   ...
+/// }
+/// %0 = tensor.dim %r, %c0 : tensor<...>
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+///   ...
+/// }
+/// %0 = tensor.dim %y, %c0 : tensor<...>
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
+template <typename OpTy>
+struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>();
+    if (!loopOp)
+      return failure();
+    auto opResult = dimOp.source().template cast<OpResult>();
+    unsigned resultNumber = opResult.getResultNumber();
+    if (!isShapePreserving(loopOp, resultNumber))
+      return failure();
+    rewriter.updateRootInPlace(dimOp, [&]() {
+      dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]);
+    });
+    return success();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2485,7 +2528,9 @@
                                               MLIRContext *context) {
   results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
                  DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
-                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>,
+                 DimOfTiledLoopResultFolder<tensor::DimOp>,
+                 DimOfTiledLoopResultFolder<memref::DimOp>>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -956,3 +956,51 @@
   }
   return %r : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   tensor.dim %[[arg2]], %[[c0]]
+func @dim_of_tiled_loop_result(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %out1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result_no_canonicalize(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[r:.*]] = linalg.tiled_loop
+//       CHECK:   tensor.dim %[[r]], %[[c0]]
+func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %arg1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
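For reference, a minimal sketch (not part of the patch) of the rewrite that the new `DimOfTiledLoopResultFolder` pattern performs on the first test case above, e.g. when the canonicalizer is run via `mlir-opt -split-input-file -canonicalize`. The IR is abbreviated and only illustrates the fold that the CHECK lines verify:

  // Before: the dim op reads the loop result %r.
  %r = linalg.tiled_loop ... outs (%out1 = %arg2 : tensor<?x?xf32>) { ... }
  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>

  // After DimOfTiledLoopResultFolder<tensor::DimOp> applies: the yielded value
  // is an insert_slice into %out1, so isShapePreserving() succeeds and the dim
  // op is redirected to the corresponding output operand %arg2.
  %r = linalg.tiled_loop ... outs (%out1 = %arg2 : tensor<?x?xf32>) { ... }
  %r2 = tensor.dim %arg2, %c0 : tensor<?x?xf32>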