diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -524,6 +524,7 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = lhs.dyn_cast(); auto rhsConst = rhs.dyn_cast(); + // Fold if both LHS, RHS are a constant. if (lhsConst && rhsConst) return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), @@ -591,9 +592,10 @@ } } - // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This - // leads to a much more efficient form when 'c' is a power of two, and in - // general a more compact and readable form. + // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where + // q may be a constant or symbolic expression. This leads to a much more + // efficient form when 'c' is a power of two, and in general a more compact + // and readable form. // Process '(expr floordiv c) * (-c)'. if (!rBinOpExpr) @@ -602,13 +604,34 @@ auto lrhs = rBinOpExpr.getLHS(); auto rrhs = rBinOpExpr.getRHS(); + AffineExpr llrhs, rlrhs; + + // Check lrhsBinOpExpr = (expr floordiv q) * q, where q is a symbolic + // expression. + auto lrhsBinOpExpr = lrhs.dyn_cast(); + // Check rrhsConstOpExpr = -1. + auto rrhsConstOpExpr = rrhs.dyn_cast(); + if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && + lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { + // Check llrhs = expr floordiv q. + llrhs = lrhsBinOpExpr.getLHS(); + // Check rlrhs = q. + rlrhs = lrhsBinOpExpr.getRHS(); + auto llrhsBinOpExpr = llrhs.dyn_cast(); + if (!llrhsBinOpExpr || (llrhsBinOpExpr && llrhsBinOpExpr.getKind() != + AffineExprKind::FloorDiv)) + return nullptr; + if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) + return lhs % rlrhs; + } + // Process lrhs, which is 'expr floordiv c'. AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast(); if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) return nullptr; - auto llrhs = lrBinOpExpr.getLHS(); - auto rlrhs = lrBinOpExpr.getRHS(); + llrhs = lrBinOpExpr.getLHS(); + rlrhs = lrBinOpExpr.getRHS(); if (lhs == llrhs && rlrhs == -rrhs) { return lhs % rlrhs; diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -483,29 +483,34 @@ // ----- // Test simplification of semi affine expressions. - // CHECK-DAG: #[[MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)> - // CHECK-DAG: #[[FLOORDIV:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 + s3 + (s0 - s1) floordiv s2)> - // CHECK-DAG: #[[PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)> - // CHECK-DAG: #[[SIMPLIFIED_FLOORDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 floordiv (s2 - s0 * s1))> - // CHECK-DAG: #[[SIMPLIFIED_CEILDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 ceildiv (s2 - s0 * s1))> - // CHECK-DAG: #[[SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))> - // CHECK-DAG: #[[SORTED_INDICES_EXPR:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 * s0 + s2 + s3 * s0 + s3 * s1 + s3 + s4 * s1 + s4)> - // CHECK: func @semiaffine_simplification(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) - func @semiaffine_simplification(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index, index, index, index, index) { - %a = affine.apply affine_map<(d0, d1)[s0, s1, s2, s3] -> ((-(d1 * s0 - (s0 - s1) mod s2) + s3) + (d0 * s1 + d1 * s0))>(%arg0, %arg1)[%arg2, %arg3, %arg4, %arg5] - %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(d0 * s1 - (s0 - s1) floordiv s2) + s3) + (d0 * s1 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] - %c = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(s0 - (s0 - s1) * s2) + s3) + (d0 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] - %d = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 floordiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] - %e = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 ceildiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] - %f = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 mod (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] - %g = affine.apply affine_map<(d0, d1, d2)[s0, s1] -> (d0 + d1 * s1 + d1 + d0 * s0 + d1 * s0 + d2 * s1 + d2)>(%arg0, %arg1, %arg2)[%arg3, %arg4] - return %a, %b, %c, %d, %e, %f, %g : index, index, index, index, index, index, index +// CHECK-DAG: #[[MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)> +// CHECK-DAG: #[[FLOORDIV:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 + s3 + (s0 - s1) floordiv s2)> +// CHECK-DAG: #[[PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)> +// CHECK-DAG: #[[MOD_TRANSFORM:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s2 - s3) mod (s0 + s1))> +// CHECK-DAG: #[[SIMPLIFIED_FLOORDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 floordiv (s2 - s0 * s1))> +// CHECK-DAG: #[[SIMPLIFIED_CEILDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 ceildiv (s2 - s0 * s1))> +// CHECK-DAG: #[[SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))> +// CHECK-DAG: #[[SORTED_INDICES_EXPR:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 * s0 + s2 + s3 * s0 + s3 * s1 + s3 + s4 * s1 + s4)> +// CHECK: func @semiaffine_simplification(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) +func @semiaffine_simplification(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index, index, index, index, index, index, index) { + %a = affine.apply affine_map<(d0, d1)[s0, s1, s2, s3] -> ((-(d1 * s0 - (s0 - s1) mod s2) + s3) + (d0 * s1 + d1 * s0))>(%arg0, %arg1)[%arg2, %arg3, %arg4, %arg5] + %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(d0 * s1 - (s0 - s1) floordiv s2) + s3) + (d0 * s1 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] + %c = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(s0 - (s0 - s1) * s2) + s3) + (d0 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] + %d = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 + d1) - ((d0 + d1) floordiv (s0 - s1)) * (s0 - s1) - (d0 + d1) mod (s0 - s1))>(%arg0, %arg1)[%arg2, %arg3] + %e = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 - d1) - ((d0 - d1) floordiv (s0 + s1)) * (s0 + s1))>(%arg0, %arg1)[%arg2, %arg3] + %f = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 floordiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + %g = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 ceildiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + %h = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 mod (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + %i = affine.apply affine_map<(d0, d1, d2)[s0, s1] -> (d0 + d1 * s1 + d1 + d0 * s0 + d1 * s0 + d2 * s1 + d2)>(%arg0, %arg1, %arg2)[%arg3, %arg4] + return %a, %b, %c, %d, %e, %f, %g, %h, %i : index, index, index, index, index, index, index, index, index } - // CHECK-DAG: %[[RESULT0:.*]] = affine.apply #[[MOD]]()[%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG0]]] - // CHECK-DAG: %[[RESULT1:.*]] = affine.apply #[[FLOORDIV]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]] - // CHECK-DAG: %[[RESULT2:.*]] = affine.apply #[[PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]] - // CHECK-DAG: %[[RESULT3:.*]] = affine.apply #[[SIMPLIFIED_FLOORDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] - // CHECK-DAG: %[[RESULT4:.*]] = affine.apply #[[SIMPLIFIED_CEILDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] - // CHECK-DAG: %[[RESULT5:.*]] = affine.apply #[[SIMPLIFIED_MOD_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] - // CHECK-DAG: %[[RESULT6:.*]] = affine.apply #[[SORTED_INDICES_EXPR]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]] - // CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]], %[[RESULT3]], %[[RESULT4]], %[[RESULT5]], %[[RESULT6]] +// CHECK-NEXT: %[[ZERO:.*]] = constant 0 : index +// CHECK-DAG: %[[RESULT0:.*]] = affine.apply #[[MOD]]()[%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG0]]] +// CHECK-DAG: %[[RESULT1:.*]] = affine.apply #[[FLOORDIV]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]] +// CHECK-DAG: %[[RESULT2:.*]] = affine.apply #[[PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]] +// CHECK-DAG: %[[RESULT3:.*]] = affine.apply #[[MOD_TRANSFORM]]()[%[[ARG2]], %[[ARG3]], %[[ARG0]], %[[ARG1]]] +// CHECK-DAG: %[[RESULT4:.*]] = affine.apply #[[SIMPLIFIED_FLOORDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-DAG: %[[RESULT5:.*]] = affine.apply #[[SIMPLIFIED_CEILDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-DAG: %[[RESULT6:.*]] = affine.apply #[[SIMPLIFIED_MOD_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-DAG: %[[RESULT7:.*]] = affine.apply #[[SORTED_INDICES_EXPR]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]] +// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]], %[[ZERO]], %[[RESULT3]], %[[RESULT4]], %[[RESULT5]], %[[RESULT6]], %[[RESULT7]] \ No newline at end of file