diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -25,6 +25,37 @@
 using namespace mlir;
 using namespace mlir::scf;
 
+/// A simple, conservative analysis to determine if the loop is shape
+/// conserving. I.e., the type of the arg-th yielded value is the same as the
+/// type of the corresponding basic block argument of the loop.
+/// Note: This function handles only simple cases. Expand as needed.
+static bool isShapePreserving(ForOp forOp, int64_t arg) {
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+         "arg is out of bounds");
+  Value value = yieldOp.results()[arg];
+  while (value) {
+    if (value == forOp.getRegionIterArgs()[arg])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    using tensor::InsertSliceOp;
+    value =
+        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+            .template Case<InsertSliceOp>(
+                [&](InsertSliceOp op) { return op.dest(); })
+            .template Case<ForOp>([&](ForOp forOp) {
+              return isShapePreserving(forOp, opResult.getResultNumber())
+                         ? forOp.getIterOperands()[opResult.getResultNumber()]
+                         : Value();
+            })
+            .Default([&](auto op) { return Value(); });
+  }
+  return false;
+}
+
 namespace {
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
@@ -52,37 +83,6 @@
 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  /// A simple, conservative analysis to determine if the loop is shape
-  /// conserving. I.e., the type of the arg-th yielded value is the same as the
-  /// type of the corresponding basic block argument of the loop.
-  /// Note: This function handles only simple cases. Expand as needed.
-  static bool isShapePreserving(ForOp forOp, int64_t arg) {
-    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-    assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
-           "arg is out of bounds");
-    Value value = yieldOp.results()[arg];
-    while (value) {
-      if (value == forOp.getRegionIterArgs()[arg])
-        return true;
-      OpResult opResult = value.dyn_cast<OpResult>();
-      if (!opResult)
-        return false;
-
-      using tensor::InsertSliceOp;
-      value =
-          llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-              .template Case<InsertSliceOp>(
-                  [&](InsertSliceOp op) { return op.dest(); })
-              .template Case<ForOp>([&](ForOp forOp) {
-                return isShapePreserving(forOp, opResult.getResultNumber())
-                           ? forOp.getIterOperands()[opResult.getResultNumber()]
-                           : Value();
-              })
-              .Default([&](auto op) { return Value(); });
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -102,6 +102,48 @@
   };
 };
 
+/// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   ...
+/// }
+/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   ...
+/// }
+/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the iter arg does not change with loop iterations.
+template <typename OpTy>
+struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto forOp = dimOp.source().template getDefiningOp<scf::ForOp>();
+    if (!forOp)
+      return failure();
+    auto opResult = dimOp.source().template cast<OpResult>();
+    unsigned resultNumber = opResult.getResultNumber();
+    if (!isShapePreserving(forOp, resultNumber))
+      return failure();
+    rewriter.updateRootInPlace(dimOp, [&]() {
+      dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
+    });
+    return success();
+  }
+};
+
 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
 /// and scf.parallel loops with a known range.
 template <typename OpTy>
@@ -152,11 +194,12 @@
 void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
-  patterns
-      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp>,
-              AffineOpSCFCanonicalizationPattern<AffineMaxOp>,
-              DimOfIterArgFolder<tensor::DimOp>,
-              DimOfIterArgFolder<memref::DimOp>>(ctx);
+  patterns.insert<
+      AffineOpSCFCanonicalizationPattern<AffineMinOp>,
+      AffineOpSCFCanonicalizationPattern<AffineMaxOp>,
+      DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
+      DimOfLoopResultFolder<tensor::DimOp>,
+      DimOfLoopResultFolder<memref::DimOp>>(ctx);
 }
 
 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -313,3 +313,38 @@
   }
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   tensor.dim %[[t]]
+func @tensor_dim_of_loop_result(%t : tensor<?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t)
+      -> (tensor<?xf32>) {
+    scf.yield %arg0 : tensor<?xf32>
+  }
+  %dim = tensor.dim %0, %c0 : tensor<?xf32>
+  return %dim : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result_no_canonicalize(
+//       CHECK:   %[[loop:.*]]:2 = scf.for
+//       CHECK:   tensor.dim %[[loop]]#1
+func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?xf32>,
+                                                %u : tensor<?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %u)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    scf.yield %arg0, %u : tensor<?xf32>, tensor<?xf32>
+  }
+  %dim = tensor.dim %1, %c0 : tensor<?xf32>
+  return %dim : index
+}
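
Note on usage (not part of this patch): the new DimOfLoopResultFolder is
picked up by anything that calls populateSCFForLoopCanonicalizationPatterns.
A minimal driver might look as follows, assuming MLIR's standard greedy
rewriter (applyPatternsAndFoldGreedily, declared in
mlir/Transforms/GreedyPatternRewriteDriver.h); the helper name
canonicalizeSCFForLoops is hypothetical and includes are elided:

  // Hypothetical driver: collect the SCF for-loop canonicalization patterns
  // (including the DimOfLoopResultFolder added above) and apply them greedily.
  static void canonicalizeSCFForLoops(FuncOp func) {
    RewritePatternSet patterns(func.getContext());
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    // The greedy driver rewrites to a fixed point; the LogicalResult is
    // discarded here since non-convergence is acceptable for this sketch.
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }

The same folds are exercised end-to-end by the pass created in
createSCFForLoopCanonicalizationPass, which is what the updated test file runs
through mlir-opt.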