diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -715,6 +715,29 @@
   }
 };
 
+/// Tries to compute the constant difference `u - l` between two loop bounds.
+/// Two forms are recognized: both values are integer constants, or `u` is
+/// produced by an `arith.addi` of `l` and an integer constant (in either
+/// operand order). Returns llvm::None when neither form matches, i.e. when the
+/// difference cannot be shown to be constant this way.
+static Optional<int64_t> computeConstDiff(Value l, Value u) {
+  IntegerAttr clb, cub;
+  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
+    llvm::APInt lbValue = clb.getValue();
+    llvm::APInt ubValue = cub.getValue();
+    return (ubValue - lbValue).getSExtValue();
+  }
+
+  // Otherwise, match `u = l + c` or `u = c + l` for a constant `c`.
+  llvm::APInt diff;
+  if (matchPattern(
+          u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
+      matchPattern(
+          u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
+    return diff.getSExtValue();
+  return llvm::None;
+}
+
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -730,15 +753,14 @@
       return success();
     }
 
-    auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
-    auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
-    if (!lb || !ub)
+    // The simplifications below all need `ub - lb` to be a known constant.
+    Optional<int64_t> diff =
+        computeConstDiff(op.getLowerBound(), op.getUpperBound());
+    if (!diff)
       return failure();
 
     // If the loop is known to have 0 iterations, remove it.
-    llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
-    llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
-    if (lbValue.sge(ubValue)) {
+    if (*diff <= 0) {
       rewriter.replaceOp(op, op.getIterOperands());
       return success();
     }
@@ -750,7 +772,8 @@
     // If the loop is known to have 1 iteration, inline its body and remove the
     // loop.
     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
-    if ((lbValue + stepValue).sge(ubValue)) {
+    // `lb + step >= ub` is equivalent to `step >= ub - lb`.
+    if (stepValue.sge(*diff)) {
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getNumIterOperands() + 1);
       blockArgs.push_back(op.getLowerBound());
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -719,6 +719,25 @@
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @replace_single_iteration_const_diff(
+//  CHECK-SAME:   %[[A0:.*]]: index)
+func.func @replace_single_iteration_const_diff(%arg0 : index) {
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %5 = arith.addi %arg0, %c1 : index
+  // CHECK-NOT: scf.for
+  scf.for %arg2 = %arg0 to %5 step %c1 {
+    // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
+    %7 = arith.muli %c2, %arg2 : index
+    // CHECK-NEXT: "test.consume"(%[[MUL]])
+    "test.consume"(%7) : (index) -> ()
+  }
+  return
+}
+
 // -----
 
 // CHECK-LABEL: @remove_empty_parallel_loop