diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -19,8 +19,8 @@
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
-std::unique_ptr<Pass>
-createLinalgFusionOfTensorOpsPass(bool useReshapeFusionByExpansion = false);
+std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
+std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
@@ -50,10 +50,22 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 createConvertLinalgOnTensorsToBuffersPass();
 
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 /// Patterns for fusing linalg operations on tensors.
-void populateLinalgTensorOpsFusionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
-    bool useReshapeFusionByExpansion = false);
+void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
+                                           OwningRewritePatternList &patterns);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -35,6 +35,14 @@
   let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
+def LinalgFoldReshapeOpsByLinearization :
+    Pass<"linalg-fold-reshape-ops-by-linearization"> {
+  let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
+                "linearization";
+  let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
+  let dependentDialects = ["AffineDialect"];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -93,8 +93,7 @@
 /// position `consumerIdx` of the consumer.
 Optional<SmallVector<Value, 1>>
 fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-              unsigned consumerIdx, OperationFolder *folder = nullptr,
-              bool useReshapeFusionByExpansion = false);
+              unsigned consumerIdx, OperationFolder *folder = nullptr);
 
 /// Returns the linearized list of all shape dimensions in a `linalgOp`.
 /// Applying the inverse, concatenated loopToOperandRangeMaps to this list
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -514,9 +514,8 @@
       return success();
     }
     // Check if producer and consumer are both collapsing dims.
-    else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
-                                   reshapeOp.getSrcType(),
-                                   reshapeOp.getResultType())) {
+    if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(),
+                              reshapeOp.getResultType())) {
       rewriter.replaceOpWithNewOp<TensorReshapeOp>(
           reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
           collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -312,9 +312,9 @@
 /// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
 /// true) or its producer (if `asProducer` is false) given the indexing map at
 /// its use.
-static bool isTensorReshapeOpFusableByLinearization(TensorReshapeOp reshapeOp,
-                                                    AffineMap useIndexMap,
-                                                    bool asProducer) {
+static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
+                                                     AffineMap useIndexMap,
+                                                     bool asProducer) {
   RankedTensorType returnType = reshapeOp.getResultType();
   RankedTensorType operandType = reshapeOp.getSrcType();
   // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
@@ -348,10 +348,10 @@
   return nullptr;
 }
 
-/// Conditions for fusing a generic/indexed-generic operation with a reshape op
-/// by expanding the iteration space dimensionality. These are preconditions
-/// assumed by `fusWithReshapeByDimExpansion` which implements the following
-/// fusion pattern.
+/// Conditions for folding a generic/indexed-generic operation with a reshape op
+/// by expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
+/// the following fusion pattern.
 ///
 /// Consider
 ///
@@ -413,14 +413,13 @@
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
   // - The linalgOp is a generic op.
-  // - The linalgOp has tensor semantics with a single output.
   // - All the indexing maps for operands in linalgOp are projected
   //   permutations.
   // - The indexing map at the position representing the fused tensor is a
   //   permutation.
   // - All the loops in linalgOp are parallel loops.
   return isa<GenericOp>(linalgOp.getOperation()) &&
-         linalgOp.hasTensorSemantics() && linalgOp.getNumOutputs() == 1 &&
+         linalgOp.hasTensorSemantics() &&
          llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
                           linalgOp.getNumInputs()),
                       [](Attribute attr) {
@@ -435,10 +434,15 @@
                       });
 }
 
+/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
+/// op as explained in `isFusableWithReshapeByDimExpansion`. Assumes that those
+/// conditions have been satisfied.
 static Optional<SmallVector<Value, 1>>
 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
                            unsigned fusedTensorIndex, PatternRewriter &rewriter,
                            OperationFolder *folder = nullptr) {
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
+         "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   bool isExpanding =
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
@@ -583,250 +587,281 @@
 
 namespace {
 
-/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsProducerByLinearization {
-  static bool isFusable(TensorReshapeOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
-           consumer.hasTensorSemantics() &&
-           isTensorReshapeOpFusableByLinearization(
-               producer, consumer.getInputIndexingMap(consumerIdx),
-               /*asProducer=*/true);
-  }
-
-  static Optional<SmallVector<Value, 1>>
-  fuse(TensorReshapeOp producer, LinalgOp consumer, unsigned consumerIdx,
-       PatternRewriter &rewriter, OperationFolder *folder = nullptr) {
-    if (producer.src().getDefiningOp<ConstantOp>())
-      return llvm::None;
-
-    if (!isFusable(producer, consumer, consumerIdx))
-      return llvm::None;
-
-    // Compute the fused operands list,
-    SmallVector<Value, 2> fusedOperands(consumer.getInputs());
-    fusedOperands[consumerIdx] = producer.src();
-
-    // Compute indexing_maps for the fused operation. The indexing_maps for the
-    // operands of the consumers that arent fused are the same.
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
+/// Pattern to fold tensor_reshape op with its consumer by using the source of
+/// the reshape op as the operand in the consumer (instead of the result of the
+/// tensor_reshape op) when the tensor_reshape op is collapsing. The
+/// corresponding index map in the consumer needs to be modified to linearize
+/// the folded dimension.
+///
+/// For example,
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.tensor_reshape %arg0
+///        [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
+///         affine_map<(i, j, k, l) -> (l)>]
+///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
+///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
+///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+template <typename LinalgOpTy>
+struct FoldProducerReshapeOpByLinearization
+    : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
-    // Accepted consumer maps are either identity or permutation.
-    auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
+  LogicalResult matchAndRewrite(LinalgOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics())
+      return failure();
+    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp ||
+          !isTensorReshapeOpFoldableByLinearization(
+              reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
+              /*asProducer=*/true))
+        continue;
 
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
-                               producer.getReassociationMaps());
-    for (AffineExpr expr : modifiedMap.getResults()) {
-      if (!expr.isPureAffine())
-        return llvm::None;
-    }
+      // Compute the fused operands list,
+      SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
+      fusedOperands[operand.index()] = reshapeOp.src();
+
+      // Compute indexing_maps for the fused operation. The indexing_maps for
+      // the operands of the consumers that aren't fused are the same.
+      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+          op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
+
+      // Accepted consumer maps are either identity or permutation.
+      auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
+
+      // Compute the indexing map to use for the result of the producer.
+      AffineMap modifiedMap =
+          linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
+                                 reshapeOp.getReassociationMaps());
+      for (AffineExpr expr : modifiedMap.getResults()) {
+        if (!expr.isPureAffine())
+          return failure();
+      }
+      fusedIndexMaps[operand.index()] = modifiedMap;
+
+      // Further check that the resulting index maps can be fused and
+      // inverted. Without this the resultant op is not legal.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+        return op.emitRemark("fused op loop bound computation failed");
+
+      rewriter.startRootUpdate(op);
+      op.getOperation()->setOperands(fusedOperands);
+      op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(op);
+      if (reshapeOp.use_empty())
+        rewriter.eraseOp(reshapeOp);
+      return success();
+    }
-    fusedIndexMaps[consumerIdx] = modifiedMap;
-
-    // Further check that the resulting index maps can be fused and
-    // inverted. Without this the resultant op is not legal.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
-      return llvm::None;
-
-    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
-        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
-          return AffineMapAttr::get(map);
-        }));
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        consumer, rewriter, rewriter.getUnknownLoc(),
-        consumer.getOperation()->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        /*outputBuffers=*/ValueRange{},
-        /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr,
-        /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
-    rewriter.cloneRegionBefore(consumer.getOperation()->getRegion(0),
-                               fusedRegion, fusedRegion.begin());
-    return SmallVector<Value, 1>(fusedOp.getOperation()->getResults());
+    return op.emitRemark("no fusion candidates found");
   }
 };
 
-struct FuseTensorReshapeOpAsProducerByExpansion {
-  static bool isFusable(TensorReshapeOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    // Fuse only if
-    // - The tensor reshape op is folding.
-    // - All constraints of fusing with reshape by expansion are met.
-    return producer.getSrcType().getRank() >
-               producer.getResultType().getRank() &&
-           isFusableWithReshapeByDimExpansion(consumer, consumerIdx);
-  }
+/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
+/// reshape op is collapsing dimensions. The dimensionality of the loop in the
+/// consumer generic op is expanded.
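+///
+/// For example (an illustrative sketch; the shapes and maps here are assumed,
+/// not taken from a test):
+///
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k) -> (i)>,
+///                                     affine_map<(i, j, k) -> (j, k)>] :
+///     tensor<?x4x?xf32> into tensor<?x?xf32>
+///   %1 = linalg.generic {
+///     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+///                      affine_map<(d0, d1) -> (d0, d1)>],
+///     iterator_types = ["parallel", "parallel"]}
+///     ins(%0 : tensor<?x?xf32>) ... -> tensor<?x?xf32>
+///
+/// is rewritten into a three-loop generic op that reads %arg0 directly, with
+/// the collapsing reshape applied to its result instead.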
+struct FoldWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
-  static Optional<SmallVector<Value, 1>>
-  fuse(TensorReshapeOp producer, LinalgOp consumer, unsigned consumerIdx,
-       PatternRewriter &rewriter, OperationFolder *folder = nullptr) {
-    if (!isFusable(producer, consumer, consumerIdx))
-      return llvm::None;
-    return fuseWithReshapeByExpansion(consumer, producer, consumerIdx,
-                                      rewriter);
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      // Fold only if
+      // - The tensor reshape op is folding.
+      // - All constraints of fusing with reshape by expansion are met.
+      if (reshapeOp.getSrcType().getRank() <
+              reshapeOp.getResultType().getRank() ||
+          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+        continue;
+
+      Optional<SmallVector<Value, 1>> replacementValues =
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
+                                     rewriter);
+      if (!replacementValues)
+        return failure();
+      rewriter.replaceOp(genericOp, replacementValues.getValue());
+      if (reshapeOp.use_empty())
+        rewriter.eraseOp(reshapeOp);
+      return success();
+    }
+    return failure();
   }
 };
 
-struct FuseTensorReshapeOpAsConsumerByLinearization {
-  static bool isFusable(LinalgOp producer, TensorReshapeOp consumer,
-                        unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
-           producer.hasTensorSemantics() &&
-           isTensorReshapeOpFusableByLinearization(
-               consumer, producer.getOutputIndexingMap(0),
-               /*asProducer=*/false);
-  }
+/// Pattern to fold tensor_reshape op with its producer. The corresponding index
+/// map in the consumer needs to be modified to linearize the folded dimension.
+struct FoldConsumerReshapeOpByLinearization
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
-  static Optional<SmallVector<Value, 1>>
-  fuse(LinalgOp producer, TensorReshapeOp consumer, unsigned consumerIdx,
-       PatternRewriter &rewriter, OperationFolder *folder = nullptr) {
-    if (!isFusable(producer, consumer, consumerIdx))
-      return llvm::None;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    if (!producer ||
+        !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
+        !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
+        !isTensorReshapeOpFoldableByLinearization(
+            reshapeOp, producer.getOutputIndexingMap(0), /*asProducer=*/false))
+      return failure();
 
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
+    SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+        producer.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>());
 
     auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
 
     // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
-                               consumer.getReassociationMaps());
+        linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
+                               reshapeOp.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
-        return llvm::None;
+        return reshapeOp.emitRemark("fused op indexing map is not affine");
     }
     fusedIndexMaps.back() = modifiedMap;
 
     // Further check that the resulting index maps can be fused and
     // inverted. Without this the resultant op is not legal.
     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
-      return llvm::None;
-
-    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
-        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
-          return AffineMapAttr::get(map);
-        }));
+      return reshapeOp.emitRemark("fused op loop bound computation failed");
 
     LinalgOp fusedOp = createLinalgOpOfSameType(
-        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
+        producer, rewriter, rewriter.getUnknownLoc(), reshapeOp.getResultType(),
         /*inputs=*/producer.getInputs(),
         /*outputBuffers=*/ValueRange{},
         /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        producer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
     auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
     rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0),
                                fusedRegion, fusedRegion.begin());
-    return SmallVector<Value, 1>(fusedOp.getOperation()->getResults());
+    rewriter.replaceOp(reshapeOp, fusedOp.getOperation()->getResults());
+    if (producer.use_empty())
+      rewriter.eraseOp(producer);
+    return success();
   }
 };
 
-/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsConsumerByExpansion {
-  static bool isFusable(LinalgOp producer, TensorReshapeOp consumer,
-                        unsigned consumerIdx) {
-    // Fuse only if
+/// Pattern to fold a tensor_reshape op with its producer generic op if the
+/// tensor_reshape op is expanding, by expanding the dimensionality of the loop
+/// in the producer op.
+struct FoldReshapeWithGenericOpByExpansion
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if
     // - The tensor reshape op is an expanding case.
     // - All constraints of fusing with reshape by expansion are met.
-    return consumer.getSrcType().getRank() <
-               consumer.getResultType().getRank() &&
-           isFusableWithReshapeByDimExpansion(producer,
-                                              producer.getNumInputs());
-  }
-
-  static Optional<SmallVector<Value, 1>>
-  fuse(LinalgOp producer, TensorReshapeOp consumer, unsigned consumerIdx,
-       PatternRewriter &rewriter, OperationFolder *folder = nullptr) {
-    if (!isFusable(producer, consumer, consumerIdx))
-      return llvm::None;
-    return fuseWithReshapeByExpansion(
-        producer, consumer, producer.getNumInputs(), rewriter, folder);
+    if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
+      return failure();
+    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    if (!producer || producer.getNumOutputs() != 1 ||
+        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+      return failure();
+    Optional<SmallVector<Value, 1>> replacementValues =
+        fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
+                                   rewriter);
+    if (!replacementValues)
+      return failure();
+    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
+    if (producer.use_empty())
+      rewriter.eraseOp(producer);
+    return success();
   }
 };
 
-/// Implementation of fusion on tensor ops when producer is a splat constant.
-struct FuseConstantOpAsProducer {
-  static bool isFusable(ConstantOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
-           consumer.hasTensorSemantics() &&
-           producer.getResult().getType().isa<RankedTensorType>() &&
-           producer.value().cast<DenseElementsAttr>().isSplat();
-  }
-
-  static Optional<SmallVector<Value, 1>>
-  fuse(ConstantOp producer, LinalgOp consumer, unsigned consumerIdx,
-       PatternRewriter &rewriter, OperationFolder *folder = nullptr) {
-    if (!isFusable(producer, consumer, consumerIdx))
-      return llvm::None;
-
-    // The indexing_maps for the operands of the fused operation are same as
-    // those for the operands of the consumer without the indexing map at
-    // consumerIdx
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
-    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
-
-    // The operands list is same as the consumer with the argument for constant
-    // index dropped.
-    SmallVector<Value, 4> fusedOperands(consumer.getInputs());
-    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
-
-    // Create a constant scalar value from the splat constant.
-    Value scalarConstant = rewriter.create<ConstantOp>(
-        producer.getLoc(),
-        producer.value().cast<DenseElementsAttr>().getSplatValue());
+/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
+template <typename LinalgOpTy>
+struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        consumer, rewriter, rewriter.getUnknownLoc(),
-        consumer.getOperation()->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        /*outputBuffers=*/ValueRange{},
-        /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr,
-        /*symbol_source=*/nullptr);
+  LogicalResult matchAndRewrite(LinalgOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics())
+      return failure();
+    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
+      if (!constantOp ||
+          !constantOp.value().cast<DenseElementsAttr>().isSplat())
+        continue;
 
-    // Map the block argument corresponding to the replaced argument with the
-    // scalar constant.
-    Region &consumerRegion = consumer.getOperation()->getRegion(0);
-    Block &entryBlock = *consumerRegion.begin();
-    unsigned argIndex =
-        entryBlock.getNumArguments() - consumer.getNumInputs() + consumerIdx;
-    BlockAndValueMapping mapping;
-    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
-    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
-    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
-                               mapping);
-    return SmallVector<Value, 1>(fusedOp.getOperation()->getResults());
+      // The indexing_maps for the operands of the fused operation are same as
+      // those for the operands of the linalgOp without the indexing map at
+      // operand.index()
+      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+          linalgOp.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>());
+      fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
+
+      // The operands list is same as the linalgOp with the argument for
+      // constant index dropped.
+      SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
+      fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
+
+      // Create a constant scalar value from the splat constant.
+      Value scalarConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(),
+          constantOp.value().cast<DenseElementsAttr>().getSplatValue());
+
+      LinalgOp fusedOp = createLinalgOpOfSameType(
+          linalgOp, rewriter, rewriter.getUnknownLoc(),
+          linalgOp.getOperation()->getResultTypes(),
+          /*inputs=*/fusedOperands,
+          /*outputBuffers=*/ValueRange{},
+          /*initTensors=*/ValueRange{}, // no init tensors for now.
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+          linalgOp.iterator_types(),
+          /*doc=*/nullptr,
+          /*library_call=*/nullptr,
+          /*symbol_source=*/nullptr);
+
+      // Map the block argument corresponding to the replaced argument with the
+      // scalar constant.
+      Region &linalgOpRegion = linalgOp.getOperation()->getRegion(0);
+      Block &entryBlock = *linalgOpRegion.begin();
+      unsigned argIndex = entryBlock.getNumArguments() -
+                          linalgOp.getNumInputs() + operand.index();
+      BlockAndValueMapping mapping;
+      mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+      Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
+      rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
+                                 fusedRegion.begin(), mapping);
+      rewriter.replaceOp(linalgOp, fusedOp.getOperation()->getResults());
+      if (constantOp.use_empty())
+        rewriter.eraseOp(constantOp);
+      return success();
+    }
+    return failure();
   }
 };
 } // namespace
 
 Optional<SmallVector<Value, 1>>
 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                            unsigned consumerIdx, OperationFolder *folder,
-                            bool useReshapeFusionByExpansion) {
+                            unsigned consumerIdx, OperationFolder *folder) {
   if (consumerIdx >= consumer->getNumOperands())
     return llvm::None;
   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
@@ -834,50 +869,20 @@
     return llvm::None;
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (isa<GenericOp, IndexedGenericOp>(consumer)) {
-    if (isa<GenericOp, IndexedGenericOp>(producer))
-      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
-                                           cast<LinalgOp>(consumer),
-                                           consumerIdx, rewriter, folder);
-    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
-      if (useReshapeFusionByExpansion) {
-        return FuseTensorReshapeOpAsProducerByExpansion::fuse(
-            reshapeOpProducer, cast<LinalgOp>(consumer), consumerIdx, rewriter,
-            folder);
-      }
-      return FuseTensorReshapeOpAsProducerByLinearization::fuse(
-          reshapeOpProducer, cast<LinalgOp>(consumer), consumerIdx, rewriter,
-          folder);
-    }
-    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
-      return FuseConstantOpAsProducer::fuse(constantOpProducer,
-                                            cast<LinalgOp>(consumer),
-                                            consumerIdx, rewriter, folder);
+  if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
+      !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
-  }
 
-  if (isa<GenericOp, IndexedGenericOp>(producer)) {
-    // Fuse when consumer is a TensorReshapeOp.
-    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
-      if (useReshapeFusionByExpansion) {
-        return FuseTensorReshapeOpAsConsumerByExpansion::fuse(
-            cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
-      }
-      return FuseTensorReshapeOpAsConsumerByLinearization::fuse(
-          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
-    }
-  }
-
-  return llvm::None;
+  return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
+                                       cast<LinalgOp>(consumer), consumerIdx,
+                                       rewriter, folder);
 }
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
-  FuseTensorOps(MLIRContext *context, bool useReshapeFusionByExpansion = false)
-      : OpRewritePattern<LinalgOpTy>::OpRewritePattern(context),
-        useReshapeFusionByExpansion(useReshapeFusionByExpansion) {}
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
@@ -886,56 +891,75 @@
              llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
       Operation *producer =
           op.getOperation()->getOperand(operandNum).getDefiningOp();
-      Optional<SmallVector<Value, 1>> fusedOpResults = fuseTensorOps(
-          rewriter, op, operandNum, nullptr, useReshapeFusionByExpansion);
+      if (!producer)
+        continue;
+      Optional<SmallVector<Value, 1>> fusedOpResults =
+          fuseTensorOps(rewriter, op, operandNum);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
-        if (producer && llvm::all_of(producer->getResults(),
-                                     [](Value val) { return val.use_empty(); }))
+        if (producer->use_empty())
           rewriter.eraseOp(producer);
         return success();
       }
     }
     return failure();
   }
-
-private:
-  bool useReshapeFusionByExpansion;
 };
 
 /// Pass that fuses generic ops on tensors. Used only for testing.
 struct FusionOfTensorOpsPass
     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
-  FusionOfTensorOpsPass() = default;
-  FusionOfTensorOpsPass(const FusionOfTensorOpsPass &) {}
-  FusionOfTensorOpsPass(bool vUseReshapeFusionByExpansion) {
-    this->useReshapeFusionByExpansion = vUseReshapeFusionByExpansion;
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    Operation *op = getOperation();
+    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
   }
+};
 
+/// Pass to test folding of reshape op with generic/indexed_generic ops by
+/// linearization.
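+///
+/// Exercised standalone as, e.g. (mirroring the RUN line of the new test file
+/// added below):
+///
+///   mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization <input>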
+struct FoldReshapeOpsByLinearizationPass
+    : public LinalgFoldReshapeOpsByLinearizationBase<
+          FoldReshapeOpsByLinearizationPass> {
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns,
-                                          useReshapeFusionByExpansion);
+    populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
-  };
-
-  Option<bool> useReshapeFusionByExpansion{
-      *this, "use-reshape-fusion-by-expansion", llvm::cl::init(true),
-      llvm::cl::desc("Enable use of reshape fusion with producer/consumer by "
-                     "expanding the dimensionality of the producer/consumer")};
+  }
 };
+
 } // namespace
 
+void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
+                  FoldConsumerReshapeOpByLinearization>(context);
+}
+
+void mlir::populateFoldReshapeOpsByExpansionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldWithProducerReshapeOpByExpansion,
+                  FoldReshapeWithGenericOpByExpansion>(context);
+}
+
 void mlir::populateLinalgTensorOpsFusionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
-    bool useReshapeFusionByExpansion) {
+    MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
-                  FuseTensorOps<TensorReshapeOp>>(context,
-                                                  useReshapeFusionByExpansion);
+                  FoldSplatConstants<GenericOp>,
+                  FoldSplatConstants<IndexedGenericOp>>(context);
+  populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+  GenericOp::getCanonicalizationPatterns(patterns, context);
+  IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
+  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+}
+
+std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
+  return std::make_unique<FusionOfTensorOpsPass>();
+}
 
-std::unique_ptr<Pass>
-mlir::createLinalgFusionOfTensorOpsPass(bool useReshapeFusionByExpansion) {
-  return std::make_unique<FusionOfTensorOpsPass>(useReshapeFusionByExpansion);
+std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
+  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
 }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -142,90 +142,6 @@
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
-                                         %arg1 : tensor<?x?x4x?xf32>) ->
-  tensor<?x?x4x?xf32>
-{
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
-                                    affine_map<(i, j, k, l) -> (j, k)>,
-                                    affine_map<(i, j, k, l) -> (l)>] :
-    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-  %1 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x?xf32>
-  return %1 : tensor<?x?x4x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_producer_fusion
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
-// CHECK-NOT: linalg.generic
-
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
-                                         %arg1 : tensor<?x?x4x5xf32>) ->
-  tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x5xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x4x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.generic
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x?xf32>,
-                                           %arg1 : tensor<?x?x?x?xf32>) ->
-  tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x?x?xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x?x?xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-// CHECK: linalg.tensor_reshape
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
@@ -465,159 +381,3 @@
 // CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 // CHECK:   linalg.yield %[[VAL4]] : i32
 // CHECK-NOT: linalg.indexed_generic
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-  -> tensor<?x?x4x?xi32> {
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
-                                    affine_map<(i, j, k, l) -> (j, k)>,
-                                    affine_map<(i, j, k, l) -> (l)>] :
-    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
-  %1 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%0 : tensor<?x?x4x?xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
-    %3 = addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x?xi32>
-  return %1 : tensor<?x?x4x?xi32>
-}
-
-// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-  -> tensor<?x?xi32> {
-  %0 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%arg0 : tensor<?x?x4x5xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
-    %3 = addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x5xi32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x4x5xi32> into tensor<?x?xi32>
-  return %1 : tensor<?x?xi32>
-}
-
-// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<3x7x5xf32>
-  return %1 : tensor<3x7x5xf32>
-}
-
-// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<5x7x3xf32>
-  return %1 : tensor<5x7x3xf32>
-}
-
-// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<5x3x7xf32>
-  return %1 : tensor<5x3x7xf32>
-}
-
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
-func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
-  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<5x3x7xf32>
-  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
-  return %1 : tensor<5x21xf32>
-}
-
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops=use-reshape-fusion-by-expansion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -0,0 +1,241 @@
+// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
+                                         %arg1 : tensor<?x?x4x?xf32>) ->
+  tensor<?x?x4x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x4x?xf32>
+  return %1 : tensor<?x?x4x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
+// CHECK-NOT: linalg.generic
+
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
+                                         %arg1 : tensor<?x?x4x5xf32>) ->
+  tensor<?x?xf32>
+{
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x4x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x?xf32>,
+                                           %arg1 : tensor<?x?x?x?xf32>) ->
+  tensor<?x?xf32>
+{
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
+// CHECK: linalg.tensor_reshape
+
+// -----
+
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+  -> tensor<?x?x4x?xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+  %1 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+    ins(%0 : tensor<?x?x4x?xi32>) {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  } -> tensor<?x?x4x?xi32>
+  return %1 : tensor<?x?x4x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+  -> tensor<?x?xi32> {
+  %0 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x4x5xi32>) {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  } -> tensor<?x?x4x5xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xi32> into tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<3x7x5xf32>
+  return %1 : tensor<3x7x5xf32>
+}
+
+// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x7x3xf32>
+  return %1 : tensor<5x7x3xf32>
+}
+
+// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x3x7xf32>
+  return %1 : tensor<5x3x7xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
+func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x3x7xf32>
+  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+  return %1 : tensor<5x21xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
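
For contrast with the linearization tests above: the expansion-based path is exercised by reshape_fusion.mlir, which this patch only touches to drop the now-removed pass option. As a minimal sketch of the complementary behavior — the shapes, maps, and function name below are assumed for illustration and are not taken from that file — a collapsing producer reshape such as

  #id2 = affine_map<(d0, d1) -> (d0, d1)>
  func @collapsing_reshape_producer(%arg0 : tensor<3x5x7xf32>) -> tensor<3x35xf32> {
    // Collapse (j, k) into a single 35-element dimension.
    %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k) -> (i)>,
                                      affine_map<(i, j, k) -> (j, k)>] :
      tensor<3x5x7xf32> into tensor<3x35xf32>
    // Elementwise consumer of the collapsed tensor.
    %1 = linalg.generic {indexing_maps = [#id2, #id2],
                         iterator_types = ["parallel", "parallel"]}
      ins(%0 : tensor<3x35xf32>) {
    ^bb0(%arg1: f32):  // no predecessors
      linalg.yield %arg1 : f32
    } -> tensor<3x35xf32>
    return %1 : tensor<3x35xf32>
  }

Under -linalg-fusion-for-tensor-ops, FoldWithProducerReshapeOpByExpansion rewrites the generic op to iterate over the expanded (3, 5, 7) space, reading %arg0 directly and reshaping the result instead; -linalg-fold-reshape-ops-by-linearization instead handles the expanding direction by folding the reshape into a linearized indexing map, as the CHECK lines above verify.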