Index: mlir/lib/IR/AffineExpr.cpp
===================================================================
--- mlir/lib/IR/AffineExpr.cpp
+++ mlir/lib/IR/AffineExpr.cpp
@@ -314,6 +314,30 @@
       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
   }
 
+  // Fold "expr * -1 + expr" and "expr + expr * -1" to 0. Both orders are
+  // checked because either operand may be the multiplication by -1.
+  {
+    // expr * -1 + expr = 0: the multiplication is the lhs.
+    auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
+    if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul) {
+      auto lLhs = lBinOpExpr.getLHS();
+      auto lRhsConst = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (lLhs == rhs && lRhsConst && lRhsConst.getValue() == -1)
+        return getAffineConstantExpr(0, lhs.getContext());
+    }
+  }
+  // Kept at function scope: the '(expr floordiv c) * (-c)' code reuses it.
+  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+  {
+    // expr + expr * -1 = 0: the multiplication is the rhs.
+    if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul) {
+      auto rLhs = rBinOpExpr.getLHS();
+      auto rRhsConst = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rLhs == lhs && rRhsConst && rRhsConst.getValue() == -1)
+        return getAffineConstantExpr(0, lhs.getContext());
+    }
+  }
+
   // When doing successive additions, bring constant to the right: turn (d0 + 2)
   // + d1 into (d0 + d1) + 2.
   if (lBin && lBin.getKind() == AffineExprKind::Add) {
@@ -327,7 +351,6 @@
   // general a more compact and readable form.
 
   // Process '(expr floordiv c) * (-c)'.
-  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
   if (!rBinOpExpr)
     return nullptr;
 
Index: mlir/test/IR/affine-map.mlir
===================================================================
--- mlir/test/IR/affine-map.mlir
+++ mlir/test/IR/affine-map.mlir
@@ -183,6 +183,9 @@
 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0, d0 * 2 + d1 * 4 + 2, 1, 2, (d0 * 4) mod 8)>
 #map56 = affine_map<(d0, d1) -> ((4*d0 + 2) floordiv 4, (4*d0 + 8*d1 + 5) floordiv 2, (2*d0 + 4*d1 + 3) mod 2, (3*d0 - 4) mod 3, (4*d0 + 8*d1) mod 8)>
 
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1, d0, 0)>
+#map57 = affine_map<(d0, d1) -> (d0 - d0 + d1, -d0 + d0 + d0, (1 + d0 + d1 floordiv 4) - (d0 + d1 floordiv 4 + 1))>
+
 // Single identity maps are removed.
 // CHECK: func @f0(memref<2x4xi8, 1>)
 func @f0(memref<2x4xi8, #map0, 1>)
@@ -361,3 +364,6 @@
 
 // CHECK: func @f56(memref<1x1xi8, #map{{[0-9]+}}>)
 func @f56(memref<1x1xi8, #map56>)
+
+// CHECK: "f57"() {map = #map{{[0-9]+}}} : () -> ()
+"f57"() {map = #map57} : () -> ()