diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -105,7 +105,7 @@ } def AffineForOp : Affine_Op<"for", - [AutomaticAllocationScope, ImplicitAffineTerminator, RecursivelySpeculatable, + [AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, @@ -340,6 +340,9 @@ /// Returns true if both the lower and upper bound have the same operand /// lists (same operands in the same order). bool matchingBoundOperandList(); + + /// Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -119,6 +119,7 @@ [AutomaticAllocationScope, DeclareOpInterfaceMethods, + ConditionallySpeculatable, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects]> { @@ -330,6 +331,12 @@ /// induction variable. LoopOp only has one region, so 0 is the only valid /// value for `index`. OperandRange getSuccessorEntryOperands(Optional index); + + /// Returns the step as an `APInt` if it is constant. + Optional getConstantStep(); + + /// Interface method for ConditionallySpeculatable.
+ Speculation::Speculatability getSpeculatability(); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2293,6 +2293,16 @@ return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound())); } +Speculation::Speculatability AffineForOp::getSpeculatability() { + // `affine.for (I = Start; I < End; I += 1)` terminates for all values of + // Start and End. + // + // For Step != 1, the loop may not terminate. We can add more smarts here if + // needed. + return getStep() == 1 ? Speculation::RecursivelySpeculatable + : Speculation::NotSpeculatable; +} + /// Returns true if the provided value is the induction variable of a /// AffineForOp. bool mlir::isForInductionVar(Value val) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -762,13 +762,13 @@ return success(); } - IntegerAttr step; - if (!matchPattern(op.getStep(), m_Constant(&step))) + llvm::Optional maybeStepValue = op.getConstantStep(); + if (!maybeStepValue) return failure(); // If the loop is known to have 1 iteration, inline its body and remove the // loop. - llvm::APInt stepValue = step.getValue(); + llvm::APInt stepValue = *maybeStepValue; if (stepValue.sge(*diff)) { SmallVector blockArgs; blockArgs.reserve(op.getNumIterOperands() + 1); @@ -1065,6 +1065,25 @@ LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context); } +Optional ForOp::getConstantStep() { + IntegerAttr step; + if (matchPattern(getStep(), m_Constant(&step))) + return step.getValue(); + return {}; +} + +Speculation::Speculatability ForOp::getSpeculatability() { + // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start + // and End.
+ if (auto constantStep = getConstantStep()) + if (*constantStep == 1) + return Speculation::RecursivelySpeculatable; + + // For a non-constant step, or a constant step other than 1, the loop may + // not terminate. We can add more smarts here if needed. + return Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // ForeachThreadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -103,7 +103,7 @@ %m = memref.alloc() : memref<10xf32> %cf8 = arith.constant 8.0 : f32 affine.for %arg0 = 0 to 10 { - affine.for %arg1 = 0 to 10 { + affine.for %arg1 = 0 to 20 { affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { %cf9 = arith.addf %cf8, %cf8 : f32 } @@ -112,7 +112,7 @@ // CHECK: memref.alloc() : memref<10xf32> // CHECK-NEXT: %[[CST:.*]] = arith.constant 8.000000e+00 : f32 - // CHECK-NEXT: affine.for %[[ARG:.*]] = 0 to 10 { + // CHECK-NEXT: affine.for %[[ARG:.*]] = 0 to 20 { // CHECK-NEXT: } // CHECK-NEXT: affine.for %[[ARG:.*]] = 0 to 10 { // CHECK-NEXT: affine.if #set(%[[ARG]], %[[ARG]]) { @@ -124,6 +124,95 @@ // ----- +func.func @hoist_affine_for_with_unknown_trip_count(%lb: index, %ub: index) { + affine.for %arg0 = 0 to 10 { + affine.for %arg1 = %lb to %ub { + } + } + + // CHECK: @hoist_affine_for_with_unknown_trip_count(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) { + // CHECK-NEXT: affine.for %[[ARG2:.*]] = %[[ARG0]] to %[[ARG1]] { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 10 { + // CHECK-NEXT: } + + return +} + +// ----- + +func.func @hoist_affine_for_with_unknown_trip_count_non_unit_step(%lb: index, %ub: index) { + affine.for %arg0 = 0 to 10 { + affine.for %arg1 = %lb to %ub step 2 { + } + } + + // CHECK:
@hoist_affine_for_with_unknown_trip_count_non_unit_step(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) { + // CHECK-NEXT: affine.for %[[ARG2:.*]] = 0 to 10 { + // CHECK-NEXT: affine.for %[[ARG3:.*]] = %[[ARG0]] to %[[ARG1]] step 2 { + // CHECK-NEXT: } + // CHECK-NEXT: } + + return +} + +// ----- + +func.func @hoist_scf_for_with_unknown_trip_count_unit_step(%lb: index, %ub: index) { + %c1 = arith.constant 1 : index + scf.for %arg0 = %lb to %ub step %c1 { + scf.for %arg1 = %lb to %ub step %c1 { + } + } + + // CHECK: @hoist_scf_for_with_unknown_trip_count_unit_step(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) { + // CHECK: scf.for %[[ARG2:.*]] = %[[ARG0]] to %[[ARG1]] + // CHECK-NEXT: } + // CHECK-NEXT: scf.for %[[ARG3:.*]] = %[[ARG0]] to %[[ARG1]] + // CHECK-NEXT: } + + return +} + +// ----- + +func.func @hoist_scf_for_with_unknown_trip_count_non_unit_constant_step(%lb: index, %ub: index) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + scf.for %arg0 = %lb to %ub step %c1 { + scf.for %arg1 = %lb to %ub step %c2 { + } + } + + // CHECK: @hoist_scf_for_with_unknown_trip_count_non_unit_constant_step(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) { + // CHECK: scf.for %[[ARG2:.*]] = %[[ARG0]] to %[[ARG1]] + // CHECK-NEXT: scf.for %[[ARG3:.*]] = %[[ARG0]] to %[[ARG1]] + // CHECK-NEXT: } + // CHECK-NEXT: } + + return +} + +// ----- + +func.func @hoist_scf_for_with_unknown_trip_count_unknown_step(%lb: index, %ub: index, %step: index) { + %c1 = arith.constant 1 : index + scf.for %arg0 = %lb to %ub step %c1 { + scf.for %arg1 = %lb to %ub step %step { + } + } + + // CHECK: @hoist_scf_for_with_unknown_trip_count_unknown_step(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[STEP:.*]]: index) { + // CHECK: scf.for %[[ARG2:.*]] = %[[ARG0]] to %[[ARG1]] + // CHECK-NEXT: scf.for %[[ARG3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[STEP]] + // CHECK-NEXT: } + // CHECK-NEXT: } + + return +} + +// ----- + func.func @invariant_affine_if2() { %m =
memref.alloc() : memref<10xf32> %cf8 = arith.constant 8.0 : f32