diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -19,6 +19,7 @@ class AffineExpr; class AffineMap; class OperationFolder; +class PatternRewriter; namespace linalg { class LinalgDependenceGraph; @@ -71,11 +72,11 @@ const LinalgDependenceGraph &graph, OperationFolder *folder = nullptr); -/// Fuse linalg operation on tensors, where the result of the producer is used -/// as the operand of the consumer at position `consumerIdx`. -Optional fuseTensorOps(OpBuilder &b, LinalgOp producer, - LinalgOp consumer, unsigned consumerIdx, - OperationFolder *folder = nullptr); +/// Fuse linalg operation on tensors, with the producer of the operand at +/// position `consumerIdx` of the consumer. +Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, + unsigned consumerIdx, + OperationFolder *folder = nullptr); /// Returns the linearized list of all view dimensions in a linalgOp. Applying /// the inverse, concatenated loopToOperandRangeMaps to this list allows the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -362,154 +362,6 @@ return llvm::None; } -/// Checks if two Generic ops are fusible, when one is a producer and another is -/// a consumer (with the result of the producer being the `consumerIdx` operand -/// of the consumer). -static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx) { - // Verify that the producer and consumer are ops on tensors. - if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) - return false; - - auto producerOp = dyn_cast(producer.getOperation()); - auto consumerOp = dyn_cast(consumer.getOperation()); - // Verify that - // - the producer and consumers are generic ops, - // - only handle cases where the producer has a single return value, - // - the producer return value should be the same as argument at `consumerIdx` - // of the consumer, - // - the producer has all "parallel" iterator type. - // - only handle ops that use regions for specifying the scalar operations. - if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 || - producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) || - producerOp.getNumParallelLoops() != producerOp.getNumLoops()) - return false; - - // Get the consumer index map. The number of results of the consumer index map - // must match the number of loops of the producer. - AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); - if (consumerIndexMap.getNumResults() != producerOp.getNumLoops()) - return false; - - // Finally the index_map for the result must be invertible. For now just - // verify it is a permutation. - AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0); - return producerResultIndexMap.isPermutation(); -} - -/// Computes the indexing maps for arguments of a producer generic op when the -/// result of the producer is fused with the consumer. -/// - consumerIndexMap is the indexing_map for the argument in the consumer op -/// that is the result of the producer op. -/// - invProducerResultIndexMap is the inverse of the indexing_map for the -/// result in the producer op. -/// - producerArgIndexMap is the indexing_map of the argument of the producer -/// op. -/// The result is the indexing_map to use for the producer argument when the -/// producer and consumer ops are fused. -static AffineMap computeProducerArgMap(AffineMap consumerIndexMap, - AffineMap invProducerResultIndexMap, - AffineMap producerArgIndexMap) { - // t1 is map from producer result tensor index -> producer arg tensor index. - auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap); - // The return is map from consumer loop -> producer arg tensor index, - // i.e. indexing_map for the producer argument in the fused operation. - return t1.compose(consumerIndexMap); -} - -Optional mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer, - LinalgOp consumer, - unsigned consumerIdx, - OperationFolder *folder) { - if (!areTensorOpsFusible(producer, consumer, consumerIdx)) - return {}; - - MLIRContext *context = b.getContext(); - auto producerOp = cast(producer.getOperation()); - auto consumerOp = cast(consumer.getOperation()); - AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); - AffineMap invProducerResultIndexMap = - inversePermutation(producerOp.getOutputIndexingMap(0)); - if (!invProducerResultIndexMap) - return {}; - - // Compute the fused op operandslist by replacing the operand corresponding to - // the result of the producer, with the operands of the producer. - unsigned fusedArgsIn = - producerOp.getNumInputs() + consumerOp.getNumInputs() - 1; - auto fusedArgsOut = consumerOp.getNumOutputs(); - SmallVector fusedOperandsList(consumerOp.getOperands()); - fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx)); - fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut); - fusedOperandsList.insert( - std::next(fusedOperandsList.begin(), consumerIdx), - producerOp.operand_begin(), - std::next(producerOp.operand_begin(), producerOp.getNumInputs())); - - // Compute the fused indexing_maps of the operands/results of the fused op. - SmallVector fusedIndexingMapAttrs; - fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut); - fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(), - consumerOp.indexing_maps().end()); - fusedIndexingMapAttrs.erase( - std::next(fusedIndexingMapAttrs.begin(), consumerIdx)); - auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx); - for (auto producerArgIndexAttr : - llvm::enumerate(producerOp.indexing_maps())) { - if (producerArgIndexAttr.index() == producerOp.getNumInputs()) - break; - auto composedIndexMap = computeProducerArgMap( - consumerIndexMap, invProducerResultIndexMap, - producerArgIndexAttr.value().cast().getValue()); - insertPos = std::next(fusedIndexingMapAttrs.insert( - insertPos, AffineMapAttr::get(composedIndexMap))); - } - - // Generate the fused op. - auto fusedLinalgOp = b.create( - UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList, - b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut), - b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr); - - // Build the region of the fused op. - auto &fusedOpRegion = fusedLinalgOp.region(); - Block &producerOpBlock = producerOp.region().front(); - Block &consumerOpBlock = consumerOp.region().front(); - Block *fusedBlock = new Block(); - fusedOpRegion.push_back(fusedBlock); - BlockAndValueMapping mapper; - // Map the arguments for the unmodified args from the consumer. - for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) { - if (consumerOpArg.index() == consumerIdx) { - // Map the arguments for the args from the producer. - for (auto producerOpArg : producerOpBlock.getArguments()) - mapper.map(producerOpArg, - fusedBlock->addArgument(producerOpArg.getType())); - continue; - } - mapper.map(consumerOpArg.value(), - fusedBlock->addArgument(consumerOpArg.value().getType())); - } - - // Add operations from producer (except the yield operation) to the fused op. - for (auto &op : producerOpBlock.getOperations()) { - if (auto yieldOp = dyn_cast(op)) { - // Lookup the value the yield operation is mapped to. - Value yieldVal = yieldOp.getOperand(0); - auto clonedVal = mapper.lookup(yieldVal); - mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal); - continue; - } - fusedBlock->push_back(op.clone(mapper)); - } - for (auto &op : consumerOpBlock.getOperations()) - fusedBlock->push_back(op.clone(mapper)); - - return cast(fusedLinalgOp.getOperation()); -} - static void fuseLinalgOpsGreedily(FuncOp f) { LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); @@ -551,33 +403,222 @@ LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); } +//====---------------------------------------------------------------------===// +// Fusion on Tensor operation. +//====---------------------------------------------------------------------===// + namespace { -/// Patterns to fuse a generic op, with the producer of its operands. -struct FuseGenericTensorOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Implementation of fusion of different op types. +template +struct FuseTensorOpsImpl { + /// Generates a fused operation by fusing the `producer` and `consumer`, when + /// result of producer is used as operand at `consumerIdx` in the `consumer`. + static Operation *fuse(PatternRewriter &rewriter, ProducerOpTy producer, + ConsumerOpTy consumer, unsigned consumerIdx, + OperationFolder *folder = nullptr) { + if (!DerivedTy::isFusible(producer, consumer, consumerIdx)) + return nullptr; + + unsigned numFusedOperands = producer.getOperation()->getNumOperands() + + consumer.getOperation()->getNumOperands() - 1; + + // Compute the fused operands list, + SmallVector fusedOperands; + fusedOperands.reserve(numFusedOperands); + auto consumerOperands = consumer.getOperation()->getOperands(); + auto producerOperands = producer.getOperation()->getOperands(); + fusedOperands.assign(consumerOperands.begin(), + std::next(consumerOperands.begin(), consumerIdx)); + fusedOperands.append(producerOperands.begin(), producerOperands.end()); + fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), + consumerOperands.end()); + + return DerivedTy::createFusedOp(rewriter, fusedOperands, producer, consumer, + consumerIdx, folder); + } +}; - LogicalResult matchAndRewrite(GenericOp op, - PatternRewriter &rewriter) const override { - if (!op.hasTensorSemantics()) - return failure(); +/// Implementation of fusion of generic ops. +struct FuseGenericOpsOnTensors + : public FuseTensorOpsImpl { + static bool isFusible(GenericOp producer, GenericOp consumer, + unsigned consumerIdx) { + // Verify that + // - the producer has all "parallel" iterator type. + if (producer.getNumParallelLoops() != producer.getNumLoops()) + return false; + + // Get the consumer index map. The number of results of the consumer index + // map must match the number of loops of the producer. + AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); + if (consumerIndexMap.getNumResults() != producer.getNumLoops()) + return false; + + // Finally the index_map for the result must be invertible. For now just + // verify it is a permutation. + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + return producerResultIndexMap.isPermutation(); + } - // Find the first operand that is defined by another generic op on tensors. - for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) { - auto definingOp = - dyn_cast_or_null(operand.value().getDefiningOp()); - if (!definingOp || !definingOp.hasTensorSemantics()) + static Operation *createFusedOp(PatternRewriter &rewriter, + ArrayRef fusedOperands, + GenericOp producer, GenericOp consumer, + unsigned consumerIdx, + OperationFolder *folder = nullptr) { + // Compute indexing_maps for the fused operation. The indexing_maps for the + // operands of the consumers that arent fused are the same. The + // indexing_maps for the producers need to be computed based on the + // indexing_map of the operand at consumerIdx in the consumer. + SmallVector fusedIndexMaps; + auto consumerIndexMaps = consumer.indexing_maps(); + fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults()); + fusedIndexMaps.assign(consumerIndexMaps.begin(), + std::next(consumerIndexMaps.begin(), consumerIdx)); + // Compute indexing maps for the producer args in the fused operation. + computeProducerOperandIndex( + producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); + + // Append the indexing maps for the remaining consumer operands. + fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), + consumerIndexMaps.end()); + + // Generate the fused op. + auto fusedOp = rewriter.create( + rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, + rewriter.getI64IntegerAttr(fusedOperands.size()), + rewriter.getI64IntegerAttr(consumer.getNumResults()), + rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr); + generateFusedRegion(rewriter, fusedOp.region(), producer.region(), + consumer.region(), consumerIdx); + return fusedOp; + } + +private: + /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of + /// the `producer` to use in the fused operation given the indexing map of the + /// result of the producer in the consumer. + static void computeProducerOperandIndex( + GenericOp producer, AffineMap fusedConsumerArgIndexMap, + SmallVectorImpl &fusedOpIndexingMapAttrs) { + // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map + // from consumer loop -> consumer arg tensor index/producer result tensor + // index. The fused loop is same as the consumer loop. For each producer arg + // the indexing map to be computed is a map from consumer loop -> producer + // arg tensor index. + + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // producerResultIndexMap is a map from producer loop -> tensor index. + // Compute the inverse to get map from tensor index -> producer loop. + // The inverse is a map from producer result tensor index -> producer loop. + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexig map to be invertible"); + for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { + // argMap is a map from producer loop -> producer arg tensor index. + AffineMap argMap = producer.getInputIndexingMap(argNum); + + // Compose argMap with invProducerResultIndexMap to get a map from + // producer result tensor index -> producer arg tensor index. + AffineMap t1 = argMap.compose(invProducerResultIndexMap); + + // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from + // consumer loop/ fused loop -> producer arg tensor index. + AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); + fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); + } + } + + /// Generate the region of the fused operation. The region of the fused op + /// must be empty. + static void generateFusedRegion(PatternRewriter &rewriter, + Region &fusedRegion, Region &producerRegion, + Region &consumerRegion, + unsigned consumerIdx) { + // Build the region of the fused op. + Block &producerBlock = producerRegion.front(); + Block &consumerBlock = consumerRegion.front(); + Block *fusedBlock = new Block(); + fusedRegion.push_back(fusedBlock); + BlockAndValueMapping mapper; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(fusedBlock); + // Map the arguments for the unmodified args from the consumer. + for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { + if (consumerArg.index() == consumerIdx) { + // Map the arguments for the args from the producer. + for (auto producerArg : producerBlock.getArguments()) + mapper.map(producerArg, + fusedBlock->addArgument(producerArg.getType())); continue; - auto fusedOp = - fuseTensorOps(rewriter, cast(definingOp.getOperation()), - cast(op.getOperation()), operand.index()); - if (!fusedOp) + } + mapper.map(consumerArg.value(), + fusedBlock->addArgument(consumerArg.value().getType())); + } + + // Add operations from producer (except the yield operation) to the fused + // op. + for (auto &op : producerBlock.getOperations()) { + if (auto yieldOp = dyn_cast(op)) { + // Lookup the value the yield operation is mapped to. + Value yieldVal = yieldOp.getOperand(0); + auto clonedVal = mapper.lookup(yieldVal); + mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal); continue; - rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); - if (llvm::all_of(definingOp.getResults(), - [](Value val) -> bool { return val.use_empty(); })) - rewriter.eraseOp(definingOp); - return success(); + } + rewriter.clone(op, mapper); + } + for (auto &op : consumerBlock.getOperations()) + rewriter.clone(op, mapper); + } +}; +} // namespace + +Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, + Operation *consumer, + unsigned consumerIdx, + OperationFolder *folder) { + if (consumerIdx >= consumer->getNumOperands()) + return nullptr; + Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp(); + if (!producer || producer->getNumResults() != 1) + return nullptr; + + if (GenericOp genericOp = dyn_cast(consumer)) { + if (!genericOp.hasTensorSemantics()) + return nullptr; + if (auto genericOpProducer = dyn_cast(producer)) { + if (genericOpProducer.hasTensorSemantics()) + return FuseGenericOpsOnTensors::fuse(rewriter, genericOpProducer, + genericOp, consumerIdx); + } + } + return nullptr; +} + +namespace { +/// Patterns to fuse a generic op, with the producer of its operands. +template +struct FuseTensorOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpTy op, + PatternRewriter &rewriter) const override { + // Find the first operand that is defined by another generic op on tensors. + for (auto operandNum : + llvm::seq(0, op.getOperation()->getNumOperands())) { + Operation *producer = + op.getOperation()->getOperand(operandNum).getDefiningOp(); + if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) { + rewriter.replaceOp(op, fusedOp->getResults()); + if (producer && llvm::all_of(producer->getResults(), + [](Value val) { return val.use_empty(); })) + rewriter.eraseOp(producer); + return success(); + } } return failure(); } @@ -589,7 +630,7 @@ void runOnOperation() override { OwningRewritePatternList patterns; Operation *op = getOperation(); - patterns.insert(op->getContext()); + patterns.insert>(op->getContext()); applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; };