diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -85,7 +85,7 @@ "ArrayRef attrs = {}", [{ auto reassociationMaps = convertReassociationIndicesToMaps($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -20,6 +20,7 @@ std::unique_ptr> createLinalgFusionPass(); std::unique_ptr createLinalgFusionOfTensorOpsPass(); +std::unique_ptr createFoldReshapeOpsByLinearizationPass(); std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}); @@ -48,6 +49,19 @@ /// buffers instead. std::unique_ptr> createLinalgBufferizePass(); +/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its +/// producer (consumer) generic operation by expanding the dimensionality of the +/// loop in the generic op. +void populateFoldReshapeOpsByExpansionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns); + +/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its +/// producer (consumer) generic/indexed_generic operation by linearizing the +/// indexing map used to access the source (target) of the reshape operation in +/// the generic/indexed_generic operation. +void populateFoldReshapeOpsByLinearizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns); + /// Patterns for fusing linalg operation on tensors. 
void populateLinalgTensorOpsFusionPatterns(MLIRContext *context, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -35,6 +35,14 @@ let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"]; } +def LinalgFoldReshapeOpsByLinearization : + Pass<"linalg-fold-reshape-ops-by-linearization"> { + let summary = "Fold TensorReshapeOps with generic/indexed generic ops by " + "linearization"; + let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()"; + let dependentDialects = ["AffineDialect"]; +} + def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -91,9 +91,9 @@ /// Fuse linalg operation on tensors, with the producer of the operand at /// position `consumerIdx` of the consumer. -Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, - unsigned consumerIdx, - OperationFolder *folder = nullptr); +Optional> +fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, + unsigned consumerIdx, OperationFolder *folder = nullptr); /// Returns the linearized list of all shape dimensions in a `linalgOp`. /// Applying the inverse, concatenated loopToOperandRangeMaps to this list diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -514,9 +514,8 @@ return success(); } // Check if producer and consumer are both collapsing dims. 
- else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), - reshapeOp.getSrcType(), - reshapeOp.getResultType())) { + if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(), + reshapeOp.getResultType())) { rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(), collapseReassociationMaps(srcReshapeOp.getReassociationMaps(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -24,247 +24,240 @@ using namespace mlir; using namespace mlir::linalg; -namespace { - /// Implementation of fusion of generic ops and indexed_generic ops. -struct FuseGenericOpsOnTensors { - static bool isFusible(LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx) { - // Producer and consumer must have tensor semantics. - if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) - return false; - - // Verify that - // - the producer has all "parallel" iterator type. - if (producer.getNumParallelLoops() != producer.getNumLoops()) - return false; - - // Get the consumer index map. The number of results of the consumer index - // map must match the number of loops of the producer. - AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); - if (consumerIndexMap.getNumResults() != producer.getNumLoops()) - return false; - - // Finally the index_map for the result must be invertible. For now just - // verify it is a permutation. - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - return producerResultIndexMap.isPermutation(); - } +// struct FuseGenericOpsOnTensors { +static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer, + unsigned consumerIdx) { + // Producer and consumer must have tensor semantics. 
+ if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) + return false; - static LinalgOp fuse(LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; - - unsigned numFusedOperands = producer.getOperation()->getNumOperands() + - consumer.getOperation()->getNumOperands() - 1; - - // Compute the fused operands list, - SmallVector fusedOperands; - fusedOperands.reserve(numFusedOperands); - auto consumerOperands = consumer.getOperation()->getOperands(); - auto producerOperands = producer.getOperation()->getOperands(); - fusedOperands.assign(consumerOperands.begin(), - std::next(consumerOperands.begin(), consumerIdx)); - fusedOperands.append(producerOperands.begin(), producerOperands.end()); - fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), - consumerOperands.end()); - - // Compute indexing_maps for the fused operation. The indexing_maps for the - // operands of the consumers that arent fused are the same. The - // indexing_maps for the producers need to be computed based on the - // indexing_map of the operand at consumerIdx in the consumer. - SmallVector fusedIndexMaps; - auto consumerIndexMaps = consumer.indexing_maps(); - fusedIndexMaps.reserve(fusedOperands.size() + - consumer.getOperation()->getNumResults()); - fusedIndexMaps.assign(consumerIndexMaps.begin(), - std::next(consumerIndexMaps.begin(), consumerIdx)); - // Compute indexing maps for the producer args in the fused operation. - computeProducerOperandIndex( - producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); - - // Append the indexing maps for the remaining consumer operands. - fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), - consumerIndexMaps.end()); - - // Generate the fused op. - // Tensor-level fusion is only on ops without initTensors and outputBuffers. 
- LinalgOp fusedOp; - if (isa(producer.getOperation()) && - isa(consumer.getOperation())) { - fusedOp = - rewriter - .create(consumer.getLoc(), - consumer.getOperation()->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) - .getOperation(); - } else { - fusedOp = - rewriter - .create( - consumer.getLoc(), consumer.getOperation()->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) - .getOperation(); - } + // Verify that + // - the producer has all "parallel" iterator type. + if (producer.getNumParallelLoops() != producer.getNumLoops()) + return false; - // Construct an AffineMap from consumer loops to producer loops. - // consumer loop -> tensor index - AffineMap consumerResultIndexMap = - consumer.getInputIndexingMap(consumerIdx); - // producer loop -> tensor index - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - // tensor index -> producer loop - AffineMap invProducerResultIndexMap = - inversePermutation(producerResultIndexMap); - assert(invProducerResultIndexMap && - "expected producer result indexig map to be invertible"); - // consumer loop -> producer loop - AffineMap consumerToProducerLoopsMap = - invProducerResultIndexMap.compose(consumerResultIndexMap); - - generateFusedRegion(rewriter, fusedOp, producer, consumer, - consumerToProducerLoopsMap, consumerIdx, - consumer.getNumLoops()); - return fusedOp; - } + // Get the consumer index map. The number of results of the consumer index + // map must match the number of loops of the producer. 
+ AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); + if (consumerIndexMap.getNumResults() != producer.getNumLoops()) + return false; -private: - /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of - /// the `producer` to use in the fused operation given the indexing map of the - /// result of the producer in the consumer. - static void computeProducerOperandIndex( - LinalgOp producer, AffineMap fusedConsumerArgIndexMap, - SmallVectorImpl &fusedOpIndexingMapAttrs) { - // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map - // from consumer loop -> consumer arg tensor index/producer result tensor - // index. The fused loop is same as the consumer loop. For each producer arg - // the indexing map to be computed is a map from consumer loop -> producer - // arg tensor index. - - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - // producerResultIndexMap is a map from producer loop -> tensor index. - // Compute the inverse to get map from tensor index -> producer loop. - // The inverse is a map from producer result tensor index -> producer loop. - AffineMap invProducerResultIndexMap = - inversePermutation(producerResultIndexMap); - assert(invProducerResultIndexMap && - "expected producer result indexig map to be invertible"); - for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { - // argMap is a map from producer loop -> producer arg tensor index. - AffineMap argMap = producer.getInputIndexingMap(argNum); - - // Compose argMap with invProducerResultIndexMap to get a map from - // producer result tensor index -> producer arg tensor index. - AffineMap t1 = argMap.compose(invProducerResultIndexMap); - - // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from - // consumer loop/ fused loop -> producer arg tensor index. 
- AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); - fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); - } + // Finally the index_map for the result must be invertible. For now just + // verify it is a permutation. + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + return producerResultIndexMap.isPermutation(); +} + +/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of +/// the `producer` to use in the fused operation given the indexing map of the +/// result of the producer in the consumer. +static void getIndexingMapOfProducerOperandsInFusedOp( + LinalgOp producer, AffineMap fusedConsumerArgIndexMap, + SmallVectorImpl &fusedOpIndexingMapAttrs) { + // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map + // from consumer loop -> consumer arg tensor index/producer result tensor + // index. The fused loop is same as the consumer loop. For each producer arg + // the indexing map to be computed is a map from consumer loop -> producer + // arg tensor index. + + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // producerResultIndexMap is a map from producer loop -> tensor index. + // Compute the inverse to get map from tensor index -> producer loop. + // The inverse is a map from producer result tensor index -> producer loop. + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexig map to be invertible"); + for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { + // argMap is a map from producer loop -> producer arg tensor index. + AffineMap argMap = producer.getInputIndexingMap(argNum); + + // Compose argMap with invProducerResultIndexMap to get a map from + // producer result tensor index -> producer arg tensor index. 
+ AffineMap t1 = argMap.compose(invProducerResultIndexMap); + + // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from + // consumer loop/ fused loop -> producer arg tensor index. + AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); + fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); } +} - /// Generate the region of the fused operation. The region of the fused op - /// must be empty. - static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp, - LinalgOp producer, LinalgOp consumer, - AffineMap consumerToProducerLoopsMap, - unsigned consumerIdx, unsigned nloops) { - // Build the region of the fused op. - Block &producerBlock = producer.getOperation()->getRegion(0).front(); - Block &consumerBlock = consumer.getOperation()->getRegion(0).front(); - Block *fusedBlock = new Block(); - fusedOp->getRegion(0).push_back(fusedBlock); - BlockAndValueMapping mapper; - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(fusedBlock); - - // The block arguments are - // [index_0, index_1, ... , - // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1), - // producer_operand_0, ... , producer_operand_(n-1)], - // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)] - // , where n is the number of producer's operand and m is the number - // consumer's operand. - // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a - // generic op. In this case, there are no indices in block arguments. - unsigned numProducerIndices = - isa(producer.getOperation()) ? nloops : 0; - unsigned numConsumerIndices = - isa(consumer.getOperation()) ? nloops : 0; - // Firstly, add all the indices to the block arguments. - for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices); - i < e; ++i) - fusedBlock->addArgument(rewriter.getIndexType()); - // Map the arguments for the unmodified args from the consumer. 
- for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { - if (consumerArg.index() == consumerIdx + numConsumerIndices) { - // Map the arguments for the args from the producer. - for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) { - // If producer is an indexed_generic op, map the indices from consumer - // loop to producer loop (because the fusedOp is built based on - // consumer's perspective). - if (producerArg.index() < numProducerIndices) { - auto newIndex = rewriter.create( - producer.getLoc(), - consumerToProducerLoopsMap.getSubMap(producerArg.index()), - fusedBlock->getArguments().take_front(nloops)); - mapper.map(producerArg.value(), newIndex); - } else { - mapper.map(producerArg.value(), - fusedBlock->addArgument(producerArg.value().getType())); - } +/// Generate the region of the fused tensor operation. The region of the fused +/// op must be empty. +static void generateFusedTensorOpRegion(PatternRewriter &rewriter, + Operation *fusedOp, LinalgOp producer, + LinalgOp consumer, + AffineMap consumerToProducerLoopsMap, + unsigned consumerIdx, unsigned nloops) { + // Build the region of the fused op. + Block &producerBlock = producer.getOperation()->getRegion(0).front(); + Block &consumerBlock = consumer.getOperation()->getRegion(0).front(); + Block *fusedBlock = new Block(); + fusedOp->getRegion(0).push_back(fusedBlock); + BlockAndValueMapping mapper; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(fusedBlock); + + // The block arguments are + // [index_0, index_1, ... , + // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1), + // producer_operand_0, ... , producer_operand_(n-1)], + // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)] + // , where n is the number of producer's operand and m is the number + // consumer's operand. + // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a + // generic op. 
In this case, there are no indices in block arguments. + unsigned numProducerIndices = + isa(producer.getOperation()) ? nloops : 0; + unsigned numConsumerIndices = + isa(consumer.getOperation()) ? nloops : 0; + // Firstly, add all the indices to the block arguments. + for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices); + i < e; ++i) + fusedBlock->addArgument(rewriter.getIndexType()); + // Map the arguments for the unmodified args from the consumer. + for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { + if (consumerArg.index() == consumerIdx + numConsumerIndices) { + // Map the arguments for the args from the producer. + for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) { + // If producer is an indexed_generic op, map the indices from consumer + // loop to producer loop (because the fusedOp is built based on + // consumer's perspective). + if (producerArg.index() < numProducerIndices) { + auto newIndex = rewriter.create( + producer.getLoc(), + consumerToProducerLoopsMap.getSubMap(producerArg.index()), + fusedBlock->getArguments().take_front(nloops)); + mapper.map(producerArg.value(), newIndex); + } else { + mapper.map(producerArg.value(), + fusedBlock->addArgument(producerArg.value().getType())); } - continue; } + continue; + } - // If consumer is an indexed_generic op, map the indices to the block - // arguments directly. Otherwise, add the same type of arugment and map to - // it. - if (consumerArg.index() < numConsumerIndices) { - mapper.map(consumerArg.value(), - fusedBlock->getArgument(consumerArg.index())); - } else { - mapper.map(consumerArg.value(), - fusedBlock->addArgument(consumerArg.value().getType())); - } + // If consumer is an indexed_generic op, map the indices to the block + // arguments directly. Otherwise, add the same type of arugment and map to + // it. 
+ if (consumerArg.index() < numConsumerIndices) { + mapper.map(consumerArg.value(), + fusedBlock->getArgument(consumerArg.index())); + } else { + mapper.map(consumerArg.value(), + fusedBlock->addArgument(consumerArg.value().getType())); } + } - // Add operations from producer (except the yield operation) to the fused - // op. - for (auto &op : producerBlock.getOperations()) { - if (auto yieldOp = dyn_cast(op)) { - // Lookup the value the yield operation is mapped to. - Value yieldVal = yieldOp.getOperand(0); - if (Value clonedVal = mapper.lookupOrNull(yieldVal)) - mapper.map( - consumerBlock.getArgument(consumerIdx + numConsumerIndices), - clonedVal); - continue; - } - rewriter.clone(op, mapper); + // Add operations from producer (except the yield operation) to the fused + // op. + for (auto &op : producerBlock.getOperations()) { + if (auto yieldOp = dyn_cast(op)) { + // Lookup the value the yield operation is mapped to. + Value yieldVal = yieldOp.getOperand(0); + if (Value clonedVal = mapper.lookupOrNull(yieldVal)) + mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices), + clonedVal); + continue; } - for (auto &op : consumerBlock.getOperations()) - rewriter.clone(op, mapper); + rewriter.clone(op, mapper); } -}; -} // namespace + for (auto &op : consumerBlock.getOperations()) + rewriter.clone(op, mapper); +} + +static Optional> +fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, + OperationFolder *folder = nullptr) { + if (!areTensorOpsFusable(producer, consumer, consumerIdx)) + return llvm::None; + + unsigned numFusedOperands = + producer.getNumInputs() + consumer.getNumInputs() - 1; + + // Compute the fused operands list, + SmallVector fusedOperands; + fusedOperands.reserve(numFusedOperands); + auto consumerOperands = consumer.getInputs(); + auto producerOperands = producer.getInputs(); + fusedOperands.assign(consumerOperands.begin(), + std::next(consumerOperands.begin(), consumerIdx)); + 
fusedOperands.append(producerOperands.begin(), producerOperands.end()); + fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), + consumerOperands.end()); + + // Compute indexing_maps for the fused operation. The indexing_maps for the + // operands of the consumers that arent fused are the same. The + // indexing_maps for the producers need to be computed based on the + // indexing_map of the operand at consumerIdx in the consumer. + SmallVector fusedIndexMaps; + auto consumerIndexMaps = consumer.indexing_maps(); + fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs()); + fusedIndexMaps.assign(consumerIndexMaps.begin(), + std::next(consumerIndexMaps.begin(), consumerIdx)); + // Compute indexing maps for the producer args in the fused operation. + getIndexingMapOfProducerOperandsInFusedOp( + producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); + + // Append the indexing maps for the remaining consumer operands. + fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), + consumerIndexMaps.end()); + + // Generate the fused op. + // Tensor-level fusion is only on ops without initTensors and outputBuffers. 
+ LinalgOp fusedOp; + if (isa(producer.getOperation()) && + isa(consumer.getOperation())) { + fusedOp = rewriter + .create(consumer.getLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) + .getOperation(); + } else { + fusedOp = + rewriter + .create(consumer.getLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) + .getOperation(); + } + + // Construct an AffineMap from consumer loops to producer loops. + // consumer loop -> tensor index + AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx); + // producer loop -> tensor index + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // tensor index -> producer loop + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexig map to be invertible"); + // consumer loop -> producer loop + AffineMap consumerToProducerLoopsMap = + invProducerResultIndexMap.compose(consumerResultIndexMap); + + generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer, + consumer, consumerToProducerLoopsMap, consumerIdx, + consumer.getNumLoops()); + return SmallVector(fusedOp.getOperation()->getResults()); +} /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` /// provided, given the shape of the source tensor that corresponds to the @@ -313,18 +306,21 @@ /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is /// true) or its producer (if `asProducer` is false) given 
the indexing map at /// its use. -static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp, - AffineMap useIndexMap, bool asProducer) { +static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp, + AffineMap useIndexMap, + bool asProducer) { RankedTensorType returnType = reshapeOp.getResultType(); RankedTensorType operandType = reshapeOp.getSrcType(); - // Reshape is fusible with its consumer (i.e. reshape as a producer) when its + // Reshape is fusable with its consumer (i.e. reshape as a producer) when its // operand is of lesser rank than the result. Fusing when operand has higher // rank will require use of mods and divs in the indexing maps of the fused op // which would make it non-invertible. Similarly reshape is fused with its // producer (i.e. reshape as consumer) only if the return type has lesser // rank. - if ((asProducer && returnType.getRank() < operandType.getRank()) || - (!asProducer && operandType.getRank() < returnType.getRank())) + if ((asProducer && reshapeOp.getSrcType().hasStaticShape() && + returnType.getRank() < operandType.getRank()) || + (!asProducer && reshapeOp.getResultType().hasStaticShape() && + operandType.getRank() < returnType.getRank())) return false; return useIndexMap.isPermutation(); } @@ -346,314 +342,533 @@ return nullptr; } -namespace { +/// Conditions for folding a generic/indexed-generic operation with a reshape op +/// by expanding the iteration space dimensionality for tensor operations. These +/// are preconditions assumed by `foldReshapeByDimExpansion` which implements +/// the following fusion pattern. 
+/// +/// Consider +/// +/// %c = linalg.generic ins(%a, %b : memref, memref) +/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, +/// affine_map<(d0, d1, d2) -> (d1, d2)>, +/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] +/// %d = linalg.tensor_reshape %c +/// [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, +/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>, +/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] +/// : tensor into tensor +/// +/// The reshape can be folded into the `linalgOp` if the +/// generic/indexed-generic op loop dimensionality is increased to match the +/// result (operand) of the tensor_reshape when the reshape is expanding +/// (folding). The indexing_map of the fused tensor in the `linalgOp` and the +/// reassociation map helps compute the indexing maps of the modified op. For +/// the above example, based on the reassociation map it can be concluded that +/// +/// - The loop used to access the first dimension of the fused tensor is split +/// into two. +/// - The loop used to access the second dimension of the fused tensor is kept +/// as is. +/// - The loop used to access the third dimension of the fused tensor is split +/// into three. +/// +/// i.e. 
(e0, e1, e2, e3, e4) is the domain of the indexing map of the modified +/// op, then +/// +/// d0 -> e0, e1 +/// d1 -> e2, e3, e4 +/// d2 -> e5 +/// +/// substituting this, the generic op can be rewritten as +/// +/// %d = linalg.generic ins(%0, %1 : ) +/// indexing_maps = +/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] +/// +/// Since operands to the linalg generic are now 5D, reshapes can be introduced +/// to make it consistent +/// +/// %0 = linalg.tensor_reshape %a +/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2), +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4), +/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)] +/// : tensor into tensor +/// %1 = linalg.tensor_reshape %b +/// [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2), +/// affine_map<(e0, e1, e2, e3) -> (e3)] +/// : tensor into tensor +/// +/// The added reshapes are again expanding patterns, so they will get fused +/// with its producers if possible. +static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, + unsigned fusedTensorIndex) { + // Is fusable only if: + // - The linalgOp is a generic op. + // - All the indexing maps for operands in linalgOp are projected + // permutations. + // - The indexing map at the position representing the fused tensor is a + // permutation. + // - All the loops in linalgOp are parallel loops. 
+ return isa(linalgOp.getOperation()) && + linalgOp.hasTensorSemantics() && + llvm::all_of(linalgOp.indexing_maps().getValue().take_front( + linalgOp.getNumInputs()), + [](Attribute attr) { + return attr.cast() + .getValue() + .isProjectedPermutation(); + }) && + linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() && + llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) { + return attr.cast().getValue() == + getParallelIteratorTypeName(); + }); +} -/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp. -struct FuseTensorReshapeOpAsProducer { - static bool isFusible(TensorReshapeOp producer, LinalgOp consumer, - unsigned consumerIdx) { - return isa(consumer.getOperation()) && - consumer.hasTensorSemantics() && - isTensorReshapeOpFusible(producer, - consumer.getInputIndexingMap(consumerIdx), - /*asProducer=*/true); +/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic +/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those +/// conditions have been satisfied. +static Optional> +fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp, + unsigned fusedTensorIndex, PatternRewriter &rewriter, + OperationFolder *folder = nullptr) { + assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) && + "preconditions for fuse operation failed"); + // Check if reshape is expanding or collapsing. + bool isExpanding = + reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank(); + RankedTensorType expandedType = + isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType(); + RankedTensorType foldedType = + isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType(); + AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex); + + // The reshape is folding/expanding consecutive dimensions. Given the indexing + // map of the fused tensor find the number of dimensions each of the loops of + // the original op is expanded into. 
Also record the shape of the expanded + // dimensions. + ArrayRef expandedShape = expandedType.getShape(); + SmallVector numFoldedDims(foldedType.getRank(), 0); + SmallVector, 4> expandedDimsShape( + expandedType.getRank()); + auto reassociationMaps = reshapeOp.getReassociationMaps(); + for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { + unsigned pos = resultExpr.value().cast().getPosition(); + AffineMap foldedDims = reassociationMaps[resultExpr.index()]; + numFoldedDims[pos] = foldedDims.getNumResults(); + ArrayRef shape = expandedShape.slice( + foldedDims.getResult(0).cast().getPosition(), + numFoldedDims[pos]); + expandedDimsShape[pos].assign(shape.begin(), shape.end()); } - static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (producer.src().getDefiningOp()) - return nullptr; + // The remapping of the indices is then the prefix sum (inclusive) of the + // numFoldedDims. + SmallVector remapping(numFoldedDims.size() + 1, 0); + unsigned sum = 0; + for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) { + sum += numFoldedDim.value(); + remapping[numFoldedDim.index() + 1] = sum; + } - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; + SmallVector expandedOpIndexingMaps; + // Compute the modified indexing maps by replacing every loop (AffineDimExpr) + // in the original indexing map with the sequence of loops that it is expanded + // to. 
+ for (AffineMap indexingMap : linalgOp.getIndexingMaps()) { + SmallVector newExprs; + for (AffineExpr expr : indexingMap.getResults()) { + unsigned pos = expr.cast().getPosition(); + for (unsigned newPos : + llvm::seq(remapping[pos], remapping[pos + 1])) { + newExprs.push_back(rewriter.getAffineDimExpr(newPos)); + } + } + expandedOpIndexingMaps.push_back( + AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs, + rewriter.getContext())); + } - // Compute the fused operands list, - Operation *consumerOp = consumer.getOperation(); - SmallVector fusedOperands(consumerOp->getOperands()); - fusedOperands[consumerIdx] = producer.src(); + // The operands of the expanded op are computed by reshaping the original + // operands. The reshape depends on the ordering of the loop used to access + // the tensor in the original operation, and are expanded into as many + // dimensions as the loop is expanded into (as computed by `remapping`). + auto getReshapeInfo = + [&](AffineMap operandIndexingMap, + SmallVectorImpl &reassociation, + SmallVectorImpl &expandedOpOperandShape) { + unsigned reshapeDims = 0; + for (AffineExpr expr : operandIndexingMap.getResults()) { + unsigned origDim = expr.cast().getPosition(); + auto foldedDims = llvm::seq( + reshapeDims, reshapeDims + numFoldedDims[origDim]); + reassociation.emplace_back(foldedDims.begin(), foldedDims.end()); + expandedOpOperandShape.append(expandedDimsShape[origDim].begin(), + expandedDimsShape[origDim].end()); + reshapeDims += numFoldedDims[origDim]; + } + }; + SmallVector expandedOpOperands; + for (auto operand : llvm::enumerate(linalgOp.getInputs())) { + if (operand.index() == fusedTensorIndex) { + expandedOpOperands.push_back(reshapeOp.src()); + continue; + } + AffineMap indexingMap = linalgOp.getIndexingMap(operand.index()); + SmallVector reassociation; + SmallVector expandedOperandShape; + getReshapeInfo(indexingMap, reassociation, expandedOperandShape); + Type expandedOperandType = 
RankedTensorType::get( + expandedOperandShape, + operand.value().getType().cast().getElementType()); + if (expandedOperandType != operand.value().getType()) { + expandedOpOperands.push_back(rewriter.create( + linalgOp.getLoc(), expandedOperandType, operand.value(), + reassociation)); + } else { + expandedOpOperands.push_back(operand.value()); + } + } + SmallVector resultTypes; + SmallVector, 1> resultReassociation; + for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) { + AffineMap indexingMap = + linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index()); + SmallVector reassociation; + SmallVector expandedResultShape; + getReshapeInfo(indexingMap, reassociation, expandedResultShape); + resultTypes.push_back(RankedTensorType::get( + expandedResultShape, + result.value().getType().cast().getElementType())); + resultReassociation.emplace_back(std::move(reassociation)); + } - // Compute indexing_maps for the fused operation. The indexing_maps for the - // operands of the consumers that arent fused are the same. - SmallVector fusedIndexMaps = - llvm::to_vector<4>(llvm::map_range( - consumer.indexing_maps(), [](Attribute attr) -> AffineMap { - return attr.cast().getValue(); - })); + // The iterator types of the expanded op are all parallel. + SmallVector iteratorTypes(remapping.back(), + getParallelIteratorTypeName()); + + LinalgOp fusedOp = createLinalgOpOfSameType( + linalgOp, rewriter, linalgOp.getLoc(), resultTypes, + /*inputs=*/expandedOpOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes); + Region &fusedRegion = fusedOp.getOperation()->getRegion(0); + // TODO: Add support for indexed generic op, which would need mapping the + // expanded dimensions to the original dimension arguments. 
+ rewriter.cloneRegionBefore(linalgOp.getOperation()->getRegion(0), fusedRegion, + fusedRegion.begin()); + + // Reshape the result values to their original shape if this is a collapsing + // reshape folded into its consumer. + SmallVector resultVals; + for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) { + if (!isExpanding && + resultTypes[result.index()] != result.value().getType()) { + resultVals.push_back(rewriter.create( + linalgOp.getLoc(), result.value().getType(), + fusedOp.getOperation()->getResult(result.index()), + resultReassociation[result.index()])); + } else { + resultVals.push_back(fusedOp.getOperation()->getResult(result.index())); + } + } + // Assuming a single result. + return resultVals; +} - // Accepted consumer maps are either identity or permutation. - auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]); +namespace { - // Compute the indexing map to use for the operand of the producer. - AffineMap modifiedMap = - linearizeCollapsedDims(invMap, producer.getResultType().getShape(), - producer.getReassociationMaps()); - for (AffineExpr expr : modifiedMap.getResults()) { - if (!expr.isPureAffine()) - return nullptr; - } - fusedIndexMaps[consumerIdx] = modifiedMap; +/// Pattern to fold tensor_reshape op with its consumer by using the source of +/// the reshape op as the operand in the consumer (instead of the result of the +/// tensor_reshapeop) when the tensor_reshape op is collapsing. The +/// corresponding index map in the consumer needs to be modified to linearize +/// the folded dimension. +/// +/// For example, +/// +/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +/// %0 = linalg.tensor_reshape %arg0 +/// [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>, +/// affine_map<(i, j, k, l) -> (l)>] +/// tensor into tensor +/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } +/// ins(%0, %arg1 : tensor, tensor) ... 
+/// -> tensor +/// +/// can be folded into +/// +/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> +/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } +/// ins(%arg0, %arg1 : tensor, tensor) ... +/// -> tensor +template +struct FoldProducerReshapeOpByLinearization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - // Further check that the resulting index maps can be fused and - // inverted. Without this the resultant op is not legal. - if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) - return nullptr; + LogicalResult matchAndRewrite(LinalgOpTy op, + PatternRewriter &rewriter) const override { + if (!op.hasTensorSemantics()) + return failure(); + LinalgOp linalgOp = cast(op.getOperation()); + for (auto operand : llvm::enumerate(linalgOp.getInputs())) { + TensorReshapeOp reshapeOp = + operand.value().getDefiningOp(); + if (!reshapeOp || + !isTensorReshapeOpFoldableByLinearization( + reshapeOp, linalgOp.getInputIndexingMap(operand.index()), + /*asProducer =*/true)) + continue; - SmallVector indexMapAttrs = llvm::to_vector<4>( - llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { - return AffineMapAttr::get(map); - })); - LinalgOp fusedOp = createLinalgOpOfSameType( - consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, // no init tensors for now. 
- rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr); - auto &fusedRegion = fusedOp.getOperation()->getRegion(0); - rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion, - fusedRegion.begin()); - return fusedOp; + // Compute the fused operands list, + SmallVector fusedOperands(linalgOp.getInputs()); + fusedOperands[operand.index()] = reshapeOp.src(); + + // Compute indexing_maps for the fused operation. The indexing_maps for + // the operands of the consumers that arent fused are the same. + SmallVector fusedIndexMaps = llvm::to_vector<4>( + op.indexing_maps().template getAsValueRange()); + + // Accepted consumer maps are either identity or permutation. + auto invMap = inversePermutation(fusedIndexMaps[operand.index()]); + + // Compute the indexing map to use for the result of the producer. + AffineMap modifiedMap = + linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(), + reshapeOp.getReassociationMaps()); + for (AffineExpr expr : modifiedMap.getResults()) { + if (!expr.isPureAffine()) + return failure(); + } + fusedIndexMaps[operand.index()] = modifiedMap; + + // Further check that the resulting index maps can be fused and + // inverted. Without this the resultant op is not legal. + if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) + return op.emitRemark("fused op loop bound computation failed"); + + rewriter.startRootUpdate(op); + op.getOperation()->setOperands(fusedOperands); + op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps)); + rewriter.finalizeRootUpdate(op); + if (reshapeOp.use_empty()) + rewriter.eraseOp(reshapeOp); + return success(); + } + return op.emitRemark("no fusion candidates found"); } }; -/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp. 
-struct FuseTensorReshapeOpAsConsumer { - static bool isCollapsingAndFusible(LinalgOp producer, - TensorReshapeOp consumer, - unsigned consumerIdx) { - return isa(producer.getOperation()) && - producer.hasTensorSemantics() && - isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0), - /*asProducer=*/false); +/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the +/// reshape op is collapsing dimensions. The dimensionality of the loop in the +/// consumer generic op is expanded. +struct FoldWithProducerReshapeOpByExpansion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + LinalgOp linalgOp = cast(genericOp.getOperation()); + for (auto operand : llvm::enumerate(linalgOp.getInputs())) { + TensorReshapeOp reshapeOp = + operand.value().getDefiningOp(); + if (!reshapeOp) + continue; + + // Fold only if + // - The tensor reshape op is folding. + // - All constraints of fusing with reshape by expansion are met. + if (reshapeOp.getSrcType().getRank() < + reshapeOp.getResultType().getRank() || + !isFusableWithReshapeByDimExpansion(linalgOp, operand.index())) + continue; + + Optional> replacementValues = + fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(), + rewriter); + if (!replacementValues) + return failure(); + rewriter.replaceOp(genericOp, replacementValues.getValue()); + if (reshapeOp.use_empty()) + rewriter.eraseOp(reshapeOp); + return success(); + } + return failure(); } +}; - static LinalgOp fuseCollapsingCase(LinalgOp producer, - TensorReshapeOp consumer, - unsigned consumerIdx, - PatternRewriter &rewriter) { +/// Pattern to fold tensor_reshape op with its producer. The corresponding index +/// map in the consumer needs to be modified to linearize the folded dimension. 
+struct FoldConsumerReshapeOpByLinearization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + LinalgOp producer = reshapeOp.src().getDefiningOp(); + if (!producer || + !isa(producer.getOperation()) || + !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 || + !isTensorReshapeOpFoldableByLinearization( + reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false)) + return failure(); // The indexing_maps for the operands of the fused operation are same as // those for the operands of the producer. - SmallVector fusedIndexMaps = - llvm::to_vector<4>(llvm::map_range( - producer.indexing_maps(), [](Attribute attr) -> AffineMap { - return attr.cast().getValue(); - })); + SmallVector fusedIndexMaps = llvm::to_vector<4>( + producer.indexing_maps().getAsValueRange()); auto invMap = inversePermutation(producer.getOutputIndexingMap(0)); // Compute the indexing map to use for the operand of the producer. AffineMap modifiedMap = - linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(), - consumer.getReassociationMaps()); + linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(), + reshapeOp.getReassociationMaps()); for (AffineExpr expr : modifiedMap.getResults()) { if (!expr.isPureAffine()) - return nullptr; + return reshapeOp.emitRemark("fused op indexing map is not affine"); } fusedIndexMaps.back() = modifiedMap; // Further check that the resulting index maps can be fused and // inverted. Without this the resultant op is not legal. 
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) - return nullptr; + return reshapeOp.emitRemark("fused op loop bound computation failed"); - SmallVector indexMapAttrs = llvm::to_vector<4>( - llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { - return AffineMapAttr::get(map); - })); - - Operation *producerOp = producer.getOperation(); LinalgOp fusedOp = createLinalgOpOfSameType( - producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(), - /*inputs=*/producerOp->getOperands(), + producer, rewriter, rewriter.getUnknownLoc(), reshapeOp.getResultType(), + /*inputs=*/producer.getInputs(), /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{}, // no init tensors for now. - rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(), + rewriter.getAffineMapArrayAttr(fusedIndexMaps), + producer.iterator_types(), /*doc=*/nullptr, /*library_call=*/nullptr, /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.getOperation()->getRegion(0); - rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion, - fusedRegion.begin()); - return fusedOp; - } - - static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx) { - // Is fusible only if: - // 1) The producer is a generic op. - // 2) The producer has tensor semantics. - // 3) The tensor reshape op is a expanding case. - // 4) All the shapes are the same for the generic op. - // 5) All the indexing maps in producer are identity. - // 6) All the loops in producer are parallel loops. - // 7) The producer has a single user. 
- auto types = producer.getInputOutputShapedTypes(); - assert(!types.empty()); - return isa(producer.getOperation()) && - producer.hasTensorSemantics() && - consumer.getSrcType().getRank() < - consumer.getResultType().getRank() && - std::equal(types.begin() + 1, types.end(), types.begin()) && - llvm::all_of(producer.getIndexingMaps(), - [](AffineMap map) { return map.isIdentity(); }) && - llvm::all_of(producer.iterator_types(), - [](Attribute attr) { - return attr.cast().getValue() == - getParallelIteratorTypeName(); - }) && - producer.getOperation()->hasOneUse(); - } - - static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx, - PatternRewriter &rewriter) { - Location loc = producer.getLoc(); - auto dstShape = consumer.getResultType().cast().getShape(); - SmallVector args; - for (auto arg : producer.getOperation()->getOperands()) { - auto type = RankedTensorType::get( - dstShape, arg.getType().cast().getElementType()); - args.push_back(rewriter.createOrFold( - loc, type, arg, consumer.reassociation())); - } - - SmallVector resultTypes; - for (auto t : producer.getOutputTensorTypes()) { - Type type = RankedTensorType::get(dstShape, - t.cast().getElementType()); - resultTypes.push_back(type); - } - - int rank = dstShape.size(); - auto genericOp = rewriter.create( - loc, resultTypes, /*inputs=*/args, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - SmallVector(args.size() + resultTypes.size(), - rewriter.getMultiDimIdentityMap(rank)), - SmallVector(rank, getParallelIteratorTypeName())); - Region ®ion = genericOp.getRegion(); - rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region, - region.begin()); - return cast(genericOp.getOperation()); - } - - static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (isCollapsingAndFusible(producer, consumer, consumerIdx)) - return 
fuseCollapsingCase(producer, consumer, consumerIdx, rewriter); - if (isExpandingAndFusible(producer, consumer, consumerIdx)) - return fuseExpandingCase(producer, consumer, consumerIdx, rewriter); - return nullptr; + rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), + fusedRegion, fusedRegion.begin()); + rewriter.replaceOp(reshapeOp, fusedOp.getOperation()->getResults()); + if (producer.use_empty()) + rewriter.eraseOp(producer); + return success(); } }; -/// Implementation of fusion on tensor ops when producer is a splat constant. -struct FuseConstantOpAsProducer { - static bool isFusible(ConstantOp producer, LinalgOp consumer, - unsigned consumerIdx) { - return isa(consumer.getOperation()) && - consumer.hasTensorSemantics() && - producer.getResult().getType().isa() && - producer.value().cast().isSplat(); +/// Pattern to fold a tensor_reshape op with its producer generic op if the +/// tensor_reshape op is expanding, by expanding the dimensionality of the loop +/// in the producer op. +struct FoldReshapeWithGenericOpByExpansion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + // Fold only if + // - The tensor reshape op is a expanding case. + // - All constraints of fusing with reshape by expansion are met. 
+ if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank()) + return failure(); + LinalgOp producer = reshapeOp.src().getDefiningOp(); + if (!producer || producer.getNumOutputs() != 1 || + !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs())) + return failure(); + Optional> replacementValues = + fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(), + rewriter); + if (!replacementValues) + return failure(); + rewriter.replaceOp(reshapeOp, replacementValues.getValue()); + if (producer.use_empty()) + rewriter.eraseOp(producer); + return success(); } +}; - static LinalgOp fuse(ConstantOp producer, LinalgOp consumer, - unsigned consumerIdx, PatternRewriter &rewriter, - OperationFolder *folder = nullptr) { - if (!isFusible(producer, consumer, consumerIdx)) - return nullptr; - - // The indexing_maps for the operands of the fused operation are same as - // those for the operands of the consumer without the indexing map at - // consumerIdx - SmallVector fusedIndexMaps = - llvm::to_vector<4>(llvm::map_range( - consumer.indexing_maps(), [](Attribute attr) -> AffineMap { - return attr.cast().getValue(); - })); - fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx)); - - // The operands list is same as the consumer with the argument for constant - // index dropped. - Operation *consumerOp = consumer.getOperation(); - SmallVector fusedOperands(consumerOp->getOperands()); - fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx)); - - // Create a constant scalar value from the splat constant. - Value scalarConstant = rewriter.create( - producer.getLoc(), - producer.value().cast().getSplatValue()); +/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant. 
+template +struct FoldSplatConstants : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LinalgOp fusedOp = createLinalgOpOfSameType( - consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, // no init tensors for now. - rewriter.getAffineMapArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr); + LogicalResult matchAndRewrite(LinalgOpTy op, + PatternRewriter &rewriter) const override { + if (!op.hasTensorSemantics()) + return failure(); + LinalgOp linalgOp = cast(op.getOperation()); + for (auto operand : llvm::enumerate(linalgOp.getInputs())) { + ConstantOp constantOp = operand.value().getDefiningOp(); + if (!constantOp || + !constantOp.value().cast().isSplat()) + continue; - // Map the block argument corresponding to the replaced argument with the - // scalar constant. - Region &consumerRegion = consumerOp->getRegion(0); - Block &entryBlock = *consumerRegion.begin(); - unsigned argIndex = entryBlock.getNumArguments() - - consumerOp->getNumOperands() + consumerIdx; - BlockAndValueMapping mapping; - mapping.map(entryBlock.getArgument(argIndex), scalarConstant); - Region &fusedRegion = fusedOp.getOperation()->getRegion(0); - rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(), - mapping); - return fusedOp; + // The indexing_maps for the operands of the fused operation are same as + // those for the operands of the linalgOp without the indexing map at + // operand.index() + SmallVector fusedIndexMaps = llvm::to_vector<4>( + linalgOp.indexing_maps().getAsValueRange()); + fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index())); + + // The operands list is same as the linalgOp with the argument for + // constant index dropped. 
+ SmallVector fusedOperands(linalgOp.getInputs()); + fusedOperands.erase(std::next(fusedOperands.begin(), operand.index())); + + // Create a constant scalar value from the splat constant. + Value scalarConstant = rewriter.create( + constantOp.getLoc(), + constantOp.value().cast().getSplatValue()); + + LinalgOp fusedOp = createLinalgOpOfSameType( + linalgOp, rewriter, rewriter.getUnknownLoc(), + linalgOp.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, // no init tensors for now. + rewriter.getAffineMapArrayAttr(fusedIndexMaps), + linalgOp.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); + + // Map the block argument corresponding to the replaced argument with the + // scalar constant. + Region &linalgOpRegion = linalgOp.getOperation()->getRegion(0); + Block &entryBlock = *linalgOpRegion.begin(); + unsigned argIndex = entryBlock.getNumArguments() - + linalgOp.getNumInputs() + operand.index(); + BlockAndValueMapping mapping; + mapping.map(entryBlock.getArgument(argIndex), scalarConstant); + Region &fusedRegion = fusedOp.getOperation()->getRegion(0); + rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion, + fusedRegion.begin(), mapping); + rewriter.replaceOp(linalgOp, fusedOp.getOperation()->getResults()); + if (constantOp.use_empty()) + rewriter.eraseOp(constantOp); + return success(); + } + return failure(); } }; } // namespace -Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, - Operation *consumer, - unsigned consumerIdx, - OperationFolder *folder) { +Optional> +mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, + unsigned consumerIdx, OperationFolder *folder) { if (consumerIdx >= consumer->getNumOperands()) - return nullptr; + return llvm::None; Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp(); if (!producer || producer->getNumResults() != 1) - return nullptr; + return 
llvm::None; // Fuse when consumer is GenericOp or IndexedGenericOp. - if (isa(consumer)) { - if (isa(producer)) - return FuseGenericOpsOnTensors::fuse(cast(producer), - cast(consumer), - consumerIdx, rewriter, folder); - if (auto reshapeOpProducer = dyn_cast(producer)) - return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer, - cast(consumer), - consumerIdx, rewriter, folder); - if (auto constantOpProducer = dyn_cast(producer)) - return FuseConstantOpAsProducer::fuse(constantOpProducer, - cast(consumer), - consumerIdx, rewriter, folder); - return nullptr; - } + if (!isa(consumer) || + !isa(producer)) + return llvm::None; - if (isa(producer)) { - // Fuse when consumer is a TensorReshapeOp. - if (TensorReshapeOp reshapeOp = dyn_cast(consumer)) { - return FuseTensorReshapeOpAsConsumer::fuse( - cast(producer), reshapeOp, consumerIdx, rewriter, folder); - } - } - - return nullptr; + return fuseTensorOpsImpl(cast(producer), cast(consumer), + consumerIdx, rewriter, folder); } namespace { @@ -669,10 +884,13 @@ llvm::seq(0, op.getOperation()->getNumOperands())) { Operation *producer = op.getOperation()->getOperand(operandNum).getDefiningOp(); - if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) { - rewriter.replaceOp(op, fusedOp->getResults()); - if (producer && llvm::all_of(producer->getResults(), - [](Value val) { return val.use_empty(); })) + if (!producer) + continue; + Optional> fusedOpResults = + fuseTensorOps(rewriter, op, operandNum); + if (fusedOpResults) { + rewriter.replaceOp(op, *fusedOpResults); + if (producer->use_empty()) rewriter.eraseOp(producer); return success(); } @@ -689,16 +907,52 @@ Operation *op = getOperation(); populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); applyPatternsAndFoldGreedily(op->getRegions(), patterns); - }; + } +}; + +/// Pass to test folding of reshape op with generic/indexed_generic ops by +/// linearization. 
+struct FoldReshapeOpsByLinearizationPass + : public LinalgFoldReshapeOpsByLinearizationBase< + FoldReshapeOpsByLinearizationPass> { + void runOnOperation() override { + OwningRewritePatternList patterns; + Operation *op = getOperation(); + populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); + } }; + } // namespace +void mlir::populateFoldReshapeOpsByLinearizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>(context); +} + +void mlir::populateFoldReshapeOpsByExpansionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert(context); +} + void mlir::populateLinalgTensorOpsFusionPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert, FuseTensorOps, - FuseTensorOps>(context); + FoldSplatConstants, + FoldSplatConstants>(context); + populateFoldReshapeOpsByExpansionPatterns(context, patterns); + GenericOp::getCanonicalizationPatterns(patterns, context); + IndexedGenericOp::getCanonicalizationPatterns(patterns, context); + TensorReshapeOp::getCanonicalizationPatterns(patterns, context); } std::unique_ptr mlir::createLinalgFusionOfTensorOpsPass() { return std::make_unique(); } + +std::unique_ptr mlir::createFoldReshapeOpsByLinearizationPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -142,124 +142,6 @@ // ----- -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func @generic_op_reshape_producer_fusion(%arg0 : tensor, - %arg1 : tensor) -> - 
tensor -{ - %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, - affine_map<(i, j, k, l) -> (j, k)>, - affine_map<(i, j, k, l) -> (l)>] : - tensor into tensor - %1 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - ins(%0, %arg1 : tensor, tensor) { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %1 = mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 - } -> tensor - return %1 : tensor -} - -// CHECK-LABEL: func @generic_op_reshape_producer_fusion -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]] -// CHECK-NOT: linalg.generic - - -// ----- - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func @generic_op_reshape_consumer_fusion(%arg0 : tensor, - %arg1 : tensor) -> - tensor -{ - %0 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %1 = mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 - } -> tensor - %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, - affine_map<(i, j, k, l) -> (j, k, l)>] : - tensor into tensor - return %1 : tensor -} - -// CHECK-LABEL: func @generic_op_reshape_consumer_fusion -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.generic - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func @generic_op_reshape_consumer_nofusion(%arg0 : tensor, - %arg1 : tensor) -> - tensor -{ - %0 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg3: 
f32, %arg4: f32): // no predecessors - %1 = mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 - } -> tensor - %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, - affine_map<(i, j, k, l) -> (j, k, l)>] : - tensor into tensor - return %1 : tensor -} - -// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion -// CHECK: linalg.tensor_reshape - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d2)> - -func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>) - -> tensor<8x33x4xf32> { - %cst = constant dense<2.000000e+00> : tensor<264x4xf32> - %0 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) { - ^bb0(%arg1: f32, %arg2: f32): // no predecessors - %2 = mulf %arg1, %arg2 : f32 - linalg.yield %2 : f32 - } -> tensor<264x4xf32> - %1 = linalg.tensor_reshape %0 [#map1, #map2] : - tensor<264x4xf32> into tensor<8x33x4xf32> - return %1 : tensor<8x33x4xf32> -} - -// The reshape op in `%arg0` is folded into the indexing map of generic op. 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @generic_op_reshape_consumer_expanding -// CHECK-NOT: linalg.tensor_reshape -// CHECK: %[[CST:.*]] = constant {{.*}} : f32 -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: tensor<264x4xf32> -// CHECK: -> tensor<8x33x4xf32> -// CHECK-NOT: linalg.tensor_reshape - -// ----- - #map0 = affine_map<(d0, d1, d2) -> (d0)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> @@ -499,159 +381,3 @@ // CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32 // CHECK: linalg.yield %[[VAL4]] : i32 // CHECK-NOT: linalg.indexed_generic - -// ----- - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor) - -> tensor { - %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, - affine_map<(i, j, k, l) -> (j, k)>, - affine_map<(i, j, k, l) -> (l)>] : - tensor into tensor - %1 = linalg.indexed_generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } - ins(%0 : tensor) { - ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors - %2 = index_cast %arg2 : index to i32 - %3 = addi %arg6, %2 : i32 - linalg.yield %3 : i32 - } -> tensor - return %1 : tensor -} - -// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.indexed_generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape - -// ----- - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, 
d1, d2, d3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor) - -> tensor { - %0 = linalg.indexed_generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } - ins(%arg0 : tensor) { - ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors - %2 = index_cast %arg2 : index to i32 - %3 = addi %arg6, %2 : i32 - linalg.yield %3 : i32 - } -> tensor - %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, - affine_map<(i, j, k, l) -> (j, k, l)>] : - tensor into tensor - return %1 : tensor -} - -// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion -// CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.indexed_generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -#map0 = affine_map<(d0, d1, d2) -> (d0)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> { - %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { - ^bb0(%arg2: f32): // no predecessors - linalg.yield %arg2 : f32 - } -> tensor<3x7x5xf32> - return %1 : tensor<3x7x5xf32> -} - -// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], 
#[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -#map0 = affine_map<(d0, d1, d2) -> (d0)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> { - %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { - ^bb0(%arg2: f32): // no predecessors - linalg.yield %arg2 : f32 - } -> tensor<5x7x3xf32> - return %1 : tensor<5x7x3xf32> -} - -// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion -// CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -#map0 = affine_map<(d0, d1, d2) -> (d0)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> { - %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { - ^bb0(%arg2: f32): // no predecessors - linalg.yield %arg2 : f32 - } -> tensor<5x3x7xf32> - return %1 : tensor<5x3x7xf32> -} - -// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion -// 
CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> - - -#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0)> -#map3 = affine_map<(d0, d1, d2) -> (d1, d2)> -func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> { - %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) { - ^bb0(%arg2: f32): // no predecessors - linalg.yield %arg2 : f32 - } -> tensor<5x3x7xf32> - %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32> - return %1 : tensor<5x21xf32> -} - -// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion -// CHECK-NOT: linalg.tensor_reshape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] -// CHECK-NOT: linalg.tensor_reshape diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -0,0 +1,192 @@ +// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s + +#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +func @generic_op_reshape_producer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k)>, + affine_map<(i, j, k, l) -> (l)>] : + tensor into tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map1], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0, %arg1 : 
tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + return %1 : tensor +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)> +// CHECK: func @generic_op_reshape_producer_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[T0]] : tensor, tensor) +// CHECK: %[[T2:.+]] = linalg.tensor_reshape +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: return %[[T2]] + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @generic_op_reshape_consumer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// 
CHECK: func @generic_op_reshape_consumer_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T2:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) +// CHECK: return %[[T2]] : tensor + + +// ----- + +func @reshape_as_consumer_permutation + (%a : tensor, %b : tensor) + -> tensor { + %c = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2, d1)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%a, %b : tensor, tensor) { + ^bb0(%arg0 : f32, %arg1: f32): + %1 = addf %arg0, %arg1 : f32 + linalg.yield %1 : f32 + } -> tensor + %d = linalg.tensor_reshape %c + [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] + : tensor into tensor + return %d : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 
d5, d2, d3, d4)> +// CHECK: func @reshape_as_consumer_permutation +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]] +// CHECK-SAME: [#[[MAP3]], #[[MAP4]]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[T2:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]] +// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) +// CHECK: return %[[T2]] : tensor + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d2)> + +func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) + -> tensor<8x33x4xf32> { + %cst = constant dense<2.000000e+00> : tensor<264x4xf32> + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) { + ^bb0(%arg1: f32, %arg2: f32): // no predecessors + %2 = mulf %arg1, %arg2 : f32 + linalg.yield %2 : f32 + } -> tensor<264x4xf32> + %1 = linalg.tensor_reshape %0 [#map1, #map2] : + tensor<264x4xf32> into tensor<8x33x4xf32> + return %1 : tensor<8x33x4xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: func @generic_op_reshape_consumer_static +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP2]], 
#[[MAP2]]] +// CHECK-SAME: ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>) +// CHECK: return %[[T1]] : tensor<8x33x4xf32> + +// ----- + +func @scalar_reshape(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) + -> tensor<1x10xf32> { + %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor + %1 = linalg.generic + {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} ins(%0 : tensor) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<10xf32> + %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>] + : tensor<10xf32> into tensor<1x10xf32> + return %2 : tensor<1x10xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func @scalar_reshape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32> +// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] [] +// CHECK-SAME: tensor<1xf32> into tensor +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[T0]] : tensor) +// CHECK: return %[[T1]] : tensor<1x10xf32> diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir @@ -0,0 +1,241 @@ +// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s + + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_producer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = 
linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k)>, + affine_map<(i, j, k, l) -> (l)>] : + tensor into tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %arg1 : tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func @generic_op_reshape_producer_fusion +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]] +// CHECK-NOT: linalg.generic + + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_consumer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-LABEL: func @generic_op_reshape_consumer_fusion +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.generic + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_consumer_nofusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic { + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { + ^bb0(%arg3: f32, %arg4: f32): // 
no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion +// CHECK: linalg.tensor_reshape + +// ----- + + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor) + -> tensor { + %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k)>, + affine_map<(i, j, k, l) -> (l)>] : + tensor into tensor + %1 = linalg.indexed_generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } + ins(%0 : tensor) { + ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = addi %arg6, %2 : i32 + linalg.yield %3 : i32 + } -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor) + -> tensor { + %0 = linalg.indexed_generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } + ins(%arg0 : tensor) { + ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // 
no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = addi %arg6, %2 : i32 + linalg.yield %3 : i32 + } -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +#map0 = affine_map<(d0, d1, d2) -> (d0)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> { + %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> + %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<3x7x5xf32> + return %1 : tensor<3x7x5xf32> +} + +// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +#map0 = affine_map<(d0, d1, d2) -> (d0)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> 
tensor<5x7x3xf32> { + %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> + %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<5x7x3xf32> + return %1 : tensor<5x7x3xf32> +} + +// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +#map0 = affine_map<(d0, d1, d2) -> (d0)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> { + %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32> + %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<5x3x7xf32> + return %1 : tensor<5x3x7xf32> +} + +// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> + + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0)> +#map3 = affine_map<(d0, d1, 
d2) -> (d1, d2)> +func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> { + %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<5x3x7xf32> + %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32> + return %1 : tensor<5x21xf32> +} + +// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion +// CHECK-NOT: linalg.tensor_reshape +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-NOT: linalg.tensor_reshape