diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -934,6 +934,118 @@
   }
 }
 
+/// Simplify the expressions in `map` while making use of lower or upper bounds
+/// of its operands. If `isMax` is true, the map is to be treated as a max of
+/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
+/// d1) can be simplified to (8) if the operands are respectively lower bounded
+/// by 2 and 0 (the second expression can't be lower than 8).
+static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
+                                             ArrayRef<Value> operands,
+                                             bool isMax) {
+  // Without operands there are no operand bounds to exploit.
+  if (operands.empty())
+    return;
+
+  // Get the constant lower and upper bounds on the operands. An empty
+  // Optional means that no constant bound is available.
+  SmallVector<Optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  // Compute the constant lower and upper bounds of each result expression.
+  // Then, depending on max or min, check whether an expression is redundant:
+  // for a max, when its upper bound is below the lower bound of another
+  // expression; for a min, when its lower bound is above the upper bound of
+  // another expression. Equal bounds are handled separately below.
+  SmallVector<Optional<int64_t>, 4> lowerBounds, upperBounds;
+  lowerBounds.reserve(map.getNumResults());
+  upperBounds.reserve(map.getNumResults());
+  for (AffineExpr e : map.getResults()) {
+    if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+      lowerBounds.push_back(constExpr.getValue());
+      upperBounds.push_back(constExpr.getValue());
+    } else {
+      lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/false));
+      upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/true));
+    }
+  }
+
+  // Collect expressions that are not redundant.
+  SmallVector<AffineExpr, 4> irredundantExprs;
+  for (auto exprEn : llvm::enumerate(map.getResults())) {
+    AffineExpr e = exprEn.value();
+    unsigned i = exprEn.index();
+    // Some expressions can be turned into constants.
+    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
+      e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
+
+    // Check if the expression is redundant.
+    if (isMax) {
+      if (!upperBounds[i]) {
+        irredundantExprs.push_back(e);
+        continue;
+      }
+      // If there exists another expression such that its lower bound is greater
+      // than this expression's upper bound, it's redundant.
+      if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
+            auto otherLowerBound = en.value();
+            unsigned pos = en.index();
+            if (pos == i || !otherLowerBound)
+              return false;
+            if (*otherLowerBound > *upperBounds[i])
+              return true;
+            if (*otherLowerBound < *upperBounds[i])
+              return false;
+            // Equality case. If both are the same constant, each would make the
+            // other redundant; keep only the one that appears first.
+            if (upperBounds[pos] && lowerBounds[i] &&
+                *lowerBounds[i] == *upperBounds[i] &&
+                *otherLowerBound == *upperBounds[pos] && i < pos)
+              return false;
+            return true;
+          }))
+        irredundantExprs.push_back(e);
+    } else {
+      if (!lowerBounds[i]) {
+        irredundantExprs.push_back(e);
+        continue;
+      }
+      // Likewise for the `min` case, with the roles of the bounds swapped.
+      if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
+            auto otherUpperBound = en.value();
+            unsigned pos = en.index();
+            if (pos == i || !otherUpperBound)
+              return false;
+            if (*otherUpperBound < *lowerBounds[i])
+              return true;
+            if (*otherUpperBound > *lowerBounds[i])
+              return false;
+            // Equality case: same tie-breaking rule as for `max` above.
+            if (lowerBounds[pos] && upperBounds[i] &&
+                *lowerBounds[i] == *upperBounds[i] &&
+                *otherUpperBound == *lowerBounds[pos] && i < pos)
+              return false;
+            return true;
+          }))
+        irredundantExprs.push_back(e);
+    }
+  }
+
+  // Create the map without the redundant expressions.
+  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
+                       map.getContext());
+}
+
 /// Simplify the map while exploiting information on the values in `operands`.
 // Use "unused attribute" marker to silence warning stemming from the inability
 // to see through the template expansion.
@@ -2210,6 +2322,8 @@
 
   composeAffineMapAndOperands(&lbMap, &lbOperands);
   canonicalizeMapAndOperands(&lbMap, &lbOperands);
+  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
+  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
   lbMap = removeDuplicateExprs(lbMap);
 
   composeAffineMapAndOperands(&ubMap, &ubOperands);
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1264,6 +1264,127 @@
 
 // -----
 
+#map0 = affine_map<(d0) -> (32, d0 * -32 + 32)>
+#map1 = affine_map<(d0) -> (32, d0 * -32 + 64)>
+#map3 = affine_map<(d0) -> (16, d0 * -16 + 32)>
+
+// CHECK-DAG: #[[$SIMPLE_MAP:.*]] = affine_map<()[s0] -> (3, s0)>
+// CHECK-DAG: #[[$SIMPLE_MAP_MAX:.*]] = affine_map<()[s0] -> (5, s0)>
+// CHECK-DAG: #[[$SIMPLIFIED_MAP:.*]] = affine_map<(d0, d1) -> (-9, d0 * 4 - d1 * 4)>
+// CHECK-DAG: #[[$FLOORDIV:.*]] = affine_map<(d0) -> (d0 floordiv 2)>
+
+// CHECK-LABEL: func @simplify_min_max_bounds
+func.func @simplify_min_max_bounds(%M: index) {
+
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+  affine.for %i = 0 to min affine_map<(d0) -> (3, 5, d0)>(%M) {
+    "test.foo"() : () -> ()
+  }
+
+  // CHECK: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+  affine.for %i = 0 to min affine_map<(d0) -> (3, 3, d0)>(%M) {
+    "test.foo"() : () -> ()
+  }
+
+  // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+  affine.for %i = max affine_map<(d0) -> (3, 5, d0)>(%M) to 10 {
+    "test.foo"() : () -> ()
+  }
+
+  // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+  affine.for %i = max affine_map<(d0) -> (5, 5, d0)>(%M) to 10 {
+    "test.foo"() : () -> ()
+  }
+
+  affine.for %arg5 = 0 to 1 {
+    affine.for %arg6 = 0 to 2 {
+      affine.for %arg8 = 0 to min #map0(%arg5) step 16 {
+        affine.for %arg9 = 0 to min #map1(%arg6) step 16 {
+          affine.for %arg10 = 0 to 2 {
+            affine.for %arg12 = 0 to min #map3(%arg10) step 16 {
+              "test.foo"() : () -> ()
+            }
+          }
+        }
+      }
+    }
+  }
+  // CHECK: affine.for
+  // CHECK-NEXT: affine.for
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 32 step 16
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 32 step 16
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 16 step 16
+
+
+  // Lower bound max.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for %{{.*}} = 5 to
+    affine.for %j = max affine_map<(d0) -> (5, 4 * d0)> (%i) to affine_map<(d0) -> (4 * d0 + 3)>(%i) {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  // Expressions with multiple operands.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for
+    affine.for %j = 0 to 4 {
+      // The first upper bound expression will not be lower than -9. So, it's redundant.
+      // CHECK-NEXT: affine.for %{{.*}} = -10 to -9
+      affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9)>(%i, %j) {
+        "test.foo"() : () -> ()
+      }
+    }
+  }
+
+  // One expression is redundant but not the others.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for
+    affine.for %j = 0 to 4 {
+      // The first expression is never lower than -9, so it's redundant; the third can be lower and is kept.
+      // CHECK-NEXT: affine.for %{{.*}} = -10 to min #[[$SIMPLIFIED_MAP]]
+      affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9, 4 * d0 - 4 * d1)>(%i, %j) {
+        "test.foo"() : () -> ()
+      }
+    }
+  }
+
+  // CHECK: affine.for %{{.*}} = 0 to 1
+  affine.for %i = 0 to 2 {
+    affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 1 {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  // The constant bound is redundant here.
+  // CHECK: affine.for %{{.*}} = #[[$FLOORDIV]](%{{.*}}) to 10
+  affine.for %i = 0 to 8 {
+    affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 10 {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  // Negative test cases: no bound expression can be proven redundant.
+  // CHECK: affine.for
+  affine.for %i = 0 to 4 {
+    // CHECK-NEXT: affine.for %{{.*}} = 0 to min
+    affine.for %j = 0 to min affine_map<(d0) -> (2 * d0, 2)>(%i) {
+      "test.foo"() : () -> ()
+    }
+    // CHECK: affine.for %{{.*}} = 0 to min {{.*}}(%{{.*}})[%{{.*}}]
+    affine.for %j = 0 to min affine_map<(d0)[s0] -> (d0, s0)>(%i)[%M] {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  return
+}
+
+// -----
+
 // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
 // CHECK-BOTTOM-UP: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
 // CHECK-LABEL: func @regression_do_not_perform_invalid_replacements