diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -591,9 +591,10 @@ } } - // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This - // leads to a much more efficient form when 'c' is a power of two, and in - // general a more compact and readable form. + // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where + // q may be a constant or symbolic expression. This leads to a much more + // efficient form when 'q' is a power of two, and in general a more compact + // and readable form. // Process '(expr floordiv c) * (-c)'. if (!rBinOpExpr) @@ -602,13 +603,33 @@ auto lrhs = rBinOpExpr.getLHS(); auto rrhs = rBinOpExpr.getRHS(); + AffineExpr llrhs, rlrhs; + + // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a + // symbolic expression. + auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); + // Check rrhsConstOpExpr = -1. + auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>(); + if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && + lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { + // Check llrhs = expr floordiv q. + llrhs = lrhsBinOpExpr.getLHS(); + // Check rlrhs = q. + rlrhs = lrhsBinOpExpr.getRHS(); + auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>(); + if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv) + return nullptr; + if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) + return lhs % rlrhs; + } + // Process lrhs, which is 'expr floordiv c'. 
AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) return nullptr; - auto llrhs = lrBinOpExpr.getLHS(); - auto rlrhs = lrBinOpExpr.getRHS(); + llrhs = lrBinOpExpr.getLHS(); + rlrhs = lrBinOpExpr.getRHS(); if (lhs == llrhs && rlrhs == -rrhs) { return lhs % rlrhs; diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -533,3 +533,17 @@ // CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[$PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]] // CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[$SUM_OF_PRODUCTS]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]] // CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] + +// ----- + +// CHECK-DAG: #[[$SIMPLIFIED_MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> ((-s0 + s2 + s3) mod (s0 + s1))> +// CHECK-LABEL: func @semi_affine_simplification_euclidean_lemma +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) +func @semi_affine_simplification_euclidean_lemma(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) { + %a = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 + d1) - ((d0 + d1) floordiv (s0 - s1)) * (s0 - s1) - (d0 + d1) mod (s0 - s1))>(%arg0, %arg1)[%arg2, %arg3] + %b = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 + d1 - s0) - ((d0 + d1 - s0) floordiv (s0 + s1)) * (s0 + s1))>(%arg0, %arg1)[%arg2, %arg3] + return %a, %b : index, index +} +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[RESULT:.*]] = affine.apply #[[$SIMPLIFIED_MAP]]()[%[[ARG2]], %[[ARG3]], %[[ARG0]], %[[ARG1]]] +// CHECK-NEXT: return %[[ZERO]], %[[RESULT]]