diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -147,17 +147,16 @@
     passed as additional SSA operands to the "scf.for" following the 3 loop
     control SSA values mentioned above (lower bound, upper bound and step).
     The operation region has an argument for the induction variable, followed by
-    one argument for each loop-carried variable, representing he value of the
+    one argument for each loop-carried variable, representing the value of the
     variable at the current iteration.

     The region must terminate with a "scf.yield" that passes the current
-    values of loop-carried variables to the next iteration, or to the "scf.for"
-    result, if at the last iteration. The type (static or dynamic) of a
-    loop-carried variable may not change with iterations. E.g., it is illegal
-    to pass a tensor of larger size to the next iteration; even if the tensor's
-    dimensions are dynamic (i.e., same static type). Note, that when the
-    loop-carried variables are present, calling ForOp::build will not insert the
-    terminator implicitly. The caller must insert "scf.yield" in that case.
+    values of all loop-carried variables to the next iteration, or to the
+    "scf.for" result, if at the last iteration. The static type of a
+    loop-carried variable may not change with iterations; its runtime type is
+    allowed to change. Note that when loop-carried variables are present,
+    calling ForOp::build will not insert the terminator implicitly. The caller
+    must insert "scf.yield" in that case.

     "scf.for" results hold the final values after the last iteration.
     For example, to sum-reduce a memref:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"

 using namespace mlir;
 using namespace mlir::scf;
@@ -44,10 +45,44 @@
 ///     ...
 ///   }
 /// ```
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the iter arg does not change with loop iterations.
 template <typename OpTy>
 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;

+  /// A simple, conservative analysis to determine if the loop is shape
+  /// preserving, i.e., the type of the arg-th yielded value is the same as
+  /// the type of the corresponding basic block argument of the loop.
+  /// Note: This function handles only simple cases. Expand as needed.
+  static bool isShapePreserving(ForOp forOp, int64_t arg) {
+    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+    assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+           "arg is out of bounds");
+    Value value = yieldOp.results()[arg];
+    while (value) {
+      if (value == forOp.getRegionIterArgs()[arg])
+        return true;
+      OpResult opResult = value.dyn_cast<OpResult>();
+      if (!opResult)
+        return false;
+
+      using tensor::InsertSliceOp;
+      value =
+          llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+              .template Case<InsertSliceOp>(
+                  [&](InsertSliceOp op) { return op.dest(); })
+              .template Case<ForOp>([&](ForOp forOp) {
+                return isShapePreserving(forOp, opResult.getResultNumber())
+                           ? forOp.getIterOperands()[opResult.getResultNumber()]
+                           : Value();
+              })
+              .Default([&](auto op) { return Value(); });
+    }
+    return false;
+  }
+
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -56,6 +91,8 @@
     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
     if (!forOp)
       return failure();
+    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
+      return failure();

     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
     rewriter.updateRootInPlace(
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -242,3 +242,74 @@
   }
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_insertslice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for
+//       CHECK:     tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
+                                         %t2 : tensor<10x10xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
+        : tensor<10x10xf32> into tensor<?x?xf32>
+    %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
+        : tensor<10x10xf32> into tensor<?x?xf32>
+    scf.yield %3, %dim : tensor<?x?xf32>, index
+  }
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_nested_for(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for
+//       CHECK:     scf.for
+//       CHECK:       tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
+                                        %t2 : tensor<10x10xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %2, %3 = scf.for %j = %c0 to %c10 step %c1
+        iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, index) {
+      %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+      %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
+          : tensor<10x10xf32> into tensor<?x?xf32>
+      scf.yield %4, %dim : tensor<?x?xf32>, index
+    }
+    scf.yield %2, %3 : tensor<?x?xf32>, index
+  }
+  return %1 : index
+}
+
+// -----
+
+// A test case that should not canonicalize because the loop is not shape
+// preserving.
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_no_canonicalize(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[t]]
+//       CHECK:     tensor.dim %[[arg0]]
+func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
+                                             %t2 : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    scf.yield %t2, %dim : tensor<?x?xf32>, index
+  }
+  return %1 : index
+}
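
For reviewers, a minimal before/after sketch of the fold this revision gates
(illustrative only: the values %t and %next, the loop bounds, and the element
type are hypothetical, and the pattern is assumed to run via the
canonicalization pass exercised by for-loop-canonicalization.mlir):

  // Before: %arg0 is a loop-carried tensor whose runtime shape provably never
  // changes: every yielded value traces back to %arg0 through shape-preserving
  // ops such as tensor.insert_slice, so isShapePreserving() returns true.
  %0 = scf.for %i = %c0 to %c10 step %c1
      iter_args(%arg0 = %t) -> (tensor<?xf32>) {
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    ...
    scf.yield %next : tensor<?xf32>
  }

  // After: only the source operand of the tensor.dim is replaced with the
  // init value %t (via updateRootInPlace); the op itself is not hoisted out
  // of the loop. A later pass such as LICM may hoist it once its operands are
  // loop-invariant.
  %0 = scf.for %i = %c0 to %c10 step %c1
      iter_args(%arg0 = %t) -> (tensor<?xf32>) {
    %dim = tensor.dim %t, %c0 : tensor<?xf32>
    ...
    scf.yield %next : tensor<?xf32>
  }

In the non-shape-preserving case (the third test above), %t2 is yielded
directly, so %arg0 may have a different runtime shape on later iterations and
the fold must not fire; the dim legitimately refers to the current iteration's
value.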