diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -144,6 +144,17 @@
   LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
                              SmallVectorImpl<Attribute> &results) const;
 
+  /// Propagates the constant operands into this affine map. Operands are
+  /// allowed to be null, at which point they are treated as non-constant. This
+  /// does not change the number of symbols and dimensions. Returns a new map,
+  /// which may be equal to the old map if no folding happened. If `results` is
+  /// provided and all expressions in the map were folded to constants, the
+  /// values of these constants are appended to `results`; otherwise `results`
+  /// is cleared, including any elements it held before the call.
+  AffineMap
+  partialConstantFold(ArrayRef<Attribute> operandConstants,
+                      SmallVectorImpl<int64_t> *results = nullptr) const;
+
   /// Returns the AffineMap resulting from composing `this` with `map`.
   /// The resulting AffineMap has as many AffineDimExpr as `map` and as many
   /// AffineSymbolExpr as the concatenation of `this` and `map` (in which case
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2089,6 +2089,38 @@
       parser.addTypeToList(indexType, result.types));
 }
 
+/// Fold an affine min or max operation with the given operands. The operand
+/// list may contain nulls, which are interpreted as the operand not being a
+/// constant.
+template <typename T>
+static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
+  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
+                "expected affine min or max op");
+
+  // Fold the affine map.
+  // TODO(andydavis, ntv) Fold more cases:
+  // min(some_affine, some_affine + constant, ...), etc.
+  SmallVector<int64_t, 2> results;
+  auto foldedMap = op.map().partialConstantFold(operands, &results);
+
+  // If some of the map results are not constant, try changing the map in-place.
+  if (results.empty()) {
+    // If the map is the same, report that folding did not happen.
+    if (foldedMap == op.map())
+      return {};
+    op.setAttr("map", AffineMapAttr::get(foldedMap));
+    return op.getResult();
+  }
+
+  // Otherwise, completely fold the op into a constant.
+  auto resultIt = std::is_same<T, AffineMinOp>::value
+                      ? std::min_element(results.begin(), results.end())
+                      : std::max_element(results.begin(), results.end());
+  if (resultIt == results.end())
+    return {};
+  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
+}
+
 //===----------------------------------------------------------------------===//
 // AffineMinOp
 //===----------------------------------------------------------------------===//
@@ -2097,26 +2129,7 @@
 //
 
 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
-  // Fold the affine map.
-  // TODO(andydavis, ntv) Fold more cases: partial static information,
-  // min(some_affine, some_affine + constant, ...).
-  SmallVector<Attribute, 2> results;
-  if (failed(map().constantFold(operands, results)))
-    return {};
-
-  // Compute and return min of folded map results.
-  int64_t min = std::numeric_limits<int64_t>::max();
-  int minIndex = -1;
-  for (unsigned i = 0, e = results.size(); i < e; ++i) {
-    auto intAttr = results[i].cast<IntegerAttr>();
-    if (intAttr.getInt() < min) {
-      min = intAttr.getInt();
-      minIndex = i;
-    }
-  }
-  if (minIndex < 0)
-    return {};
-  return results[minIndex];
+  return foldMinMaxOp(*this, operands);
 }
 
 void AffineMinOp::getCanonicalizationPatterns(
@@ -2132,26 +2145,7 @@
 //
 
 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
-  // Fold the affine map.
-  // TODO(andydavis, ntv, ouhang) Fold more cases: partial static information,
-  // max(some_affine, some_affine + constant, ...).
-  SmallVector<Attribute, 2> results;
-  if (failed(map().constantFold(operands, results)))
-    return {};
-
-  // Compute and return max of folded map results.
-  int64_t max = std::numeric_limits<int64_t>::min();
-  int maxIndex = -1;
-  for (unsigned i = 0, e = results.size(); i < e; ++i) {
-    auto intAttr = results[i].cast<IntegerAttr>();
-    if (intAttr.getInt() > max) {
-      max = intAttr.getInt();
-      maxIndex = i;
-    }
-  }
-  if (maxIndex < 0)
-    return {};
-  return results[maxIndex];
+  return foldMinMaxOp(*this, operands);
 }
 
 void AffineMaxOp::getCanonicalizationPatterns(
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -234,22 +234,51 @@
 LogicalResult
 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
                         SmallVectorImpl<Attribute> &results) const {
+  // Attempt partial folding.
+  SmallVector<int64_t, 2> integers;
+  partialConstantFold(operandConstants, &integers);
+
+  // Fail unless every expression folded to a constant. A partial fold leaves
+  // `integers` empty, whereas a map without results trivially folds.
+  if (integers.size() != getNumResults())
+    return failure();
+
+  auto range = llvm::map_range(integers, [this](int64_t i) {
+    return IntegerAttr::get(IndexType::get(getContext()), i);
+  });
+  results.append(range.begin(), range.end());
+  return success();
+}
+
+AffineMap
+AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
+                               SmallVectorImpl<int64_t> *results) const {
   assert(getNumInputs() == operandConstants.size());
 
   // Fold each of the result expressions.
   AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
-  // Constant fold each AffineExpr in AffineMap and add to 'results'.
+  SmallVector<AffineExpr, 4> exprs;
+  exprs.reserve(getNumResults());
+
   for (auto expr : getResults()) {
     auto folded = exprFolder.constantFold(expr);
-    // If we didn't fold to a constant, then folding fails.
-    if (!folded)
-      return failure();
-
-    results.push_back(folded);
+    // If it did not fold to a constant, keep the original expression and
+    // clear the integer results vector to signal that folding was partial.
+    if (folded) {
+      exprs.push_back(
+          getAffineConstantExpr(folded.getInt(), folded.getContext()));
+      if (results)
+        results->push_back(folded.getInt());
+    } else {
+      exprs.push_back(expr);
+      if (results) {
+        results->clear();
+        results = nullptr;
+      }
+    }
   }
-  assert(results.size() == getNumResults() &&
-         "constant folding produced the wrong number of results");
-  return success();
+
+  return get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -13,10 +13,12 @@
 // TILE-002-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // TILE-234-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 
-// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-// TILE-234-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
+// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
+// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
+// TILE-234-DAG: #[[bound_map_2:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
+// TILE-234-DAG: #[[bound_map_3:.*]] = affine_map<(d0, d1, d2) -> (3, d1 - d2)>
+// TILE-234-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)>
 
 // TILE-2-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 // TILE-02-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
@@ -97,19 +99,19 @@
 // TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} {
 // TILE-234: loop.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} {
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[localK:.*]] = dim %{{.*}}, 1
-// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]])
+// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]])
 // TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [%[[C1]], %[[C1]]] : memref to memref
 // TILE-234: %[[localK:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]])
+// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]])
 // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
-// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
+// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
 // TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[szK]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
-// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
+// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
 // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref
 //
 // TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref
@@ -230,15 +232,15 @@
 // TILE-234: loop.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
-// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
+// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
 // TILE-234: %[[sAij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref
 // TILE-234: %[[localN:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
+// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
 // TILE-234: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [%[[C1]]] : memref to memref
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref
 //
 // TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref
@@ -274,10 +276,10 @@
 // TILE-234: %[[ubK:.*]] = dim %{{.*}}, 0 : memref
 // TILE-234: loop.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} {
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref
 // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
-// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
+// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
 // TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref
 // TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref
 
diff --git a/mlir/test/Dialect/Linalg/tile_conv.mlir b/mlir/test/Dialect/Linalg/tile_conv.mlir
--- a/mlir/test/Dialect/Linalg/tile_conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile_conv.mlir
@@ -4,7 +4,7 @@
 // TILE-23004-DAG: #[[S0x10p90:.*]] = affine_map<()[s0] -> (s0 * 10 + 90)>
 // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
 // TILE-23004-DAG: #[[strided4D_dynamic:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-// TILE-23004-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// TILE-23004-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)>
 
 func @conv(%arg0: memref, %arg1: memref, %arg2: memref) {
   linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref
@@ -27,7 +27,7 @@
 // TILE-23004: %[[Z0:.*]] = dim %{{.*}}, 0 : memref
 // TILE-23004: %[[Z1:.*]] = dim %{{.*}}, 1 : memref
 // TILE-23004: %[[Z2:.*]] = dim %{{.*}}, 2 : memref
-// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[Z2]], %[[ivK]])
+// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[Z2]], %[[ivK]])
 // TILE-23004: %[[K:.*]] = dim %{{.*}}, 3 : memref
 // TILE-23004: %[[FilterView:.*]] = subview %{{.*}}[%[[C0]], %[[C0]], %[[ivK]], %[[C0]]] [%[[Z0]], %[[Z1]], %[[szK]], %[[K]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref
 //
@@ -35,7 +35,7 @@
 // T__ILE-23004: %[[I1pStep:.*]] = affine.apply #[[S0x10p90]]()[%[[I1]]]
 // TILE-23004: %[[SZ2:.*]] = dim %{{.*}}, 2 : memref
 // TILE-23004: %[[dim3:.*]] = dim %{{.*}}, 3
-// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map]](%[[C4]], %[[dim3]], %[[ivK]]
+// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[dim3]], %[[ivK]]
 // TILE-23004: %[[InputView:.*]] = subview %{{.*}}[%[[ivI]], %[[J1]], %[[C0]], %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref
 //
 // TILE-23004: %[[X0:.*]] = dim %{{.*}}, 2 : memref
diff --git a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir
--- a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir
+++ b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir
@@ -3,7 +3,7 @@
 
 // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
 // TILE-20000-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
 // TILE-20000-DAG: #[[subviewstride:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
 
 func @conv_padding(%arg0: memref, %arg1: memref, %arg2: memref) {