diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -86,7 +86,7 @@ "ArrayRef attrs = {}", [{ auto reassociationMaps = convertReassociationIndicesToMaps(b, reassociation); - build(b, result, src, reassociationMaps, attrs); + build(b, result, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -19,7 +19,8 @@ std::unique_ptr> createLinalgFoldUnitExtentDimsPass(); std::unique_ptr> createLinalgFusionPass(); -std::unique_ptr createLinalgFusionOfTensorOpsPass(); +std::unique_ptr +createLinalgFusionOfTensorOpsPass(bool useReshapeFusionByExpansion = false); std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}); @@ -50,8 +51,9 @@ createConvertLinalgOnTensorsToBuffersPass(); /// Patterns for fusing linalg operation on tensors. -void populateLinalgTensorOpsFusionPatterns(MLIRContext *context, - OwningRewritePatternList &patterns); +void populateLinalgTensorOpsFusionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + bool useReshapeFusionByExpansion = false); /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on /// tensors. diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -91,9 +91,10 @@ /// Fuse linalg operation on tensors, with the producer of the operand at /// position `consumerIdx` of the consumer.
-Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, - unsigned consumerIdx, - OperationFolder *folder = nullptr); +Optional> +fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, + unsigned consumerIdx, OperationFolder *folder = nullptr, + bool useReshapeFusionByExpansion = false); /// Returns the linearized list of all shape dimensions in a `linalgOp`. /// Applying the inverse, concatenated loopToOperandRangeMaps to this list diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -28,7 +28,7 @@ /// Implementation of fusion of generic ops and indexed_generic ops. struct FuseGenericOpsOnTensors { - static bool isFusible(LinalgOp producer, LinalgOp consumer, + static bool isFusable(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx) { // Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) @@ -51,20 +51,20 @@ return producerResultIndexMap.isPermutation(); } - static LinalgOp fuse(LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; + static Optional> + fuse(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; - unsigned numFusedOperands = producer.getOperation()->getNumOperands() + - consumer.getOperation()->getNumOperands() - 1; + unsigned numFusedOperands = + producer.getNumInputs() + consumer.getNumInputs() - 1; // Compute the fused operands list, SmallVector fusedOperands; fusedOperands.reserve(numFusedOperands); - auto consumerOperands = consumer.getOperation()->getOperands(); - auto producerOperands = producer.getOperation()->getOperands(); + auto consumerOperands = consumer.getInputs(); + auto producerOperands = producer.getInputs(); fusedOperands.assign(consumerOperands.begin(), std::next(consumerOperands.begin(), consumerIdx)); fusedOperands.append(producerOperands.begin(), producerOperands.end()); @@ -77,8 +77,7 @@ // indexing_map of the operand at consumerIdx in the consumer. SmallVector fusedIndexMaps; auto consumerIndexMaps = consumer.indexing_maps(); - fusedIndexMaps.reserve(fusedOperands.size() + - consumer.getOperation()->getNumResults()); + fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs()); fusedIndexMaps.assign(consumerIndexMaps.begin(), std::next(consumerIndexMaps.begin(), consumerIdx)); // Compute indexing maps for the producer args in the fused operation.
@@ -141,7 +140,7 @@ generateFusedRegion(rewriter, fusedOp, producer, consumer, consumerToProducerLoopsMap, consumerIdx, consumer.getNumLoops()); - return fusedOp; + return SmallVector(fusedOp.getOperation()->getResults()); } private: @@ -313,18 +312,21 @@ /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is /// true) or its producer (if `asProducer` is false) given the indexing map at /// its use. -static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp, - AffineMap useIndexMap, bool asProducer) { +static bool isTensorReshapeOpFusableByLinearization(TensorReshapeOp reshapeOp, + AffineMap useIndexMap, + bool asProducer) { RankedTensorType returnType = reshapeOp.getResultType(); RankedTensorType operandType = reshapeOp.getSrcType(); - // Reshape is fusible with its consumer (i.e. reshape as a producer) when its + // Reshape is fusable with its consumer (i.e. reshape as a producer) when its // operand is of lesser rank than the result. Fusing when operand has higher // rank will require use of mods and divs in the indexing maps of the fused op // which would make it non-invertible. Similarly reshape is fused with its // producer (i.e. reshape as consumer) only if the return type has lesser // rank. - if ((asProducer && returnType.getRank() < operandType.getRank()) || - (!asProducer && operandType.getRank() < returnType.getRank())) + if ((asProducer && reshapeOp.getSrcType().hasStaticShape() && + returnType.getRank() < operandType.getRank()) || + (!asProducer && reshapeOp.getResultType().hasStaticShape() && + operandType.getRank() < returnType.getRank())) return false; return useIndexMap.isPermutation(); } @@ -346,31 +348,266 @@ return nullptr; } +/// Conditions for fusing a generic/indexed-generic operation with a reshape op +/// by expanding the iteration space dimensionality. These are preconditions +/// assumed by `fuseWithReshapeByExpansion` which implements the following +/// fusion pattern.
+/// +/// Consider +/// +/// %c = linalg.generic ins(%a, %b : tensor, tensor) +/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, +/// affine_map<(d0, d1, d2) -> (d1, d2)>, +/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] +/// %d = linalg.tensor_reshape %c +/// [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, +/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>, +/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] +/// : tensor into tensor +/// +/// The reshape can be folded into the `linalgOp` if the +/// generic/indexed-generic op loop dimensionality is increased to match the +/// result (operand) of the tensor_reshape when the reshape is expanding +/// (folding). The indexing_map of the fused tensor in the `linalgOp` and the +/// reassociation map help compute the indexing maps of the modified op. For +/// the above example, based on the reassociation map it can be concluded that +/// +/// - The loop used to access the first dimension of the fused tensor is split +/// into two. +/// - The loop used to access the second dimension of the fused tensor is kept +/// as is. +/// - The loop used to access the third dimension of the fused tensor is split +/// into three. +/// +/// i.e. if
(e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the +/// modified op, then +/// +/// d0 -> e0, e1 +/// d1 -> e2, e3, e4 +/// d2 -> e5 +/// +/// substituting this, the generic op can be rewritten as +/// +/// %d = linalg.generic ins(%0, %1 : ) +/// indexing_maps = +/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] +/// +/// Since the operands of the linalg generic now have higher rank, reshapes can +/// be introduced to make them consistent +/// +/// %0 = linalg.tensor_reshape %a +/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2), +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4), +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)] +/// : tensor into tensor +/// %1 = linalg.tensor_reshape %b +/// [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2), +/// affine_map<(e0, e1, e2, e3) -> (e3)] +/// : tensor into tensor +/// +/// The added reshapes are again expanding patterns, so they will get fused +/// with their producers if possible. +static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, + unsigned fusedTensorIndex) { + // Is fusable only if: + // - The linalgOp is a generic op. + // - The linalgOp has tensor semantics with a single output. + // - All the indexing maps for operands (inputs and outputs) in linalgOp + // are projected permutations. + // - The indexing map at the position representing the fused tensor is a + // permutation. + // - All the loops in linalgOp are parallel loops.
+ return isa(linalgOp.getOperation()) && + linalgOp.hasTensorSemantics() && linalgOp.getNumOutputs() == 1 && + llvm::all_of(linalgOp.indexing_maps().getValue(), + [](Attribute attr) { + return attr.cast() + .getValue() + .isProjectedPermutation(); + }) && + linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() && + llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) { + return attr.cast().getValue() == + getParallelIteratorTypeName(); + }); +} + +/// Implements the fusion of a tensor_reshape op with a generic op as explained +/// in `isFusableWithReshapeByDimExpansion`. Assumes that those conditions have +/// been checked by the caller. +static Optional> +fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp, + unsigned fusedTensorIndex, PatternRewriter &rewriter, + OperationFolder *folder = nullptr) { + // Check if reshape is expanding or collapsing. + bool isExpanding = + reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank(); + RankedTensorType expandedType = + isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType(); + RankedTensorType foldedType = + isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType(); + AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex); + + // The reshape is folding/expanding consecutive dimensions. Given the indexing + // map of the fused tensor find the number of dimensions each of the loops of + // the original op is expanded into. Also record the shape of the expanded + // dimensions.
+ ArrayRef expandedShape = expandedType.getShape(); + SmallVector numFoldedDims(foldedType.getRank(), 0); + SmallVector, 4> expandedDimsShape( + foldedType.getRank()); + auto reassociationMaps = reshapeOp.getReassociationMaps(); + for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { + unsigned pos = resultExpr.value().cast().getPosition(); + AffineMap foldedDims = reassociationMaps[resultExpr.index()]; + numFoldedDims[pos] = foldedDims.getNumResults(); + ArrayRef shape = expandedShape.slice( + foldedDims.getResult(0).cast().getPosition(), + numFoldedDims[pos]); + expandedDimsShape[pos].assign(shape.begin(), shape.end()); + } + + // The remapping of the indices is then the prefix sum (inclusive) of the + // numFoldedDims. + SmallVector remapping(numFoldedDims.size() + 1, 0); + unsigned sum = 0; + for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) { + sum += numFoldedDim.value(); + remapping[numFoldedDim.index() + 1] = sum; + } + + SmallVector expandedOpIndexingMaps; + // Compute the modified indexing maps by replacing every loop (AffineDimExpr) + // in the original indexing map with the sequence of loops that it is expanded + // to. + for (AffineMap indexingMap : linalgOp.getIndexingMaps()) { + SmallVector newExprs; + for (AffineExpr expr : indexingMap.getResults()) { + unsigned pos = expr.cast().getPosition(); + for (unsigned newPos : + llvm::seq(remapping[pos], remapping[pos + 1])) { + newExprs.push_back(rewriter.getAffineDimExpr(newPos)); + } + } + expandedOpIndexingMaps.push_back( + AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs, + rewriter.getContext())); + } + + // The operands of the expanded op are computed by reshaping the original + // operands. The reshape depends on the ordering of the loop used to access + // the tensor in the original operation, and are expanded into as many + // dimensions as the loop is expanded into (as computed by `remapping`).
+ auto getReshapeInfo = + [&](AffineMap operandIndexingMap, + SmallVectorImpl &reassociation, + SmallVectorImpl &expandedOpOperandShape) { + unsigned reshapeDims = 0; + for (AffineExpr expr : operandIndexingMap.getResults()) { + unsigned origDim = expr.cast().getPosition(); + auto foldedDims = llvm::seq( + reshapeDims, reshapeDims + numFoldedDims[origDim]); + reassociation.emplace_back(foldedDims.begin(), foldedDims.end()); + expandedOpOperandShape.append(expandedDimsShape[origDim].begin(), + expandedDimsShape[origDim].end()); + reshapeDims += numFoldedDims[origDim]; + } + }; + SmallVector expandedOpOperands; + for (auto operand : llvm::enumerate(linalgOp.getInputs())) { + if (operand.index() == fusedTensorIndex) { + expandedOpOperands.push_back(reshapeOp.src()); + continue; + } + AffineMap indexingMap = linalgOp.getIndexingMap(operand.index()); + SmallVector reassociation; + SmallVector expandedOperandShape; + getReshapeInfo(indexingMap, reassociation, expandedOperandShape); + Type expandedOperandType = RankedTensorType::get( + expandedOperandShape, + operand.value().getType().cast().getElementType()); + if (expandedOperandType != operand.value().getType()) { + expandedOpOperands.push_back(rewriter.create( + linalgOp.getLoc(), expandedOperandType, operand.value(), + reassociation)); + continue; + } + // Operands whose type is unchanged by the expansion are used as is. + expandedOpOperands.push_back(operand.value()); + } + SmallVector resultTypes; + SmallVector, 1> resultReassociation; + for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) { + AffineMap indexingMap = + linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index()); + SmallVector reassociation; + SmallVector expandedResultShape; + getReshapeInfo(indexingMap, reassociation, expandedResultShape); + resultTypes.push_back(RankedTensorType::get( + expandedResultShape, + result.value().getType().cast().getElementType())); + resultReassociation.emplace_back(std::move(reassociation)); + } + + // The iterator types of the expanded op are all parallel.
+ SmallVector iteratorTypes(remapping.back(), + getParallelIteratorTypeName()); + + LinalgOp fusedOp = createLinalgOpOfSameType( + linalgOp, rewriter, linalgOp.getLoc(), resultTypes, + /*inputs=*/expandedOpOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes); + Region &fusedRegion = fusedOp.getOperation()->getRegion(0); + // TODO: Add support for indexed generic op, which would need mapping the + // expanded dimensions to the original dimension arguments. + rewriter.cloneRegionBefore(linalgOp.getOperation()->getRegion(0), fusedRegion, + fusedRegion.begin()); + + // Reshape the result values to their original shape if this is a collapsing + // reshape folded into its consumer. + SmallVector resultVals; + for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) { + if (!isExpanding && + resultTypes[result.index()] != result.value().getType()) { + resultVals.push_back(rewriter.create( + linalgOp.getLoc(), result.value().getType(), + fusedOp.getOperation()->getResult(result.index()), + resultReassociation[result.index()])); + } else { + resultVals.push_back(fusedOp.getOperation()->getResult(result.index())); + } + } + // One replacement value per result of `linalgOp`, in result order. + return resultVals; +} + namespace { /// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsProducer { - static bool isFusible(TensorReshapeOp producer, LinalgOp consumer, +struct FuseTensorReshapeOpAsProducerByLinearization { + static bool isFusable(TensorReshapeOp producer, LinalgOp consumer, unsigned consumerIdx) { return isa(consumer.getOperation()) && consumer.hasTensorSemantics() && - isTensorReshapeOpFusible(producer, - consumer.getInputIndexingMap(consumerIdx), - /*asProducer=*/true); + isTensorReshapeOpFusableByLinearization( + producer, consumer.getInputIndexingMap(consumerIdx), + /*asProducer=*/true); } - static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { + static Optional> + fuse(TensorReshapeOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { if (producer.src().getDefiningOp()) - return nullptr; + return llvm::None; - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; // Compute the fused operands list, - Operation *consumerOp = consumer.getOperation(); - SmallVector fusedOperands(consumerOp->getOperands()); + SmallVector fusedOperands(consumer.getInputs()); fusedOperands[consumerIdx] = producer.src(); // Compute indexing_maps for the fused operation. The indexing_maps for the @@ -390,14 +622,14 @@ producer.getReassociationMaps()); for (AffineExpr expr : modifiedMap.getResults()) { if (!expr.isPureAffine()) - return nullptr; + return llvm::None; } fusedIndexMaps[consumerIdx] = modifiedMap; // Further check that the resulting index maps can be fused and // inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) - return nullptr; + return llvm::None; SmallVector indexMapAttrs = llvm::to_vector<4>( llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { @@ -405,7 +637,7 @@ })); LinalgOp fusedOp = createLinalgOpOfSameType( consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), + consumer.getOperation()->getResultTypes(), /*inputs=*/fusedOperands, /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{}, // no init tensors for now. @@ -414,27 +646,52 @@ /*library_call=*/nullptr, /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.getOperation()->getRegion(0); - rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion, - fusedRegion.begin()); - return fusedOp; + rewriter.cloneRegionBefore(consumer.getOperation()->getRegion(0), + fusedRegion, fusedRegion.begin()); + return SmallVector(fusedOp.getOperation()->getResults()); } }; -/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp. -struct FuseTensorReshapeOpAsConsumer { - static bool isCollapsingAndFusible(LinalgOp producer, - TensorReshapeOp consumer, - unsigned consumerIdx) { +/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp, +/// by expanding the dimensionality of the consumer op. +struct FuseTensorReshapeOpAsProducerByExpansion { + static bool isFusable(TensorReshapeOp producer, LinalgOp consumer, + unsigned consumerIdx) { + // Fuse only if + // - The tensor reshape op is folding. + // - All constraints of fusing with reshape by expansion are met.
+ return producer.getSrcType().getRank() > + producer.getResultType().getRank() && + isFusableWithReshapeByDimExpansion(consumer, consumerIdx); + } + + static Optional> + fuse(TensorReshapeOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; + return fuseWithReshapeByExpansion(consumer, producer, consumerIdx, + rewriter, folder); + } +}; + +/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp, +/// by folding the reshape into the indexing map of the producer's result. +struct FuseTensorReshapeOpAsConsumerByLinearization { + static bool isFusable(LinalgOp producer, TensorReshapeOp consumer, + unsigned consumerIdx) { return isa(producer.getOperation()) && producer.hasTensorSemantics() && - isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0), - /*asProducer=*/false); + isTensorReshapeOpFusableByLinearization( + consumer, producer.getOutputIndexingMap(0), + /*asProducer=*/false); } - static LinalgOp fuseCollapsingCase(LinalgOp producer, - TensorReshapeOp consumer, - unsigned consumerIdx, - PatternRewriter &rewriter) { + static Optional> + fuse(LinalgOp producer, TensorReshapeOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; // The indexing_maps for the operands of the fused operation are same as // those for the operands of the producer. SmallVector fusedIndexMaps = @@ -451,24 +704,23 @@ consumer.getReassociationMaps()); for (AffineExpr expr : modifiedMap.getResults()) { if (!expr.isPureAffine()) - return nullptr; + return llvm::None; } fusedIndexMaps.back() = modifiedMap; // Further check that the resulting index maps can be fused and // inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) - return nullptr; + return llvm::None; SmallVector indexMapAttrs = llvm::to_vector<4>( llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })); - Operation *producerOp = producer.getOperation(); LinalgOp fusedOp = createLinalgOpOfSameType( producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(), - /*inputs=*/producerOp->getOperands(), + /*inputs=*/producer.getInputs(), /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{}, // no init tensors for now. rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(), @@ -476,86 +728,41 @@ /*library_call=*/nullptr, /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.getOperation()->getRegion(0); - rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion, - fusedRegion.begin()); - return fusedOp; + rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), + fusedRegion, fusedRegion.begin()); + return SmallVector(fusedOp.getOperation()->getResults()); } +}; - static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx) { - // Is fusible only if: - // 1) The producer is a generic op. - // 2) The producer has tensor semantics. - // 3) The tensor reshape op is a expanding case. - // 4) All the shapes are the same for the generic op. - // 5) All the indexing maps in producer are identity. - // 6) All the loops in producer are parallel loops. - // 7) The producer has a single user. - auto types = producer.getInputOutputShapedTypes(); - assert(!types.empty()); - return isa(producer.getOperation()) && - producer.hasTensorSemantics() && - consumer.getSrcType().getRank() < +/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp, +/// by expanding the dimensionality of the producer op.
+struct FuseTensorReshapeOpAsConsumerByExpansion { + static bool isFusable(LinalgOp producer, TensorReshapeOp consumer, + unsigned consumerIdx) { + // Fuse only if + // - The tensor reshape op is an expanding case. + // - The producer has a single user. + // - All constraints of fusing with reshape by expansion are met. + return consumer.getSrcType().getRank() < consumer.getResultType().getRank() && - std::equal(types.begin() + 1, types.end(), types.begin()) && - llvm::all_of(producer.getIndexingMaps(), - [](AffineMap map) { return map.isIdentity(); }) && - llvm::all_of(producer.iterator_types(), - [](Attribute attr) { - return attr.cast().getValue() == - getParallelIteratorTypeName(); - }) && - producer.getOperation()->hasOneUse(); + producer.getOperation()->hasOneUse() && + isFusableWithReshapeByDimExpansion(producer, + producer.getNumInputs()); } - static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx, - PatternRewriter &rewriter) { - Location loc = producer.getLoc(); - auto dstShape = consumer.getResultType().cast().getShape(); - SmallVector args; - for (auto arg : producer.getOperation()->getOperands()) { - auto type = RankedTensorType::get( - dstShape, arg.getType().cast().getElementType()); - args.push_back(rewriter.createOrFold( - loc, type, arg, consumer.reassociation())); - } - - SmallVector resultTypes; - for (auto t : producer.getOutputTensorTypes()) { - Type type = RankedTensorType::get(dstShape, - t.cast().getElementType()); - resultTypes.push_back(type); - } - - int rank = dstShape.size(); - auto genericOp = rewriter.create( - loc, resultTypes, /*inputs=*/args, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - SmallVector(args.size() + resultTypes.size(), - rewriter.getMultiDimIdentityMap(rank)), - SmallVector(rank, getParallelIteratorTypeName())); - Region &region = genericOp.getRegion(); - rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region, - region.begin()); -
return cast(genericOp.getOperation()); - } - - static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (isCollapsingAndFusible(producer, consumer, consumerIdx)) - return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter); - if (isExpandingAndFusible(producer, consumer, consumerIdx)) - return fuseExpandingCase(producer, consumer, consumerIdx, rewriter); - return nullptr; + static Optional> + fuse(LinalgOp producer, TensorReshapeOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; + return fuseWithReshapeByExpansion( + producer, consumer, producer.getNumInputs(), rewriter, folder); } }; /// Implementation of fusion on tensor ops when producer is a splat constant. struct FuseConstantOpAsProducer { - static bool isFusible(ConstantOp producer, LinalgOp consumer, + static bool isFusable(ConstantOp producer, LinalgOp consumer, unsigned consumerIdx) { return isa(consumer.getOperation()) && consumer.hasTensorSemantics() && @@ -563,11 +769,11 @@ producer.value().cast().isSplat(); } - static LinalgOp fuse(ConstantOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; + static Optional> + fuse(ConstantOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, OperationFolder *folder = nullptr) { + if (!isFusable(producer, consumer, consumerIdx)) + return llvm::None; // The indexing_maps for the operands of the fused operation are same as // those for the operands of the consumer without the indexing map at @@ -581,8 +787,7 @@ // The operands list is same as the consumer with the argument for constant // index dropped.
- Operation *consumerOp = consumer.getOperation(); - SmallVector fusedOperands(consumerOp->getOperands()); + SmallVector fusedOperands(consumer.getInputs()); fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx)); // Create a constant scalar value from the splat constant. @@ -592,7 +797,7 @@ LinalgOp fusedOp = createLinalgOpOfSameType( consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), + consumer.getOperation()->getResultTypes(), /*inputs=*/fusedOperands, /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{}, // no init tensors for now. @@ -604,29 +809,29 @@ // Map the block argument corresponding to the replaced argument with the // scalar constant. - Region &consumerRegion = consumerOp->getRegion(0); + Region &consumerRegion = consumer.getOperation()->getRegion(0); Block &entryBlock = *consumerRegion.begin(); - unsigned argIndex = entryBlock.getNumArguments() - - consumerOp->getNumOperands() + consumerIdx; + unsigned argIndex = + entryBlock.getNumArguments() - consumer.getNumInputs() + consumerIdx; BlockAndValueMapping mapping; mapping.map(entryBlock.getArgument(argIndex), scalarConstant); Region &fusedRegion = fusedOp.getOperation()->getRegion(0); rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(), mapping); - return fusedOp; + return SmallVector(fusedOp.getOperation()->getResults()); } }; } // namespace -Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, - Operation *consumer, - unsigned consumerIdx, - OperationFolder *folder) { +Optional> +mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, + unsigned consumerIdx, OperationFolder *folder, + bool useReshapeFusionByExpansion) { if (consumerIdx >= consumer->getNumOperands()) - return nullptr; + return llvm::None; Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp(); if (!producer || producer->getNumResults() != 1) - return nullptr; + return llvm::None; // Fuse when consumer is GenericOp
or IndexedGenericOp. if (isa(consumer)) { @@ -634,33 +839,47 @@ return FuseGenericOpsOnTensors::fuse(cast(producer), cast(consumer), consumerIdx, rewriter, folder); - if (auto reshapeOpProducer = dyn_cast(producer)) - return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer, - cast(consumer), - consumerIdx, rewriter, folder); + if (auto reshapeOpProducer = dyn_cast(producer)) { + if (useReshapeFusionByExpansion) { + return FuseTensorReshapeOpAsProducerByExpansion::fuse( + reshapeOpProducer, cast(consumer), consumerIdx, rewriter, + folder); + } else { + return FuseTensorReshapeOpAsProducerByLinearization::fuse( + reshapeOpProducer, cast(consumer), consumerIdx, rewriter, + folder); + } + } if (auto constantOpProducer = dyn_cast(producer)) return FuseConstantOpAsProducer::fuse(constantOpProducer, cast(consumer), consumerIdx, rewriter, folder); - return nullptr; + return llvm::None; } if (isa(producer)) { // Fuse when consumer is a TensorReshapeOp. if (TensorReshapeOp reshapeOp = dyn_cast(consumer)) { - return FuseTensorReshapeOpAsConsumer::fuse( - cast(producer), reshapeOp, consumerIdx, rewriter, folder); + if (useReshapeFusionByExpansion) { + return FuseTensorReshapeOpAsConsumerByExpansion::fuse( + cast(producer), reshapeOp, consumerIdx, rewriter, folder); + } else { + return FuseTensorReshapeOpAsConsumerByLinearization::fuse( + cast(producer), reshapeOp, consumerIdx, rewriter, folder); + } } } - return nullptr; + return llvm::None; } namespace { /// Patterns to fuse a generic op, with the producer of its operands.
template struct FuseTensorOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + FuseTensorOps(MLIRContext *context, bool useReshapeFusionByExpansion = false) + : OpRewritePattern::OpRewritePattern(context), + useReshapeFusionByExpansion(useReshapeFusionByExpansion) {} LogicalResult matchAndRewrite(LinalgOpTy op, PatternRewriter &rewriter) const override { @@ -669,8 +888,10 @@ llvm::seq(0, op.getOperation()->getNumOperands())) { Operation *producer = op.getOperation()->getOperand(operandNum).getDefiningOp(); - if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) { - rewriter.replaceOp(op, fusedOp->getResults()); + Optional> fusedOpResults = fuseTensorOps( + rewriter, op, operandNum, nullptr, useReshapeFusionByExpansion); + if (fusedOpResults) { + rewriter.replaceOp(op, *fusedOpResults); if (producer && llvm::all_of(producer->getResults(), [](Value val) { return val.use_empty(); })) rewriter.eraseOp(producer); @@ -679,26 +900,44 @@ } return failure(); } + +private: + bool useReshapeFusionByExpansion; }; /// Pass that fuses generic ops on tensors. Used only for testing. 
struct FusionOfTensorOpsPass : public LinalgFusionOfTensorOpsBase { + FusionOfTensorOpsPass() = default; + FusionOfTensorOpsPass(const FusionOfTensorOpsPass &) {} + explicit FusionOfTensorOpsPass(bool vUseReshapeFusionByExpansion) { + this->useReshapeFusionByExpansion = vUseReshapeFusionByExpansion; + } + void runOnOperation() override { OwningRewritePatternList patterns; Operation *op = getOperation(); - populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); + populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns, + useReshapeFusionByExpansion); applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; + + Option useReshapeFusionByExpansion{ + *this, "use-reshape-fusion-by-expansion", llvm::cl::init(false), + llvm::cl::desc("Enable use of reshape fusion with producer/consumer by " + "expanding the dimensionality of the producer/consumer")}; }; } // namespace void mlir::populateLinalgTensorOpsFusionPatterns( - MLIRContext *context, OwningRewritePatternList &patterns) { + MLIRContext *context, OwningRewritePatternList &patterns, + bool useReshapeFusionByExpansion) { patterns.insert, FuseTensorOps, - FuseTensorOps>(context); + FuseTensorOps>(context, + useReshapeFusionByExpansion); } -std::unique_ptr mlir::createLinalgFusionOfTensorOpsPass() { - return std::make_unique(); +std::unique_ptr +mlir::createLinalgFusionOfTensorOpsPass(bool useReshapeFusionByExpansion) { + return std::make_unique(useReshapeFusionByExpansion); } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -226,40 +226,6 @@ // ----- -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d2)> - -func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>) - -> tensor<8x33x4xf32> { - %cst = constant dense<2.000000e+00> : tensor<264x4xf32> - %0 =
linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) { - ^bb0(%arg1: f32, %arg2: f32): // no predecessors - %2 = mulf %arg1, %arg2 : f32 - linalg.yield %2 : f32 - } -> tensor<264x4xf32> - %1 = linalg.tensor_reshape %0 [#map1, #map2] : - tensor<264x4xf32> into tensor<8x33x4xf32> - return %1 : tensor<8x33x4xf32> -} - -// The reshape op in `%arg0` is folded into the indexing map of generic op. -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @generic_op_reshape_consumer_expanding -// CHECK-NOT: linalg.tensor_reshape -// CHECK: %[[CST:.*]] = constant {{.*}} : f32 -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: tensor<264x4xf32> -// CHECK: -> tensor<8x33x4xf32> -// CHECK-NOT: linalg.tensor_reshape - -// ----- - #map0 = affine_map<(d0, d1, d2) -> (d0)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -0,0 +1,163 @@ +// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops=use-reshape-fusion-by-expansion -split-input-file | FileCheck %s + +#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +func @generic_op_reshape_producer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k)>, + affine_map<(i, j, k, l) -> (l)>] : + tensor into tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map1], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0, %arg1 : tensor, tensor) {
+ ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + return %1 : tensor +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)> +// CHECK: func @generic_op_reshape_producer_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[T0]] : tensor, tensor) +// CHECK: %[[T2:.+]] = linalg.tensor_reshape +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: return %[[T2]] + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @generic_op_reshape_consumer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: func
@generic_op_reshape_consumer_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T2:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) +// CHECK: return %[[T2]] : tensor + + +// ----- + +func @reshape_as_consumer_permutation + (%a : tensor, %b : tensor) + -> tensor { + %c = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2, d1)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%a, %b : tensor, tensor) { + ^bb0(%arg0 : f32, %arg1: f32): + %1 = addf %arg0, %arg1 : f32 + linalg.yield %1 : f32 + } -> tensor + %d = linalg.tensor_reshape %c + [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] + : tensor into tensor + return %d : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3,
d4)> +// CHECK: func @reshape_as_consumer_permutation +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP3]], #[[MAP4]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T2:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) +// CHECK: return %[[T2]] : tensor + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d2)> + +func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) + -> tensor<8x33x4xf32> { + %cst = constant dense<2.000000e+00> : tensor<264x4xf32> + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) { + ^bb0(%arg1: f32, %arg2: f32): // no predecessors + %2 = mulf %arg1, %arg2 : f32 + linalg.yield %2 : f32 + } -> tensor<264x4xf32> + %1 = linalg.tensor_reshape %0 [#map1, #map2] : + tensor<264x4xf32> into tensor<8x33x4xf32> + return %1 : tensor<8x33x4xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: func @generic_op_reshape_consumer_static +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] +//
CHECK-SAME: ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>) +// CHECK: return %[[T1]] : tensor<8x33x4xf32>