Index: mlir/lib/Dialect/Affine/IR/AffineOps.cpp
===================================================================
--- mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1657,6 +1657,18 @@
 }
 
 namespace {
+/// Returns the constant trip count of `forOp` when both bounds are constant
+/// and the step is positive; returns None in every other case.
+static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+  int64_t step = forOp.getStep();
+  if (!forOp.hasConstantBounds() || step <= 0)
+    return None;
+  int64_t lb = forOp.getConstantLowerBound();
+  int64_t ub = forOp.getConstantUpperBound();
+  // Ceiling division: a final partial step still executes one iteration.
+  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
 /// This is a pattern to fold trivially empty loop bodies.
 /// TODO: This should be moved into the folding hook.
 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
@@ -1667,8 +1679,50 @@
     // Check that the body only contains a yield.
     if (!llvm::hasSingleElement(*forOp.getBody()))
       return failure();
-    // The initial values of the iteration arguments would be the op's results.
-    rewriter.replaceOp(forOp, forOp.getIterOperands());
+    // An empty loop without results has no observable effect: erase it, so
+    // that success() is only reported after the IR was actually modified.
+    if (forOp.getNumResults() == 0) {
+      rewriter.eraseOp(forOp);
+      return success();
+    }
+    Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+    if (tripCount.hasValue() && tripCount.getValue() == 0) {
+      // The initial values of the iteration arguments would be the op's
+      // results.
+      rewriter.replaceOp(forOp, forOp.getIterOperands());
+      return success();
+    }
+    SmallVector<Value, 4> replacements;
+    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+    auto iterArgs = forOp.getRegionIterArgs();
+    bool hasValDefinedOutsideLoop = false;
+    bool iterArgsNotInOrder = false;
+    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+      Value val = yieldOp.getOperand(i);
+      auto iterArgIt = llvm::find(iterArgs, val);
+      if (iterArgIt == iterArgs.end()) {
+        // `val` is defined outside of the loop.
+        assert(forOp.isDefinedOutsideOfLoop(val) &&
+               "must be defined outside of the loop");
+        hasValDefinedOutsideLoop = true;
+        replacements.push_back(val);
+      } else {
+        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+        if (pos != i)
+          iterArgsNotInOrder = true;
+        replacements.push_back(forOp.getIterOperands()[pos]);
+      }
+    }
+    // Bail out when the trip count is unknown and the loop returns any value
+    // defined outside of the loop or any iterArg out of order.
+    if (!tripCount.hasValue() &&
+        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+      return failure();
+    // Bail out when the loop iterates more than once and it returns any iterArg
+    // out of order.
+    if (tripCount.hasValue() && tripCount.getValue() >= 2 && iterArgsNotInOrder)
+      return failure();
+    rewriter.replaceOp(forOp, replacements);
     return success();
   }
 };
@@ -1681,11 +1735,8 @@
 
 /// Returns true if the affine.for has zero iterations in trivial cases.
 static bool hasTrivialZeroTripCount(AffineForOp op) {
-  if (!op.hasConstantBounds())
-    return false;
-  int64_t lb = op.getConstantLowerBound();
-  int64_t ub = op.getConstantUpperBound();
-  return ub - lb <= 0;
+  Optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
+  return tripCount.hasValue() && tripCount.getValue() == 0;
 }
 
 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
Index: mlir/test/Dialect/Affine/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Affine/canonicalize.mlir
+++ mlir/test/Dialect/Affine/canonicalize.mlir
@@ -475,6 +475,112 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_empty_loop()
+func @fold_empty_loop() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %c2, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_empty_loops_trip_count_1()
+func @fold_empty_loops_trip_count_1() -> (index, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res1:2 = affine.for %i = 0 to 1 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  %res2:2 = affine.for %i = 0 to 2 step 3 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[one]], %[[two]], %[[zero]], %[[two]]
+  return %res1#0, %res1#1, %res2#0, %res2#1 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_empty_loop_trip_count_0()
+func @fold_empty_loop_trip_count_0() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 0 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[zero]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_empty_loop_trip_count_unknown
+func @fold_empty_loop_trip_count_unknown(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg0, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-NEXT: return %[[zero]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @empty_loops_not_folded_1
+func @empty_loops_not_folded_1(%in : index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res = affine.for %i = 0 to %in iter_args(%arg = %c0) -> index {
+    affine.yield %c1 : index
+  }
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL: func @empty_loops_not_folded_2
+func @empty_loops_not_folded_2(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @empty_loops_not_folded_3
+func @empty_loops_not_folded_3() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_zero_iter_loops
 // CHECK-SAME: %[[ARG:.*]]: index
 func @fold_zero_iter_loops(%in : index) -> index {