diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -179,6 +179,40 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+def TileMultiSizeOp : Op<Transform_Dialect, "structured.tile.multisize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Indicates that the multi-size tiling transformation should be applied to
+    the given target. This transformation partitions each dimension of the
+    structured op's iteration space into two parts such that the shape of each
+    part is perfectly divisible by some value not exceeding the corresponding
+    "target size". The value itself is divisible by the corresponding "target
+    size divisor" (defaults to 1 if not provided). This produces a tree of
+    imperfectly nested loops, at the leaves of which 2^n smaller-sized
+    structured operations are created, where n is the rank of the original
+    iteration space.
+
+    The op requires at least as many target sizes as the target op has
+    iteration space dimensions. Extra sizes are ignored. Target size divisors
+    may be provided for either all or none of the tile sizes, and do not need
+    to evenly divide the provided sizes. Note that zero tile sizes indicating
+    the absence of tiling along the given dimension are **not** currently
+    supported, therefore the tile sizes must be strictly positive.
+
+    This op returns a handle to the flattened list of tiled ops, grouped by
+    target op. For each op, the group of tiled ops covers the parts of the
+    original iteration space in the lexicographical order of dimensions.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   I64ArrayAttr:$target_sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$target_size_divisors);
+  let results = (outs PDL_Operation:$tiled_linalg_ops);
+  let assemblyFormat = "$target attr-dict";
+  let hasVerifier = 1;
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -129,6 +129,28 @@
 FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                       const LinalgTilingOptions &options);
 
+/// Perform multi-size tiling of a single LinalgOp with desired tile sizes
+/// `targetSizes`. Multi-size tiling dynamically computes, for each desired
+/// tile size, two smaller tile sizes T1, T2 such that some linear combination
+/// of *full* tiles covers the entire iteration space dimension. That is, given
+/// the iteration space dimension D, T1*n + T2*m == D where n, m are some
+/// integer values. If `targetSizeDivisors` are provided, the computed tile
+/// sizes will be divisible by the corresponding value; this is useful for
+/// vectorization. Multi-size tiling produces an imperfectly-nested loop
+/// structure with two loops at each tiled dimension, each with a different
+/// tile size. The result contains a list of LinalgOp instances operating on
+/// smaller data tiles, in their order of appearance in the IR, which
+/// corresponds to the lexicographical order of tile indices. It also contains
+/// the tensor-valued outputs that were used to replace the original operation.
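+/// For example, for an iteration space dimension D = 10 and target size 3
+/// (with divisor 1), the computed tile sizes are T1 = 2 and T2 = 3 with
+/// n = m = 2, since 2*2 + 3*2 == 10; this is the case exercised by the tests
+/// accompanying this transformation.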
+struct MultiSizedTilingResult {
+  SmallVector<LinalgOp> tiledOps;
+  ValueRange tensorResults;
+};
+FailureOr<MultiSizedTilingResult>
+multiSizeTileLinalgOp(RewriterBase &b, LinalgOp linalgOp,
+                      ValueRange targetSizes,
+                      ValueRange targetSizeDivisors = {});
+
 /// Peel and canonicalize 'loops'.
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -170,7 +170,7 @@
 /// Note: The returned tile sizes are closed intervals.
 SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
-                                    ArrayRef<Value> sizeBounds);
+                                    ValueRange sizeBounds);
 
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
@@ -193,12 +193,10 @@
 /// Note that a constant zero in `tileSizes` means no tiling at that implicit
 /// loop. The number of non-zero values in `tileSizes` should be equal to the
 /// number of values in `ivs`.
-SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
-                                   LinalgOp linalgOp,
-                                   ArrayRef<Value> valuesToTile,
-                                   ValueRange ivs, ValueRange tileSizes,
-                                   ArrayRef<Value> sizeBounds,
-                                   bool omitPartialTileCheck);
+SmallVector<Value>
+makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp,
+                ValueRange valuesToTile, ValueRange ivs, ValueRange tileSizes,
+                ValueRange sizeBounds, bool omitPartialTileCheck);
 
 /// Add the tile loop induction variables `ivs` to the IndexOp results found in
 /// the body of the `tiledOp` to account for the tile offset.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -8,6 +8,7 @@
   MLIRLinalgTransformOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLinalgTransforms
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -430,6 +431,80 @@
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// TileMultiSizeOp
+//===----------------------------------------------------------------------===//
+
+/// Emits the arithmetic constant operations defining index-typed values for
+/// the given list of constants using the provided builder and location.
+static SmallVector<Value> emitIndexConstants(OpBuilder &b, Location loc,
+                                             ArrayRef<int64_t> values) {
+  return llvm::to_vector(llvm::map_range(values, [&](int64_t value) -> Value {
+    return b.create<arith::ConstantIndexOp>(loc, value);
+  }));
+}
+
+DiagnosedSilenceableFailure
+transform::TileMultiSizeOp::apply(TransformResults &results,
+                                  TransformState &state) {
+  SmallVector<int64_t> tileSizes = extractI64Array(getTargetSizes());
+  SmallVector<int64_t> tileSizeDivisors =
+      extractI64Array(getTargetSizeDivisors());
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+  // Each tiled dimension doubles the number of linalg ops produced.
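+  // For example, tiling both dimensions of a rank-2 op yields up to
+  // 2^2 = 4 tiled ops per target.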
+  SmallVector<Operation *> flattenedResults;
+  flattenedResults.reserve((1 << tileSizes.size()) * targets.size());
+  for (Operation *target : targets) {
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "only applies to Linalg structured ops";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Multi-sized tiling is designed for dynamic tile sizes provided as
+    // values, so emit the constants as operations and use them to configure
+    // the transformation.
+    SimpleRewriter rewriter(getContext());
+    rewriter.setInsertionPoint(target);
+    SmallVector<Value> dynamicTileSizes =
+        emitIndexConstants(rewriter, target->getLoc(), tileSizes);
+    SmallVector<Value> dynamicTileSizeDivisors =
+        emitIndexConstants(rewriter, target->getLoc(), tileSizeDivisors);
+    FailureOr<MultiSizedTilingResult> result = multiSizeTileLinalgOp(
+        rewriter, linalgOp, dynamicTileSizes, dynamicTileSizeDivisors);
+    if (failed(result)) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "failed to apply";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    llvm::append_range(flattenedResults, result->tiledOps);
+  }
+  results.set(getTiledLinalgOps().cast<OpResult>(), flattenedResults);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::TileMultiSizeOp::verify() {
+  SmallVector<int64_t> tileSizes = extractI64Array(getTargetSizes());
+  SmallVector<int64_t> tileSizeDivisors =
+      extractI64Array(getTargetSizeDivisors());
+  if (tileSizes.size() != tileSizeDivisors.size() &&
+      !tileSizeDivisors.empty()) {
+    return emitOpError() << "expects as many divisors as tile sizes or none";
+  }
+
+  auto isNonPositive = [](int64_t value) { return value <= 0; };
+  if (llvm::any_of(tileSizes, isNonPositive))
+    return emitOpError() << "expects tile sizes to be strictly positive";
+  if (llvm::any_of(tileSizeDivisors, isNonPositive))
+    return emitOpError() << "expects divisors to be strictly positive";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // VectorizeOp
 //===----------------------------------------------------------------------===//
@@ -466,13 +541,13 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Registers new ops and declares PDL as dependent dialect since the
-/// additional ops are using PDL types for operands and results.
+/// Registers new ops and declares dependent dialects.
 class LinalgTransformDialectExtension
     : public transform::TransformDialectExtension<
           LinalgTransformDialectExtension> {
 public:
   LinalgTransformDialectExtension() {
+    declareDependentDialect<arith::ArithmeticDialect>();
     declareDependentDialect<pdl::PDLDialect>();
     declareDependentDialect<scf::SCFDialect>();
     declareDependentDialect<vector::VectorDialect>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -26,13 +26,16 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 using namespace mlir::scf;
 
 #define DEBUG_TYPE "linalg-tiling"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
 
 static bool isZero(Value v) {
   if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
@@ -82,10 +85,10 @@
   addTileLoopIvsToIndexOpResults(b, op, allIvs);
 }
 
-// Insert a tile `source` into the destination tensor `dest`. The position at
-// which the tile is inserted (as well as size of tile) is taken from a given
-// ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
+/// Insert a tile `source` into the destination tensor `dest`. The position at
+/// which the tile is inserted (as well as the size of the tile) is taken from
+/// a given ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
                                    tensor::ExtractSliceOp sliceOp, Value source,
                                    Value dest) {
   return b.create<tensor::InsertSliceOp>(
@@ -94,6 +97,41 @@
       sliceOp.static_sizes(), sliceOp.static_strides());
 }
 
+/// Insert the result slices produced by the `tiled` op back into output tensor
+/// operands in case these operands are produced by slice extraction.
+static scf::ValueVector insertSlicesBack(OpBuilder &b, Location loc,
+                                         LinalgOp tiled,
+                                         ValueRange tiledOperands) {
+  scf::ValueVector tensorResults;
+  unsigned resultIdx = 0;
+  for (OpOperand *opOperand : tiled.getOutputTensorOperands()) {
+    // TODO: use an interface/adaptor to avoid leaking position in
+    // `tiledOperands`.
+    Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
+      tensorResults.push_back(insertSliceIntoTensor(
+          b, loc, sliceOp, tiled->getResult(resultIdx), sliceOp.source()));
+    } else {
+      tensorResults.push_back(tiled->getResult(resultIdx));
+    }
+    ++resultIdx;
+  }
+  return tensorResults;
+}
+
+/// Clone the given `op` while adjusting its result types to match those of the
+/// values taken as output tensor operands.
+static LinalgOp cloneWithSubshapeOperands(OpBuilder &b, Location loc,
+                                          LinalgOp op, ValueRange operands) {
+  // TODO: use an interface/adaptor to avoid leaking position in `operands`.
+  SmallVector<Type> resultTensorTypes;
+  for (OpOperand *opOperand : op.getOutputTensorOperands())
+    resultTensorTypes.push_back(
+        operands[opOperand->getOperandNumber()].getType());
+
+  return op.clone(b, loc, resultTensorTypes, operands);
+}
+
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                  const LinalgTilingOptions &options) {
@@ -181,34 +219,10 @@
     SmallVector<Value> tiledOperands =
         makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                         sizeBounds, /*omitPartialTileCheck=*/false);
+    res = cloneWithSubshapeOperands(b, loc, op, tiledOperands);
 
-    // TODO: use an interface/adaptor to avoid leaking position in
-    // `tiledOperands`.
-    SmallVector<Type> resultTensorTypes;
-    for (OpOperand *opOperand : op.getOutputTensorOperands())
-      resultTensorTypes.push_back(
-          tiledOperands[opOperand->getOperandNumber()].getType());
-
-    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
-
-    // Insert a insert_slice for each output tensor.
-    unsigned resultIdx = 0;
-    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
-      // TODO: use an interface/adaptor to avoid leaking position in
-      // `tiledOperands`.
-      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
-      // TODO: Propagate RewriterBase everywhere.
-      IRRewriter rewriter(b);
-      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(insertSliceIntoTensor(
-            rewriter, loc, sliceOp, res->getResult(resultIdx),
-            sliceOp.source()));
-      } else {
-        tensorResults.push_back(res->getResult(resultIdx));
-      }
-      ++resultIdx;
-    }
-    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
+    // Insert an insert_slice for each output tensor.
+    return insertSlicesBack(b, loc, res, tiledOperands);
   };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                  tiledLoopBodyBuilder, options.distribution,
@@ -278,6 +292,272 @@
   return failure();
 }
 
+namespace {
+/// A description of a multi-size tiling comprising tile sizes and numbers of
+/// tiles, expressed as Values which may or may not be constant. Multi-size
+/// currently means two-size.
+struct MultiSizeSpecification {
+  /// Tile sizes.
+  Value lowTileSize, highTileSize;
+  /// Number of tiles associated with each size.
+  Value lowTripCount, highTripCount;
+};
+} // namespace
+
+/// Emit the IR computing the multi-sized tiling specification with two tile
+/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
+/// that there exist numbers of tiles with these sizes that fully cover the
+/// `originalTripCount` iterations.
+///
+/// The computation is as follows:
+///
+///   b = originalTripCount floordiv sizeDivisor
+///   t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
+///   d = (b + t - 1) floordiv t
+///   s = (b floordiv d) * sizeDivisor
+///   v = b % d
+///   u = d - v
+///
+/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
+/// the corresponding tiles are `u` and `v`, respectively. All four values are
+/// returned.
+static MultiSizeSpecification
+computeMultiSizeLoopTripCounts(ImplicitLocOpBuilder &b, Value targetSize,
+                               Value sizeDivisor, Value originalTripCount,
+                               Value one) {
+  assert(targetSize.getType() == originalTripCount.getType() &&
+         sizeDivisor.getType() == originalTripCount.getType() &&
+         "expected all types to match");
+  Value dividedBound =
+      b.create<arith::FloorDivSIOp>(originalTripCount, sizeDivisor);
+  Value targetPlusDivisor = b.create<arith::AddIOp>(targetSize, sizeDivisor);
+  Value targetPlusDivisorSubOne =
+      b.create<arith::SubIOp>(targetPlusDivisor, one);
+  Value roundedTarget =
+      b.create<arith::FloorDivSIOp>(targetPlusDivisorSubOne, sizeDivisor);
+  Value boundPlusTargetRounded =
+      b.create<arith::AddIOp>(dividedBound, roundedTarget);
+  Value boundPlusTargetRoundedSubOne =
+      b.create<arith::SubIOp>(boundPlusTargetRounded, one);
+  Value divisorRounded = b.create<arith::FloorDivSIOp>(
+      boundPlusTargetRoundedSubOne, roundedTarget);
+  Value unscaledLowTileSize =
+      b.create<arith::FloorDivSIOp>(dividedBound, divisorRounded);
+  Value lowTileSize =
+      b.create<arith::MulIOp>(unscaledLowTileSize, sizeDivisor);
+  Value highTileSize = b.create<arith::AddIOp>(lowTileSize, sizeDivisor);
+  Value highTripCount =
+      b.create<arith::RemSIOp>(dividedBound, divisorRounded);
+  Value lowTripCount =
+      b.create<arith::SubIOp>(divisorRounded, highTripCount);
+  return {lowTileSize, highTileSize, lowTripCount, highTripCount};
+}
+
+static ValueRange createMultiSizedTilingLoops(
+    ImplicitLocOpBuilder &b, LinalgOp op, Value zero, Value one,
+    ValueRange sizeBounds, ArrayRef<MultiSizeSpecification> specs,
+    ValueRange initArgs, SmallVectorImpl<Value> &adjustedIterators,
+    SmallVectorImpl<Value> &tileSizes, SmallVectorImpl<LinalgOp> &tiledOps);
+
+/// Emit the IR for one of the loops produced by multi-sized tiling, including
+/// all nested loops recursively. The recursion is bounded by the number of
+/// dimensions being tiled, which is known to be small (usually <10). `isLow`
+/// indicates whether the part being emitted is the first (lower indices) or
+/// the last (higher indices), which affects the index adjustment. `op` is the
+/// operation being tiled. `zero` and `one` correspond to index-typed constants
+/// visible in the loop. `sizeBounds` contains the dimensions of the original
+/// iteration space; `specs` contains the tile sizes and numbers of tiles yet
+/// to be generated, starting with the current one. `initArgs` are the values
+/// that are passed as `iter_args` of the loop being emitted, and typically
+/// correspond to results of the operation being tiled partially updated by
+/// previous parts, if any. `adjustedIterators` is a mutable list of values
+/// that replace the indices of the original iteration space. `tileSizes` is a
+/// mutable list of tile sizes used by the parent (already emitted) loops.
+/// `tiledOps` is a mutable list of smaller instances of `op` produced by
+/// tiling. All mutable lists are updated before entering the recursion.
+///
+/// The loop structure resembles:
+///
+///   %partial = scf.for %i = 0 to %numLowTiles step 1
+///              iter_args(%iter_args = %result_inits) {
+///     %adjusted = %i * %lowTileSize
+///     %slices... = extract_slice %original_inputs[...]
+///     %out_slices... = extract_slice %iter_args[%adjusted, ...]
+///     %res = linalg.op ins(%slices) outs(%out_slices)
+///     %loop_res = insert_slice %res into %iter_args[%adjusted, ...]
+///     scf.yield %loop_res
+///   }
+///
+/// where "linalg.op" can be further decomposed into pairs of loops
+/// implementing the multi-size tiling on deeper dimensions.
+static ValueRange createOneMultiSizedPart(
+    ImplicitLocOpBuilder &b, LinalgOp op, Value zero, Value one,
+    ValueRange sizeBounds, bool isLow, ArrayRef<MultiSizeSpecification> specs,
+    ValueRange initArgs, SmallVectorImpl<Value> &adjustedIterators,
+    SmallVectorImpl<Value> &tileSizes, SmallVectorImpl<LinalgOp> &tiledOps) {
+  // Create the loop itself.
+  auto loop = b.create<scf::ForOp>(
+      zero, isLow ? specs[0].lowTripCount : specs[0].highTripCount, one,
+      initArgs);
+
+  // Emit IR recovering the indices in the original iteration space. For the
+  // lower part, this accounts for the tile size. For the higher part, this
+  // accounts for the tile size and for the pieces computed by the lower part.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(loop.getBody());
+  if (isLow) {
+    adjustedIterators.push_back(
+        b.create<arith::MulIOp>(loop.getInductionVar(), specs[0].lowTileSize));
+    tileSizes.push_back(specs[0].lowTileSize);
+  } else {
+    Value previousProduct =
+        b.create<arith::MulIOp>(specs[0].lowTripCount, specs[0].lowTileSize);
+    Value currentProduct =
+        b.create<arith::MulIOp>(loop.getInductionVar(), specs[0].highTileSize);
+    adjustedIterators.push_back(
+        b.create<arith::AddIOp>(previousProduct, currentProduct));
+    tileSizes.push_back(specs[0].highTileSize);
+  }
+  auto scope = llvm::make_scope_exit([&] {
+    adjustedIterators.pop_back();
+    tileSizes.pop_back();
+  });
+
+  // If this is not the innermost loop, recurse.
+  if (specs.size() > 1) {
+    ValueRange yielded = createMultiSizedTilingLoops(
+        b, op, zero, one, sizeBounds, specs.drop_front(),
+        loop.getRegionIterArgs(), adjustedIterators, tileSizes, tiledOps);
+    b.create<scf::YieldOp>(yielded);
+    return loop->getResults();
+  }
+
+  // Emit tiled operations in the innermost loop. By construction, both parts
+  // in the multi-sized case have full tiles, so omit generating the partial
+  // tile check.
+  auto operandsToTile = llvm::to_vector(
+      llvm::map_range(op.getInputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  llvm::append_range(operandsToTile, loop.getRegionIterArgs());
+  SmallVector<Value> tiledOperands =
+      makeTiledShapes(b, b.getLoc(), op, operandsToTile, adjustedIterators,
                      tileSizes, sizeBounds,
                      /*omitPartialTileCheck=*/true);
+
+  // Create the tiled operation.
+  LinalgOp tiled = cloneWithSubshapeOperands(b, b.getLoc(), op, tiledOperands);
+  addTileLoopIvsToIndexOpResults(b, tiled, adjustedIterators);
+  tiledOps.push_back(tiled);
+
+  // Insert partial results into tensors and yield them.
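+  // The values yielded here become the results of the enclosing loop; they
+  // are either fed into the sibling high-sized part as its iter_args or used
+  // to replace the original untiled op.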
+  scf::ValueVector results =
+      insertSlicesBack(b, b.getLoc(), tiled, tiledOperands);
+  b.create<scf::YieldOp>(results);
+  return loop->getResults();
+}
+
+/// Emit the IR for loops representing the lower and the higher part of the
+/// multi-sized tiling of the given operation `op`. Operates recursively per
+/// dimension. See createOneMultiSizedPart for documentation on arguments. The
+/// IR structure resembles the following:
+///
+///   %partial = scf.for %i = 0 to %numLowTiles step 1 iter_args(%init_args) {
+///     %adjusted = %i * %lowTileSize
+///     // see createOneMultiSizedPart for the loop body
+///   }
+///   %full = scf.for %i = 0 to %numHighTiles step 1 iter_args(%partial) {
+///     %adjusted = %i * %highTileSize + %numLowTiles * %lowTileSize
+///     // see createOneMultiSizedPart for the loop body
+///   }
+static ValueRange createMultiSizedTilingLoops(
+    ImplicitLocOpBuilder &b, LinalgOp op, Value zero, Value one,
+    ValueRange sizeBounds, ArrayRef<MultiSizeSpecification> specs,
+    ValueRange initArgs, SmallVectorImpl<Value> &adjustedIterators,
+    SmallVectorImpl<Value> &tileSizes, SmallVectorImpl<LinalgOp> &tiledOps) {
+  assert(!specs.empty() && "expected a non-empty list of tile specs");
+
+  ValueRange lowResults = createOneMultiSizedPart(
+      b, op, zero, one, sizeBounds, /*isLow=*/true, specs, initArgs,
+      adjustedIterators, tileSizes, tiledOps);
+  return createOneMultiSizedPart(b, op, zero, one, sizeBounds, /*isLow=*/false,
+                                 specs, lowResults, adjustedIterators,
+                                 tileSizes, tiledOps);
+}
+
+FailureOr<MultiSizedTilingResult>
+mlir::linalg::multiSizeTileLinalgOp(RewriterBase &b, LinalgOp linalgOp,
+                                    ValueRange targetSizes,
+                                    ValueRange targetSizeDivisors) {
+  if (linalgOp.getNumLoops() > targetSizes.size()) {
+    LLVM_DEBUG(
+        DBGS() << "NYI: multi-size tiling requires sizes for all dimensions\n");
+    return failure();
+  }
+
+  if (linalgOp.getNumWindowLoops() != 0) {
+    LLVM_DEBUG(
+        DBGS() << "NYI: multi-size tiling does not support window loops\n");
+    return failure();
+  }
+  assert((targetSizeDivisors.empty() ||
+          targetSizeDivisors.size() == targetSizes.size()) &&
+         "expected the same number of divisors as target sizes");
+
+  // No tiling is required.
+  if (targetSizes.empty()) {
+    MultiSizedTilingResult result;
+    result.tiledOps.push_back(linalgOp);
+    result.tensorResults = linalgOp->getResults();
+    return result;
+  }
+
+  // Set up target divisors if necessary.
+  ImplicitLocOpBuilder builder(linalgOp.getLoc(), b);
+  SmallVector<Value> updatedTargetSizeDivisors;
+  Value one = nullptr;
+  if (targetSizeDivisors.empty()) {
+    one = builder.create<arith::ConstantIndexOp>(1);
+    updatedTargetSizeDivisors.resize(targetSizes.size(), one);
+    targetSizeDivisors = llvm::makeArrayRef(updatedTargetSizeDivisors);
+  }
+
+  // Compute multi-sized tiling specifications. These include the tile sizes
+  // and the numbers of tiles, the latter serving as trip counts for the
+  // produced loops.
+  Location loc = linalgOp.getLoc();
+  SmallVector<Value> allShapes =
+      linalgOp.createFlatListOfOperandDims(b, loc);
+  AffineMap shapesToLoopsMap = linalgOp.getShapesToLoopsMap();
+  if (!shapesToLoopsMap) {
+    LLVM_DEBUG(DBGS() << "the op does not provide the shapes-to-loops map\n");
+    return failure();
+  }
+  SmallVector<Value> loopTripCounts =
+      applyMapToValues(b, loc, shapesToLoopsMap, allShapes);
+  unsigned numTiledDims = std::min(loopTripCounts.size(), targetSizes.size());
+  if (!one)
+    one = builder.create<arith::ConstantIndexOp>(1);
+  SmallVector<MultiSizeSpecification> specs;
+  specs.reserve(numTiledDims);
+  for (unsigned i = 0; i < numTiledDims; ++i) {
+    specs.push_back(computeMultiSizeLoopTripCounts(builder, targetSizes[i],
+                                                   targetSizeDivisors[i],
+                                                   loopTripCounts[i], one));
+  }
+
+  // Generate loops recursively for each iteration space dimension.
+  SmallVector<Value> sizeBounds =
+      applyMapToValues(b, loc, shapesToLoopsMap, allShapes);
+  Value zero = builder.create<arith::ConstantIndexOp>(0);
+  auto tensorOutputs = llvm::to_vector(
+      llvm::map_range(linalgOp.getOutputOperands(),
+                      [](OpOperand *operand) { return operand->get(); }));
+  SmallVector<Value> adjustedIterators;
+  SmallVector<Value> tileSizes;
+  MultiSizedTilingResult result;
+  result.tensorResults = createMultiSizedTilingLoops(
+      builder, linalgOp, zero, one, sizeBounds, specs, tensorOutputs,
+      adjustedIterators, tileSizes, result.tiledOps);
+
+  b.replaceOp(linalgOp, result.tensorResults);
+  return result;
+}
+
 /// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
 /// and `loopNest` are output parameters that return the new (tiled)
 /// tensor::PadOp and the loop nest.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -895,7 +895,7 @@
 SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
-                                    ArrayRef<Value> sizeBounds) {
+                                    ValueRange sizeBounds) {
   SmallVector<Value> sizes;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
     bool isTiled = !isZero(tileSizes[idx]);
@@ -908,12 +908,10 @@
   return sizes;
 }
 
-SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
-                                   LinalgOp linalgOp,
-                                   ArrayRef<Value> valuesToTile,
-                                   ValueRange ivs, ValueRange tileSizes,
-                                   ArrayRef<Value> sizeBounds,
-                                   bool omitPartialTileCheck) {
+SmallVector<Value>
+makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
+                ValueRange valuesToTile, ValueRange ivs, ValueRange tileSizes,
+                ValueRange sizeBounds, bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(),
                                             tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -180,6 +180,26 @@
   return [IntegerAttr(element).value for element in attr]
 
 
+class TileMultiSizeOp:
+  """Specialization for TileMultiSizeOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               target_sizes: Union[ArrayAttr, IntOrAttrList],
+               target_size_divisors: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        target_sizes=_get_int_array_attr(target_sizes),
+        target_size_divisors=_get_int_array_attr(target_size_divisors),
+        loc=loc,
+        ip=ip)
+
+
 class VectorizeOp:
   """Specialization for VectorizeOp class."""
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile-multisize.mlir b/mlir/test/Dialect/Linalg/transform-op-tile-multisize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-tile-multisize.mlir
@@ -0,0 +1,261 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file --verify-diagnostics | FileCheck %s --check-prefix=CANON
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{expects tile sizes to be strictly positive}}
+  transform.structured.tile.multisize %arg0 { target_sizes = [0, 10] }
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{expects divisors to be strictly positive}}
+  transform.structured.tile.multisize %arg0 {
+    target_sizes = [10, 10],
+    target_size_divisors = [1, -1]
+  }
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{expects as many divisors as tile sizes or none}}
+  transform.structured.tile.multisize %arg0 {
+    target_sizes = [10, 10],
+    target_size_divisors = [3]
+  }
+}
+
+// -----
+
+//
+// Checking the successful transformation.
+//
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    transform.structured.tile.multisize %0 { target_sizes = [3, 10] }
+  }
+}
+
+// CHECK-DAG: #[[$MINUS_ONE:.+]] = affine_map<()[s0] -> (s0 - 1)>
+// CHECK-DAG: #[[$ID_1D:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$ID_2D:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$SUM:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-LABEL: @one_d
+// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
+// CANON-LABEL: @one_d
+// CANON-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
+func.func @one_d(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: %[[SIZE1:.+]] = arith.constant 3
+  // CHECK: %[[SIZE2:.+]] = arith.constant 10
+  // CHECK: %[[ONE:.+]] = arith.constant 1 :
+  // These are emitted by createOrFold on dims.
+  // CHECK-COUNT-2: constant 10
+  // CHECK: %[[SHAPE:.+]] = arith.constant 10
+  // CHECK: %[[DIVIDED_BOUND:.+]] = arith.floordivsi %[[SHAPE]], %[[ONE]]
+  // CHECK: %[[TARGET_PLUS_DIVISOR:.+]] = arith.addi %[[SIZE1]], %[[ONE]]
+  // CHECK: %[[TARGET_PLUS_DIVISOR_SUB_ONE:.+]] = arith.subi %[[TARGET_PLUS_DIVISOR]], %[[ONE]]
+  // CHECK: %[[ROUNDED_TARGET:.+]] = arith.floordivsi %[[TARGET_PLUS_DIVISOR_SUB_ONE]], %[[ONE]]
+  // CHECK: %[[BOUND_PLUS_TARGET_ROUNDED:.+]] = arith.addi %[[DIVIDED_BOUND]], %[[ROUNDED_TARGET]]
+  // CHECK: %[[BOUND_PLUS_TARGET_ROUNDED_SUB_ONE:.+]] = arith.subi %[[BOUND_PLUS_TARGET_ROUNDED]], %[[ONE]]
+  // CHECK: %[[DIVISOR_ROUNDED:.+]] = arith.floordivsi %[[BOUND_PLUS_TARGET_ROUNDED_SUB_ONE]], %[[ROUNDED_TARGET]]
+  // CHECK: %[[UNSCALED_LOW_TILE_SIZE:.+]] = arith.floordivsi %[[DIVIDED_BOUND]], %[[DIVISOR_ROUNDED]]
+  // CHECK: %[[LOW_TILE_SIZE:.+]] = arith.muli %[[UNSCALED_LOW_TILE_SIZE]], %[[ONE]]
+  // CHECK: %[[HIGH_TILE_SIZE:.+]] = arith.addi %[[LOW_TILE_SIZE]], %[[ONE]]
+  // CHECK: %[[HIGH_TRIP_COUNT:.+]] = arith.remsi %[[DIVIDED_BOUND]], %[[DIVISOR_ROUNDED]]
+  // CHECK: %[[LOW_TRIP_COUNT:.+]] = arith.subi %[[DIVISOR_ROUNDED]], %[[HIGH_TRIP_COUNT]]
+
+  // CHECK: %[[ZERO:.+]] = arith.constant 0
+  // CHECK: %[[PARTIAL:.+]] = scf.for %[[I:.+]] = %[[ZERO]] to %[[LOW_TRIP_COUNT]] step %[[ONE]] iter_args(%[[ITER_ARG_I:.+]] = %[[OUT]])
+  // CHECK: %[[ADJUSTED_I:.+]] = arith.muli %[[I]], %[[LOW_TILE_SIZE]]
+  // CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]][%[[ADJUSTED_I]]] [%[[LOW_TILE_SIZE]]] [1]
+  // CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG_I]][%[[ADJUSTED_I]]] [%[[LOW_TILE_SIZE]]] [1]
+  // CHECK: %[[RES_SLICE:.+]] = linalg.generic {{.*}} ins(%[[IN_SLICE]] : tensor<?xf32>) outs(%[[OUT_SLICE]] : tensor<?xf32>)
+  // CHECK: %[[INDEX_0:.+]] = linalg.index 0
+  // CHECK: %[[ADJUSTED_INDEX_0:.+]] = affine.apply #[[$SUM]](%[[INDEX_0]], %[[ADJUSTED_I]])
+  // CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[RES_SLICE]] into %[[ITER_ARG_I]][%[[ADJUSTED_I]]] [%[[LOW_TILE_SIZE]]] [1]
+  // CHECK: scf.yield %[[YIELD]]
+
+  // CHECK: scf.for %[[I:.+]] = %[[ZERO]] to %[[HIGH_TRIP_COUNT]] step %[[ONE]] iter_args(%[[ITER_ARG_I:.+]] = %[[PARTIAL]])
+  // CHECK: %[[LOW_PART:.+]] = arith.muli %[[LOW_TRIP_COUNT]], %[[LOW_TILE_SIZE]]
+  // CHECK: %[[SCALED_I:.+]] = arith.muli %[[I]], %[[HIGH_TILE_SIZE]]
+  // CHECK: %[[ADJUSTED_I:.+]] = arith.addi %[[LOW_PART]], %[[SCALED_I]]
+  // CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]][%[[ADJUSTED_I]]] [%[[HIGH_TILE_SIZE]]] [1]
+  // CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG_I]][%[[ADJUSTED_I]]] [%[[HIGH_TILE_SIZE]]] [1]
+  // CHECK: %[[RES_SLICE:.+]] = linalg.generic {{.*}} ins(%[[IN_SLICE]] : tensor<?xf32>) outs(%[[OUT_SLICE]] : tensor<?xf32>)
+  // CHECK: %[[INDEX_0:.+]] = linalg.index 0
+  // CHECK: %[[ADJUSTED_INDEX_0:.+]] = affine.apply #[[$SUM]](%[[INDEX_0]], %[[ADJUSTED_I]])
+  // CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[RES_SLICE]] into %[[ITER_ARG_I]][%[[ADJUSTED_I]]] [%[[HIGH_TILE_SIZE]]] [1]
+  // CHECK: scf.yield %[[YIELD]]
+
+  // Check that canonicalization is able to recover static shapes.
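+  // With a trip count of 10 and target size 3, multi-size tiling computes
+  // tile sizes 2 and 3 with two tiles of each: 2*2 + 3*2 == 10.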
+  // CANON-DAG: %[[C0:.+]] = arith.constant 0
+  // CANON-DAG: %[[C1:.+]] = arith.constant 1
+  // CANON-DAG: %[[C2:.+]] = arith.constant 2
+  // CANON-DAG: %[[C3:.+]] = arith.constant 3
+  // CANON-DAG: %[[C4:.+]] = arith.constant 4
+  // CANON: scf.for %{{.*}} = %[[C0]] to %[[C2]] step %[[C1]]
+  // CANON: tensor.extract_slice %[[IN]][%{{.*}}] [2] [1] : tensor<10xf32> to tensor<2xf32>
+  // CANON: tensor.extract_slice %{{.*}}[%{{.*}}] [2] [1] : tensor<10xf32> to tensor<2xf32>
+  // CANON: scf.for %{{.*}} = %[[C0]] to %[[C2]] step %[[C1]]
+  // CANON: tensor.extract_slice %[[IN]][%{{.*}}] [3] [1] : tensor<10xf32> to tensor<3xf32>
+  // CANON: tensor.extract_slice %{{.*}}[%{{.*}}] [3] [1] : tensor<10xf32> to tensor<3xf32>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%arg2, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @two_d
+// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
+// CANON-LABEL: @two_d
+// CANON-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
+func.func @two_d(%arg0: tensor<10x34xf32>, %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  // Only check the overall nesting and computation structure.
+  // CHECK: %[[PARTIAL_I:.+]] = scf.for %{{.*}} iter_args(%[[ITER_ARG_I:.+]] = %[[OUT]])
+  // CHECK: %[[ADJUSTED_I:.+]] = arith.muli
+  //
+  // CHECK: %[[PARTIAL_J:.+]] = scf.for %{{.*}} iter_args(%[[ITER_ARG_J:.+]] = %[[ITER_ARG_I]])
+  // CHECK: %[[ADJUSTED_J:.+]] = arith.muli
+  // CHECK: tensor.extract_slice %[[IN]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: tensor.extract_slice %[[ITER_ARG_J]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: %[[RES:.+]] = linalg.generic
+  // CHECK: %[[INSERTED_J:.+]] = tensor.insert_slice %[[RES]] into %[[ITER_ARG_J]]
+  // CHECK: scf.yield %[[INSERTED_J]]
+  //
+  // CHECK: %[[FULL_J:.+]] = scf.for %{{.*}} iter_args(%[[ITER_ARG_J:.+]] = %[[PARTIAL_J]])
+  // CHECK: arith.muli
+  // CHECK: arith.muli
+  // CHECK: %[[ADJUSTED_J:.+]] = arith.addi
+  // CHECK: tensor.extract_slice %[[IN]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: tensor.extract_slice %[[ITER_ARG_J]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: %[[RES:.+]] = linalg.generic
+  // CHECK: %[[INSERTED_J:.+]] = tensor.insert_slice %[[RES]] into %[[ITER_ARG_J]]
+  // CHECK: scf.yield %[[INSERTED_J]]
+  //
+  // CHECK: %{{.+}} = scf.for %{{.*}} iter_args(%[[ITER_ARG_I:.+]] = %[[PARTIAL_I]])
+  // CHECK: arith.muli
+  // CHECK: arith.muli
+  // CHECK: %[[ADJUSTED_I:.+]] = arith.addi
+  //
+  // CHECK: %[[PARTIAL_J:.+]] = scf.for %{{.*}} iter_args(%[[ITER_ARG_J:.+]] = %[[ITER_ARG_I]])
+  // CHECK: %[[ADJUSTED_J:.+]] = arith.muli
+  // CHECK: tensor.extract_slice %[[IN]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: tensor.extract_slice %[[ITER_ARG_J]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: %[[RES:.+]] = linalg.generic
+  // CHECK: %[[INSERTED_J:.+]] = tensor.insert_slice %[[RES]] into %[[ITER_ARG_J]]
+  // CHECK: scf.yield %[[INSERTED_J]]
+  //
+  // CHECK: %[[FULL_J:.+]] = scf.for %{{.*}} iter_args(%[[ITER_ARG_J:.+]] = %[[PARTIAL_J]])
+  // CHECK: arith.muli
+  // CHECK: arith.muli
+  // CHECK: %[[ADJUSTED_J:.+]] = arith.addi
+  // CHECK: tensor.extract_slice %[[IN]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: tensor.extract_slice %[[ITER_ARG_J]][%[[ADJUSTED_I]], %[[ADJUSTED_J]]]
+  // CHECK: %[[RES:.+]] = linalg.generic
+  // CHECK: %[[INSERTED_J:.+]] = tensor.insert_slice %[[RES]] into %[[ITER_ARG_J]]
+  // CHECK: scf.yield %[[INSERTED_J]]
+
+  // Check that canonicalization is able to recover static shapes.
+  // CANON-COUNT-2: tensor.extract_slice {{.*}} : tensor<10x34xf32> to tensor<2x8xf32>
+  // CANON-COUNT-2: tensor.extract_slice {{.*}} : tensor<10x34xf32> to tensor<2x9xf32>
+  // CANON-COUNT-2: tensor.extract_slice {{.*}} : tensor<10x34xf32> to tensor<3x8xf32>
+  // CANON-COUNT-2: tensor.extract_slice {{.*}} : tensor<10x34xf32> to tensor<3x9xf32>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%arg2, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+  return %0 : tensor<10x34xf32>
+}
+
+// -----
+
+// expected-note @below {{target op}}
+module {
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+    // expected-error @below {{only applies to Linalg structured ops}}
+    transform.structured.tile.multisize %arg1 { target_sizes = [3, 10] }
+  }
+}
+
+// -----
+
+//
+// Check failure to apply due to an insufficient number of target sizes.
+//
+// TODO: this should have a better error message, but it needs to be surfaced
+// from the transformation code itself to avoid duplication.
+//
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+func.func @two_d(%arg0: tensor<10x34xf32>, %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%arg2, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+  return %0 : tensor<10x34xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    // expected-error @below {{failed to apply}}
+    transform.structured.tile.multisize %0 { target_sizes = [3] }
+  }
+}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -114,6 +114,18 @@
   # CHECK-DAG: sizes = [4, 8]
 
 
+@run
+def testTileMultiSize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.TileMultiSizeOp(sequence.bodyTarget, target_sizes=[3, 7])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileMultiSize
+  # CHECK: transform.sequence
+  # CHECK: structured.tile.multisize
+  # CHECK: target_sizes = [3, 7]
+
+
 @run
 def testTileZero():
   sequence = transform.SequenceOp()
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7398,6 +7398,7 @@
     ],
    includes = ["include"],
    deps = [
+        ":ArithmeticDialect",
        ":IR",
        ":LinalgDialect",
        ":LinalgTransformOpsIncGen",