diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -56,6 +56,7 @@ LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
 using ReassociationIndices = SmallVector<int64_t, 2>;
+using ReassociationIndicesRef = ArrayRef<int64_t>;
 using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Returns the name mangled library call name to disambiguate between different
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -431,22 +431,160 @@
   });
 }
-// Get the output tensor to use for the expanded operation. Creates an
-// `linalg.init_tensor` operation to materialize the tensor that carries the
-// shape information.
-static Value getOutputValueForExpansion(
-    OpBuilder &builder, Location loc, AffineMap outputIndexingMap, Value result,
-    ArrayRef<SmallVector<int64_t, 4>> origDimToExpandedShapeMap) {
+namespace {
+/// Information needed to expand a generic/indexed_generic operation to fold
+/// the reshape with it.
+class ExpansionInfo {
+public:
+  // Computes the mapping from original dimensions of the op to the dimensions
+  // of the expanded op given the `indexingMap` of the fused operand/result of
+  // the generic/indexed_generic op, the `reassociationMaps` of the reshape op
+  // and the shape of the expanded op.
+  LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
+                        ArrayRef<AffineMap> reassociationMaps,
+                        ArrayRef<int64_t> expandedShape);
+  unsigned getOrigOpNumDims() const { return reassociation.size(); }
+  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
+  ReassociationIndicesRef getExpandedDims(unsigned i) const {
+    return reassociation[i];
+  }
+  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+    return expandedShapeMap[i];
+  }
+
+private:
+  /// Reassociation from the dimensions in the original operation to the
+  /// dimension of the expanded operation.
+  SmallVector<ReassociationIndices, 4> reassociation;
+  /// Mapping from extent of loops in the original operation, to the extent of
+  /// loops in the expanded operation.
+  SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
+  unsigned expandedOpNumDims;
+};
+} // namespace
+
+LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
+                                     unsigned fusedTensorIndex,
+                                     ArrayRef<AffineMap> reassociationMaps,
+                                     ArrayRef<int64_t> expandedShape) {
+  if (reassociationMaps.empty())
+    return failure();
+  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+
+  Optional<SmallVector<int64_t, 4>> originalLoopRange =
+      getStaticLoopRanges(linalgOp);
+  if (!originalLoopRange)
+    return linalgOp.emitError("unable to find loop range for operation");
+
+  reassociation.clear();
+  expandedShapeMap.clear();
+  // Compute the number of dimensions in the expanded op that correspond to
+  // each dimension of the original op.
+  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+  expandedShapeMap.resize(fusedIndexMap.getNumDims());
+  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
+    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
+    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
+    numExpandedDims[pos] = foldedDims.getNumResults();
+    ArrayRef<int64_t> shape =
+        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
+    expandedShapeMap[pos].assign(shape.begin(), shape.end());
+  }
+  // The remaining dimensions remain the same.
+  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+    if (expandedShapeMap[i].empty())
+      expandedShapeMap[i] = {(*originalLoopRange)[i]};
+
+  // Compute reassociation map from the original op to the expanded op.
+  unsigned sum = 0;
+  reassociation.reserve(fusedIndexMap.getNumDims());
+  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
+    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
+    reassociation.emplace_back(seq.begin(), seq.end());
+    sum += numFoldedDim.value();
+  }
+  expandedOpNumDims = sum;
+  return success();
+}
+
+/// To expand an indexed_generic operation, the body of the indexed generic op
+/// needs to be modified appropriately. Specifically, uses of arguments for
+/// induction variables in the original operation need to be replaced with the
+/// linearization of the corresponding arguments in the expanded op. That
+/// requires the shape of the expanded dimensions (at least all but the most
+/// significant) to be known. For now, check that these are all statically
+/// sized. Note that this could be extended to handle the dynamic case, but the
+/// implementation below uses `affine.apply` which seems to have issues when
+/// the shapes are not static.
+LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
+                                           const ExpansionInfo &expansionInfo) {
+  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
+    if (expandedShape.size() == 1)
+      continue;
+    for (int64_t shape : expandedShape.drop_front()) {
+      if (ShapedType::isDynamic(shape)) {
+        return linalgOp.emitError(
+            "unable to fuse indexed generic op where the expanded dim is "
+            "dynamic");
+      }
+    }
+  }
+  return success();
+}
+
+/// Return the indexing map to use in the expanded op given the `indexingMap`
+/// of the original operation.
+static AffineMap
+getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
+                           const ExpansionInfo &expansionInfo) {
+  SmallVector<AffineExpr, 4> newExprs;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
+        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+          return builder.getAffineDimExpr(static_cast<unsigned>(v));
+        }));
+    newExprs.append(expandedExprs.begin(), expandedExprs.end());
+  }
+  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+                        indexingMap.getNumSymbols(), newExprs,
+                        builder.getContext());
+}
+
+/// Return the type of the operand/result to use in the expanded op given the
+/// type in the original op.
+static RankedTensorType getExpandedType(RankedTensorType originalType,
+                                        AffineMap indexingMap,
+                                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t, 4> expandedShape;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
+  }
+  return RankedTensorType::get(expandedShape, originalType.getElementType());
+}
+
+/// Get the value to use for the output in the expanded operation given the
+/// `indexingMap` for the output in the original op. Creates a
+/// `linalg.init_tensor` operation to materialize the tensor that carries the
+/// shape information. This is only used when the tensor_reshape is expanding
+/// and is a consumer.
+/// In such cases, the tensor_reshape op semantics guarantees that the shape of
+/// the output is computable from the shape of the input since at most one of
+/// the expanded dims can be dynamic.
+static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
+                                         AffineMap indexingMap, Value result,
+                                         const ExpansionInfo &expansionInfo) {
   SmallVector<Value, 4> dynamicDims;
   SmallVector<int64_t, 4> staticDims;
   ShapedType resultType = result.getType().cast<ShapedType>();
   ArrayRef<int64_t> origShape = resultType.getShape();
-  for (AffineExpr expr : outputIndexingMap.getResults()) {
+  for (AffineExpr expr : indexingMap.getResults()) {
     unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
-    ArrayRef<int64_t> expandedShape(origDimToExpandedShapeMap[origDimPos]);
     bool foundDynamic = false;
     int64_t linearizedShape = 1;
-    for (int64_t extent : expandedShape) {
+    for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
       if (ShapedType::isDynamic(extent)) {
         assert(!foundDynamic &&
                "Expanded dimensions of reshape can have only one dynamic dim");
@@ -467,6 +605,79 @@
   resultType.getElementType());
 }
+/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
+/// operation to convert the operands of the original operation to operands of
+/// the expanded operation. The same method is used to compute the
+/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
+/// get the value that can replace all uses of the results of the original op.
+static SmallVector<ReassociationIndices>
+getReassociationForExpansion(AffineMap indexingMap,
+                             const ExpansionInfo &expansionInfo) {
+  SmallVector<ReassociationIndices> reassociation;
+  unsigned numReshapeDims = 0;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
+    auto indices = llvm::to_vector<2>(
+        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
+    reassociation.emplace_back(std::move(indices));
+    numReshapeDims += numExpandedDims;
+  }
+  return reassociation;
+}
+
+/// Build the body of the expanded IndexedGenericOp. The arguments for the
+/// induction variables of the original operation need to be recovered by
+/// linearizing the arguments of the corresponding dimensions of the expanded
+/// op. For now it is assumed that the shapes of the expanded op needed for
+/// the linearization are static.
+static void buildExpandedIndexedGenericOpRegion(
+    PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
+    Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
+  assert(fusedOpRegion.empty() && "expected fused op to have empty region");
+  // Create an entry block in the fused region with the same number of
+  // arguments as the fused op.
+  Block *fusedEntryBlock = new Block;
+  fusedOpRegion.push_back(fusedEntryBlock);
+  rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
+                             fusedOpRegion.end());
+
+  // Merge the entry block of the fused op with the cloned blocks. For this,
+  // compute the value for the arguments of the region in the original
+  // operation in terms of the arguments of the fused op. Since the original
+  // operation is expanded, the expanded dimensions need to be folded back to
+  // get the replacement value for the arguments corresponding to the iteration
+  // index. For now this expects that all the loop ranges are constants, which
+  // is true if the shapes are all static. This has already been checked in the
+  // precondition.
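+  // For example, if an original induction variable i is expanded into three
+  // new induction variables (i0, i1, i2) with static extents (s0, s1, s2),
+  // its replacement value is computed as (i0 * s1 + i1) * s2 + i2. The extent
+  // of the outermost expanded dimension is not needed, which is why only the
+  // trailing extents have to be static.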
+  using namespace edsc::op;
+  using namespace edsc::intrinsics;
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
+  rewriter.setInsertionPointToStart(fusedEntryBlock);
+  edsc::ScopedContext scopedContext(rewriter, loc);
+  IndexType indexType = rewriter.getIndexType();
+  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+    Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
+    ArrayRef<int64_t> expandedDimsShape =
+        expansionInfo.getExpandedShapeOfDim(i).drop_front();
+    for (unsigned shape : expandedDimsShape) {
+      assert(!ShapedType::isDynamic(shape));
+      linearizedIndex = linearizedIndex * std_constant_index(shape);
+      linearizedIndex =
+          linearizedIndex + fusedEntryBlock->addArgument(indexType);
+    }
+    argReplacements[i] = linearizedIndex;
+  }
+  for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
                                    argReplacements.size())) {
+    argReplacements[i] =
+        fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
+  }
+  rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
+                       argReplacements);
+}
+
 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
 /// conditions have been satisfied.
@@ -481,104 +692,22 @@
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
-  // The reshape is folding/expanding consecutive dimensions. Given the indexing
-  // map of the fused tensor find the number of dimensions each of the loops of
-  // the original op is expanded into. Also record the shape of the expanded
-  // dimensions.
-  ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  Optional<SmallVector<int64_t, 4>> origOpLoopRange =
-      getStaticLoopRanges(linalgOp);
-  if (!origOpLoopRange) {
-    linalgOp.emitError("unable to find loop range for operation");
+  ExpansionInfo expansionInfo;
+  if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+                                   reshapeOp.getReassociationMaps(),
+                                   expandedType.getShape())))
     return llvm::None;
-  }
-  SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
-  SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      fusedIndexMap.getNumDims());
-  auto reassociationMaps = reshapeOp.getReassociationMaps();
-  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
-    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
-    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
-    numFoldedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
-        expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
-    expandedDimsShape[pos].assign(shape.begin(), shape.end());
-  }
-  // The remaining dimensions remain the same.
-  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
-    if (expandedDimsShape[i].empty())
-      expandedDimsShape[i] = {(*origOpLoopRange)[i]};
-
-  if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
-    // For indexed generic op, the region contains arguments that represent the
-    // induction variable value of the loops. In the fused op these values are
-    // obtained by linearizing the expanded dimensions. For now just check that
-    // the extents used in the linearization (all the expanded dims except the
-    // front) are statically know. For dynamic case, we would need shape
-    // information on these dimensions to get these.
-    for (auto &expandedShape : expandedDimsShape) {
-      if (expandedShape.size() == 1)
-        continue;
-      for (int64_t expandedDimShape : llvm::make_range(
-               std::next(expandedShape.begin()), expandedShape.end())) {
-        if (ShapedType::isDynamic(expandedDimShape)) {
-          linalgOp.emitError(
-              "unable to fuse indexed generic op where the expanded dim is "
-              "dynamic");
-          return llvm::None;
-        }
-      }
-    }
-  }
-  // The remapping of the indices is then the prefix sum (inclusive) of the
-  // numFoldedDims.
-  SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
-  unsigned sum = 0;
-  for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
-    sum += numFoldedDim.value();
-    remapping[numFoldedDim.index() + 1] = sum;
-  }
+  if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
+      failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+    return llvm::None;
-  SmallVector<AffineMap, 4> expandedOpIndexingMaps;
-  // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
-  // in the original indexing map with the sequence of loops that it is expanded
-  // to.
-  for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
-    SmallVector<AffineExpr, 4> newExprs;
-    for (AffineExpr expr : indexingMap.getResults()) {
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
-      for (unsigned newPos :
-           llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
-        newExprs.push_back(rewriter.getAffineDimExpr(newPos));
-      }
-    }
-    expandedOpIndexingMaps.push_back(
-        AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
-                       rewriter.getContext()));
-  }
+  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
+      llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
+      }));
-  // The operands of the expanded op are computed by reshaping the original
-  // operands. The reshape depends on the ordering of the loop used to access
-  // the tensor in the original operation, and are expanded into as many
-  // dimensions as the loop is expanded into (as computed by `remapping`).
-  auto getReshapeInfo =
-      [&](AffineMap operandIndexingMap,
-          SmallVectorImpl<ReassociationIndices> &reassociation,
-          SmallVectorImpl<int64_t> &expandedOpOperandShape) {
-        unsigned reshapeDims = 0;
-        for (AffineExpr expr : operandIndexingMap.getResults()) {
-          unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
-          auto foldedDims = llvm::seq<int64_t>(
-              reshapeDims, reshapeDims + numFoldedDims[origDim]);
-          reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
-          expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
-                                        expandedDimsShape[origDim].end());
-          reshapeDims += numFoldedDims[origDim];
-        }
-      };
   SmallVector<Value, 4> expandedOpOperands;
   for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
     if (operand.index() == fusedTensorIndex) {
@@ -586,36 +715,31 @@
       continue;
     }
     AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
-    SmallVector<ReassociationIndices, 4> reassociation;
-    SmallVector<int64_t, 4> expandedOperandShape;
-    getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
-    Type expandedOperandType = RankedTensorType::get(
-        expandedOperandShape,
-        operand.value().getType().cast<RankedTensorType>().getElementType());
+    RankedTensorType expandedOperandType =
+        getExpandedType(operand.value().getType().cast<RankedTensorType>(),
+                        indexingMap, expansionInfo);
     if (expandedOperandType != operand.value().getType()) {
+      // Reshape the operand to get the right type.
+      SmallVector<ReassociationIndices> reassociation =
+          getReassociationForExpansion(indexingMap, expansionInfo);
       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
           linalgOp.getLoc(), expandedOperandType, operand.value(),
           reassociation));
-    } else {
-      expandedOpOperands.push_back(operand.value());
+      continue;
     }
+    expandedOpOperands.push_back(operand.value());
   }
   Location loc = linalgOp.getLoc();
   SmallVector<Value, 1> outputs;
-  SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
-    SmallVector<ReassociationIndices, 4> reassociation;
-    SmallVector<int64_t, 4> expandedResultShape;
-    getReshapeInfo(indexingMap, reassociation, expandedResultShape);
-    outputs.push_back(getOutputValueForExpansion(
-        rewriter, loc, indexingMap, result.value(), expandedDimsShape));
-    resultReassociation.emplace_back(std::move(reassociation));
+    outputs.push_back(getOutputValueForExpandedOp(
+        rewriter, loc, indexingMap, result.value(), expansionInfo));
   }
   // The iterator types of the expanded op are all parallel.
-  SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
+  SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                           getParallelIteratorTypeName());
   TypeRange resultTypes = ValueRange(outputs).getTypes();
@@ -631,48 +755,8 @@
         fusedRegion.begin());
   } else {
     assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
-    // Create an entry block in the fused Region with same number of arguments
-    // as the fused op
-    Block *fusedEntryBlock = new Block;
-    fusedRegion.push_back(fusedEntryBlock);
-    rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
-
-    // Merge the entry block of the fused op with the cloned blocks. For this
-    // compute the value for arguments of the region in the original operation
-    // in terms of the arguments of the fused op. Since the original operation
-    // is expanded, the expanded dimensions need to be folded back to get the
-    // replacement value for the arguments corresponding to interation index.
-    // For now this expects that all the loop ranges are constants, which is
-    // true if the shapes are all static. This has already been checked in the
-    // precondition.
-    using namespace edsc::op;
-    using namespace edsc::intrinsics;
-    OpBuilder::InsertionGuard guard(rewriter);
-    SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
-    rewriter.setInsertionPointToStart(fusedEntryBlock);
-    edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
-    IndexType indexType = rewriter.getIndexType();
-    for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
-      Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
-      for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
-           foldedDim++) {
-        int64_t expandedDimExtent =
-            expandedDimsShape[i][foldedDim - remapping[i]];
-        assert(!ShapedType::isDynamic(expandedDimExtent));
-        linearizedIndex =
-            linearizedIndex * std_constant_index(expandedDimExtent);
-        linearizedIndex =
-            linearizedIndex + fusedEntryBlock->addArgument(indexType);
-      }
-      argReplacements[i] = linearizedIndex;
-    }
-    for (unsigned i :
-         llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
-      argReplacements[i] =
-          fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
-    }
-    rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
-                         argReplacements);
+    buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
+                                        fusedRegion, expansionInfo);
   }
   // Reshape the result values to their original shape if this is a collapsing
@@ -681,10 +765,12 @@
   for (auto result : llvm::enumerate(linalgOp->getResults())) {
     if (!isExpanding &&
         resultTypes[result.index()] != result.value().getType()) {
+      SmallVector<ReassociationIndices> reassociation =
+          getReassociationForExpansion(
+              linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
       resultVals.push_back(rewriter.create<TensorReshapeOp>(
           linalgOp.getLoc(), result.value().getType(),
-          fusedOp->getResult(result.index()),
-          resultReassociation[result.index()]));
+          fusedOp->getResult(result.index()), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(result.index()));
     }
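As a note for reviewers, the dimension mapping that ExpansionInfo::compute builds can be pictured with a small standalone sketch (plain C++ with illustrative names and values, independent of MLIR and not part of the patch): each original loop dimension owns a contiguous range of expanded dimensions, and those ranges are obtained by a running prefix sum over the per-dimension expansion counts.

// Illustrative-only sketch of the reassociation computed in
// ExpansionInfo::compute; names and the example counts are hypothetical.
#include <cstdio>
#include <vector>

int main() {
  // Per original dimension: how many expanded dimensions it maps to
  // (taken from the reshape's reassociation maps in the real code).
  std::vector<unsigned> numExpandedDims = {2, 1, 3};

  std::vector<std::vector<unsigned>> reassociation;
  unsigned sum = 0;
  for (unsigned count : numExpandedDims) {
    std::vector<unsigned> indices;
    for (unsigned d = sum; d < sum + count; ++d)
      indices.push_back(d);
    reassociation.push_back(std::move(indices));
    sum += count;
  }

  // Expected output:
  //   dim 0 -> [ 0 1 ]
  //   dim 1 -> [ 2 ]
  //   dim 2 -> [ 3 4 5 ]
  //   expanded op has 6 loops
  for (unsigned i = 0; i < reassociation.size(); ++i) {
    std::printf("dim %u -> [", i);
    for (unsigned d : reassociation[i])
      std::printf(" %u", d);
    std::printf(" ]\n");
  }
  std::printf("expanded op has %u loops\n", sum);
  return 0;
}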