Index: mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
===================================================================
--- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -608,6 +608,91 @@
   }];
 }
 
+def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Indicates that the given `target` op should be transformed with the
+    `tileReductionUsingScf` transformation, using the tile sizes provided as
+    an attribute.
+
+    This transformation tiles the `target` along its reduction dimensions. It
+    creates a tensor initialized with the neutral (identity) value, then
+    creates nested loops containing a parallel version of the `target` op.
+    The dimensions of the parallel op are less than or equal to the tile sizes
+    passed by the user.
+    After the loop, a merge operation combines the partial reductions into the
+    final reduction.
+
+    #### Return modes
+
+    The 3 returned handles point to:
+      - the fill op used to initialize the neutral element,
+      - the parallel tiled op and
+      - the result-combining op.
+
+    #### Example:
+
+    ```
+    %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                            affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%out : tensor<?xf32>) {
+      ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.addf %arg7, %arg9 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?xf32>
+    return %red : tensor<?xf32>
+    ```
+
+    is transformed into:
+
+    ```
+    %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+    %2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor<?x5xf32>) {
+      %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor<?x5xf32> to tensor<?x5xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor<?x?xf32> to tensor<?x5xf32>
+      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                            affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%extracted_slice_2 : tensor<?x5xf32>)
+      outs(%extracted_slice : tensor<?x5xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %5 = arith.addf %in, %out : f32
+        linalg.yield %5 : f32
+      } -> tensor<?x5xf32>
+      %dim_3 = tensor.dim %1, %c0 : tensor<?x5xf32>
+      %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor<?x5xf32> into tensor<?x5xf32>
+      scf.yield %inserted_slice : tensor<?x5xf32>
+    }
+    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%2 : tensor<?x5xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%in: f32, %out: f32):
+      %4 = arith.addf %in, %out : f32
+      linalg.yield %4 : f32
+    } -> tensor<?xf32>
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+  let results = (outs PDL_Operation:$fill_op,
+                  PDL_Operation:$split_linalg_op,
+                  PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1027,6 +1027,45 @@
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Transformation information returned after reduction tiling.
+struct ReductionTilingResult {
+  /// The partial reduction tiled op generated.
+  Operation *parallelTiledOp;
+  /// The final reduction operation merging all the partial reductions.
+  Operation *mergeOp;
+  /// The op initializing the tensor used for partial reductions.
+  Operation *initialOp;
+  /// The `scf.for` operations that iterate over the tiles.
+  SmallVector<scf::ForOp> loops;
+};
+
+/// Method to tile a reduction and generate a parallel op within a serial loop.
+/// Each of the partial reductions is calculated in parallel. Then, after the
+/// loop, all the partial reductions are merged into a final reduction.
+/// For example, the following sequence
+///
+/// ```mlir
+/// %0 = linalg.generic %in ["parallel", "reduction"]
+///   : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+///
+/// is transformed into:
+///
+/// ```mlir
+/// %0 = linalg.fill ... : tensor<7x4xf32>
+/// %1 = scf.for ... iter_args(%arg0 = %0)
+///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
+///   %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
+///   %4 = linalg.generic %2, %3 ["parallel", "parallel"]
+///     : tensor<7x?xf32> -> tensor<7x?xf32>
+///   %5 = tensor.insert_slice %4 into %arg0[0, 0] : tensor<7x4xf32>
+/// }
+/// %6 = linalg.generic %1 ["parallel", "reduction"]
+///   : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+FailureOr<ReductionTilingResult>
+tileReductionUsingScf(PatternRewriter &b, LinalgOp op,
+                      ArrayRef<OpFoldResult> tileSize);
+
 } // namespace linalg
 } // namespace mlir
Index: mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -137,6 +137,10 @@
 Optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+/// Return the identity numeric value associated to the given op. Return
+/// llvm::None if there is no known neutral element.
+Optional<Attribute> getNeutralElement(Operation *op);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/SCF/Utils/Utils.h
===================================================================
--- mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -176,6 +176,75 @@
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration
+///   space.
+/// - `tileSizeVals` specifies the tile sizes to use. A zero tile size means
+///   the corresponding loop is not tiled.
+/// - `offsets` and `sizes` return the multi-dimensional offset and size of
+///   the tile processed within the innermost loop.
+SmallVector<scf::ForOp> generateTileLoopNest(OpBuilder &builder, Location loc,
+                                             ArrayRef<Range> loopRanges,
+                                             ArrayRef<Value> tileSizeVals,
+                                             SmallVector<OpFoldResult> &offsets,
+                                             SmallVector<OpFoldResult> &sizes);
+
+/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
+/// construct the destructive update pattern that inserts the yielded
+/// value into a destination tensor provided by `initValue` at offset
+/// `tileOffsets` and size `tileSizes`. For example,
+///
+/// ```mlir
+/// scf.for %iv0 = ... {
+///   %0 = tiled_op
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+///   %1 = tensor.extract_slice %arg
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
+FailureOr<SmallVector<Value>>
+yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
+                 ValueRange yieldedValues,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops);
+
+/// If the tiled operation is destination passing style, update the
+/// slice of the destination used (which refers to the untiled destination)
+/// to use the corresponding region argument of the innermost loop.
+///
+/// ```mlir
+/// %0 =
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+///   %1 = tensor.extract_slice %0
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+///   %1 = tensor.extract_slice %arg
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+void updateDestinationOperandsForTiledOp(OpBuilder &builder,
+                                         ValueRange tiledOpDestinationValues,
+                                         ValueRange bbArgsList);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1076,6 +1076,32 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
+    linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<OpFoldResult> sizes;
+  for (int64_t size : tileSizes) {
+    sizes.push_back(rewriter.getIndexAttr(size));
+  }
+
+  FailureOr<linalg::ReductionTilingResult> result =
+      linalg::tileReductionUsingScf(rewriter, target, sizes);
+
+  if (failed(result))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  results.push_back(result->initialOp);
+  results.push_back(result->parallelTiledOp);
+  results.push_back(result->mergeOp);
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -26,38 +27,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Return the identity numeric value associated to the give op.
-static Attribute getNeutralElement(Operation *op) {
-  // Builder only used as helper for attribute creation.
-  OpBuilder b(op->getContext());
-  Type resultType = op->getResult(0).getType();
-  if (auto floatType = resultType.dyn_cast()) {
-    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
-    if (isa(op))
-      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
-    if (isa(op))
-      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
-    if (isa(op))
-      return b.getFloatAttr(resultType,
-                            llvm::APFloat::getLargest(semantic, true));
-    if (isa(op))
-      return b.getFloatAttr(resultType,
-                            llvm::APFloat::getLargest(semantic, true));
-    return Attribute();
-  }
-  if (isa(op))
-    return b.getIntegerAttr(resultType, 0);
-  if (isa(op))
-    return b.getIntegerAttr(resultType, -1);
-  if (isa(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits::min());
-  if (isa(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits::max());
-  if (isa(op))
-    return b.getIntegerAttr(resultType, 1);
-  return Attribute();
-}
-
 FailureOr mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
@@ -88,8 +57,8 @@
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
 
   Operation *reductionOp = combinerOps[0];
-  Attribute identity = getNeutralElement(reductionOp);
-  if (!identity)
+  Optional<Attribute> identity = getNeutralElement(reductionOp);
+  if (!identity.has_value())
     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
 
   Location loc = op->getLoc();
@@ -187,7 +156,7 @@
     emptyOrAllocTensor = b.create(
         loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
   }
-  Value constantOp = b.create(loc, identity);
+  Value constantOp = b.create(loc, *identity);
   Value identityTensor =
       b.create(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);
@@ -309,10 +278,13 @@
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
     return b.notifyMatchFailure(op, "cannot match a reduction pattern");
 
-  SmallVector neutralElements = llvm::to_vector<4>(
-      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
-        return getNeutralElement(reductionOp);
-      }));
+  SmallVector<Attribute> neutralElements;
+  for (Operation *reductionOp : combinerOps) {
+    Optional<Attribute> neutralElement = getNeutralElement(reductionOp);
+    if (!neutralElement.has_value())
+      return b.notifyMatchFailure(op, "cannot find neutral element.");
+    neutralElements.push_back(*neutralElement);
+  }
   if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
     return b.notifyMatchFailure(op, "unknown reduction neutral");
 
@@ -455,6 +427,232 @@
     results.front()};
 }
 
+static FailureOr<Operation *> generateInitialTensorForPartialReduction(
+    Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<int> reductionDims) {
+  auto linalgOp = cast<LinalgOp>(op);
+  assert(reductionDims.size() == 1 &&
+         "only support single reduction right now.");
+  if (linalgOp.hasBufferSemantics())
+    return op->emitOpError("expected operation to have tensor semantics");
+  // Insert the new parallel dimension based on the index of the reduction
+  // loop. This could be controlled by the user for more flexibility.
+  unsigned insertSplitDimension = unsigned(reductionDims[0]);
+
+  SmallVector combinerOps;
+  if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+      combinerOps.size() != 1)
+    return op->emitOpError("Failed to analyze the reduction operation.");
+
+  Operation *reductionOp = combinerOps[0];
+  Optional<Attribute> identity = getNeutralElement(reductionOp);
+  if (!identity.has_value())
+    return op->emitOpError(
+        "Failed to get an identity value for the reduction operation.");
+
+  // Calculate the new output map and shape; we insert the new dimension in
+  // the intermediate tensor based on the index returned by
+  // `controlSplitReductionFn`.
+  SmallVector newOutputShape;
+  ArrayRef oldShape = linalgOp.getShape(linalgOp.getOutputOperand(0));
+  SmallVector dynamicDims;
+  for (unsigned idx : llvm::seq(0, oldShape.size() + 1)) {
+    if (idx == insertSplitDimension) {
+      dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape,
+                                 ShapedType::kDynamicStrideOrOffset);
+      continue;
+    }
+    unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+    int64_t dim = oldShape[oldIdx];
+    newOutputShape.push_back(dim);
+    if (dim == ShapedType::kDynamicSize)
+      dynamicDims.push_back(b.createOrFold(
+          loc, linalgOp.getOutputOperand(0)->get(), oldIdx));
+  }
+  Value emptyTensor = b.create(
+      loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
+      dynamicDims);
+  Value constantOp = b.create(loc, *identity);
+  auto identityTensor = b.create(loc, constantOp, emptyTensor);
+  return identityTensor.getOperation();
+}
+
+static Operation *tileToPartialReduction(Operation *op, OpBuilder &b,
+                                         Location loc, ValueRange init,
+                                         ArrayRef<OpFoldResult> offsets,
+                                         ArrayRef<OpFoldResult> sizes,
+                                         ArrayRef<int> reductionDims) {
+  auto linalgOp = cast<LinalgOp>(op);
+  assert(reductionDims.size() == 1 &&
+         "only support single reduction right now.");
+  unsigned insertSplitDimension = unsigned(reductionDims[0]);
+
+  AffineMap oldOutputMap =
+      linalgOp.getMatchingIndexingMap(linalgOp.getOutputOperand(0));
+  SmallVector outputExpr;
+  for (auto &[idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
+    if (idx == insertSplitDimension) {
+      outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+    }
+    outputExpr.push_back(expr);
+  }
+  if (insertSplitDimension == oldOutputMap.getNumResults())
+    outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+
+  SmallVector valuesToTile = linalgOp.getInputOperands();
+  SmallVector tiledOperands =
+      makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+
+  SmallVector strides(offsets.size(), b.getIndexAttr(1));
+  SmallVector outOffsets(offsets.size(), b.getIndexAttr(0));
+  Value out = b.create(loc, init[0], outOffsets, sizes,
+                       strides);
+
+  // Create the new op matching the original op where the reduction dimension
+  // is now parallel.
+  SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray();
+  newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+  SmallVector newMaps = linalgOp.getIndexingMapsArray();
+  newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
+                                  linalgOp.getContext());
+  auto genericOp =
+      b.create(loc, TypeRange({out.getType()}), tiledOperands,
+               ValueRange({out}), newMaps, newIteratorTypes);
+  BlockAndValueMapping mapping;
+  op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                             genericOp.getRegion().begin(), mapping);
+  return genericOp.getOperation();
+}
+
+static Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                                  ValueRange partialReduce,
+                                  ArrayRef<int> reductionDims) {
+  auto linalgOp = cast<LinalgOp>(op);
+  assert(reductionDims.size() == 1 &&
+         "only support single reduction right now.");
+  unsigned dimToMerge = reductionDims[0];
+
+  // Then create a new reduction that only reduces the newly added dimension
+  // from the previous op.
+  unsigned intermRank = partialReduce[0].getType().cast().getRank();
+  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+  SmallVector reductionIteratorTypes;
+  SmallVector exprs;
+  for (unsigned i : llvm::seq(0, intermRank)) {
+    if (dimToMerge == i) {
+      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+    } else {
+      exprs.push_back(b.getAffineDimExpr(i));
+      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+    }
+  }
+  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op->getContext());
+  SmallVector reductionMaps = {inputMap, outputMap};
+
+  SmallVector combinerOps;
+  matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
+  Operation *reductionOp = combinerOps[0];
+
+  auto reduction = b.create(
+      loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
+      SmallVector{linalgOp.getOutputOperands()}, reductionMaps,
+      reductionIteratorTypes,
+      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+        Operation *clonedReductionOp = b.clone(*reductionOp);
+        clonedReductionOp->setOperand(0, inputs[0]);
+        clonedReductionOp->setOperand(1, inputs[1]);
+        b.create(loc, clonedReductionOp->getResult(0));
+      });
+  return reduction.getOperation();
+}
+
+FailureOr<ReductionTilingResult>
+mlir::linalg::tileReductionUsingScf(PatternRewriter &b, LinalgOp op,
+                                    ArrayRef<OpFoldResult> tileSize) {
+  Location loc = op.getLoc();
+  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  SmallVector<Value> tileSizeVector =
+      getValueOrCreateConstantIndexOp(b, loc, tileSize);
+  if (tileSizeVector.size() < iterationDomain.size()) {
+    auto zero = b.create(loc, 0);
+    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+  }
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  Optional<int> reductionDim;
+  for (auto &[idx, iteratorType] :
+       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
+    if (iteratorType == utils::IteratorType::reduction) {
+      // Only support cases with one reduction dimension for now.
+      if (reductionDim.has_value())
+        return failure();
+      reductionDim = idx;
+      continue;
+    }
+    if (!isConstantIntValue(tileSize[idx], 0))
+      return b.notifyMatchFailure(
+          op, "only reduction dimensions can have non zero tile sizes");
+  }
+  if (!reductionDim.has_value())
+    return b.notifyMatchFailure(op, "doesn't have a reduction dimension");
+  // 1. Create the initial tensor value.
+  FailureOr<Operation *> identityTensor =
+      generateInitialTensorForPartialReduction(op, b, loc, tileSize,
+                                               *reductionDim);
+  if (failed(identityTensor))
+    return failure();
+
+  // 2. Create the nested loops.
+  SmallVector<OpFoldResult> offsets, sizes;
+  SmallVector<scf::ForOp> loops = generateTileLoopNest(
+      b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+
+  // 3. Generate the tiled implementation within the innermost loop.
+  if (!loops.empty())
+    b.setInsertionPoint(loops.back().getBody()->getTerminator());
+  Operation *parallelOp =
+      tileToPartialReduction(op, b, loc, identityTensor.value()->getResults(),
+                             offsets, sizes, *reductionDim);
+
+  SmallVector<OpFoldResult> resultSizesList;
+  for (size_t i = 0; i < offsets.size(); i++)
+    resultSizesList.push_back(
+        b.createOrFold(loc, parallelOp->getResult(0), i));
+  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+      b, identityTensor.value()->getResults(), parallelOp->getResults(),
+      outOffsets, resultSizesList, loops);
+  if (failed(replacementOr))
+    return b.notifyMatchFailure(op, "failed to yield replacement");
+
+  if (auto dstOp = dyn_cast(parallelOp)) {
+    auto innerMostLoop = loops.back();
+    SmallVector destinationTensors = dstOp.getOutputOperands();
+    assert(destinationTensors.size() ==
+               innerMostLoop.getRegionIterArgs().size() &&
+           "unexpected number of outputs");
+    updateDestinationOperandsForTiledOp(b, destinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+
+  // 4. Apply the merge reduction to combine all the partial values.
+  b.setInsertionPointAfter(*loops.begin());
+  Operation *mergeOp =
+      mergeReductions(op, b, loc, replacementOr.value(), *reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+
+  ReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = std::move(loops);
+  results.parallelTiledOp = parallelOp;
+  results.mergeOp = mergeOp;
+  return results;
+}
+
 namespace {
 
 struct LinalgSplitReduction
     : public OpInterfaceRewritePattern {
Index: mlir/lib/Dialect/Linalg/Utils/Utils.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -948,13 +948,14 @@
   SmallVector subShapeSizes =
       computeTileSizes(builder, loc, tileSizes, sizeBounds);
 
-  assert(static_cast(valuesToTile.size()) ==
+  assert(static_cast(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
-         "expected one value to tile for every operand");
+         "more values to tile than operands.");
   SmallVector> allSliceParams;
   allSliceParams.reserve(valuesToTile.size());
-  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-    Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
+  for (auto [opOperand, val] :
+       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
+    Value shapedOp = val;
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
@@ -1059,5 +1060,37 @@
   return reassociation;
 }
 
+/// Return the identity numeric value associated to the given op.
+Optional<Attribute> getNeutralElement(Operation *op) {
+  // Builder only used as helper for attribute creation.
+ OpBuilder b(op->getContext()); + Type resultType = op->getResult(0).getType(); + if (auto floatType = resultType.dyn_cast()) { + const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); + if (isa(op)) + return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); + if (isa(op)) + return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); + if (isa(op)) + return b.getFloatAttr(resultType, + llvm::APFloat::getLargest(semantic, true)); + if (isa(op)) + return b.getFloatAttr(resultType, + llvm::APFloat::getLargest(semantic, true)); + return Attribute(); + } + if (isa(op)) + return b.getIntegerAttr(resultType, 0); + if (isa(op)) + return b.getIntegerAttr(resultType, -1); + if (isa(op)) + return b.getIntegerAttr(resultType, std::numeric_limits::min()); + if (isa(op)) + return b.getIntegerAttr(resultType, std::numeric_limits::max()); + if (isa(op)) + return b.getIntegerAttr(resultType, 1); + return llvm::None; +} + } // namespace linalg } // namespace mlir Index: mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp =================================================================== --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -91,177 +91,6 @@ // tileUsingSCFForOp implementation. //===----------------------------------------------------------------------===// -// Check if `stride` evenly divides the trip count `size - offset`. -static bool tileDividesIterationDomain(Range loopRange) { - Optional offsetAsInt = getConstantIntValue(loopRange.offset); - if (!offsetAsInt) - return false; - Optional sizeAsInt = getConstantIntValue(loopRange.size); - if (!sizeAsInt) - return false; - Optional strideAsInt = getConstantIntValue(loopRange.stride); - if (!strideAsInt) - return false; - return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); -} - -/// Generate an empty loop nest that represents the tiled loop nest shell. -/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. -/// - In `offsets` and `sizes` return the multi-dimensional offset and size of -/// the -/// tile processed within the inner most loop. -static SmallVector -generateTileLoopNest(OpBuilder &builder, Location loc, - ArrayRef loopRanges, ArrayRef tileSizeVals, - SmallVector &offsets, - SmallVector &sizes) { - assert(!loopRanges.empty() && "expected at least one loop range"); - assert(loopRanges.size() == tileSizeVals.size() && - "expected as many tile sizes as loop ranges"); - OpBuilder::InsertionGuard guard(builder); - SmallVector loops; - offsets.resize(loopRanges.size()); - sizes.resize(loopRanges.size()); - - // The tile size to use (to avoid out of bounds access) is minimum of - // `tileSize` and `ub - iv`, where `iv` is the induction variable - // of the tiled loop. - AffineExpr s0, s1, d0; - bindDims(builder.getContext(), d0); - bindSymbols(builder.getContext(), s0, s1); - AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); - - for (auto loopRange : llvm::enumerate(loopRanges)) { - Value offset = - getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); - Value size = - getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); - // No loops if tile size is zero. Set offset and size to the loop - // offset and size. 
- if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { - offsets[loopRange.index()] = offset; - sizes[loopRange.index()] = size; - continue; - } - - auto loop = builder.create( - loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, - [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, - ValueRange /*iterArgs*/) { - bool canAvoidMap = tileDividesIterationDomain( - Range{loopRange.value().offset, loopRange.value().size, - tileSizeVals[loopRange.index()]}); - Value boundedTileSize = - (canAvoidMap) - ? tileSizeVals[loopRange.index()] - : builder.create( - bodyLoc, minMap, - ValueRange{iv, tileSizeVals[loopRange.index()], size}); - sizes[loopRange.index()] = boundedTileSize; - builder.create(loc); - }); - offsets[loopRange.index()] = loop.getInductionVar(); - loops.push_back(loop); - builder.setInsertionPoint(loop.getBody()->getTerminator()); - } - return loops; -} - -/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, -/// construct the destructive update pattern that inserts the yielded -/// value into a destination tensor provided by `initValue` at offset -/// `tileOffsets` and size `tileSizes`. For example, -/// -/// ```mlir -/// scf.for %iv0 = ... { -/// %0 = tiled_op -/// } -/// ``` -/// -/// is transformed to -/// -/// ```mlir -/// scf.for %iv0 = ... iter_args(%arg = %0) { -/// %1 = tensor.extract_slice %arg -/// %2 = tiled_op -/// %3 = tensor.insert_slice %2 into %arg -/// scf.yield %3 -/// } -/// ``` -/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. -static FailureOr> -yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, - ValueRange yieldedValues, - ArrayRef> tileOffsetsList, - ArrayRef> tileSizesList, - MutableArrayRef loops) { - NewYieldValueFn yieldValueFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) -> SmallVector { - SmallVector inserts; - for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { - ArrayRef tileOffsets = - tileOffsetsList[yieldedValue.index()]; - ArrayRef tileSizes = tileSizesList[yieldedValue.index()]; - SmallVector tileStrides(tileOffsets.size(), - b.getIndexAttr(1)); - Value insert = b.create( - loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], - tileOffsets, tileSizes, tileStrides); - inserts.push_back(insert); - } - return inserts; - }; - - SmallVector newLoops = - replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, - /*replaceIterOperandsUsesInLoop =*/false); - for (const auto &loop : llvm::enumerate(loops)) { - rewriter.eraseOp(loop.value()); - loops[loop.index()] = newLoops[loop.index()]; - } - return llvm::to_vector(llvm::map_range( - loops.front().getResults().take_back(yieldedValues.size()), - [](OpResult r) -> Value { return r; })); -} - -/// If the tiled operation is destination passing style, update the -/// slice of the destination used (which refers to the untiled destination) -/// to use the corresponding region argument of the innermost loop. -/// -/// ```mlir -/// %0 = -/// scf.for %iv0 = ... iter_args(%arg = %0) { -/// %1 = tensor.extract_slice %0 -/// %2 = tiled_op -/// %3 = tensor.insert_slice %2 into %arg -/// scf.yield %3 -/// } -/// ``` -/// -/// is transformed to -/// -/// ```mlir -/// scf.for %iv0 = ... 
iter_args(%arg = %0) { -/// %1 = tensor.extract_slice %arg -/// %2 = tiled_op -/// %3 = tensor.insert_slice %2 into %arg -/// scf.yield %3 -/// } -/// ``` -static void -updateDestinationOperandsForTiledOp(OpBuilder &builder, - ValueRange tiledOpDestinationValues, - ValueRange bbArgsList) { - for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { - auto sliceOp = destValue.value().getDefiningOp(); - if (!sliceOp) - continue; - sliceOp.setOperand(0, bbArgsList[destValue.index()]); - } -} - /// Implementation of tiling transformation of `op` that implements the /// `TilingInterface` using `scf.for` to iterate over the tiles. FailureOr Index: mlir/lib/Dialect/SCF/Utils/Utils.cpp =================================================================== --- mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -12,9 +12,12 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" @@ -1009,3 +1012,166 @@ return tileLoops; } + +// Check if `stride` evenly divides the trip count `size - offset`. +static bool tileDividesIterationDomain(Range loopRange) { + Optional offsetAsInt = getConstantIntValue(loopRange.offset); + if (!offsetAsInt) + return false; + Optional sizeAsInt = getConstantIntValue(loopRange.size); + if (!sizeAsInt) + return false; + Optional strideAsInt = getConstantIntValue(loopRange.stride); + if (!strideAsInt) + return false; + return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); +} + +SmallVector mlir::generateTileLoopNest( + OpBuilder &builder, Location loc, ArrayRef loopRanges, + ArrayRef tileSizeVals, SmallVector &offsets, + SmallVector &sizes) { + assert(!loopRanges.empty() && "expected at least one loop range"); + assert(loopRanges.size() == tileSizeVals.size() && + "expected as many tile sizes as loop ranges"); + OpBuilder::InsertionGuard guard(builder); + SmallVector loops; + offsets.resize(loopRanges.size()); + sizes.resize(loopRanges.size()); + + // The tile size to use (to avoid out of bounds access) is minimum of + // `tileSize` and `ub - iv`, where `iv` is the induction variable + // of the tiled loop. + AffineExpr s0, s1, d0; + bindDims(builder.getContext(), d0); + bindSymbols(builder.getContext(), s0, s1); + AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); + + for (auto loopRange : llvm::enumerate(loopRanges)) { + Value offset = + getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); + Value size = + getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); + // No loops if tile size is zero. Set offset and size to the loop + // offset and size. 
+ if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { + offsets[loopRange.index()] = offset; + sizes[loopRange.index()] = size; + continue; + } + + auto loop = builder.create( + loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, + [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, + ValueRange /*iterArgs*/) { + bool canAvoidMap = tileDividesIterationDomain( + Range{loopRange.value().offset, loopRange.value().size, + tileSizeVals[loopRange.index()]}); + Value boundedTileSize = + (canAvoidMap) + ? tileSizeVals[loopRange.index()] + : builder.create( + bodyLoc, minMap, + ValueRange{iv, tileSizeVals[loopRange.index()], size}); + sizes[loopRange.index()] = boundedTileSize; + builder.create(loc); + }); + offsets[loopRange.index()] = loop.getInductionVar(); + loops.push_back(loop); + builder.setInsertionPoint(loop.getBody()->getTerminator()); + } + return loops; +} + +/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, +/// construct the destructive update pattern that inserts the yielded +/// value into a destination tensor provided by `initValue` at offset +/// `tileOffsets` and size `tileSizes`. For example, +/// +/// ```mlir +/// scf.for %iv0 = ... { +/// %0 = tiled_op +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. +FailureOr> +mlir::yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, + ValueRange yieldedValues, + ArrayRef> tileOffsetsList, + ArrayRef> tileSizesList, + MutableArrayRef loops) { + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector inserts; + for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { + ArrayRef tileOffsets = + tileOffsetsList[yieldedValue.index()]; + ArrayRef tileSizes = tileSizesList[yieldedValue.index()]; + SmallVector tileStrides(tileOffsets.size(), + b.getIndexAttr(1)); + Value insert = b.create( + loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], + tileOffsets, tileSizes, tileStrides); + inserts.push_back(insert); + } + return inserts; + }; + + SmallVector newLoops = + replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, + /*replaceIterOperandsUsesInLoop =*/false); + for (const auto &loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = newLoops[loop.index()]; + } + return llvm::to_vector(llvm::map_range( + loops.front().getResults().take_back(yieldedValues.size()), + [](OpResult r) -> Value { return r; })); +} + +/// If the tiled operation is destination passing style, update the +/// slice of the destination used (which refers to the untiled destination) +/// to use the corresponding region argument of the innermost loop. +/// +/// ```mlir +/// %0 = +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %0 +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... 
iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +void mlir::updateDestinationOperandsForTiledOp( + OpBuilder &builder, ValueRange tiledOpDestinationValues, + ValueRange bbArgsList) { + for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { + auto sliceOp = destValue.value().getDefiningOp(); + if (!sliceOp) + continue; + sliceOp.setOperand(0, bbArgsList[destValue.index()]); + } +}
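
Usage note (not part of the diff): below is a minimal sketch of how the new transform op could be driven from a transform dialect script. The `transform.sequence` / `transform.structured.match` wrapper, the `structured.tile_reduction_using_scf` mnemonic, and the concrete tile sizes are assumptions based on the op definition above, not verbatim from this patch; the three results correspond to the `fill_op`, `split_linalg_op` and `combining_linalg_op` handles it declares.

```mlir
// Hypothetical transform script: tile only the reduction dimension of a
// matched linalg.generic by 5, leaving the parallel dimension untiled
// (tile size 0), as required by tileReductionUsingScf.
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
  // Match the reduction op to tile.
  %red = transform.structured.match ops{["linalg.generic"]} in %arg0
  // Returns handles to the fill op, the parallel partial-reduction op and
  // the final combining op.
  %fill, %partial, %combine =
    transform.structured.tile_reduction_using_scf %red {tile_sizes = [0, 5]}
}
```

Only reduction dimensions may receive a non-zero tile size; the implementation rejects non-zero tile sizes on parallel dimensions and ops with more than one reduction dimension or more than one result.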