diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -24,246 +24,240 @@ using namespace mlir; using namespace mlir::linalg; -namespace { - /// Implementation of fusion of generic ops and indexed_generic ops. -struct FuseGenericOpsOnTensors { - static bool isFusable(LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx) { - // Producer and consumer must have tensor semantics. - if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) - return false; - - // Verify that - // - the producer has all "parallel" iterator type. - if (producer.getNumParallelLoops() != producer.getNumLoops()) - return false; - - // Get the consumer index map. The number of results of the consumer index - // map must match the number of loops of the producer. - AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); - if (consumerIndexMap.getNumResults() != producer.getNumLoops()) - return false; - - // Finally the index_map for the result must be invertible. For now just - // verify it is a permutation. - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - return producerResultIndexMap.isPermutation(); - } +// struct FuseGenericOpsOnTensors { +static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer, + unsigned consumerIdx) { + // Producer and consumer must have tensor semantics. + if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) + return false; - static Optional> - fuse(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, - PatternRewriter &rewriter, OperationFolder *folder = nullptr) { - if (!isFusable(producer, consumer, consumerIdx)) - return llvm::None; - - unsigned numFusedOperands = - producer.getNumInputs() + consumer.getNumInputs() - 1; - - // Compute the fused operands list, - SmallVector fusedOperands; - fusedOperands.reserve(numFusedOperands); - auto consumerOperands = consumer.getInputs(); - auto producerOperands = producer.getInputs(); - fusedOperands.assign(consumerOperands.begin(), - std::next(consumerOperands.begin(), consumerIdx)); - fusedOperands.append(producerOperands.begin(), producerOperands.end()); - fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), - consumerOperands.end()); - - // Compute indexing_maps for the fused operation. The indexing_maps for the - // operands of the consumers that arent fused are the same. The - // indexing_maps for the producers need to be computed based on the - // indexing_map of the operand at consumerIdx in the consumer. - SmallVector fusedIndexMaps; - auto consumerIndexMaps = consumer.indexing_maps(); - fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs()); - fusedIndexMaps.assign(consumerIndexMaps.begin(), - std::next(consumerIndexMaps.begin(), consumerIdx)); - // Compute indexing maps for the producer args in the fused operation. - computeProducerOperandIndex( - producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); - - // Append the indexing maps for the remaining consumer operands. - fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), - consumerIndexMaps.end()); - - // Generate the fused op. - // Tensor-level fusion is only on ops without initTensors and outputBuffers. - LinalgOp fusedOp; - if (isa(producer.getOperation()) && - isa(consumer.getOperation())) { - fusedOp = - rewriter - .create(consumer.getLoc(), - consumer.getOperation()->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) - .getOperation(); - } else { - fusedOp = - rewriter - .create( - consumer.getLoc(), consumer.getOperation()->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputBuffers=*/ValueRange{}, - /*initTensors=*/ValueRange{}, - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) - .getOperation(); - } + // Verify that + // - the producer has all "parallel" iterator type. + if (producer.getNumParallelLoops() != producer.getNumLoops()) + return false; - // Construct an AffineMap from consumer loops to producer loops. - // consumer loop -> tensor index - AffineMap consumerResultIndexMap = - consumer.getInputIndexingMap(consumerIdx); - // producer loop -> tensor index - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - // tensor index -> producer loop - AffineMap invProducerResultIndexMap = - inversePermutation(producerResultIndexMap); - assert(invProducerResultIndexMap && - "expected producer result indexig map to be invertible"); - // consumer loop -> producer loop - AffineMap consumerToProducerLoopsMap = - invProducerResultIndexMap.compose(consumerResultIndexMap); - - generateFusedRegion(rewriter, fusedOp, producer, consumer, - consumerToProducerLoopsMap, consumerIdx, - consumer.getNumLoops()); - return SmallVector(fusedOp.getOperation()->getResults()); - } + // Get the consumer index map. The number of results of the consumer index + // map must match the number of loops of the producer. + AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); + if (consumerIndexMap.getNumResults() != producer.getNumLoops()) + return false; -private: - /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of - /// the `producer` to use in the fused operation given the indexing map of the - /// result of the producer in the consumer. - static void computeProducerOperandIndex( - LinalgOp producer, AffineMap fusedConsumerArgIndexMap, - SmallVectorImpl &fusedOpIndexingMapAttrs) { - // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map - // from consumer loop -> consumer arg tensor index/producer result tensor - // index. The fused loop is same as the consumer loop. For each producer arg - // the indexing map to be computed is a map from consumer loop -> producer - // arg tensor index. - - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - // producerResultIndexMap is a map from producer loop -> tensor index. - // Compute the inverse to get map from tensor index -> producer loop. - // The inverse is a map from producer result tensor index -> producer loop. - AffineMap invProducerResultIndexMap = - inversePermutation(producerResultIndexMap); - assert(invProducerResultIndexMap && - "expected producer result indexig map to be invertible"); - for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { - // argMap is a map from producer loop -> producer arg tensor index. - AffineMap argMap = producer.getInputIndexingMap(argNum); - - // Compose argMap with invProducerResultIndexMap to get a map from - // producer result tensor index -> producer arg tensor index. - AffineMap t1 = argMap.compose(invProducerResultIndexMap); - - // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from - // consumer loop/ fused loop -> producer arg tensor index. - AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); - fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); - } + // Finally the index_map for the result must be invertible. For now just + // verify it is a permutation. + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + return producerResultIndexMap.isPermutation(); +} + +/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of +/// the `producer` to use in the fused operation given the indexing map of the +/// result of the producer in the consumer. +static void getIndexingMapOfProducerOperandsInFusedOp( + LinalgOp producer, AffineMap fusedConsumerArgIndexMap, + SmallVectorImpl &fusedOpIndexingMapAttrs) { + // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map + // from consumer loop -> consumer arg tensor index/producer result tensor + // index. The fused loop is same as the consumer loop. For each producer arg + // the indexing map to be computed is a map from consumer loop -> producer + // arg tensor index. + + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // producerResultIndexMap is a map from producer loop -> tensor index. + // Compute the inverse to get map from tensor index -> producer loop. + // The inverse is a map from producer result tensor index -> producer loop. + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexig map to be invertible"); + for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { + // argMap is a map from producer loop -> producer arg tensor index. + AffineMap argMap = producer.getInputIndexingMap(argNum); + + // Compose argMap with invProducerResultIndexMap to get a map from + // producer result tensor index -> producer arg tensor index. + AffineMap t1 = argMap.compose(invProducerResultIndexMap); + + // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from + // consumer loop/ fused loop -> producer arg tensor index. + AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); + fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); } +} - /// Generate the region of the fused operation. The region of the fused op - /// must be empty. - static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp, - LinalgOp producer, LinalgOp consumer, - AffineMap consumerToProducerLoopsMap, - unsigned consumerIdx, unsigned nloops) { - // Build the region of the fused op. - Block &producerBlock = producer.getOperation()->getRegion(0).front(); - Block &consumerBlock = consumer.getOperation()->getRegion(0).front(); - Block *fusedBlock = new Block(); - fusedOp->getRegion(0).push_back(fusedBlock); - BlockAndValueMapping mapper; - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(fusedBlock); - - // The block arguments are - // [index_0, index_1, ... , - // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1), - // producer_operand_0, ... , producer_operand_(n-1)], - // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)] - // , where n is the number of producer's operand and m is the number - // consumer's operand. - // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a - // generic op. In this case, there are no indices in block arguments. - unsigned numProducerIndices = - isa(producer.getOperation()) ? nloops : 0; - unsigned numConsumerIndices = - isa(consumer.getOperation()) ? nloops : 0; - // Firstly, add all the indices to the block arguments. - for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices); - i < e; ++i) - fusedBlock->addArgument(rewriter.getIndexType()); - // Map the arguments for the unmodified args from the consumer. - for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { - if (consumerArg.index() == consumerIdx + numConsumerIndices) { - // Map the arguments for the args from the producer. - for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) { - // If producer is an indexed_generic op, map the indices from consumer - // loop to producer loop (because the fusedOp is built based on - // consumer's perspective). - if (producerArg.index() < numProducerIndices) { - auto newIndex = rewriter.create( - producer.getLoc(), - consumerToProducerLoopsMap.getSubMap(producerArg.index()), - fusedBlock->getArguments().take_front(nloops)); - mapper.map(producerArg.value(), newIndex); - } else { - mapper.map(producerArg.value(), - fusedBlock->addArgument(producerArg.value().getType())); - } +/// Generate the region of the fused tensor operation. The region of the fused +/// op must be empty. +static void generateFusedTensorOpRegion(PatternRewriter &rewriter, + Operation *fusedOp, LinalgOp producer, + LinalgOp consumer, + AffineMap consumerToProducerLoopsMap, + unsigned consumerIdx, unsigned nloops) { + // Build the region of the fused op. + Block &producerBlock = producer.getOperation()->getRegion(0).front(); + Block &consumerBlock = consumer.getOperation()->getRegion(0).front(); + Block *fusedBlock = new Block(); + fusedOp->getRegion(0).push_back(fusedBlock); + BlockAndValueMapping mapper; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(fusedBlock); + + // The block arguments are + // [index_0, index_1, ... , + // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1), + // producer_operand_0, ... , producer_operand_(n-1)], + // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)] + // , where n is the number of producer's operand and m is the number + // consumer's operand. + // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a + // generic op. In this case, there are no indices in block arguments. + unsigned numProducerIndices = + isa(producer.getOperation()) ? nloops : 0; + unsigned numConsumerIndices = + isa(consumer.getOperation()) ? nloops : 0; + // Firstly, add all the indices to the block arguments. + for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices); + i < e; ++i) + fusedBlock->addArgument(rewriter.getIndexType()); + // Map the arguments for the unmodified args from the consumer. + for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { + if (consumerArg.index() == consumerIdx + numConsumerIndices) { + // Map the arguments for the args from the producer. + for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) { + // If producer is an indexed_generic op, map the indices from consumer + // loop to producer loop (because the fusedOp is built based on + // consumer's perspective). + if (producerArg.index() < numProducerIndices) { + auto newIndex = rewriter.create( + producer.getLoc(), + consumerToProducerLoopsMap.getSubMap(producerArg.index()), + fusedBlock->getArguments().take_front(nloops)); + mapper.map(producerArg.value(), newIndex); + } else { + mapper.map(producerArg.value(), + fusedBlock->addArgument(producerArg.value().getType())); } - continue; } + continue; + } - // If consumer is an indexed_generic op, map the indices to the block - // arguments directly. Otherwise, add the same type of arugment and map to - // it. - if (consumerArg.index() < numConsumerIndices) { - mapper.map(consumerArg.value(), - fusedBlock->getArgument(consumerArg.index())); - } else { - mapper.map(consumerArg.value(), - fusedBlock->addArgument(consumerArg.value().getType())); - } + // If consumer is an indexed_generic op, map the indices to the block + // arguments directly. Otherwise, add the same type of arugment and map to + // it. + if (consumerArg.index() < numConsumerIndices) { + mapper.map(consumerArg.value(), + fusedBlock->getArgument(consumerArg.index())); + } else { + mapper.map(consumerArg.value(), + fusedBlock->addArgument(consumerArg.value().getType())); } + } - // Add operations from producer (except the yield operation) to the fused - // op. - for (auto &op : producerBlock.getOperations()) { - if (auto yieldOp = dyn_cast(op)) { - // Lookup the value the yield operation is mapped to. - Value yieldVal = yieldOp.getOperand(0); - if (Value clonedVal = mapper.lookupOrNull(yieldVal)) - mapper.map( - consumerBlock.getArgument(consumerIdx + numConsumerIndices), - clonedVal); - continue; - } - rewriter.clone(op, mapper); + // Add operations from producer (except the yield operation) to the fused + // op. + for (auto &op : producerBlock.getOperations()) { + if (auto yieldOp = dyn_cast(op)) { + // Lookup the value the yield operation is mapped to. + Value yieldVal = yieldOp.getOperand(0); + if (Value clonedVal = mapper.lookupOrNull(yieldVal)) + mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices), + clonedVal); + continue; } - for (auto &op : consumerBlock.getOperations()) - rewriter.clone(op, mapper); + rewriter.clone(op, mapper); } -}; -} // namespace + for (auto &op : consumerBlock.getOperations()) + rewriter.clone(op, mapper); +} + +static Optional> +fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, + PatternRewriter &rewriter, + OperationFolder *folder = nullptr) { + if (!areTensorOpsFusable(producer, consumer, consumerIdx)) + return llvm::None; + + unsigned numFusedOperands = + producer.getNumInputs() + consumer.getNumInputs() - 1; + + // Compute the fused operands list, + SmallVector fusedOperands; + fusedOperands.reserve(numFusedOperands); + auto consumerOperands = consumer.getInputs(); + auto producerOperands = producer.getInputs(); + fusedOperands.assign(consumerOperands.begin(), + std::next(consumerOperands.begin(), consumerIdx)); + fusedOperands.append(producerOperands.begin(), producerOperands.end()); + fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), + consumerOperands.end()); + + // Compute indexing_maps for the fused operation. The indexing_maps for the + // operands of the consumers that arent fused are the same. The + // indexing_maps for the producers need to be computed based on the + // indexing_map of the operand at consumerIdx in the consumer. + SmallVector fusedIndexMaps; + auto consumerIndexMaps = consumer.indexing_maps(); + fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs()); + fusedIndexMaps.assign(consumerIndexMaps.begin(), + std::next(consumerIndexMaps.begin(), consumerIdx)); + // Compute indexing maps for the producer args in the fused operation. + getIndexingMapOfProducerOperandsInFusedOp( + producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); + + // Append the indexing maps for the remaining consumer operands. + fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), + consumerIndexMaps.end()); + + // Generate the fused op. + // Tensor-level fusion is only on ops without initTensors and outputBuffers. + LinalgOp fusedOp; + if (isa(producer.getOperation()) && + isa(consumer.getOperation())) { + fusedOp = rewriter + .create(consumer.getLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) + .getOperation(); + } else { + fusedOp = + rewriter + .create(consumer.getLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) + .getOperation(); + } + + // Construct an AffineMap from consumer loops to producer loops. + // consumer loop -> tensor index + AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx); + // producer loop -> tensor index + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // tensor index -> producer loop + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexig map to be invertible"); + // consumer loop -> producer loop + AffineMap consumerToProducerLoopsMap = + invProducerResultIndexMap.compose(consumerResultIndexMap); + + generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer, + consumer, consumerToProducerLoopsMap, consumerIdx, + consumer.getNumLoops()); + return SmallVector(fusedOp.getOperation()->getResults()); +} /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` /// provided, given the shape of the source tensor that corresponds to the @@ -878,9 +872,8 @@ !isa(producer)) return llvm::None; - return FuseGenericOpsOnTensors::fuse(cast(producer), - cast(consumer), consumerIdx, - rewriter, folder); + return fuseTensorOpsImpl(cast(producer), cast(consumer), + consumerIdx, rewriter, folder); } namespace {