diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -407,31 +408,68 @@
 
 namespace {
 
-/// Implementation of fusion of generic ops.
-struct FuseGenericOpsOnTensors {
-  static bool isFusible(GenericOp producer, GenericOp consumer,
-                        unsigned consumerIdx) {
-    // Verify that
-    // - the producer has all "parallel" iterator type.
-    if (producer.getNumParallelLoops() != producer.getNumLoops())
-      return false;
+static bool isGenericOpFusible(LinalgOp producer, LinalgOp consumer,
+                               unsigned consumerIdx) {
+  // Verify that
+  // - the producer has all "parallel" iterator types.
+  if (producer.getNumParallelLoops() != producer.getNumLoops())
+    return false;
 
-    // Get the consumer index map. The number of results of the consumer index
-    // map must match the number of loops of the producer.
-    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
-    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
-      return false;
+  // Get the consumer index map. The number of results of the consumer index
+  // map must match the number of loops of the producer.
+  AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
+    return false;
 
-    // Finally the index_map for the result must be invertible. For now just
-    // verify it is a permutation.
-    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-    return producerResultIndexMap.isPermutation();
+  // Finally the index_map for the result must be invertible. For now just
+  // verify it is a permutation.
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  return producerResultIndexMap.isPermutation();
+}
+
+/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+/// the `producer` to use in the fused operation, given the indexing map of
+/// the producer's result in the consumer.
+static void computeProducerOperandIndex(
+    LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
+    SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+  // from consumer loop -> consumer arg tensor index/producer result tensor
+  // index. The fused loop is the same as the consumer loop. For each producer
+  // arg, the indexing map to be computed is a map from consumer loop ->
+  // producer arg tensor index.
+
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  // producerResultIndexMap is a map from producer loop -> tensor index.
+  // Compute the inverse to get a map from tensor index -> producer loop.
+  // The inverse is a map from producer result tensor index -> producer loop.
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerResultIndexMap);
+  assert(invProducerResultIndexMap &&
+         "expected producer result indexing map to be invertible");
+  for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
+    // argMap is a map from producer loop -> producer arg tensor index.
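+    // As a concrete example (mirroring the indexed_generic fusion test added
+    // below): with producer arg and result maps both `(d0, d1) -> (d1, d0)`
+    // and an identity consumer map `(d0, d1) -> (d0, d1)`, the two
+    // compositions below yield the identity map `(d0, d1) -> (d0, d1)`, i.e.
+    // the producer's transpose is folded into the fused op's indexing.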
+    AffineMap argMap = producer.getInputIndexingMap(argNum);
+
+    // Compose argMap with invProducerResultIndexMap to get a map from
+    // producer result tensor index -> producer arg tensor index.
+    AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+    // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
+    // consumer loop / fused loop -> producer arg tensor index.
+    AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
+    fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
+  }
+}
+
+/// Implementation of fusion of generic ops.
+struct FuseGenericOpsOnTensors {
   static Operation *fuse(GenericOp producer, GenericOp consumer,
                          unsigned consumerIdx, PatternRewriter &rewriter,
                          OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
+    if (!isGenericOpFusible(cast<LinalgOp>(producer.getOperation()),
+                            cast<LinalgOp>(consumer.getOperation()),
+                            consumerIdx))
       return nullptr;
 
     unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
@@ -458,8 +496,9 @@
     fusedIndexMaps.assign(consumerIndexMaps.begin(),
                           std::next(consumerIndexMaps.begin(), consumerIdx));
     // Compute indexing maps for the producer args in the fused operation.
-    computeProducerOperandIndex(
-        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
+    computeProducerOperandIndex(cast<LinalgOp>(producer.getOperation()),
+                                consumer.getInputIndexingMap(consumerIdx),
+                                fusedIndexMaps);
 
     // Append the indexing maps for the remaining consumer operands.
     fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                           consumerIndexMaps.end());
@@ -479,41 +518,6 @@
   }
 
 private:
-  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
-  /// the `producer` to use in the fused operation given the indexing map of the
-  /// result of the producer in the consumer.
-  static void computeProducerOperandIndex(
-      GenericOp producer, AffineMap fusedConsumerArgIndexMap,
-      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
-    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
-    // from consumer loop -> consumer arg tensor index/producer result tensor
-    // index. The fused loop is same as the consumer loop. For each producer arg
-    // the indexing map to be computed is a map from consumer loop -> producer
-    // arg tensor index.
-
-    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-    // producerResultIndexMap is a map from producer loop -> tensor index.
-    // Compute the inverse to get map from tensor index -> producer loop.
-    // The inverse is a map from producer result tensor index -> producer loop.
-    AffineMap invProducerResultIndexMap =
-        inversePermutation(producerResultIndexMap);
-    assert(invProducerResultIndexMap &&
-           "expected producer result indexig map to be invertible");
-    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
-      // argMap is a map from producer loop -> producer arg tensor index.
-      AffineMap argMap = producer.getInputIndexingMap(argNum);
-
-      // Compose argMap with invProducerResultIndexMap to get a map from
-      // producer result tensor index -> producer arg tensor index.
-      AffineMap t1 = argMap.compose(invProducerResultIndexMap);
-
-      // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
-      // consumer loop/ fused loop -> producer arg tensor index.
-      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
-      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
-    }
-  }
-
   /// Generate the region of the fused operation. The region of the fused op
   /// must be empty.
   static void generateFusedRegion(PatternRewriter &rewriter,
@@ -557,6 +561,156 @@
       rewriter.clone(op, mapper);
   }
 };
+
+/// Implementation of fusion of indexed_generic ops.
+struct FuseIndexedGenericOpsOnTensors {
+  template <typename T1, typename T2>
+  static Operation *fuse(T1 producer, T2 consumer, unsigned consumerIdx,
+                         PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isGenericOpFusible(cast<LinalgOp>(producer.getOperation()),
+                            cast<LinalgOp>(consumer.getOperation()),
+                            consumerIdx))
+      return nullptr;
+
+    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
+                                consumer.getOperation()->getNumOperands() - 1;
+
+    // Compute the fused operands list.
+    SmallVector<Value, 2> fusedOperands;
+    fusedOperands.reserve(numFusedOperands);
+    auto consumerOperands = consumer.getOperation()->getOperands();
+    auto producerOperands = producer.getOperation()->getOperands();
+    fusedOperands.assign(consumerOperands.begin(),
+                         std::next(consumerOperands.begin(), consumerIdx));
+    fusedOperands.append(producerOperands.begin(), producerOperands.end());
+    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
+                         consumerOperands.end());
+
+    // Compute indexing_maps for the fused operation. The indexing_maps for
+    // the operands of the consumer that aren't fused stay the same. The
+    // indexing_maps for the producer operands need to be computed based on
+    // the indexing_map of the operand at consumerIdx in the consumer.
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    auto consumerIndexMaps = consumer.indexing_maps();
+    fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
+    fusedIndexMaps.assign(consumerIndexMaps.begin(),
+                          std::next(consumerIndexMaps.begin(), consumerIdx));
+    // Compute indexing maps for the producer args in the fused operation.
+    computeProducerOperandIndex(cast<LinalgOp>(producer.getOperation()),
+                                consumer.getInputIndexingMap(consumerIdx),
+                                fusedIndexMaps);
+
+    // Append the indexing maps for the remaining consumer operands.
+    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
+                          consumerIndexMaps.end());
+
+    // Generate the fused op.
+    auto fusedOp = rewriter.create<IndexedGenericOp>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(fusedOperands.size()),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    AffineMap consumerResultIndexMap =
+        consumer.getInputIndexingMap(consumerIdx);
+    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+    AffineMap invProducerResultIndexMap =
+        inversePermutation(producerResultIndexMap);
+    assert(invProducerResultIndexMap &&
+           "expected producer result indexing map to be invertible");
+    generateFusedIndexedRegion(
+        rewriter, fusedOp, cast<LinalgOp>(producer.getOperation()),
+        cast<LinalgOp>(consumer.getOperation()),
+        consumerResultIndexMap.compose(invProducerResultIndexMap), consumerIdx,
+        consumer.getNumLoops());
+    return fusedOp;
+  }
+
+private:
+  /// Generate the region of the fused operation. The region of the fused op
+  /// must be empty.
+  static void generateFusedIndexedRegion(PatternRewriter &rewriter,
+                                         IndexedGenericOp &fusedOp,
+                                         LinalgOp producer, LinalgOp consumer,
+                                         AffineMap consumerToProducerLoopsMap,
+                                         unsigned consumerIdx,
+                                         unsigned nloops) {
+    // Build the region of the fused op.
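+    // `consumerToProducerLoopsMap` maps the fused (consumer) loop dimensions
+    // back to the producer's loop dimensions; it is used below to rewrite the
+    // producer's index arguments with affine.apply ops so the cloned producer
+    // body still observes its original iteration indices.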
+    Block &producerBlock = producer.getOperation()->getRegion(0).front();
+    Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
+    Block *fusedBlock = new Block();
+    fusedOp.region().push_back(fusedBlock);
+    BlockAndValueMapping mapper;
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(fusedBlock);
+
+    // The block arguments are
+    //   [index_0, index_1, ... ,
+    //    consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
+    //    producer_operand_0, ... , producer_operand_(n-1),
+    //    consumer_operand_(`consumerIdx`), ... , consumer_operand_(m-1)],
+    // where n is the number of the producer's operands and m is the number of
+    // the consumer's operands.
+    unsigned numProducerIndices =
+        isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
+    unsigned numConsumerIndices =
+        isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+    // First, add all the indices to the block arguments.
+    for (unsigned i = 0; i < nloops; ++i)
+      fusedBlock->addArgument(rewriter.getIndexType());
+    // Map the arguments for the unmodified args from the consumer.
+    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
+      if (consumerArg.index() == consumerIdx + numConsumerIndices) {
+        // Map the arguments for the args from the producer.
+        for (auto producerArg :
+             llvm::enumerate(producerBlock.getArguments())) {
+          // If the producer is an indexed_generic op, map the indices from
+          // the consumer loop to the producer loop (because the fused op is
+          // built from the consumer's perspective).
+          if (producerArg.index() < numProducerIndices) {
+            auto newIndex = rewriter.create<AffineApplyOp>(
+                producer.getLoc(),
+                consumerToProducerLoopsMap.getSubMap(producerArg.index()),
+                fusedBlock->getArguments().take_front(nloops));
+            mapper.map(producerArg.value(), newIndex);
+          } else {
+            mapper.map(producerArg.value(),
+                       fusedBlock->addArgument(producerArg.value().getType()));
+          }
+        }
+        continue;
+      }
+
+      // If the consumer is an indexed_generic op, map the indices to the
+      // block arguments directly. Otherwise, add an argument of the same type
+      // and map to it.
+      if (consumerArg.index() < numConsumerIndices) {
+        mapper.map(consumerArg.value(),
+                   fusedBlock->getArgument(consumerArg.index()));
+      } else {
+        mapper.map(consumerArg.value(),
+                   fusedBlock->addArgument(consumerArg.value().getType()));
+      }
+    }
+
+    // Add operations from the producer (except the yield operation) to the
+    // fused op.
+    for (auto &op : producerBlock.getOperations()) {
+      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+        // Lookup the value the yield operation is mapped to, and map the
+        // consumer block argument that received the producer result to it.
+        Value yieldVal = yieldOp.getOperand(0);
+        auto clonedVal = mapper.lookup(yieldVal);
+        mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+                   clonedVal);
+        continue;
+      }
+      rewriter.clone(op, mapper);
+    }
+    for (auto &op : consumerBlock.getOperations())
+      rewriter.clone(op, mapper);
+  }
+};
 } // namespace
 
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
@@ -761,6 +915,10 @@
     } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
       return FuseTensorReshapeOpAsProducer::fuse(
           reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+    } else if (auto indexedGenericOpProducer =
+                   dyn_cast<IndexedGenericOp>(producer)) {
+      return FuseIndexedGenericOpsOnTensors::fuse(
+          indexedGenericOpProducer, genericOp, consumerIdx, rewriter, folder);
     }
     return nullptr;
   }
@@ -774,6 +932,26 @@
     }
     return nullptr;
   }
+
+  // Fuse when the consumer is an IndexedGenericOp.
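+  // Both generic and indexed_generic producers are handled here; in either
+  // case the fused op is materialized as an indexed_generic op.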
+  if (IndexedGenericOp indexedGenericOp =
+          dyn_cast<IndexedGenericOp>(consumer)) {
+    if (!indexedGenericOp.hasTensorSemantics())
+      return nullptr;
+    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
+      if (genericOpProducer.hasTensorSemantics())
+        return FuseIndexedGenericOpsOnTensors::fuse(
+            genericOpProducer, indexedGenericOp, consumerIdx, rewriter,
+            folder);
+    }
+    if (auto indexedGenericOpProducer = dyn_cast<IndexedGenericOp>(producer)) {
+      if (indexedGenericOpProducer.hasTensorSemantics())
+        return FuseIndexedGenericOpsOnTensors::fuse(
+            indexedGenericOpProducer, indexedGenericOp, consumerIdx, rewriter,
+            folder);
+    }
+    return nullptr;
+  }
+
   return nullptr;
 }
 
@@ -820,8 +998,8 @@
 
 void mlir::populateLinalgTensorOpsFusionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
-      context);
+  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>,
+                  FuseTensorOps<IndexedGenericOp>>(context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -219,3 +219,150 @@
 
 // CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
 // CHECK:       linalg.tensor_reshape
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @generic_op_indexed_generic_op_fusion(%arg0: tensor<2x2xi32>,
+                                           %arg1: tensor<2x2xi32>) {
+  %0 = linalg.generic {
+    args_in = 2 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0, %arg1 {
+  ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+    %10 = addi %arg2, %arg3 : i32
+    linalg.yield %10 : i32
+  } : tensor<2x2xi32>, tensor<2x2xi32> -> tensor<2x2xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<2x2xi32> -> tensor<2x2xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_generic_op_fusion(%arg0: tensor<2x2xi32>,
+                                           %arg1: tensor<2x2xi32>) {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<2x2xi32> -> tensor<2x2xi32>
+  %1 = linalg.generic {
+    args_in = 2 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %0, %arg1 {
+  ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+    %10 = addi %arg2, %arg3 : i32
+    linalg.yield %10 : i32
+  } : tensor<2x2xi32>, tensor<2x2xi32> -> tensor<2x2xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// The indices of the first indexed_generic op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_fusion(%arg0: tensor<2x2xi32>) {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<2x2xi32> -> tensor<2x2xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map1, #map1],
+    iterator_types = ["parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<2x2xi32> -> tensor<2x2xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND1]] : i32
+// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+// CHECK: linalg.yield %[[VAL4]] : i32
+// CHECK-NOT: linalg.indexed_generic