diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -8,7 +8,6 @@
   MLIRSCFOpsIncGen

   LINK_LIBS PUBLIC
-  MLIRAffineDialect
   MLIRArithmeticDialect
   MLIRBufferizationDialect
   MLIRControlFlowDialect
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -7,8 +7,6 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -721,17 +719,25 @@
 /// Returns llvm::None when the difference between two AffineValueMap is
 /// dynamic.
 static Optional<int64_t> computeConstDiff(Value l, Value u) {
-  auto alb = l.getDefiningOp<AffineApplyOp>();
-  auto aub = u.getDefiningOp<AffineApplyOp>();
-  // ID map: (d0)->d0
-  auto id = AffineMap::getMultiDimIdentityMap(1, l.getContext());
-  auto lb = alb ? alb.getAffineValueMap() : AffineValueMap(id, l);
-  auto ub = aub ? aub.getAffineValueMap() : AffineValueMap(id, u);
-
-  AffineValueMap diffMap;
-  AffineValueMap::difference(ub, lb, &diffMap);
-  if (auto constDiff = diffMap.getResult(0).dyn_cast<AffineConstantExpr>())
-    return constDiff.getValue();
+  // Only two forms are recognized; anything else yields llvm::None:
+  //  1. both values are integer constants, and
+  //  2. `u` is an `arith.addi` of `l` and an integer constant.
+  // Binding through m_Constant<IntegerAttr> makes non-integer constants fail
+  // to match instead of asserting on an unchecked cast<IntegerAttr>().
+  IntegerAttr clb, cub;
+  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
+    llvm::APInt lbValue = clb.getValue();
+    llvm::APInt ubValue = cub.getValue();
+    return (ubValue - lbValue).getSExtValue();
+  }
+
+  // Else a simple pattern match for x + c or c + x.
+  llvm::APInt diff;
+  if (matchPattern(
+          u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
+      matchPattern(
+          u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
+    return diff.getSExtValue();
   return llvm::None;
 }

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -722,14 +722,13 @@

 // -----

-#map = affine_map<(d0) -> (d0 + 1)>
 // CHECK-LABEL: func @replace_single_iteration_const_diff(
 // CHECK-SAME: %[[A0:.*]]: index)
 func.func @replace_single_iteration_const_diff(%arg0 : index) {
   // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %5 = affine.apply #map(%arg0)
+  %5 = arith.addi %arg0, %c1 : index
   // CHECK-NOT: scf.for
   scf.for %arg2 = %arg0 to %5 step %c1 {
     // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2678,7 +2678,6 @@
     ),
     includes = ["include"],
     deps = [
-        ":AffineDialect",
         ":ArithmeticDialect",
         ":ArithmeticUtils",
         ":BufferizationDialect",