diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -25,7 +25,7 @@
 class AffineApplyOp;
 class AffineBound;
 class AffineValueMap;
-class IRRewriter;
+class RewriterBase;

 /// TODO: These should be renamed if they are on the mlir namespace.
 /// Ideally, they should go in a mlir::affine:: namespace.
@@ -381,13 +381,18 @@
 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
                                       ValueRange values);

+/// Returns an AffineMinOp obtained by composing `map` and `operands` with
+/// AffineApplyOps supplying those operands.
+Value makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
+                            ValueRange operands);
+
 /// Returns the values obtained by applying `map` to the list of values.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);

 /// Returns the values obtained by applying `map` to the list of values, which
 /// may be known constants.
-SmallVector<OpFoldResult> applyMapToValues(IRRewriter &b, Location loc,
+SmallVector<OpFoldResult> applyMapToValues(RewriterBase &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<OpFoldResult> values);
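Both new entry points are plain builder utilities. For orientation, a hypothetical call site for the RewriterBase overload of applyMapToValues is sketched below; the names `rewriter`, `loc`, `offset`, and `size` are assumed for illustration and do not appear in the patch.

    // Apply d0 + s0 to mixed attribute/value inputs (`offset` and `size` are
    // assumed to be OpFoldResult). Results that fold are returned as
    // Attributes wrapped in OpFoldResult; constants materialized temporarily
    // for folding are erased again once all results fold.
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + s0);
    SmallVector<OpFoldResult> sums =
        applyMapToValues(rewriter, loc, map, {offset, size});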
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -588,7 +588,7 @@
 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
 /// Mutate `map`,`dims` and `syms` in place as follows:
 /// 1. `dims` and `syms` are only appended to.
-/// 2. `map` dim and symbols are gradually shifted to higer positions.
+/// 2. `map` dims and symbols are gradually shifted to higher positions.
 /// 3. Old `dim` and `sym` entries are replaced by nullptr
 /// This avoids the need for any bookkeeping.
 static LogicalResult replaceDimOrSym(AffineMap *map,
@@ -722,6 +722,75 @@
                                  values);
 }

+Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
+                                  ValueRange operands) {
+  // 1. Compose maps and operands for individual expressions. Store results as
+  // single-expression maps in order to keep track of the numbers of dimensions
+  // and symbols used in each.
+  SmallVector<SmallVector<Value>, 2> exprOperands;
+  SmallVector<AffineMap, 2> exprs;
+  exprOperands.reserve(map.getNumResults());
+  exprs.reserve(map.getNumResults());
+  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    AffineMap &exprMap = exprs.emplace_back(map.getSubMap({i}));
+    fullyComposeAffineMapAndOperands(
+        &exprMap, &exprOperands.emplace_back(llvm::to_vector(operands)));
+    assert(exprMap && exprMap.getNumResults() == 1 && "could not compose maps");
+  }
+
+  // 2. Collect unique dimension and symbol operands used across the individual
+  // expressions. These sets are not expected to overlap: composition implies
+  // canonicalization, and canonicalization decides whether an operand is a
+  // dimension or a symbol based only on where its value is defined, so it
+  // classifies the same operand identically across all expressions.
+  SetVector<Value> normalizedDimOperandSet;
+  SetVector<Value> normalizedSymbolOperandSet;
+  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    auto *it = exprOperands[i].begin() + exprs[i].getNumDims();
+    normalizedDimOperandSet.insert(exprOperands[i].begin(), it);
+    normalizedSymbolOperandSet.insert(it, exprOperands[i].end());
+  }
+
+  // 3. Create a single list of unique operands containing dimensions followed
+  // by symbols.
+  unsigned numDims = normalizedDimOperandSet.size();
+  unsigned numSymbols = normalizedSymbolOperandSet.size();
+  SmallVector<Value> normalizedOperands =
+      llvm::to_vector(normalizedDimOperandSet);
+  llvm::append_range(normalizedOperands, normalizedSymbolOperandSet);
+
+  // 4. For each composed expression, remap its dimensions and symbols to the
+  // dimensions and symbols corresponding to the positions of this expression's
+  // operands in the normalized operand list.
+  SmallVector<AffineExpr> normalizedExprs;
+  normalizedExprs.reserve(exprs.size());
+  for (unsigned i : llvm::seq<unsigned>(0, exprs.size())) {
+    SmallVector<AffineExpr> dimReplacements;
+    SmallVector<AffineExpr> symReplacements;
+    for (const auto &en : llvm::enumerate(exprOperands[i])) {
+      Value operand = en.value();
+      unsigned position = std::distance(
+          normalizedOperands.begin(), llvm::find(normalizedOperands, operand));
+      AffineExpr replacement = position < numDims
+                                   ? b.getAffineDimExpr(position)
+                                   : b.getAffineSymbolExpr(position - numDims);
+      if (en.index() < exprs[i].getNumDims())
+        dimReplacements.push_back(replacement);
+      else
+        symReplacements.push_back(replacement);
+    }
+    normalizedExprs.push_back(exprs[i].getResult(0).replaceDimsAndSymbols(
+        dimReplacements, symReplacements));
+  }
+
+  // 5. Construct an affine.min with the normalized expressions and operands.
+  // Note that it may even fold to a constant thanks to normalization.
+  auto normalizedMap =
+      AffineMap::get(numDims, numSymbols, normalizedExprs, b.getContext());
+  return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), normalizedMap,
+                                     normalizedOperands);
+}
+
 /// Fully compose map with operands and canonicalize the result.
 /// Return the `createOrFold`'ed AffineApply op.
 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
@@ -748,13 +817,10 @@
   return res;
 }

-SmallVector<OpFoldResult>
-mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
-                       ArrayRef<OpFoldResult> values) {
-  // Materialize constants and keep track of produced operations so we can clean
-  // them up later.
-  SmallVector<Operation *> constants;
-  SmallVector<Value> actualValues;
+static void materializeConstants(OpBuilder &b, Location loc,
+                                 ArrayRef<OpFoldResult> values,
+                                 SmallVectorImpl<Operation *> &constants,
+                                 SmallVectorImpl<Value> &actualValues) {
   actualValues.reserve(values.size());
   auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
   for (OpFoldResult ofr : values) {
@@ -766,6 +832,50 @@
                                                    b.getIndexType(), loc));
     actualValues.push_back(constants.back()->getResult(0));
   }
+}
+
+template <typename OpTy, typename... Args>
+static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
+                        OpFoldResult>
+createOrFold(RewriterBase &b, Location loc, ValueRange operands,
+             Args &&...leadingOperands) {
+  // Identify the constant operands and extract their values as attributes.
+  // Note that we cannot use the original values directly because the list of
+  // operands may have changed due to canonicalization and composition.
+  SmallVector<Attribute> constantOperands;
+  constantOperands.reserve(operands.size());
+  for (Value operand : operands) {
+    IntegerAttr attr;
+    if (matchPattern(operand, m_Constant(&attr)))
+      constantOperands.push_back(attr);
+    else
+      constantOperands.push_back(nullptr);
+  }
+
+  // Create the operation and immediately attempt to fold it. On success,
+  // delete the operation and prepare the (unmaterialized) value for being
+  // returned. On failure, return the operation result value.
+  // TODO: arguably, the main folder (createOrFold) API should support this use
+  // case instead of indiscriminately materializing constants.
+  OpTy op =
+      b.create<OpTy>(loc, std::forward<Args>(leadingOperands)..., operands);
+  SmallVector<OpFoldResult> foldResults;
+  if (succeeded(op->fold(constantOperands, foldResults)) &&
+      !foldResults.empty()) {
+    b.eraseOp(op);
+    return foldResults.front();
+  }
+  return op->getResult(0);
+}
+
+SmallVector<OpFoldResult>
+mlir::applyMapToValues(RewriterBase &b, Location loc, AffineMap map,
+                       ArrayRef<OpFoldResult> values) {
+  // Materialize constants and keep track of produced operations so we can clean
+  // them up later.
+  SmallVector<Operation *> constants;
+  SmallVector<Value> actualValues;
+  materializeConstants(b, loc, values, constants, actualValues);

   // Compose, fold and construct maps for each result independently because they
   // may simplify more effectively.
@@ -777,35 +887,9 @@
     SmallVector<Value> operands = actualValues;
     fullyComposeAffineMapAndOperands(&submap, &operands);
     canonicalizeMapAndOperands(&submap, &operands);
-
-    // Identify the constant operands and extract their values as attributes.
-    // Note that we cannot use the original values directly because the list of
-    // operands may have changed due to canonicalization and composition.
-    SmallVector<Attribute> constantOperands;
-    constantOperands.reserve(operands.size());
-    for (Value operand : operands) {
-      IntegerAttr attr;
-      if (matchPattern(operand, m_Constant(&attr)))
-        constantOperands.push_back(attr);
-      else
-        constantOperands.push_back(nullptr);
-    }
-
-    // Create an apply operation and immediately attempt to fold it. On sucess,
-    // delete the operation and prepare the (unmaterialized) value for being
-    // returned. On failure, return the function result.
-    // TODO: arguably, the main folder (createOrFold) API should support this
-    // use case instead of indiscriminately materializing constants.
-    auto apply = b.create<AffineApplyOp>(loc, submap, operands);
-    SmallVector<OpFoldResult> foldResult;
-    if (succeeded(apply->fold(constantOperands, foldResult))) {
-      assert(foldResult.size() == 1 && "expected single-result map");
-      b.eraseOp(apply);
-      results.push_back(foldResult.front());
-    } else {
-      results.push_back(apply.getResult());
+    results.push_back(createOrFold<AffineApplyOp>(b, loc, operands, submap));
+    if (!results.back().is<Attribute>())
       foldedAll = false;
-    }
   }

   // If the entire map could be folded, remove the constants that were used in
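To make the normalization concrete: when an operand of the requested min is itself produced by an affine.apply, that apply is composed into each sub-expression, and the surviving top-level operands are classified as symbols. A sketch with illustrative values (`%n`, `%v`, and the maps are not from the patch):

    %v = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%n]
    // Requesting min over (d0, 256) with d0 = %v through makeComposedAffineMin
    // composes the apply away and promotes %n to a symbol:
    %min = affine.min affine_map<()[s0] -> (s0 * 4, 256)>()[%n]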
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -294,11 +294,11 @@
     return emitSilenceableError() << "could not generate tile size computation";
   }

+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
   Operation *splitPoint =
-      builder
-          .createOrFold<arith::MulIOp>(target.getLoc(), spec->lowTileSize,
-                                       spec->lowTripCount)
-          .getDefiningOp();
+      makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
+                              {spec->lowTileSize, spec->lowTripCount});
   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
   Operation *highTileSize = spec->highTileSize.getDefiningOp();
   assert(lowTileSize && highTileSize && splitPoint &&
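The split point above is now built with makeComposedAffineApply instead of an arith.muli, so the product can compose and fold with the affine ops that consume it later in the split transformation. Schematically, with illustrative SSA names and assuming both factors end up classified as symbols:

    // splitPoint = lowTileSize * lowTripCount, expressed affinely.
    %split = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%low_size, %low_trips]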
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"

 #include "llvm/ADT/STLExtras.h"

@@ -42,8 +43,15 @@
       continue;
     }

-    SmallVector<Value> sizes =
-        applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
+    SmallVector<Value> sizes;
+    sizes.reserve(indexing.getNumResults());
+    for (AffineExpr dimIndexing : indexing.getResults()) {
+      AffineApplyOp sizeOp = makeComposedAffineApply(
+          builder, op.getLoc(), dimIndexing, splitIterationSpace);
+      assert(sizeOp->getNumResults() == 1 &&
+             "expected single-result affine map");
+      sizes.push_back(sizeOp.getResult());
+    }
     SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));

@@ -107,8 +115,8 @@
   // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
   // results back.
   Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
-  splitPointValue = builder.createOrFold<AffineMinOp>(
-      builder.getIndexType(),
+  splitPointValue = makeComposedAffineMin(
+      builder, builder.getLoc(),
       AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
       ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
   SmallVector<Value> splitIterationSpace =
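At the IR level, the Split.cpp change trades a dimension-based affine.min for a composed, symbol-based one, which is what the FileCheck update below verifies. Roughly, with illustrative SSA names (the constant operand of the old form comes from the static tensor<100xf32> shape in the test):

    // Before: the split point fed affine.min as a map dimension.
    %split_low = affine.min affine_map<(d0, d1) -> (d0, 100)>(%split, %c100)
    // After: composition promotes the top-level %split to a symbol and the
    // remaining constant folds into the map, leaving a single operand.
    %split_low = affine.min affine_map<()[s0] -> (100, s0)>()[%split]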
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -60,15 +60,16 @@
 // CHECK-LABEL: @one_d_static_overflow
 // CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
 func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
-  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
-  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
-  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
-  // CHECK:   ins(%[[IN_SLICE_LOW]]
-  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // Due to overflow, the first part of the split computes everything and the
+  // insert/extract slices are folded away by the canonicalizer.
+  // CHECK: %[[RES_PARTIAL:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN]]
+  // CHECK:   outs(%[[OUT]]
   // CHECK:   linalg.index 0
   // CHECK:   func.call @elem
-  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
   //
+  // The second part operates on zero-sized slices that are not currently
+  // folded away.
   // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
   // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
   // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
@@ -118,13 +119,13 @@

 func.func private @get_size() -> index

-// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
-// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
+// CHECK-DAG: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (100, s0)>
+// CHECK-DAG: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>

 // CHECK-LABEL: @dynamic
 func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
   // CHECK: %[[SPLIT:.+]] = call @get_size
-  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
+  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]]
   // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
@@ -133,14 +134,13 @@
   // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
   //
   // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_1]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
-  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
-  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
-  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
   // CHECK:   ins(%[[IN_SLICE_HIGH]]
   // CHECK:   outs(%[[OUT_SLICE_HIGH]]
-  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
+  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1]
   %0 = func.call @get_size() : () -> index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
@@ -148,7 +148,8 @@
   } ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
   ^bb0(%3: f32, %4: f32):
-    linalg.yield %3 : f32
+    %5 = arith.addf %3, %4 : f32
+    linalg.yield %5 : f32
   } -> tensor<100xf32>
   return %1 : tensor<100xf32>
 }