diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" @@ -407,9 +408,9 @@ namespace { -/// Implementation of fusion of generic ops. +/// Implementation of fusion of generic ops and indexed_generic ops. struct FuseGenericOpsOnTensors { - static bool isFusible(GenericOp producer, GenericOp consumer, + static bool isFusible(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx) { // Verify that // - the producer has all "parallel" iterator type. @@ -428,7 +429,7 @@ return producerResultIndexMap.isPermutation(); } - static Operation *fuse(GenericOp producer, GenericOp consumer, + static Operation *fuse(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, PatternRewriter &rewriter, OperationFolder *folder = nullptr) { if (!isFusible(producer, consumer, consumerIdx)) @@ -454,7 +455,8 @@ // indexing_map of the operand at consumerIdx in the consumer. SmallVector fusedIndexMaps; auto consumerIndexMaps = consumer.indexing_maps(); - fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults()); + fusedIndexMaps.reserve(fusedOperands.size() + + consumer.getOperation()->getNumResults()); fusedIndexMaps.assign(consumerIndexMaps.begin(), std::next(consumerIndexMaps.begin(), consumerIdx)); // Compute indexing maps for the producer args in the fused operation. @@ -466,15 +468,56 @@ consumerIndexMaps.end()); // Generate the fused op. 
- auto fusedOp = rewriter.create( - rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, - rewriter.getI64IntegerAttr(fusedOperands.size()), - rewriter.getI64IntegerAttr(consumer.getNumResults()), - rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr); - generateFusedRegion(rewriter, fusedOp.region(), producer.region(), - consumer.region(), consumerIdx); + LinalgOp fusedOp; + if (isa(producer.getOperation()) && + isa(consumer.getOperation())) { + fusedOp = + rewriter + .create( + rewriter.getUnknownLoc(), + consumer.getOperation()->getResultTypes(), fusedOperands, + rewriter.getI64IntegerAttr(fusedOperands.size()), + rewriter.getI64IntegerAttr( + consumer.getOperation()->getNumResults()), + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr) + .getOperation(); + } else { + fusedOp = + rewriter + .create( + rewriter.getUnknownLoc(), + consumer.getOperation()->getResultTypes(), fusedOperands, + rewriter.getI64IntegerAttr(fusedOperands.size()), + rewriter.getI64IntegerAttr( + consumer.getOperation()->getNumResults()), + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr) + .getOperation(); + } + + // Construct an AffineMap from consumer loops to producer loops. 
+ // consumer loop -> tensor index + AffineMap consumerResultIndexMap = + consumer.getInputIndexingMap(consumerIdx); + // producer loop -> tensor index + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + // tensor index -> producer loop + AffineMap invProducerResultIndexMap = + inversePermutation(producerResultIndexMap); + assert(invProducerResultIndexMap && + "expected producer result indexing map to be invertible"); + // consumer loop -> producer loop + AffineMap consumerToProducerLoopsMap = + invProducerResultIndexMap.compose(consumerResultIndexMap); + + generateFusedRegion(rewriter, fusedOp, producer, consumer, + consumerToProducerLoopsMap, consumerIdx, + consumer.getNumLoops()); return fusedOp; } @@ -483,7 +526,7 @@ /// the `producer` to use in the fused operation given the indexing map of the /// result of the producer in the consumer. static void computeProducerOperandIndex( - GenericOp producer, AffineMap fusedConsumerArgIndexMap, + LinalgOp producer, AffineMap fusedConsumerArgIndexMap, SmallVectorImpl &fusedOpIndexingMapAttrs) { // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map // from consumer loop -> consumer arg tensor index/producer result tensor @@ -516,29 +559,68 @@ /// Generate the region of the fused operation. The region of the fused op /// must be empty. - static void generateFusedRegion(PatternRewriter &rewriter, - Region &fusedRegion, Region &producerRegion, - Region &consumerRegion, - unsigned consumerIdx) { + static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp, + LinalgOp producer, LinalgOp consumer, + AffineMap consumerToProducerLoopsMap, + unsigned consumerIdx, unsigned nloops) { // Build the region of the fused op. 
- Block &producerBlock = producerRegion.front(); - Block &consumerBlock = consumerRegion.front(); + Block &producerBlock = producer.getOperation()->getRegion(0).front(); + Block &consumerBlock = consumer.getOperation()->getRegion(0).front(); Block *fusedBlock = new Block(); - fusedRegion.push_back(fusedBlock); + fusedOp->getRegion(0).push_back(fusedBlock); BlockAndValueMapping mapper; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(fusedBlock); + + // The block arguments are + // [index_0, index_1, ... , + // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1), + // producer_operand_0, ... , producer_operand_(n-1)], + // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)] + // , where n is the number of producer's operands and m is the number of + // consumer's operands. + // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a + // generic op. In this case, there are no indices in block arguments. + unsigned numProducerIndices = + isa(producer.getOperation()) ? nloops : 0; + unsigned numConsumerIndices = + isa(consumer.getOperation()) ? nloops : 0; + // Firstly, add all the indices to the block arguments. + for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices); + i < e; ++i) + fusedBlock->addArgument(rewriter.getIndexType()); // Map the arguments for the unmodified args from the consumer. for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { - if (consumerArg.index() == consumerIdx) { + if (consumerArg.index() == consumerIdx + numConsumerIndices) { // Map the arguments for the args from the producer. 
- for (auto producerArg : producerBlock.getArguments()) - mapper.map(producerArg, - fusedBlock->addArgument(producerArg.getType())); + for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) { + // If producer is an indexed_generic op, map the indices from consumer + // loop to producer loop (because the fusedOp is built based on + // consumer's perspective). + if (producerArg.index() < numProducerIndices) { + auto newIndex = rewriter.create( + producer.getLoc(), + consumerToProducerLoopsMap.getSubMap(producerArg.index()), + fusedBlock->getArguments().take_front(nloops)); + mapper.map(producerArg.value(), newIndex); + } else { + mapper.map(producerArg.value(), + fusedBlock->addArgument(producerArg.value().getType())); + } + } continue; } - mapper.map(consumerArg.value(), - fusedBlock->addArgument(consumerArg.value().getType())); + + // If consumer is an indexed_generic op, map the indices to the block + // arguments directly. Otherwise, add the same type of arugment and map to + // it. + if (consumerArg.index() < numConsumerIndices) { + mapper.map(consumerArg.value(), + fusedBlock->getArgument(consumerArg.index())); + } else { + mapper.map(consumerArg.value(), + fusedBlock->addArgument(consumerArg.value().getType())); + } } // Add operations from producer (except the yield operation) to the fused @@ -548,7 +630,8 @@ // Lookup the value the yield operation is mapped to. Value yieldVal = yieldOp.getOperand(0); auto clonedVal = mapper.lookup(yieldVal); - mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal); + mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices), + clonedVal); continue; } rewriter.clone(op, mapper); @@ -750,17 +833,22 @@ if (!producer || producer->getNumResults() != 1) return nullptr; - // Fuse when consumer is GenericOp. - if (GenericOp genericOp = dyn_cast(consumer)) { - if (!genericOp.hasTensorSemantics()) + // Fuse when consumer is GenericOp or IndexedGenericOp. 
+ if (isa(consumer) || isa(consumer)) { + auto linalgOpConsumer = cast(consumer); + if (!linalgOpConsumer.hasTensorSemantics()) return nullptr; - if (auto genericOpProducer = dyn_cast(producer)) { - if (genericOpProducer.hasTensorSemantics()) - return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp, + if (isa(producer) || isa(producer)) { + auto linalgOpProducer = cast(producer); + if (linalgOpProducer.hasTensorSemantics()) + return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer, consumerIdx, rewriter, folder); } else if (auto reshapeOpProducer = dyn_cast(producer)) { - return FuseTensorReshapeOpAsProducer::fuse( - reshapeOpProducer, genericOp, consumerIdx, rewriter, folder); + if (auto genericOpConsumer = dyn_cast(consumer)) { + return FuseTensorReshapeOpAsProducer::fuse( + reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter, + folder); + } } return nullptr; } @@ -774,6 +862,7 @@ } return nullptr; } + return nullptr; } @@ -820,8 +909,8 @@ void mlir::populateLinalgTensorOpsFusionPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert, FuseTensorOps>( - context); + patterns.insert, FuseTensorOps, + FuseTensorOps>(context); } std::unique_ptr> mlir::createLinalgFusionPass() { diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -219,3 +219,150 @@ // CHECK-LABEL: func @generic_op_reshape_consumer_nofusion // CHECK: linalg.tensor_reshape + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @generic_op_indexed_generic_op_fusion(%arg0: tensor, + %arg1: tensor) { + %0 = linalg.generic { + args_in = 2 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"] } %arg0, %arg1 { + ^bb0(%arg2: i32, %arg3: i32): // no predecessors + %10 = addi %arg2, %arg3 : i32 + linalg.yield %10 : i32 + } : 
tensor, tensor -> tensor + %1 = linalg.indexed_generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel"] } %0 { + ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = index_cast %arg3 : index to i32 + %4 = addi %arg4, %2 : i32 + %5 = subi %4, %3 : i32 + linalg.yield %5 : i32 + }: tensor -> tensor + return +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion +// CHECK-NOT: linalg.generic +// CHECK: linalg.indexed_generic +// CHECK-SAME: args_in = 2 +// CHECK-SAME: args_out = 1 +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]] +// CHECK: ^{{[a-zA-Z0-9_]*}} +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32 +// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32 +// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32 +// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32 +// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32 +// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32 +// CHECK: linalg.yield %[[VAL3]] : i32 + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @indexed_generic_op_generic_op_fusion(%arg0: tensor, + %arg1: tensor) { + %0 = linalg.indexed_generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel"] } %arg0 { + ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = index_cast %arg3 : index to i32 + %4 = addi %arg4, %2 : i32 + %5 = subi %4, %3 : i32 + linalg.yield %5 : i32 + }: tensor -> tensor + %1 = linalg.generic { + args_in = 2 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = 
["parallel", "parallel"] } %0, %arg1 { + ^bb0(%arg2: i32, %arg3: i32): // no predecessors + %10 = addi %arg2, %arg3 : i32 + linalg.yield %10 : i32 + } : tensor, tensor -> tensor + return +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion +// CHECK: linalg.indexed_generic +// CHECK-SAME: args_in = 2 +// CHECK-SAME: args_out = 1 +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]] +// CHECK: ^{{[a-zA-Z0-9_]*}} +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32 +// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32 +// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32 +// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32 +// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32 +// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32 +// CHECK: linalg.yield %[[VAL3]] : i32 +// CHECK-NOT: linalg.generic + +// ----- + +// The indices of the first indexed_generic op are swapped after fusion. 
+#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func @indexed_generic_op_fusion(%arg0: tensor) { + %0 = linalg.indexed_generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel"] } %arg0 { + ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = index_cast %arg3 : index to i32 + %4 = addi %arg4, %2 : i32 + %5 = subi %4, %3 : i32 + linalg.yield %5 : i32 + }: tensor -> tensor + %1 = linalg.indexed_generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map1, #map1], + iterator_types = ["parallel", "parallel"] } %0 { + ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors + %2 = index_cast %arg2 : index to i32 + %3 = index_cast %arg3 : index to i32 + %4 = addi %arg4, %2 : i32 + %5 = subi %4, %3 : i32 + linalg.yield %5 : i32 + }: tensor -> tensor + return +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @indexed_generic_op_fusion +// CHECK: linalg.indexed_generic +// CHECK-SAME: args_in = 1 +// CHECK-SAME: args_out = 1 +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] +// CHECK: ^{{[a-zA-Z0-9_]*}} +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 +// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32 +// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32 +// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32 +// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND1]] : i32 +// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32 +// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32 +// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32 +// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32 +// CHECK: linalg.yield %[[VAL4]] : i32 +// CHECK-NOT: 
linalg.indexed_generic