diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -440,6 +440,10 @@ struct FuseGenericOpsOnTensors {
 
   static bool isFusible(LinalgOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
+    // Producer and consumer must have tensor semantics.
+    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+      return false;
+
     // Verify that
     // - the producer has all "parallel" iterator type.
     if (producer.getNumParallelLoops() != producer.getNumLoops())
@@ -457,9 +461,9 @@
     return producerResultIndexMap.isPermutation();
   }
 
-  static Operation *fuse(LinalgOp producer, LinalgOp consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -736,24 +740,45 @@
   return useIndexMap.isIdentity();
 }
 
+/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
+/// is a linalg.generic operation, then create a `linalg.generic` operation
+/// with the given `args`. Expects `op` to be `linalg.generic` or
+/// `linalg.indexed_generic`.
+template <typename... Args>
+static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
+                                         Args... args) {
+  if (isa<GenericOp>(op.getOperation()))
+    return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
+  if (isa<IndexedGenericOp>(op.getOperation()))
+    return cast<LinalgOp>(
+        rewriter.create<IndexedGenericOp>(args...).getOperation());
+  llvm_unreachable(
+      "expected only linalg.generic or linalg.indexed_generic ops");
+  return nullptr;
+}
+
 namespace {
+
 /// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
-  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+struct FuseTensorReshapeOpAsProducer {
+  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
-    return isTensorReshapeOpFusible(
-        producer, consumer.getInputIndexingMap(consumerIdx), true);
+    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+           consumer.hasTensorSemantics() &&
+           isTensorReshapeOpFusible(producer,
+                                    consumer.getInputIndexingMap(consumerIdx),
+                                    /*asProducer=*/true);
   }
 
-  static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
     // Compute the fused operands list,
-    SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
-                                        consumer.operand_end());
+    Operation *consumerOp = consumer.getOperation();
+    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
     fusedOperands[consumerIdx] = producer.src();
 
     // Compute indexing_maps for the fused operation. The indexing_maps for the
@@ -783,32 +808,35 @@
       llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
         return AffineMapAttr::get(map);
       }));
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        consumer, rewriter, rewriter.getUnknownLoc(),
+        consumerOp->getResultTypes(), fusedOperands,
         rewriter.getI64IntegerAttr(fusedOperands.size()),
-        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
         rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.region();
-    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
                                fusedRegion.begin());
     return fusedOp;
   }
 };
 
 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
-  static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+struct FuseTensorReshapeOpAsConsumer {
+  static bool isFusible(LinalgOp producer, TensorReshapeOp consumer,
                         unsigned consumerIdx) {
-    return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
-                                    false);
+    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
+           producer.hasTensorSemantics() &&
+           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+                                    /*asProducer=*/false);
   }
 
-  static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -839,33 +867,36 @@
         return AffineMapAttr::get(map);
       }));
 
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultType(),
-        producer.getOperands(),
-        rewriter.getI64IntegerAttr(producer.getNumOperands()),
+    Operation *producerOp = producer.getOperation();
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
+        producerOp->getOperands(),
+        rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
         rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
         producer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.region();
-    rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
                                fusedRegion.begin());
     return fusedOp;
   }
 };
 
 /// Implementation of fusion on tensor ops when producer is a splat constant.
-template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
-  static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+struct FuseConstantOpAsProducer {
+  static bool isFusible(ConstantOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
-    return producer.getResult().getType().isa<RankedTensorType>() &&
-           producer.value().template cast<DenseElementsAttr>().isSplat();
+    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+           consumer.hasTensorSemantics() &&
+           producer.getResult().getType().isa<RankedTensorType>() &&
+           producer.value().cast<DenseElementsAttr>().isSplat();
   }
 
-  static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -881,19 +912,20 @@
 
     // The operands list is same as the consumer with the argument for constant
     // index dropped.
-    SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
-                                        consumer.operand_end());
+    Operation *consumerOp = consumer.getOperation();
+    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
     fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
 
     // Create a constant scalar value from the splat constant.
     Value scalarConstant = rewriter.create<ConstantOp>(
         producer.getLoc(),
-        producer.value().template cast<DenseElementsAttr>().getSplatValue());
+        producer.value().cast<DenseElementsAttr>().getSplatValue());
 
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
-        rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
-        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        consumer, rewriter, rewriter.getUnknownLoc(),
+        consumerOp->getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
+        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
         rewriter.getAffineMapArrayAttr(fusedIndexMaps),
         consumer.iterator_types(),
         /*doc=*/nullptr,
@@ -902,19 +934,18 @@
 
     // Map the block argument corresponding to the replaced argument with the
     // scalar constant.
-    Region &consumerRegion = consumer.region();
+    Region &consumerRegion = consumerOp->getRegion(0);
     Block &entryBlock = *consumerRegion.begin();
-    unsigned argIndex =
-        entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
+    unsigned argIndex = entryBlock.getNumArguments() -
+                        consumerOp->getNumOperands() + consumerIdx;
     BlockAndValueMapping mapping;
     mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
-    Region &fusedRegion = fusedOp.region();
+    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
     rewriter.cloneRegionBefore(consumerRegion, fusedRegion,
                                fusedRegion.begin(), mapping);
     return fusedOp;
   }
 };
-
 } // namespace
 
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
@@ -929,48 +960,27 @@
   // Fuse when consumer is GenericOp or IndexedGenericOp.
   if (isa<GenericOp, IndexedGenericOp>(consumer)) {
-    auto linalgOpConsumer = cast<LinalgOp>(consumer);
-    if (!linalgOpConsumer.hasTensorSemantics())
-      return nullptr;
-    if (isa<GenericOp, IndexedGenericOp>(producer)) {
-      auto linalgOpProducer = cast<LinalgOp>(producer);
-      if (linalgOpProducer.hasTensorSemantics())
-        return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
-                                             consumerIdx, rewriter, folder);
-    } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
-      if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
-        return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
-            reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
-            folder);
-      } else if (auto indexedGenericOpConsumer =
-                     dyn_cast<IndexedGenericOp>(consumer)) {
-        return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
-            reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
-            folder);
-      }
-    } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
-      if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
-        return FuseConstantOpAsProducer<GenericOp>::fuse(
-            constantOpProducer, genericOpConsumer, consumerIdx, rewriter,
-            folder);
-      }
-    }
+    if (isa<GenericOp, IndexedGenericOp>(producer))
+      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
+                                           cast<LinalgOp>(consumer),
+                                           consumerIdx, rewriter, folder);
+    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
+      return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
+                                                 cast<LinalgOp>(consumer),
+                                                 consumerIdx, rewriter, folder);
+    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
+      return FuseConstantOpAsProducer::fuse(constantOpProducer,
+                                            cast<LinalgOp>(consumer),
+                                            consumerIdx, rewriter, folder);
     return nullptr;
   }
 
-  // Fuse when consumer is a TensorReshapeOp.
-  if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
-    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
-      if (genericOpProducer.hasTensorSemantics())
-        return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
-            genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
-    } else if (auto indexedGenericOpProducer =
-                   dyn_cast<IndexedGenericOp>(producer)) {
-      if (indexedGenericOpProducer.hasTensorSemantics())
-        return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
-            indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+  if (isa<GenericOp, IndexedGenericOp>(producer)) {
+    // Fuse when consumer is a TensorReshapeOp.
+    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+      return FuseTensorReshapeOpAsConsumer::fuse(
+          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
     }
-    return nullptr;
   }
 
   return nullptr;
 }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -249,6 +249,38 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+  -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<5xf32>
+  %1 = linalg.indexed_generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    %0, %arg0 {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32):
+      %2 = mulf %arg4, %arg5 : f32
+      linalg.yield %2 : f32
+    }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_constant_fusion
+//       CHECK: %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   args_in = 1 : i64
+//  CHECK-SAME:   args_out = 1 : i64
+//       CHECK: ^{{[a-zA-Z0-9_]*}}
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG4:.*]]: f32)
+//       CHECK:   mulf %[[CST]], %[[ARG4]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> ()>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
@@ -277,6 +309,38 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> ()>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_zero_dim_constant_fusion
+  (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<f32>
+  %1 = linalg.indexed_generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    %0, %arg0 {
+    ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32):
+      %2 = mulf %arg4, %arg5 : f32
+      linalg.yield %2 : f32
+    }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
+//       CHECK: %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   args_in = 1 : i64
+//  CHECK-SAME:   args_out = 1 : i64
+//       CHECK: ^{{[a-zA-Z0-9_]*}}
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:   %[[ARG4:.*]]: f32)
+//       CHECK:   mulf %[[CST]], %[[ARG4]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
                                            %arg1: tensor<?x?xi32>) {
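
Note (editorial illustration, not part of the patch): the CHECK lines in
@indexed_generic_op_constant_fusion above imply that the new
FuseConstantOpAsProducer path inlines the splat constant as a scalar, drops it
from the fused operand list (args_in goes from 2 to 1), and keeps the three
index arguments of the indexed_generic region. A hand-written sketch of the
fused IR, with assumed SSA names rather than actual mlir-opt output:

  #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  // The splat constant becomes a scalar referenced from inside the region.
  %cst = constant 42.0 : f32
  %0 = linalg.indexed_generic
    {args_in = 1 : i64, args_out = 1 : i64,
     indexing_maps = [#map1, #map1],
     iterator_types = ["parallel", "parallel", "parallel"]}
    %arg0 {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32):
      // The block argument for the fused constant is remapped to %cst.
      %1 = mulf %cst, %arg4 : f32
      linalg.yield %1 : f32
    }: tensor<5x?x?xf32> -> tensor<5x?x?xf32>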