diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -20,6 +20,8 @@
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgFuseTensorOpsPass();
+
 std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -56,6 +56,17 @@
   ];
 }
 
+def LinalgFuseTensorOps : FunctionPass<"linalg-fuse-tensor-ops"> {
+  let summary = "Fuse the producers of tiled operations on tensors";
+  let constructor = "mlir::createLinalgFuseTensorOpsPass()";
+  let options = [
+    ListOption<"tileLoops", "tile-loops", "int64_t",
+               "The tiled consumer loop dimensions from outer to inner",
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+  ];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
+}
+
 def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
   let summary = "Fuse elementwise operations on tensors";
   let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -108,6 +108,11 @@
                                ValueRange ivs, ValueRange tileSizes,
                                ArrayRef<Value> sizeBounds);
 
+/// Add the tile loop induction variables `ivs` to the index op results found
+/// in the body of the `tiledOp` to account for the tile offset.
+void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
+                                    ArrayRef<Value> ivs);
+
 using FusableOpDependencesTy = llvm::MapVector<
     Operation *,
     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
 
@@ -148,6 +153,40 @@
                                    OpResult producerOpResult,
                                    OpOperand &consumerOpOperand);
 
+//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct storing the state computed by `isProducerFusable` and consumed by
+/// `fuseProducer`.
+struct FusionState {
+  /// Result to fuse.
+  OpResult producerOpResult;
+  /// Iteration argument of the outermost tile loop if one exists.
+  OpOperand *tileLoopIterArg;
+  /// All fusable producer loop dimensions.
+  SmallVector<int64_t> producerLoopsToFuse;
+  /// A producer result shape dimension per fusable producer loop.
+  SmallVector<int64_t> producerShapeDimsToFuse;
+  /// An ExtractSliceOp per fusable producer loop if one exists.
+  SmallVector<tensor::ExtractSliceOp> sliceOpsToFuse;
+};
+
+/// Verify that the producer of `consumerOpOperand` is fusable in place of an
+/// ExtractSliceOp that is part of the use-def chain connecting consumer and
+/// producer. Use `tileLoopDims` to map the tile loops from outer to inner to
+/// the tiled consumer loop dimensions. In case of transitive fusion, a tile
+/// loop dimension may be set to None if the loop does not tile a previously
+/// fused operation.
+FailureOr<FusionState>
+isProducerFusable(OpOperand *consumerOpOperand,
+                  ArrayRef<Optional<int64_t>> tileLoopDims,
+                  function_ref<void(StringRef)> notifyFailure);
+
+/// Fuse the producer in place of the ExtractSliceOp found by
+/// `isProducerFusable`. Clones the producer to replace the ExtractSliceOp and
+/// returns the cloned producer.
+LinalgOp fuseProducer(OpBuilder &b, FusionState &fusionState);
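+
+// A typical driver (an illustrative sketch; `b`, `opOperand`, `tileLoopDims`,
+// and `notifyFailure` are assumed to exist in the caller):
+//
+//   FailureOr<FusionState> state =
+//       isProducerFusable(opOperand, tileLoopDims, notifyFailure);
+//   if (succeeded(state))
+//     fuseProducer(b, state.getValue());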
+
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp
+  FusionOnTensors.cpp
   Generalization.cpp
   Hoisting.cpp
   InlineScalarOperands.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -0,0 +1,452 @@
+//===- FusionOnTensors.cpp - Implementation of linalg fusion on tensors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg fusion on tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace linalg;
+
+/// Walk the use-def chain starting from the consumer and collect the
+/// ExtractSliceOps and ForOps found. Stop the search as soon as an operation
+/// of a different kind is found. At the same time, collect the parent tile
+/// loops of type ForOp and add them to `tileLoops`. For every tile loop, add
+/// the ExtractSliceOp or the loop iteration argument found to `sliceOps` and
+/// `bbArgs`, respectively, and set the entry to nullptr otherwise. Finally,
+/// return the back of the use-def chain.
+// TODO: Support additional loop types and control flow operations.
+static Value getTileLoopNest(OpOperand *consumerOpOperand,
+                             SmallVectorImpl<scf::ForOp> &tileLoops,
+                             SmallVectorImpl<tensor::ExtractSliceOp> &sliceOps,
+                             SmallVectorImpl<BlockArgument> &bbArgs) {
+  // Get the initial value of the use-def chain and the innermost tile loop.
+  Value current = consumerOpOperand->get();
+  LinalgOp consumerOp = consumerOpOperand->getOwner();
+  auto tileLoop = dyn_cast<scf::ForOp>(consumerOp->getParentOp());
+
+  // Walk the tile loop nest.
+  while (tileLoop) {
+    // Advance to the next tile loop if the current value of the use-def chain
+    // is defined outside of the loop.
+    if (current.getParentBlock()->getParentOp() != tileLoop) {
+      tileLoops.push_back(tileLoop);
+      sliceOps.push_back(nullptr);
+      bbArgs.push_back(nullptr);
+      tileLoop = dyn_cast<scf::ForOp>(tileLoop->getParentOp());
+      continue;
+    }
+    // Search for an ExtractSliceOp at the current tile loop level.
+    auto sliceOp = current.getDefiningOp<tensor::ExtractSliceOp>();
+    if (sliceOp && sliceOp->getParentOp() == tileLoop) {
+      sliceOps.push_back(sliceOp);
+      current = sliceOp.source();
+    }
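+    // For example, given the chain `%bbArg -> tensor.extract_slice ->
+    // consumer`, the step above records the slice and continues the walk from
+    // its source, which the iteration argument check below then resolves
+    // through the enclosing loop.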
+    // Search for an iteration argument of the current tile loop.
+    if (auto bbArg = current.dyn_cast<BlockArgument>()) {
+      Operation *parentOp = bbArg.getParentBlock()->getParentOp();
+      auto forOp = dyn_cast<scf::ForOp>(parentOp);
+      if (forOp && forOp == tileLoop) {
+        bbArgs.push_back(bbArg);
+        current = forOp.getOpOperandForRegionIterArg(bbArg).get();
+      }
+    }
+    // Exit if neither an ExtractSliceOp nor an iteration argument was found at
+    // the current tile loop level, since we may have found the producer.
+    if (sliceOps.size() == tileLoops.size() &&
+        bbArgs.size() == tileLoops.size())
+      return current;
+    // Advance to the next tile loop level and append nullptr to the
+    // collections if needed.
+    tileLoops.push_back(tileLoop);
+    sliceOps.resize(tileLoops.size(), nullptr);
+    bbArgs.resize(tileLoops.size(), nullptr);
+    tileLoop = dyn_cast<scf::ForOp>(tileLoop->getParentOp());
+  }
+  return current;
+}
+
+/// Relate the producer to the consumer loop iterations that access the same
+/// producer result element:
+///   consumerToProducerLoops =
+///       inverse(producerIndexingMap).compose(consumerIndexingMap).
+/// Return `consumerToProducerLoops` or None if the inversion fails.
+static Optional<AffineMap>
+getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
+                              AffineMap consumerIndexingMap) {
+  assert(consumerIndexingMap.getNumResults() ==
+             producerIndexingMap.getNumResults() &&
+         "expect the number of indexing map results to match");
+  // Ensure the producer indexing map is a projected permutation.
+  if (!producerIndexingMap.isProjectedPermutation())
+    return None;
+  AffineMap inverseIndexingMap =
+      inverseAndBroadcastProjectedPermuation(producerIndexingMap);
+  return inverseIndexingMap.compose(consumerIndexingMap);
+}
+
+/// Return the producer loop dimension mapped to the given consumer tile loop
+/// dimension or None if no mapping exists.
+static Optional<int64_t>
+getProducerLoopDim(int64_t tileLoopDim, AffineMap consumerToProducerLoops) {
+  assert(count_if(consumerToProducerLoops.getResults(),
+                  [&](AffineExpr expr) {
+                    return expr.isFunctionOfDim(tileLoopDim);
+                  }) <= 1 &&
+         "expect the tile loop to tile at most one producer loop");
+  for (auto en : enumerate(consumerToProducerLoops.getResults()))
+    if (en.value().isFunctionOfDim(tileLoopDim))
+      return en.index();
+  return None;
+}
+
+/// Return the bound for the given producer loop dimension.
+static Value getLoopBound(OpBuilder &b, LinalgOp producerOp, int64_t dim) {
+  Location loc = producerOp.getLoc();
+  for (OpOperand *opOperand : producerOp.getInputAndOutputOperands()) {
+    AffineMap indexingMap = producerOp.getTiedIndexingMap(opOperand);
+    for (auto en : enumerate(indexingMap.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+      if (dimExpr && dim == static_cast<int64_t>(dimExpr.getPosition()))
+        return createOrFoldDimOp(b, loc, opOperand->get(), en.index());
+    }
+  }
+  return nullptr;
+}
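+
+// Illustrative example for getLoopBound: for a producer with indexing map
+// (d0, d1, d2) -> (d0, d2) on some operand %t, the bound of loop d2 is
+// `tensor.dim %t, 1`, which createOrFoldDimOp folds to a constant for static
+// shapes.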
+
+/// Tile the producer operands given an ExtractSliceOp that is part of the
+/// use-def chain.
+static SmallVector<Value>
+getTiledOperands(OpBuilder &b, LinalgOp producerOp,
+                 tensor::ExtractSliceOp sliceOp, ArrayRef<Value> valuesToTile,
+                 ArrayRef<int64_t> producerLoopsToFuse,
+                 ArrayRef<int64_t> producerShapeDimsToFuse,
+                 SmallVectorImpl<Value> &sizeBounds,
+                 SmallVectorImpl<Value> &allIvs, FusionState &fusionState) {
+  Location loc = producerOp.getLoc();
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(sliceOp);
+
+  // Get the offsets and sizes extracted by the ExtractSliceOp.
+  SmallVector<Range> ranges = sliceOp.getOrCreateRanges(b, loc);
+
+  // Get the induction variables and tile sizes for the fused producer loops.
+  auto zero = b.create<ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
+  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
+  for (auto it : zip(producerLoopsToFuse, producerShapeDimsToFuse)) {
+    int64_t loopDim, shapeDim;
+    std::tie(loopDim, shapeDim) = it;
+    tileIvs[loopDim] = ranges[shapeDim].offset;
+    tileSizes[loopDim] = ranges[shapeDim].size;
+    allIvs[loopDim] = tileIvs[loopDim];
+  }
+  erase_value(tileIvs, nullptr);
+
+  // Tile the producer operands given the ivs and tile sizes / size bounds.
+  SmallVector<Value> tiledOperands = makeTiledShapes(
+      b, loc, producerOp, valuesToTile, tileIvs, tileSizes, sizeBounds);
+
+  // Update the size bounds for the tiled dimensions.
+  for (int64_t loopDim : producerLoopsToFuse)
+    sizeBounds[loopDim] = tileSizes[loopDim];
+
+  return tiledOperands;
+}
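+
+// Illustrative example for getTiledOperands: with producerLoopsToFuse = {0, 2}
+// and producerShapeDimsToFuse = {0, 2}, the offsets and sizes of slice
+// dimensions 0 and 2 become the induction variables and tile sizes of producer
+// loops 0 and 2, while all remaining loops keep the zero tile size and are
+// left untiled by makeTiledShapes.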
+
+FailureOr<FusionState>
+mlir::linalg::isProducerFusable(OpOperand *consumerOpOperand,
+                                ArrayRef<Optional<int64_t>> tileLoopDims,
+                                function_ref<void(StringRef)> notifyFailure) {
+  // Call the notify failure callback and return failure.
+  auto handleFailure = [&](StringRef message) {
+    notifyFailure(message);
+    return failure();
+  };
+
+  // Get the consumer and check it has tensor semantics.
+  auto consumerOp = dyn_cast<LinalgOp>(consumerOpOperand->getOwner());
+  if (!consumerOp || !consumerOp.hasTensorSemantics())
+    return handleFailure("expect consumer to be a linalg op on tensors");
+
+  // Collect the tile loops plus the ExtractSliceOps and BlockArguments for
+  // every tile loop level, or nullptr entries where they do not exist.
+  SmallVector<scf::ForOp> tileLoops;
+  SmallVector<tensor::ExtractSliceOp> sliceOps;
+  SmallVector<BlockArgument> bbArgs;
+  Value producerResult =
+      getTileLoopNest(consumerOpOperand, tileLoops, sliceOps, bbArgs);
+
+  // Check there are tile loops and the tile loop dimensions are known.
+  if (tileLoops.empty() || tileLoops.size() > tileLoopDims.size())
+    return handleFailure("expect >0 and <=tileLoopDims.size() tile loops");
+
+  // Check the producer is a LinalgOp and has tensor semantics.
+  auto producerOp = producerResult.getDefiningOp<LinalgOp>();
+  if (!producerOp || !producerOp.hasTensorSemantics())
+    return handleFailure("expect producer to be a linalg op on tensors");
+  auto producerOpResult = producerResult.cast<OpResult>();
+
+  // Check the parents of the producer result and the outermost tile loop
+  // match.
+  scf::ForOp outermostTileLoop = tileLoops.back();
+  if (producerResult.getParentBlock()->getParentOp() !=
+      outermostTileLoop->getParentOp())
+    return handleFailure("expect producer and tile loop parents to match");
+
+  // Compute the consumer to producer loops mapping and exit on failure.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+      producerOp.getOutputOperand(producerOpResult.getResultNumber()));
+  AffineMap consumerIndexingMap =
+      consumerOp.getTiedIndexingMap(consumerOpOperand);
+  Optional<AffineMap> consumerToProducerLoops =
+      getConsumerToProducerLoopsMap(producerIndexingMap, consumerIndexingMap);
+  if (!consumerToProducerLoops.hasValue())
+    return handleFailure("cannot compute consumer to producer loop map");
+
+  // Reverse `sliceOps` and `bbArgs` to process them from outer to inner and to
+  // match the `tileLoopDims` order.
+  std::reverse(sliceOps.begin(), sliceOps.end());
+  std::reverse(bbArgs.begin(), bbArgs.end());
+
+  // Search the innermost fusable ExtractSliceOp.
+  int64_t consumerOpDepth = tileLoops.size();
+  int64_t sliceOpDepth = 0;
+  for (auto en : enumerate(tileLoopDims.take_front(consumerOpDepth))) {
+    // Stop fusion if an output operand is passed by BlockArgument into a
+    // non-parallel tile loop.
+    if (bbArgs[en.index()] && en.value().hasValue()) {
+      Attribute iteratorType =
+          consumerOp.iterator_types().getValue()[en.value().getValue()];
+      if (!isParallelIterator(iteratorType))
+        break;
+    }
+    // Update the depth to the innermost ExtractSliceOp found so far.
+    if (sliceOps[en.index()])
+      sliceOpDepth = en.index() + 1;
+  }
+  if (sliceOpDepth == 0)
+    return handleFailure("expect to find a slice op to replace");
+
+  // Keep only the fusable ExtractSliceOps and BlockArguments.
+  sliceOps.resize(sliceOpDepth);
+  bbArgs.resize(sliceOpDepth);
+  int64_t bbArgsCount = sliceOpDepth - count(bbArgs, nullptr);
+  if (bbArgsCount != 0 && bbArgsCount != sliceOpDepth)
+    return handleFailure("expect one block argument per tile loop or none");
+
+  // If the producer of a consumer output is fused into a tile loop nest,
+  // fusion sets the iteration argument of the outermost tile loop to the
+  // producer output instead of its result. This transformation is only valid
+  // if all values along the use-def chain between the outermost tile loop
+  // iteration argument and the ExtractSliceOp are solely used by the consumer.
+  // We thus ensure these values either have one use or are used by
+  // ExtractSliceOp / InsertSliceOp pairs exclusively.
+  if (bbArgsCount != 0) {
+    // Ensure all slice ops except for the last one have one use.
+    for (auto sliceOp : sliceOps) {
+      if (sliceOp && !(sliceOp->hasOneUse() || sliceOp == sliceOps.back()))
+        return handleFailure("expect slice op to have one use");
+    }
+    // Ensure all block arguments have one use or are used by one
+    // ExtractSliceOp / InsertSliceOp pair except for possible dim accesses.
+    for (auto bbArg : bbArgs) {
+      int64_t defaultCount = 0, extractCount = 0, insertCount = 0;
+      for (Operation *op : bbArg.getUsers()) {
+        TypeSwitch<Operation *>(op)
+            .Case<tensor::ExtractSliceOp>([&](auto) { extractCount++; })
+            .Case<tensor::InsertSliceOp>([&](auto) { insertCount++; })
+            .Case<tensor::DimOp>([&](auto) {})
+            .Default([&](auto) { defaultCount++; });
+      }
+      if (!bbArg.hasOneUse() &&
+          !(extractCount == 1 && insertCount == 1 && defaultCount == 0))
+        return handleFailure("expect one use or one extract/insert pair");
+    }
+  }
+
+  // Search the producer loops to fuse and exit if there are none.
+  SmallVector<int64_t> producerLoopsToFuse;
+  SmallVector<tensor::ExtractSliceOp> sliceOpsToFuse;
+  for (auto en : enumerate(tileLoopDims.take_front(consumerOpDepth)
+                               .drop_back(consumerOpDepth - sliceOpDepth))) {
+    // Search the tiled producer loop and add it if one exists.
+    if (en.value().hasValue()) {
+      Optional<int64_t> producerTileLoopDim = getProducerLoopDim(
+          en.value().getValue(), consumerToProducerLoops.getValue());
+      if (producerTileLoopDim.hasValue())
+        producerLoopsToFuse.push_back(producerTileLoopDim.getValue());
+    }
+    // Add an ExtractSliceOp to tile if there are producer loops to fuse.
+    int64_t numLoopsToTile = producerLoopsToFuse.size() - sliceOpsToFuse.size();
+    if (sliceOps[en.index()] && numLoopsToTile != 0) {
+      sliceOpsToFuse.insert(sliceOpsToFuse.end(), numLoopsToTile, nullptr);
+      sliceOpsToFuse.back() = sliceOps[en.index()];
+    }
+  }
+  if (producerLoopsToFuse.empty())
+    return handleFailure("expect to find producer loops to fuse");
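+
+  // Illustrative example: for a producerIndexingMap (d0, d1, d2) -> (d2, d0),
+  // producer loop d0 appears at result position 1, so fusing loop d0 tiles
+  // shape dimension 1 of the producer result in the search below.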
+  // Search the producer shape dimensions to fuse.
+  SmallVector<int64_t> producerShapeDimsToFuse(producerLoopsToFuse.size());
+  for (auto en : enumerate(producerLoopsToFuse)) {
+    auto *it = find_if(producerIndexingMap.getResults(), [&](AffineExpr expr) {
+      AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
+      return dimExpr &&
+             static_cast<int64_t>(dimExpr.getPosition()) == en.value();
+    });
+    assert(it != producerIndexingMap.getResults().end() &&
+           "expect to find the loop in the indexing map");
+    producerShapeDimsToFuse[en.index()] =
+        std::distance(producerIndexingMap.getResults().begin(), it);
+  }
+
+  // Compute the tile loop iteration argument of the outermost tile loop.
+  OpOperand *tileLoopIterArg = nullptr;
+  if (bbArgsCount != 0)
+    tileLoopIterArg =
+        &outermostTileLoop.getOpOperandForRegionIterArg(bbArgs.front());
+
+  return FusionState{producerOpResult, tileLoopIterArg, producerLoopsToFuse,
+                     producerShapeDimsToFuse, sliceOpsToFuse};
+}
+
+LinalgOp mlir::linalg::fuseProducer(OpBuilder &b, FusionState &fusionState) {
+  LinalgOp producerOp = fusionState.producerOpResult.getOwner();
+
+  // Set the insertion point to the producer to compute its bounds.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(producerOp);
+
+  // Compute the size bounds for all producer loop dimensions.
+  SmallVector<Value> sizeBounds(producerOp.getNumLoops());
+  for (auto &en : enumerate(sizeBounds)) {
+    en.value() = getLoopBound(b, producerOp, en.index());
+    assert(en.value() && "cannot compute producer loop bound");
+  }
+
+  // Tile the producer operands at every tile loop level associated with an
+  // ExtractSliceOp. Tiling the producer operands level-by-level may unlock
+  // additional fusion opportunities in case of transitive fusion.
+  assert(fusionState.producerLoopsToFuse.size() ==
+             fusionState.producerShapeDimsToFuse.size() &&
+         "expect one shape dimension per producer loop dimension");
+  assert(fusionState.producerLoopsToFuse.size() ==
+             fusionState.sliceOpsToFuse.size() &&
+         "expect one slice op entry per producer loop dimension");
+  int64_t startIndex = 0;
+  SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
+  SmallVector<Value> tiledOperands;
+  for (OpOperand *opOperand : producerOp.getInputAndOutputOperands())
+    tiledOperands.push_back(opOperand->get());
+  for (auto en : enumerate(fusionState.sliceOpsToFuse)) {
+    // Skip the tile loop if there is no ExtractSliceOp.
+    if (!en.value())
+      continue;
+
+    // Get the producer loops and shapes tiled by the current tile loop level.
+    SmallVector<int64_t> producerLoopsToFuse(
+        fusionState.producerLoopsToFuse.begin() + startIndex,
+        fusionState.producerLoopsToFuse.begin() + en.index() + 1);
+    SmallVector<int64_t> producerShapeDimsToFuse(
+        fusionState.producerShapeDimsToFuse.begin() + startIndex,
+        fusionState.producerShapeDimsToFuse.begin() + en.index() + 1);
+    startIndex = en.index() + 1;
+    assert(!producerLoopsToFuse.empty() && !producerShapeDimsToFuse.empty() &&
+           "expect the slice op to tile at least one producer loop");
+
+    tiledOperands = getTiledOperands(
+        b, producerOp, en.value(), tiledOperands, producerLoopsToFuse,
+        producerShapeDimsToFuse, sizeBounds, allIvs, fusionState);
+  }
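+
+  // For an output producer, the rewrite below turns (illustrative IR only):
+  //   %0 = linalg.fill(%cst, %out) ...
+  //   scf.for ... iter_args(%arg = %0) { %s = tensor.extract_slice %arg ... }
+  // into:
+  //   scf.for ... iter_args(%arg = %out) {
+  //     %s = tensor.extract_slice %arg ...
+  //     %t = linalg.fill(%cst, %s) ...
+  //   }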
+  // Replace the producer result iteration argument of the outermost loop by
+  // the tied producer output, and use the existing ExtractSliceOp result
+  // instead of the tiled producer output.
+  if (fusionState.tileLoopIterArg) {
+    OpOperand *outputOperand = producerOp.getOutputOperand(
+        fusionState.producerOpResult.getResultNumber());
+    fusionState.tileLoopIterArg->set(outputOperand->get());
+    tiledOperands[outputOperand->getOperandNumber()] =
+        fusionState.sliceOpsToFuse.back().getResult();
+  }
+
+  // Set the insertion point to after the ExtractSliceOp since the cloned
+  // producer may access its result.
+  b.setInsertionPointAfter(fusionState.sliceOpsToFuse.back());
+
+  // Clone the producer using the tiled producer operands.
+  Location loc = producerOp.getLoc();
+  TypeRange resultTypes = ValueRange(tiledOperands)
+                              .take_back(producerOp.getNumOutputs())
+                              .getTypes();
+  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+
+  // Shift all index op results by the tile offset.
+  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+
+  // Cast the cloned op result to bridge type mismatches before
+  // canonicalization patterns run.
+  Type operandType = fusionState.sliceOpsToFuse.back().getResult().getType();
+  int64_t resultNumber = fusionState.producerOpResult.getResultNumber();
+  Value result = clonedOp->getResult(resultNumber);
+  if (result.getType() != operandType)
+    result = b.create<tensor::CastOp>(loc, operandType, result).getResult();
+
+  // Replace all ExtractSliceOp uses except for possible uses by the cloned op.
+  fusionState.sliceOpsToFuse.back().getResult().replaceAllUsesExcept(result,
+                                                                     clonedOp);
+  return clonedOp;
+}
+
+namespace {
+
+struct LinalgFuseTensorOps
+    : public LinalgFuseTensorOpsBase<LinalgFuseTensorOps> {
+
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    // Search all tiled ops.
+    SmallVector<LinalgOp> tiledOps;
+    funcOp.walk([&](LinalgOp linalgOp) {
+      if (isa<scf::ForOp>(linalgOp->getParentOp()))
+        tiledOps.push_back(linalgOp);
+    });
+    // Try to fuse all producers.
+    OpBuilder b(funcOp.getContext());
+    SmallVector<Optional<int64_t>> tileLoopDims(tileLoops.begin(),
+                                                tileLoops.end());
+    for (auto tiledOp : tiledOps) {
+      for (OpOperand *consumerOpOperand :
+           tiledOp.getInputAndOutputOperands()) {
+        auto notifyFailure = [&](StringRef message) {
+          llvm::errs() << " - LinalgFusionOnTensors (" << tiledOp->getName()
+                       << " operand #" << consumerOpOperand->getOperandNumber()
+                       << "): " << message << "\n";
+        };
+        FailureOr<FusionState> fusionState =
+            isProducerFusable(consumerOpOperand, tileLoopDims, notifyFailure);
+        if (failed(fusionState))
+          continue;
+        fuseProducer(b, fusionState.getValue());
+      }
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFuseTensorOpsPass() {
+  return std::make_unique<LinalgFuseTensorOps>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -663,5 +663,28 @@
   return tiledShapes;
 }
 
+void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
+                                    ArrayRef<Value> ivs) {
+  if (tiledOp.hasIndexSemantics()) {
+    assert(tiledOp->getNumRegions() == 1 &&
+           tiledOp->getRegion(0).getBlocks().size() == 1 &&
+           "expect producer to have one block");
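+    // Illustrative example: when a tile loop with induction variable %iv
+    // tiles producer loop dimension 0, a `%i = linalg.index 0` result is
+    // rewritten below to `affine.apply affine_map<(d0, d1) -> (d0 + d1)>
+    // (%i, %iv)`, and all other uses of %i are redirected to the sum.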
+    // Shift all index op results by the tile offset.
+    Block &block = tiledOp->getRegion(0).front();
+    for (IndexOp indexOp : block.getOps<IndexOp>()) {
+      if (ivs[indexOp.dim()] == nullptr)
+        continue;
+      OpBuilder::InsertionGuard g(b);
+      b.setInsertionPointAfter(indexOp);
+      AffineExpr index, offset;
+      bindDims(b.getContext(), index, offset);
+      AffineApplyOp applyOp = b.create<AffineApplyOp>(
+          indexOp.getLoc(), index + offset,
+          ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
+      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
+    }
+  }
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fusion-on-tensors.mlir b/mlir/test/Dialect/Linalg/fusion-on-tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-on-tensors.mlir
@@ -0,0 +1,471 @@
+// RUN: mlir-opt %s -linalg-fuse-tensor-ops="tile-loops=0,1,2,2" -split-input-file --cse | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 12)>
+#map = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_input
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
+                         %arg1: tensor<12x25xf32>,
+                         %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %1 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %2 = scf.for %arg5 = %c0 to %c25 step %c4 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %3 = affine.min #map(%arg5)
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %4 = scf.for %arg7 = %c0 to %c12 step %c4 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+
+        // Tile both fill output operand dimensions.
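+        // The CHECK lines below verify that the fused fill computes its tile
+        // at the offsets %[[IV0]] and %[[IV2]] with the bounds from MAP0 and
+        // MAP1, and that a tensor.cast reconciles the dynamic tile type with
+        // the consumer's static operand type.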
+        // CHECK: %[[UB0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+        // CHECK: %[[UB1:.*]] = affine.min #[[MAP1]](%[[IV2]])
+        // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+        // CHECK-SAME: %[[IV0]], %[[IV2]]
+        // CHECK-SAME: %[[UB0]], %[[UB1]]
+        // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+        // CHECK: %[[T2:.*]] = tensor.cast %[[T1]] : tensor<?x?xf32>
+        %5 = tensor.extract_slice %0[%arg3, %arg7] [4, 4] [1, 1] : tensor<24x12xf32> to tensor<4x4xf32>
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [4, %3] [1, 1] : tensor<12x25xf32> to tensor<4x?xf32>
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+
+        // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]]
+        %8 = linalg.matmul ins(%5, %6 : tensor<4x4xf32>, tensor<4x?xf32>) outs(%7 : tensor<4x?xf32>) -> tensor<4x?xf32>
+        %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %4 : tensor<24x25xf32>
+    }
+    scf.yield %2 : tensor<24x25xf32>
+  }
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, -d0 + 25)>
+#map = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_input_2d_tiling
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+builtin.func @fuse_input_2d_tiling(%arg0: tensor<24x12xf32>,
+                                   %arg1: tensor<12x25xf32>,
+                                   %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg1) : f32, tensor<12x25xf32> -> tensor<12x25xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %1 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+    %2 = tensor.extract_slice %arg0[%arg3, 0] [4, 12] [1, 1] : tensor<24x12xf32> to tensor<4x12xf32>
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %3 = scf.for %arg5 = %c0 to %c25 step %c4 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+      // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+      %4 = affine.min #map(%arg5)
+
+      // Tile the second fill output operand dimension, taking into account
+      // that the domain size is not an integer multiple of the step.
+ // CHECK: %[[UB1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[TS1]]) + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG1]] + // CHECK-SAME: 0, %[[IV1]] + // CHECK-SAME: 12, %[[UB1]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + %5 = tensor.extract_slice %0[0, %arg5] [12, %4] [1, 1] : tensor<12x25xf32> to tensor<12x?xf32> + %6 = tensor.extract_slice %arg6[%arg3, %arg5] [4, %4] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32> + + // CHECK: %{{.*}} = linalg.matmul ins({{.*}}, %[[T1]] + %7 = linalg.matmul ins(%2, %5 : tensor<4x12xf32>, tensor<12x?xf32>) outs(%6 : tensor<4x?xf32>) -> tensor<4x?xf32> + %8 = tensor.insert_slice %7 into %arg6[%arg3, %arg5] [4, %4] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32> + scf.yield %8 : tensor<24x25xf32> + } + scf.yield %3 : tensor<24x25xf32> + } + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +#map = affine_map<(d0) -> (4, -d0 + 25)> + +// CHECK: fuse_output +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +builtin.func @fuse_output(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + + // Verify the iteration argument is updated and the extract slice reused. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%{{.*}} = %[[ARG2]] + %1 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %0) -> (tensor<24x25xf32>) { + + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = + %2 = scf.for %arg5 = %c0 to %c25 step %c4 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) { + + // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + %3 = affine.min #map(%arg5) + + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK-SAME: 4, %[[TS1]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + %4 = tensor.extract_slice %arg6[%arg3, %arg5] [4, %3] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32> + + // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[T1]] + %5 = scf.for %arg7 = %c0 to %c12 step %c4 iter_args(%arg8 = %4) -> (tensor<4x?xf32>) { + %7 = tensor.extract_slice %arg0[%arg3, %arg7] [4, 4] [1, 1] : tensor<24x12xf32> to tensor<4x4xf32> + %8 = tensor.extract_slice %arg1[%arg7, %arg5] [4, %3] [1, 1] : tensor<12x25xf32> to tensor<4x?xf32> + + // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG4]] + %9 = linalg.matmul ins(%7, %8 : tensor<4x4xf32>, tensor<4x?xf32>) outs(%arg8 : tensor<4x?xf32>) -> tensor<4x?xf32> + scf.yield %9 : tensor<4x?xf32> + } + %6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [4, %3] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32> + scf.yield %6 : tensor<24x25xf32> + } + scf.yield %2 : tensor<24x25xf32> + } + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 12)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, -d0 + 25)> +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0) -> (4, -d0 + 25)> + +// CHECK: fuse_reduction +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32> +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> 
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg4, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<12x25xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %1 = scf.for %arg4 = %c0 to %c24 step %c4 iter_args(%arg5 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %2 = scf.for %arg6 = %c0 to %c25 step %c4 iter_args(%arg7 = %arg5) -> (tensor<24x25xf32>) {
+
+      // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+      %3 = affine.min #map2(%arg6)
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %4 = scf.for %arg8 = %c0 to %c12 step %c4 iter_args(%arg9 = %arg7) -> (tensor<24x25xf32>) {
+        %5 = tensor.extract_slice %arg0[%arg4, %arg8] [4, 4] [1, 1] : tensor<24x12xf32> to tensor<4x4xf32>
+
+        // CHECK: %[[UB0:.*]] = affine.min #[[MAP1]](%[[IV2]])
+        // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[IV1]], %[[TS1]])
+        // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+        // CHECK-SAME: %[[IV2]], 0, %[[IV1]]
+        // CHECK-SAME: %[[UB0]], 7, %[[UB1]]
+        // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+        // CHECK-SAME: %[[IV2]], %[[IV1]]
+        // CHECK-SAME: %[[UB0]], %[[UB1]]
+        // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+        // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to
+        %6 = tensor.extract_slice %0[%arg8, %arg6] [4, %3] [1, 1] : tensor<12x25xf32> to tensor<4x?xf32>
+        %7 = tensor.extract_slice %arg9[%arg4, %arg6] [4, %3] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+
+        // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T3]]
+        %8 = linalg.matmul ins(%5, %6 : tensor<4x4xf32>, tensor<4x?xf32>) outs(%7 : tensor<4x?xf32>) -> tensor<4x?xf32>
+        %9 = tensor.insert_slice %8 into %arg9[%arg4, %arg6] [4, %3] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %4 : tensor<24x25xf32>
+    }
+    scf.yield %2 : tensor<24x25xf32>
+  }
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+#map = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_output_not_fusable
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_output_not_fusable(%arg0: tensor<24x12xf32>,
+                                      %arg1: tensor<12x25xf32>,
+                                      %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.0 : f32
+
+  // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[ARG2]])
+  %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[T1]]
+  %1 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %0) -> (tensor<24x25xf32>) {
+
+    // CHECK: scf.for
+    %2 = scf.for %arg5 = %c0 to %c25 step %c4 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %3 = affine.min #map(%arg5)
+
+      // Cannot fuse producer in place of the slice op inside the reduction loop.
+      // CHECK: scf.for
+      // CHECK-NOT: linalg.fill
+      %4 = scf.for %arg7 = %c0 to %c12 step %c4 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+        %5 = tensor.extract_slice %arg0[%arg3, %arg7] [4, 4] [1, 1] : tensor<24x12xf32> to tensor<4x4xf32>
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [4, %3] [1, 1] : tensor<12x25xf32> to tensor<4x?xf32>
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+        %8 = linalg.matmul ins(%5, %6 : tensor<4x4xf32>, tensor<4x?xf32>) outs(%7 : tensor<4x?xf32>) -> tensor<4x?xf32>
+        %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %4 : tensor<24x25xf32>
+    }
+    scf.yield %2 : tensor<24x25xf32>
+  }
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+#map0 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map1 = affine_map<(d0, d1) -> (4, d0 - d1)>
+
+// CHECK: fuse_twisted_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+builtin.func @fuse_twisted_dynamic(%arg0: tensor<?x?xf32>,
+                                   %arg1: tensor<?x?xf32>,
+                                   %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  // CHECK-DAG: %[[C0:.*]] = constant 0
+  // CHECK-DAG: %[[C1:.*]] = constant 1
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+
+  // CHECK: %[[T0:.*]] = linalg.generic
+  // CHECK-DAG: %[[DIM0_T0:.*]] = tensor.dim %[[T0]], %[[C0]]
+  // CHECK-DAG: %[[DIM1_T0:.*]] = tensor.dim %[[T0]], %[[C1]]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):  // no predecessors
+    linalg.yield %arg3 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %3 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %4 = tensor.dim %0, %c0 : tensor<?x?xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %5 = scf.for %arg3 = %c0 to %1 step %c4 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {
+    %6 = affine.min #map0(%arg3)[%1]
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %7 = scf.for %arg5 = %c0 to %3 step %c4 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
+
+      // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])[%[[DIM1_T0]]]
+      %8 = affine.min #map0(%arg5)[%3]
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %9 = scf.for %arg7 = %c0 to %2 step %c4 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
+        %10 = affine.min #map0(%arg7)[%2]
+        %11 = tensor.extract_slice %arg0[%arg3, %arg7] [%6, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+        // CHECK: %[[TS2:.*]] = affine.min #[[MAP0]](%[[IV2]])[%[[DIM0_T0]]]
+        %12 = affine.min #map0(%arg7)[%4]
+
+        // Tile the producer output operand, taking into account that the
+        // domain size is dynamic and may not be an integer multiple of the
+        // step.
+        // CHECK: %[[DIM0_ARG1:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+        // CHECK: %[[UB0:.*]] = affine.min #[[MAP1]](%[[IV2]], %[[TS2]])[%[[DIM0_ARG1]]]
+        // CHECK: %[[DIM1_ARG1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+        // CHECK: %[[UB1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[TS1]])[%[[DIM1_ARG1]]]
+        // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+        // CHECK-SAME: %[[IV2]], %[[IV1]]
+        // CHECK-SAME: %[[UB0]], %[[UB1]]
+        // CHECK: %[[T2:.*]] = linalg.generic {{.*}} outs(%[[T1]]
+        %13 = tensor.extract_slice %0[%arg7, %arg5] [%12, %8] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %14 = tensor.dim %arg8, %c0 : tensor<?x?xf32>
+        %15 = affine.min #map1(%14, %arg3)
+        %16 = tensor.dim %arg8, %c1 : tensor<?x?xf32>
+        %17 = affine.min #map1(%16, %arg5)
+        %18 = tensor.extract_slice %arg8[%arg3, %arg5] [%15, %17] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+        // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+        %19 = linalg.matmul ins(%11, %13 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%18 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %20 = tensor.insert_slice %19 into %arg8[%arg3, %arg5] [%15, %17] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+        scf.yield %20 : tensor<?x?xf32>
+      }
+      scf.yield %9 : tensor<?x?xf32>
+    }
+    scf.yield %7 : tensor<?x?xf32>
+  }
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map2 = affine_map<(d0, d1) -> (4, d0 - d1)>
+
+// CHECK: fuse_index_dynamic
+builtin.func @fuse_index_dynamic(%arg0: tensor<?x?xi32>,
+                                 %arg1: tensor<?x?xi32>,
+                                 %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {
+    indexing_maps = [#map0],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg1 : tensor<?x?xi32>) {
+  ^bb0(%arg3: i32):  // no predecessors
+    %6 = linalg.index 0 : index
+    %7 = linalg.index 1 : index
+    %8 = addi %6, %7 : index
+    %9 = index_cast %8 : index to i32
+    linalg.yield %9 : i32
+  } -> tensor<?x?xi32>
+  %1 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %2 = tensor.dim %0, %c1 : tensor<?x?xi32>
+  %3 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %4 = tensor.dim %0, %c0 : tensor<?x?xi32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %5 = scf.for %arg3 = %c0 to %1 step %c4 iter_args(%arg4 = %arg2) -> (tensor<?x?xi32>) {
+    %6 = affine.min #map1(%arg3)[%1]
+    %7 = tensor.extract_slice %arg0[%arg3, 0] [%6, %3] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %8 = scf.for %arg5 = %c0 to %2 step %c4 iter_args(%arg6 = %arg4) -> (tensor<?x?xi32>) {
+      %9 = affine.min #map1(%arg5)[%2]
+      %10 = tensor.extract_slice %0[0, %arg5] [%4, %9] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+      %11 = tensor.dim %arg6, %c0 : tensor<?x?xi32>
+      %12 = affine.min #map2(%11, %arg3)
+      %13 = tensor.dim %arg6, %c1 : tensor<?x?xi32>
+      %14 = affine.min #map2(%13, %arg5)
+
+      // Shift only linalg.index 0: the tile loop tiles the second dimension
+      // of the producer output, and the transposed indexing map ties that
+      // dimension to producer loop zero.
+      // CHECK: linalg.generic
+      // CHECK: %[[IDX0:.*]] = linalg.index 0
+      // CHECK: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV1]])
+      // CHECK: %[[IDX1:.*]] = linalg.index 1
+      // CHECK: %{{.*}} = addi %[[IDX0_SHIFTED]], %[[IDX1]]
+      %15 = tensor.extract_slice %arg6[%arg3, %arg5] [%12, %14] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+      %16 = linalg.matmul ins(%7, %10 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%15 : tensor<?x?xi32>) -> tensor<?x?xi32>
+      %17 = tensor.insert_slice %16 into %arg6[%arg3, %arg5] [%12, %14] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
+      scf.yield %17 : tensor<?x?xi32>
+    }
+    scf.yield %8 : tensor<?x?xi32>
+  }
+  return %5 : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 12)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (2, d0 - d1)>
+#map = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_input_4d_tiling
+builtin.func @fuse_input_4d_tiling(%arg0: tensor<24x12xf32>,
+                                   %arg1: tensor<12x25xf32>,
+                                   %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c2 = constant 2 : index
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c25 = constant 25 : index
+  %c12 = constant 12 : index
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %1 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %2 = scf.for %arg5 = %c0 to %c25 step %c4 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %3 = affine.min #map(%arg5)
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %4 = scf.for %arg7 = %c0 to %c12 step %c4 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+
+        // Tile both fill output operand dimensions.
+        // CHECK: %[[UB0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+        // CHECK: %[[UB1:.*]] = affine.min #[[MAP1]](%[[IV2]])
+        // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+        // CHECK-SAME: %[[IV0]], %[[IV2]]
+        // CHECK-SAME: %[[UB0]], %[[UB1]]
+        %5 = tensor.extract_slice %0[%arg3, %arg7] [4, 4] [1, 1] : tensor<24x12xf32> to tensor<4x4xf32>
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [4, %3] [1, 1] : tensor<12x25xf32> to tensor<4x?xf32>
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+
+        // CHECK: scf.for %[[IV3:[0-9a-zA-Z]*]] =
+        %8 = scf.for %arg9 = %c0 to %c4 step %c2 iter_args(%arg10 = %7) -> (tensor<4x?xf32>) {
+
+          // Tile only the second fill output operand dimension.
+          // CHECK: %[[UB3:.*]] = affine.min #[[MAP2]](%[[UB1]], %[[IV3]])
+
+          // CHECK: %[[T1:.*]] = tensor.extract_slice %[[T0]]
+          // CHECK-SAME: 0, %[[IV3]]
+          // CHECK-SAME: %[[UB0]], %[[UB3]]
+          // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+          // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to
+          %10 = tensor.extract_slice %5[0, %arg9] [4, 2] [1, 1] : tensor<4x4xf32> to tensor<4x2xf32>
+          %11 = tensor.extract_slice %6[%arg9, 0] [2, %3] [1, 1] : tensor<4x?xf32> to tensor<2x?xf32>
+
+          // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]]
+          %12 = linalg.matmul ins(%10, %11 : tensor<4x2xf32>, tensor<2x?xf32>) outs(%arg10 : tensor<4x?xf32>) -> tensor<4x?xf32>
+          scf.yield %12 : tensor<4x?xf32>
+        }
+        %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [4, %3] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %4 : tensor<24x25xf32>
+    }
+    scf.yield %2 : tensor<24x25xf32>
+  }
+  return %1 : tensor<24x25xf32>
+}