Index: mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td =================================================================== --- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -608,6 +608,92 @@ }]; } + +def TileReductionUsingScfOp : Op { + let description = [{ + Indicates that the given `target` op should be transformed with the + `tileReduction` transformation with the tile size provided as attribute. + + This transformation tiles the `target` along the reduction dimensions. It + creates a tensor initialized with the identity value. Then it creates nested + loops with a parallel version of `target` op inside. The parallel op + dimensions are less than or equal to the tile size passed by user. + After the loop a merge operation is created to do a final reduction with the + partial reductions. + + #### Return modes + + The 3 returned handles point to: + - the fill op used to initialize the neutral element, + - the parallel tiled op and + - the result-combining op.
+ + #### Example: + + ``` + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.addf %arg7, %arg9 : f32 + linalg.yield %1 : f32 + } -> tensor + return %red : tensor + ``` + + is transformed into: + + ``` + %0 = tensor.empty(%dim_1) : tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor) { + %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor to tensor + %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor to tensor + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%extracted_slice_2 : tensor) + outs(%extracted_slice : tensor) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %out : f32 + linalg.yield %5 : f32 + } -> tensor + %dim_3 = tensor.dim %1, %c0 : tensor + %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor into tensor + scf.yield %inserted_slice : tensor + } + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%2 : tensor) + outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %4 = arith.addf %in, %out : f32 + linalg.yield %4 : f32 + } -> tensor + ``` + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$tile_sizes); + let results = (outs PDL_Operation:$fill_op, + PDL_Operation:$split_linalg_op, + PDL_Operation:$combining_linalg_op); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> 
&results, + ::mlir::transform::TransformState &state); + }]; +} + def TileOp : Op, DeclareOpInterfaceMethods]> { Index: mlir/include/mlir/Dialect/Linalg/Utils/Utils.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -137,6 +137,10 @@ Optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); +/// Return the identity numeric value associated to the given op. Return +/// llvm::None if there is no known neutral element. +Optional getNeutralElement(Operation *op); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h =================================================================== --- mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -136,6 +136,45 @@ FailureOr> lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); +/// Transformation information returned after reduction tiling. +struct SCFReductionTilingResult { + /// The partial reduction tiled op generated. + Operation *parallelTiledOp; + /// The final reduction operation merging all the partial reductions. + Operation *mergeOp; + /// The op initializing the tensor used for partial reductions. + Operation *initialOp; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; +}; + +/// Method to tile a reduction and generate a parallel op within a serial loop. +/// Each of the partial reductions is calculated in parallel. Then after the +/// loop all the partial reductions are merged into a final reduction. +/// For example, the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// is transformed into: +/// +/// ```mlir +/// %0 = linalg.fill ...
: tensor<7x4xf32> +/// %1 = scf.for ... iter_args(%arg0 = %0) +/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "parallel"] +/// : tensor<7x?xf32> -> tensor<7x?xf32> +/// %5 = tensor.insert_slice %4, %arg0[0, 0] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x4xf32> -> tensor<7xf32> +/// ``` +FailureOr +tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op, + ArrayRef tileSize); + } // namespace scf } // namespace mlir Index: mlir/include/mlir/Interfaces/TilingInterface.td =================================================================== --- mlir/include/mlir/Interfaces/TilingInterface.td +++ mlir/include/mlir/Interfaces/TilingInterface.td @@ -155,4 +155,72 @@ > ]; } + +def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> { + let description = [{ + Interface for allowing operations to expose information needed to + tile partial reductions. This is complementary to TilingInterface to tile + reductions. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Method to generate a tensor initialized with the identity value of the + operation reduction. The tensor shape is equal to operation result + shape with a new dimension for each non zero tile size. + }], + /*retType=*/"FailureOr", + /*methodName=*/"generateInitialTensorForPartialReduction", + /*args=*/(ins + "OpBuilder &":$b, + "Location ":$loc, + "ArrayRef":$sizes, + "ArrayRef":$reductionDim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to generate a tiled version of the operation where the tiled + reduction dimensions are converted to parallel dimensions with a size + less than or equal to the tile size.
This is meant to be used with + `mergeReductions` method which will combine the partial reductions. + }], + /*retType=*/"Operation*", + /*methodName=*/"tileToPartialReduction", + /*args=*/(ins + "OpBuilder &":$b, + "Location ":$loc, + "ValueRange":$init, + "ArrayRef":$offsets, + "ArrayRef":$sizes, + "ArrayRef":$reductionDim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return nullptr; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to merge partial reductions for an operation that has been + tiled along the reduction dimensions. This will only apply the + reduction of the operation. + }], + /*retType=*/"Operation*", + /*methodName=*/"mergeReductions", + /*args=*/(ins + "OpBuilder &":$b, + "Location ":$loc, + "ValueRange":$partialReduce, + "ArrayRef":$reductionDim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return nullptr; + }] + > + ]; +} #endif // MLIR_TILINGINTERFACE Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1094,6 +1094,33 @@ return DiagnosedSilenceableFailure(success()); } +//===----------------------------------------------------------------------===// +// TileReductionUsingScfOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( + linalg::LinalgOp target, SmallVectorImpl &results, + transform::TransformState &state) { + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); + SmallVector sizes; + for (int64_t size : tileSizes) { + sizes.push_back(rewriter.getIndexAttr(size)); + } + + FailureOr result = scf::tileReductionUsingScf( + rewriter, cast(target.getOperation()), + sizes); + + if (failed(result)) + return
DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + results.push_back(result->initialOp); + results.push_back(result->parallelTiledOp); + results.push_back(result->mergeOp); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -26,38 +26,6 @@ using namespace mlir; using namespace mlir::linalg; -/// Return the identity numeric value associated to the give op. -static Attribute getNeutralElement(Operation *op) { - // Builder only used as helper for attribute creation. - OpBuilder b(op->getContext()); - Type resultType = op->getResult(0).getType(); - if (auto floatType = resultType.dyn_cast()) { - const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); - if (isa(op)) - return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); - if (isa(op)) - return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); - if (isa(op)) - return b.getFloatAttr(resultType, - llvm::APFloat::getLargest(semantic, true)); - if (isa(op)) - return b.getFloatAttr(resultType, - llvm::APFloat::getLargest(semantic, true)); - return Attribute(); - } - if (isa(op)) - return b.getIntegerAttr(resultType, 0); - if (isa(op)) - return b.getIntegerAttr(resultType, -1); - if (isa(op)) - return b.getIntegerAttr(resultType, std::numeric_limits::min()); - if (isa(op)) - return b.getIntegerAttr(resultType, std::numeric_limits::max()); - if (isa(op)) - return b.getIntegerAttr(resultType, 1); - return Attribute(); -} - FailureOr mlir::linalg::splitReduction( PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) 
{ @@ -88,8 +56,8 @@ return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; - Attribute identity = getNeutralElement(reductionOp); - if (!identity) + Optional identity = getNeutralElement(reductionOp); + if (!identity.has_value()) return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); Location loc = op->getLoc(); @@ -187,7 +155,7 @@ emptyOrAllocTensor = b.create( loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); } - Value constantOp = b.create(loc, identity); + Value constantOp = b.create(loc, *identity); Value identityTensor = b.create(op->getLoc(), constantOp, emptyOrAllocTensor) .getResult(0); @@ -309,10 +277,13 @@ if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps)) return b.notifyMatchFailure(op, "cannot match a reduction pattern"); - SmallVector neutralElements = llvm::to_vector<4>( - llvm::map_range(combinerOps, [&](Operation *reductionOp) { - return getNeutralElement(reductionOp); - })); + SmallVector neutralElements; + for (Operation *reductionOp : combinerOps) { + Optional neutralElement = getNeutralElement(reductionOp); + if (!neutralElement.has_value()) + return b.notifyMatchFailure(op, "cannot find neutral element."); + neutralElements.push_back(*neutralElement); + } if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; })) return b.notifyMatchFailure(op, "unknown reduction neutral"); Index: mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -240,11 +241,166 @@ } }; 
+//===----------------------------------------------------------------------===//
+// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+/// External model implementation of PartialReductionInterface for LinalgOps.
+template
+struct LinalgOpPartialReductionInterface
+    : public PartialReductionOpInterface::ExternalModel<
+          LinalgOpPartialReductionInterface, LinalgOpTy> {
+  FailureOr generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, ArrayRef sizes,
+      ArrayRef reductionDims) const {
+    auto linalgOp = cast(op);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    if (linalgOp.hasBufferSemantics())
+      return op->emitOpError("expected operation to have tensor semantics");
+    // Insert the new parallel dimension based on the index of the reduction
+    // loop. This could be controlled by user for more flexibility.
+    unsigned insertSplitDimension = unsigned(reductionDims[0]);
+
+    SmallVector combinerOps;
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+        combinerOps.size() != 1)
+      return op->emitOpError("Failed to anaysis the reduction operation.");
+
+    Operation *reductionOp = combinerOps[0];
+    Optional identity = getNeutralElement(reductionOp);
+    if (!identity.has_value())
+      return op->emitOpError(
+          "Failed to get an identity value for the reduction operation.");
+
+    // Calculate the new output shape: the new dimension is inserted in the
+    // intermediate tensor at the index of the reduction loop
+    // (`insertSplitDimension`); a non-constant tile size becomes a dynamic dim.
+    SmallVector newOutputShape;
+    ArrayRef oldShape =
+        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+    SmallVector dynamicDims;
+    for (unsigned idx : llvm::seq(0, oldShape.size() + 1)) {
+      if (idx == insertSplitDimension) {
+        dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape,
+                                   ShapedType::kDynamicSize);
+        continue;
+      }
+      unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+      int64_t dim = oldShape[oldIdx];
+      newOutputShape.push_back(dim);
+      if (dim == ShapedType::kDynamicSize)
+        dynamicDims.push_back(b.createOrFold(
+            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+    }
+    Value emptyTensor = b.create(
+        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
+        dynamicDims);
+    Value constantOp = b.create(loc, *identity);
+    auto identityTensor =
+        b.create(loc, constantOp, emptyTensor);
+    return identityTensor.getOperation();
+  }
+
+  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                                    ValueRange init,
+                                    ArrayRef offsets,
+                                    ArrayRef sizes,
+                                    ArrayRef reductionDims) const {
+    auto linalgOp = cast(op);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    unsigned insertSplitDimension = unsigned(reductionDims[0]);
+
+    AffineMap oldOutputMap =
+        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
+    SmallVector outputExpr;
+    for (auto &[idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
+      if (idx == insertSplitDimension) {
+        outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+      }
+      outputExpr.push_back(expr);
+    }
+    if (insertSplitDimension == oldOutputMap.getNumResults())
+      outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+
+    SmallVector valuesToTile = linalgOp.getDpsInputOperands();
+    SmallVector tiledOperands =
+        makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+
+    SmallVector strides(offsets.size(), b.getIndexAttr(1));
+    SmallVector outOffsets(offsets.size(), b.getIndexAttr(0));
+    Value out = b.create(loc, init[0], outOffsets,
+                         sizes, strides);
+
+    // Create the new op matching the original op where the reduction dimension
+    // is now parallel.
+    SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray();
+    newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+    SmallVector newMaps = linalgOp.getIndexingMapsArray();
+    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
+                                    linalgOp.getContext());
+    auto genericOp =
+        b.create(loc, TypeRange({out.getType()}), tiledOperands,
+                 ValueRange({out}), newMaps, newIteratorTypes);
+    BlockAndValueMapping mapping;
+    op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                               genericOp.getRegion().begin(), mapping);
+    return genericOp.getOperation();
+  }
+
+  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                             ValueRange partialReduce,
+                             ArrayRef reductionDims) const {
+    auto linalgOp = cast(op);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    unsigned dimToMerge = reductionDims[0];
+
+    // Then create a new reduction that only reduces the newly added dimension
+    // from the previous op.
+    unsigned intermRank =
+        partialReduce[0].getType().cast().getRank();
+    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+    SmallVector reductionIteratorTypes;
+    SmallVector exprs;
+    for (unsigned i : llvm::seq(0, intermRank)) {
+      if (dimToMerge == i) {
+        reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+      } else {
+        exprs.push_back(b.getAffineDimExpr(i));
+        reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+      }
+    }
+    AffineMap outputMap =
+        AffineMap::get(intermRank, 0, exprs, op->getContext());
+    SmallVector reductionMaps = {inputMap, outputMap};
+
+    SmallVector combinerOps;
+    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
+    Operation *reductionOp = combinerOps[0];
+
+    auto reduction = b.create(
+        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
+        SmallVector{linalgOp.getDpsInitOperands()}, reductionMaps,
+        reductionIteratorTypes,
+        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          Operation *clonedReductionOp = b.clone(*reductionOp);
+          clonedReductionOp->setOperand(0, inputs[0]);
+          clonedReductionOp->setOperand(1, inputs[1]);
+          b.create(loc, clonedReductionOp->getResult(0));
+        });
+    return reduction.getOperation();
+  }
+};
+
 } // namespace
 
 template
 static void registerOne(MLIRContext *ctx) {
   OpType::template attachInterface>(*ctx);
+  OpType::template attachInterface>(
+      *ctx);
 }
 
 /// Variadic helper function.
Index: mlir/lib/Dialect/Linalg/Utils/Utils.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -948,13 +948,14 @@
   SmallVector subShapeSizes =
       computeTileSizes(builder, loc, tileSizes, sizeBounds);
 
-  assert(static_cast(valuesToTile.size()) ==
+  assert(static_cast(valuesToTile.size()) <=
              linalgOp->getNumOperands() &&
-         "expected one value to tile for every operand");
+         "more value to tile than operands.");
   SmallVector> allSliceParams;
   allSliceParams.reserve(valuesToTile.size());
-  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-    Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
+  for (auto [opOperand, val] :
+       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
+    Value shapedOp = val;
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
@@ -1059,5 +1060,44 @@
   return reassociation;
 }
 
+/// Return the neutral (identity) element of the reduction performed by `op`,
+/// built as an attribute of the type of `op`'s first result. Return llvm::None
+/// if there is no known neutral element.
+Optional getNeutralElement(Operation *op) {
+  // Builder only used as helper for attribute creation.
+  OpBuilder b(op->getContext());
+  Type resultType = op->getResult(0).getType();
+  if (auto floatType = resultType.dyn_cast()) {
+    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+    if (isa(op))
+      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+    if (isa(op))
+      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+    // NOTE(review): the next two branches build the same constant; if they
+    // cover a max and a min combiner, one of them presumably needs the
+    // opposite sign -- confirm.
+    if (isa(op))
+      return b.getFloatAttr(resultType,
+                            llvm::APFloat::getLargest(semantic, true));
+    if (isa(op))
+      return b.getFloatAttr(resultType,
+                            llvm::APFloat::getLargest(semantic, true));
+    // Unsupported float combiner. Returning a null Attribute here would give
+    // an engaged Optional that passes the callers' `has_value()` checks.
+    return llvm::None;
+  }
+  if (isa(op))
+    return b.getIntegerAttr(resultType, 0);
+  if (isa(op))
+    return b.getIntegerAttr(resultType, -1);
+  if (isa(op))
+    return b.getIntegerAttr(resultType, std::numeric_limits::min());
+  if (isa(op))
+    return b.getIntegerAttr(resultType, std::numeric_limits::max());
+  if (isa(op))
+    return b.getIntegerAttr(resultType, 1);
+  return llvm::None;
+}
+
 } // namespace linalg
 } // namespace mlir
Index: mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -424,6 +424,101 @@
   return tilingResult;
 }
 
+FailureOr
+mlir::scf::tileReductionUsingScf(PatternRewriter &b,
+                                 PartialReductionOpInterface op,
+                                 ArrayRef tileSize) {
+  Location loc = op.getLoc();
+  auto tilingInterfaceOp = cast(op.getOperation());
+  SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  SmallVector tileSizeVector =
+      getValueOrCreateConstantIndexOp(b, loc, tileSize);
+  // Pad the missing trailing tile sizes with zero.
+  if (tileSizeVector.size() < iterationDomain.size()) {
+    auto zero = b.create(loc, 0);
+    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+  }
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  SmallVector iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  Optional reductionDim;
+  for (auto &[idx, iteratorType] : llvm::enumerate(iterators)) {
+    if (iteratorType == utils::IteratorType::reduction) {
+      // Only support cases with one reduction dimension for now.
+      if (reductionDim.has_value())
+        return failure();
+      reductionDim = idx;
+      continue;
+    }
+    // `tileSize` may be shorter than the iteration domain (it is only padded
+    // in `tileSizeVector`), so bound-check before reading an entry.
+    if (idx < tileSize.size() && !isConstantIntValue(tileSize[idx], 0))
+      return b.notifyMatchFailure(
+          op, "only reduction dimensions can have non zero tile sizes");
+  }
+  if (!reductionDim.has_value())
+    return b.notifyMatchFailure(op, "doesn't have a reduction dimension");
+  // The reduction dimension needs an explicit, non zero tile size: the code
+  // below dereferences `loops.back()` and so relies on a loop being created.
+  if (static_cast<size_t>(*reductionDim) >= tileSize.size() ||
+      isConstantIntValue(tileSize[*reductionDim], 0))
+    return b.notifyMatchFailure(
+        op, "reduction dimension must have a non zero tile size");
+  // 1. create the initial tensor value.
+  FailureOr identityTensor =
+      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+                                                  *reductionDim);
+
+  if (failed(identityTensor))
+    return failure();
+  // 2. Create the nested loops.
+  SmallVector offsets, sizes;
+  SmallVector loops = generateTileLoopNest(
+      b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+  assert(!loops.empty() && "expected at least one tiling loop");
+
+  // 3. Generate the tiled implementation within the inner most loop.
+  b.setInsertionPoint(loops.back().getBody()->getTerminator());
+  Operation *parallelOp =
+      op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(),
+                                offsets, sizes, *reductionDim);
+
+  SmallVector resultSizesList;
+  for (size_t i = 0; i < offsets.size(); i++)
+    resultSizesList.push_back(
+        b.createOrFold(loc, parallelOp->getResult(0), i));
+  SmallVector outOffsets(offsets.size(), b.getIndexAttr(0));
+  FailureOr> replacementOr = yieldTiledValues(
+      b, identityTensor.value()->getResults(), parallelOp->getResults(),
+      outOffsets, resultSizesList, loops);
+  if (failed(replacementOr))
+    return b.notifyMatchFailure(op, "failed to yield replacement");
+
+  if (auto dstOp = dyn_cast(parallelOp)) {
+    auto innerMostLoop = loops.back();
+    SmallVector destinationTensors = dstOp.getDpsInitOperands();
+    assert(destinationTensors.size() ==
+               innerMostLoop.getRegionIterArgs().size() &&
+           "unexpected number of outputs");
+    updateDestinationOperandsForTiledOp(b, destinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+
+  // 4. Apply the merge reduction to combine all the partial values.
+  b.setInsertionPointAfter(*loops.begin());
+  Operation *mergeOp =
+      op.mergeReductions(b, loc, replacementOr.value(), *reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+
+  SCFReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = std::move(loops);
+  results.parallelTiledOp = parallelOp;
+  results.mergeOp = mergeOp;
+  return results;
+}
 //===----------------------------------------------------------------------===//
 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===// Index: mlir/test/Dialect/Linalg/transform-tile-reduction.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s + +func.func @reduction_tile(%arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.mulf %arg7, %arg7 : f32 + %2 = arith.addf %1, %arg9 : f32 + linalg.yield %2 : f32 + } -> tensor + return %red : tensor +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] } +} + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)> +// CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor +// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor) -> tensor +// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] 
= %[[F]]) -> (tensor) { +// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]] +// CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor to tensor +// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor to tensor +// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor) outs(%[[EXT]] : tensor) { +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor +// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor +// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor into tensor +// CHECK: scf.yield %[[INS]] : tensor +// CHECK: } +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor) outs(%[[ARG1]] : tensor) { +// CHECK: arith.addf +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: return %[[R]] : tensor + +// ----- + +func.func @reduction_tile_transpose(%arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>], + iterator_types = ["reduction", "parallel"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %42 = arith.addf %arg7, %arg9 : f32 + linalg.yield %42 : f32 + } -> tensor + return %red : tensor +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] } +} + +// CHECK: func @reduction_tile_transpose +// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32> +// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32> +// CHECK: scf.for +// CHECK: 
linalg.generic +// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor +// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor +// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor into tensor<5x?xf32> +// CHECK: scf.yield {{.*}} : tensor<5x?xf32> +// CHECK: } +// CHECK: linalg.generic +// CHECK: return