diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -737,6 +737,67 @@
     return fusedOp;
   }
 };
+
+/// Implementation of fusion on tensor ops when the producer is a splat
+/// constant.
+template <typename LinalgOpTy>
+struct FuseConstantOpAsProducer {
+  static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+                        unsigned consumerIdx) {
+    return producer.getResult().getType().isa<RankedTensorType>() &&
+           producer.value().template cast<DenseElementsAttr>().isSplat();
+  }
+
+  static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
+                         unsigned consumerIdx, PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isFusible(producer, consumer, consumerIdx))
+      return nullptr;
+
+    // The indexing_maps for the operands of the fused operation are the same
+    // as those for the operands of the consumer, without the indexing map at
+    // consumerIdx.
+    SmallVector<AffineMap, 4> fusedIndexMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+              return attr.cast<AffineMapAttr>().getValue();
+            }));
+    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
+
+    // The operand list is the same as the consumer's, with the operand at
+    // consumerIdx (the constant) dropped.
+    SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
+                                        consumer.operand_end());
+    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
+
+    // Create a constant scalar value from the splat constant.
+    Value scalarConstant = rewriter.create<ConstantOp>(
+        producer.getLoc(),
+        producer.value().template cast<DenseElementsAttr>().getSplatValue());
+
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+
+    // Map the block argument corresponding to the replaced operand to the
+    // scalar constant.
+    Region &consumerRegion = consumer.region();
+    Block &entryBlock = *consumerRegion.begin();
+    unsigned argIndex =
+        entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
+    BlockAndValueMapping mapping;
+    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+    Region &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
+                               mapping);
+    return fusedOp;
+  }
+};
+
 } // namespace
 
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
@@ -760,6 +821,9 @@
   } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
     return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
         reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+  } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
+    return FuseConstantOpAsProducer<GenericOp>::fuse(
+        constantOpProducer, genericOp, consumerIdx, rewriter, folder);
   }
   return nullptr;
 }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -219,3 +219,58 @@
 
 // CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
 // CHECK: linalg.tensor_reshape
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<5xf32>
+  %1 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    %0, %arg0 {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG1]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> ()>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+  -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<f32>
+  %1 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    %0, %arg0 {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG1]]
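
Note (not part of the patch): a minimal sketch of how the new producer case might
be driven from a rewrite pattern. The pattern name FuseSplatConstantSketch, the
operand-walking loop, and the assumed full signature
linalg::fuseTensorOps(PatternRewriter &, Operation *, unsigned,
OperationFolder * = nullptr) are illustrative assumptions, not code from this
change; only the function's existence and first parameter are visible in the
hunk above.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/PatternMatch.h"
// Also assumes the header declaring linalg::fuseTensorOps is included.

using namespace mlir;

// Hypothetical driver pattern, for illustration only. For each operand of a
// linalg.generic it asks fuseTensorOps to pull that operand's producer into
// the op; with this patch, a splat-constant producer is one of the cases the
// callee now handles.
struct FuseSplatConstantSketch : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    for (unsigned idx = 0, e = op.getNumOperands(); idx != e; ++idx) {
      // fuseTensorOps returns the fused op on success, nullptr otherwise.
      if (Operation *fused = linalg::fuseTensorOps(rewriter, op, idx)) {
        rewriter.replaceOp(op, fused->getResults());
        return success();
      }
    }
    return failure();
  }
};

On the first test above, such a driver would rewrite the two-input generic into
the one-input form the CHECK lines describe: the tensor<5xf32> splat operand
disappears, and a scalar constant 42.0 : f32 is captured in the cloned region
in place of the corresponding block argument.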