diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -144,6 +144,12 @@ LogicalResult constantFold(ArrayRef operandConstants, SmallVectorImpl &results) const; + /// Folds the constant operands into this affine map. Returns a new map, which + /// may be equal to the old map if no folding happened. Operands are allowed + /// to be null, at which point they are treated as non-constant. This does not + /// change the number of symbols and dimensions. + AffineMap partialConstantFold(ArrayRef operandConstants) const; + /// Returns the AffineMap resulting from composing `this` with `map`. /// The resulting AffineMap has as many AffineDimExpr as `map` and as many /// AffineSymbolExpr as the concatenation of `this` and `map` (in which case diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2098,27 +2098,46 @@ // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) // +/// Populate "results" with constant results of "map". If some of the "map" +/// results are not constant affine expressions, return failure. +static LogicalResult extractIntegerResults(AffineMap map, + SmallVectorImpl &results) { + assert(results.empty()); + + results.reserve(map.getNumResults()); + for (AffineExpr expr : map.getResults()) { + if (auto cstExpr = expr.dyn_cast()) { + results.push_back(cstExpr.getValue()); + continue; + } + + results.clear(); + return failure(); + } + return success(); +} + OpFoldResult AffineMinOp::fold(ArrayRef operands) { // Fold the affine map. // TODO(andydavis, ntv) Fold more cases: partial static information, // min(some_affine, some_affine + constant, ...). - SmallVector results; - if (failed(map().constantFold(operands, results))) - return {}; - - // Compute and return min of folded map results.
- int64_t min = std::numeric_limits::max(); - int minIndex = -1; - for (unsigned i = 0, e = results.size(); i < e; ++i) { - auto intAttr = results[i].cast(); - if (intAttr.getInt() < min) { - min = intAttr.getInt(); - minIndex = i; - } + auto foldedMap = map().partialConstantFold(operands); + + // If some of the map results are not constant, try changing the map in-place. + SmallVector results; + if (failed(extractIntegerResults(foldedMap, results))) { + // If the map is the same, report that folding did not happen. + if (foldedMap == map()) + return {}; + setAttr("map", AffineMapAttr::get(foldedMap)); + return getResult(); } - if (minIndex < 0) + + // Otherwise, completely fold the "min" into a constant. + auto minIt = std::min_element(results.begin(), results.end()); + if (minIt == results.end()) return {}; - return results[minIndex]; + return IntegerAttr::get(IndexType::get(getContext()), *minIt); } void AffineMinOp::getCanonicalizationPatterns( @@ -2137,23 +2156,23 @@ // Fold the affine map. // TODO(andydavis, ntv, ouhang) Fold more cases: partial static information, // max(some_affine, some_affine + constant, ...). - SmallVector results; - if (failed(map().constantFold(operands, results))) - return {}; + auto foldedMap = map().partialConstantFold(operands); + + // If some of the map results are not constant, try changing the map in-place. + SmallVector results; + if (failed(extractIntegerResults(foldedMap, results))) { + // If the map is the same, report that folding did not happen. + if (foldedMap == map()) + return {}; + setAttr("map", AffineMapAttr::get(foldedMap)); + return getResult(); + } // Compute and return max of folded map results.
- int64_t max = std::numeric_limits::min(); - int maxIndex = -1; - for (unsigned i = 0, e = results.size(); i < e; ++i) { - auto intAttr = results[i].cast(); - if (intAttr.getInt() > max) { - max = intAttr.getInt(); - maxIndex = i; - } - } - if (maxIndex < 0) + auto maxIt = std::max_element(results.begin(), results.end()); + if (maxIt == results.end()) return {}; - return results[maxIndex]; + return IntegerAttr::get(IndexType::get(getContext()), *maxIt); } void AffineMaxOp::getCanonicalizationPatterns( diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -234,22 +234,44 @@ LogicalResult AffineMap::constantFold(ArrayRef operandConstants, SmallVectorImpl &results) const { + // Attempt partial folding. + AffineMap newMap = partialConstantFold(operandConstants); + + // If all expressions folded to a constant, populate results with attributes + // containing those constants. + results.reserve(newMap.getNumResults()); + for (AffineExpr expr : newMap.getResults()) { + auto cstExpr = expr.dyn_cast(); + if (!cstExpr) { + results.clear(); + return failure(); + } + + results.push_back( + IntegerAttr::get(IndexType::get(getContext()), cstExpr.getValue())); + } + + return success(); +} + +AffineMap +AffineMap::partialConstantFold(ArrayRef operandConstants) const { assert(getNumInputs() == operandConstants.size()); // Fold each of the result expressions. AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); - // Constant fold each AffineExpr in AffineMap and add to 'results'. + SmallVector results; + results.reserve(getNumResults()); + for (auto expr : getResults()) { + // If it did not fold to a constant, keep the original expression. auto folded = exprFolder.constantFold(expr); - // If we didn't fold to a constant, then folding fails. - if (!folded) - return failure(); - - results.push_back(folded); + results.push_back( + folded ?
getAffineConstantExpr(folded.getInt(), folded.getContext()) + : expr); } - assert(results.size() == getNumResults() && - "constant folding produced the wrong number of results"); - return success(); + + return get(getNumDims(), getNumSymbols(), results, getContext()); } /// Walk all of the AffineExpr's in this mapping. Each node in an expression diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -13,10 +13,12 @@ // TILE-002-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // TILE-234-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> -// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> -// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> -// TILE-234-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> +// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> +// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> +// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> +// TILE-234-DAG: #[[bound_map_2:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> +// TILE-234-DAG: #[[bound_map_3:.*]] = affine_map<(d0, d1, d2) -> (3, d1 - d2)> +// TILE-234-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)> // TILE-2-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // TILE-02-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> @@ -97,19 +99,19 @@ // TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} { // TILE-234: loop.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234:
%[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234: %[[localK:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]]) +// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]]) // TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [%[[C1]], %[[C1]]] : memref to memref // TILE-234: %[[localK:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]]) +// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]]) // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]]) +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) // TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[szK]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]]) +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref // // TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref @@ -230,15 +232,15 @@ // TILE-234: loop.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234:
%[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]]) +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) // TILE-234: %[[sAij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref // TILE-234: %[[localN:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]]) +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) // TILE-234: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [%[[C1]]] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref // // TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref @@ -274,10 +276,10 @@ // TILE-234: %[[ubK:.*]] = dim %{{.*}}, 0 : memref // TILE-234: loop.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) // TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref // TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref diff --git a/mlir/test/Dialect/Linalg/tile_conv.mlir b/mlir/test/Dialect/Linalg/tile_conv.mlir --- a/mlir/test/Dialect/Linalg/tile_conv.mlir +++
b/mlir/test/Dialect/Linalg/tile_conv.mlir @@ -4,7 +4,7 @@ // TILE-23004-DAG: #[[S0x10p90:.*]] = affine_map<()[s0] -> (s0 * 10 + 90)> // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // TILE-23004-DAG: #[[strided4D_dynamic:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> -// TILE-23004-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> +// TILE-23004-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)> func @conv(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref @@ -27,7 +27,7 @@ // TILE-23004: %[[Z0:.*]] = dim %{{.*}}, 0 : memref // TILE-23004: %[[Z1:.*]] = dim %{{.*}}, 1 : memref // TILE-23004: %[[Z2:.*]] = dim %{{.*}}, 2 : memref -// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[Z2]], %[[ivK]]) +// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[Z2]], %[[ivK]]) // TILE-23004: %[[K:.*]] = dim %{{.*}}, 3 : memref // TILE-23004: %[[FilterView:.*]] = subview %{{.*}}[%[[C0]], %[[C0]], %[[ivK]], %[[C0]]] [%[[Z0]], %[[Z1]], %[[szK]], %[[K]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref // @@ -35,7 +35,7 @@ // T__ILE-23004: %[[I1pStep:.*]] = affine.apply #[[S0x10p90]]()[%[[I1]]] // TILE-23004: %[[SZ2:.*]] = dim %{{.*}}, 2 : memref // TILE-23004: %[[dim3:.*]] = dim %{{.*}}, 3 -// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map]](%[[C4]], %[[dim3]], %[[ivK]] +// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[dim3]], %[[ivK]] // TILE-23004: %[[InputView:.*]] = subview %{{.*}}[%[[ivI]], %[[J1]], %[[C0]], %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref // // TILE-23004: %[[X0:.*]] = dim %{{.*}}, 2 : memref diff --git a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir
b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir --- a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir +++ b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir @@ -3,7 +3,7 @@ // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // TILE-20000-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> -// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> +// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> // TILE-20000-DAG: #[[subviewstride:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> func @conv_padding(%arg0: memref, %arg1: memref, %arg2: memref) {