diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -174,6 +174,18 @@
   return constraints.addBound(type, pos, alignedMap);
 }
 
+/// Add `val` to each result of `map`.
+///
+/// Returns a new map with the same number of dims and symbols as `map`, in
+/// which every result expression `r` is replaced by `r + val`.
+static AffineMap addConstToResults(AffineMap map, int64_t val) {
+  SmallVector<AffineExpr> newResults;
+  for (AffineExpr r : map.getResults())
+    newResults.push_back(r + val);
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
+                        map.getContext());
+}
+
 /// This function tries to canonicalize min/max operations by proving that their
 /// value is bounded by the same lower and upper bound. In that case, the
 /// operation can be folded away.
@@ -228,18 +240,23 @@
   // isMin: op <= expr_i, !isMin: op >= expr_i
   auto boundType =
       isMin ? FlatAffineConstraints::UB : FlatAffineConstraints::LB;
-  if (failed(alignAndAddBound(constraints, boundType, dimOp, map, operands)))
+  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
+  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
+  if (failed(
+          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
     return failure();
 
   // Try to compute a lower/upper bound for op, expressed in terms of the other
   // `dims` and extra symbols.
   SmallVector<AffineMap> opLb(1), opUb(1);
   constraints.getSliceBounds(dimOp, 1, rewriter.getContext(), &opLb, &opUb);
-  AffineMap boundMap = isMin ? opUb[0] : opLb[0];
+  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
   // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
   // a TODO of `getSliceBounds` and not handled here.
-  if (!boundMap || boundMap.getNumResults() != 1)
+  if (!sliceBound || sliceBound.getNumResults() != 1)
     return failure(); // No or multiple bounds found.
+  // Recover the inclusive UB in the case of an `affine.min`.
+  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
 
   // Add an equality: Set dimOpBound to computed bound.
   // Add back dimension for op. (Was removed by `getSliceBounds`.)