diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -188,6 +188,7 @@
 public:
   using ImplType = detail::AffineConstantExprStorage;
   /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr);
+  AffineConstantExpr() : AffineConstantExpr(nullptr) {}
   int64_t getValue() const;
 };

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/Debug.h"

 using namespace mlir;
 using namespace mlir::detail;
@@ -314,6 +316,39 @@
       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
   }

+  // Detect "c1 * expr + c2 * expr" and fold it to "(c1 + c2) * expr".
+  // c1 is rLhsConst, c2 is rRhsConst; firstExpr and secondExpr are their
+  // respective multiplicands.
+  Optional<int64_t> rLhsConst, rRhsConst;
+  AffineExpr firstExpr, secondExpr;
+  AffineConstantExpr rLhsConstExpr;
+  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
+      (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+    rLhsConst = rLhsConstExpr.getValue();
+    firstExpr = lBinOpExpr.getLHS();
+  } else {
+    rLhsConst = 1;
+    firstExpr = lhs;
+  }
+
+  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+  AffineConstantExpr rRhsConstExpr;
+  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
+      (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+    rRhsConst = rRhsConstExpr.getValue();
+    secondExpr = rBinOpExpr.getLHS();
+  } else {
+    rRhsConst = 1;
+    secondExpr = rhs;
+  }
+
+  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
+    return getAffineBinaryOpExpr(
+        AffineExprKind::Mul, firstExpr,
+        getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
+                              lhs.getContext()));
+
   // When doing successive additions, bring constant to the right: turn (d0 + 2)
   // + d1 into (d0 + d1) + 2.
   if (lBin && lBin.getKind() == AffineExprKind::Add) {
@@ -327,7 +362,6 @@
   // general a more compact and readable form.

   // Process '(expr floordiv c) * (-c)'.
-  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
   if (!rBinOpExpr)
     return nullptr;

diff --git a/mlir/test/Dialect/AffineOps/canonicalize.mlir b/mlir/test/Dialect/AffineOps/canonicalize.mlir
--- a/mlir/test/Dialect/AffineOps/canonicalize.mlir
+++ b/mlir/test/Dialect/AffineOps/canonicalize.mlir
@@ -448,7 +448,7 @@
 // -----

 // CHECK-DAG: [[LBMAP:#map[0-9]+]] = affine_map<()[s0] -> (0, s0)>
-// CHECK-DAG: [[UBMAP:#map[0-9]+]] = affine_map<()[s0] -> (1024, s0 + s0)>
+// CHECK-DAG: [[UBMAP:#map[0-9]+]] = affine_map<()[s0] -> (1024, s0 * 2)>

 // CHECK-LABEL: func @canonicalize_bounds
 // CHECK-SAME: [[M:%.*]]: index,
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -33,7 +33,7 @@
 // The following reduction should be unique'd out too but such expression
 // simplification is not performed for IR parsing, but only through analyses
 // and transforms.
-// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 + d1 + d1 + d1 + 2)>
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 * 4 + 2)>
 #map3l = affine_map<(i, j) -> ((j - i) + 2*(i - j + 1) + j - 1 + 0, j + j + 1 + j + j + 1)>

 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 + 2, d1)>
@@ -183,6 +183,12 @@
 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0, d0 * 2 + d1 * 4 + 2, 1, 2, (d0 * 4) mod 8)>
 #map56 = affine_map<(d0, d1) -> ((4*d0 + 2) floordiv 4, (4*d0 + 8*d1 + 5) floordiv 2, (2*d0 + 4*d1 + 3) mod 2, (3*d0 - 4) mod 3, (4*d0 + 8*d1) mod 8)>

+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1, d0, 0)>
+#map57 = affine_map<(d0, d1) -> (d0 - d0 + d1, -d0 + d0 + d0, (1 + d0 + d1 floordiv 4) - (d0 + d1 floordiv 4 + 1))>
+
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
+#map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>
+
 // Single identity maps are removed.
 // CHECK: func @f0(memref<2x4xi8, 1>)
 func @f0(memref<2x4xi8, #map0, 1>)
@@ -361,3 +367,9 @@

 // CHECK: func @f56(memref<1x1xi8, #map{{[0-9]+}}>)
 func @f56(memref<1x1xi8, #map56>)
+
+// CHECK: "f57"() {map = #map{{[0-9]+}}} : () -> ()
+"f57"() {map = #map57} : () -> ()
+
+// CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
+"f58"() {map = #map58} : () -> ()
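A minimal sketch (not part of the patch) of what the new folding in simplifyAdd achieves, assuming the patch is applied; the standalone driver below is hypothetical and only uses existing MLIR C++ API calls (getAffineDimExpr, the AffineExpr operator overloads, AffineExpr::dump):

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // Build "d0 * 2 + d0 * 3" through the overloaded operators, which run
  // MLIR's on-the-fly expression simplification (simplifyAdd). With this
  // patch, both summands have the same multiplicand d0, so the new
  // "c1 * expr + c2 * expr -> (c1 + c2) * expr" rule fires.
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr sum = d0 * 2 + d0 * 3;
  sum.dump(); // With the patch: "d0 * 5"; without it: "d0 * 2 + d0 * 3".
  return 0;
}
```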