diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1133,7 +1133,13 @@
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
     : public OpRewritePattern<TensorExpandShapeOp> {
-  using OpRewritePattern::OpRewritePattern;
+
+  FoldReshapeWithGenericOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(foldReshapes) {}
+
   LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
@@ -1141,7 +1147,8 @@
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getOutputOperand(0)) ||
-        isUnitDimExpansionOnly(reshapeOp))
+        !controlFoldingReshapes(producer->getResult(0),
+                                reshapeOp->getOpOperand(0)))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1150,6 +1157,9 @@
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
     return success();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
 /// Pattern to fold a generic op with a splat constant.
@@ -1242,12 +1252,15 @@
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                       OpOperand &consumer) {
-  auto expandShapeOp = producer.getDefiningOp<TensorExpandShapeOp>();
-  if (expandShapeOp)
-    return !isUnitDimExpansionOnly(expandShapeOp);
-  auto collapseShapeOp =
-      producer.getDefiningOp<TensorCollapseShapeOp>();
-  return !isUnitDimExpansionOnly(collapseShapeOp);
+  if (auto producerCollapseOp =
+          dyn_cast<TensorCollapseShapeOp>(producer.getOwner())) {
+    return !isUnitDimExpansionOnly(producerCollapseOp);
+  }
+  if (auto consumerExpandOp =
+          dyn_cast<TensorExpandShapeOp>(consumer.getOwner())) {
+    return !isUnitDimExpansionOnly(consumerExpandOp);
+  }
+  return true;
 }
 
 namespace {
@@ -1389,7 +1402,8 @@
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns,
     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
-  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
+                                                    controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
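Both expansion patterns now consult a ControlElementwiseOpsFusionFn before rewriting: the callback receives the OpResult that crosses the fusion boundary and the OpOperand that consumes it, and returning false vetoes the fold. A minimal sketch of wiring the hook into a downstream pattern set follows; it assumes the declarations above are exported through Linalg's Transforms.h, and the helper and lambda names are illustrative, not part of this patch.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: register the reshape-by-expansion folds with a
// permit-everything policy. Return false from the callback to keep a
// particular reshape in place instead of folding it.
static void addReshapeFoldingPatterns(RewritePatternSet &patterns) {
  linalg::ControlElementwiseOpsFusionFn allowAll =
      [](const OpResult &producer, OpOperand &consumer) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, allowAll);
}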
diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
+
+func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+      ^bb0(%arg2 : f32, %arg3:f32, %arg4 : f32):
+        %2 = addf %arg2, %arg3 : f32
+        linalg.yield %2 : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: builtin.func @control_producer_reshape_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
+// CHECK-SAME: {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %cst = constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
+  %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init : tensor<?x?xf32>) {
+      ^bb0(%arg2: f32):
+        linalg.yield %cst : f32
+      } -> tensor<?x?xf32>
+  %0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: builtin.func @control_consumer_reshape_fusion
+// CHECK: %[[FILL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%{{.+}} : tensor<1x?x?xf32>)
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x?x?xf32>)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -73,6 +73,52 @@
   }
 };
 
+struct TestLinalgControlFuseByExpansion
+    : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-control-fusion-by-expansion";
+  }
+  StringRef getDescription() const final {
+    return "Test controlling of fusion of elementwise ops with reshape by "
+           "expansion";
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getFunction();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+        [](const OpResult &producer, OpOperand &consumer) {
+          if (auto collapseOp =
+                  producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) {
+            if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+              return false;
+            }
+          }
+          if (auto expandOp =
+                  dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+            if (expandOp->hasOneUse()) {
+              OpOperand &use = *expandOp->getUses().begin();
+              auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+              if (linalgOp && linalgOp.isOutputTensor(&use))
+                return true;
+            }
+          }
+          return linalg::skipUnitDimReshape(producer, consumer);
+        };
+
+    linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+                                                      controlReshapeFusionFn);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
+  }
+};
+
 struct TestPushExpandingReshape
     : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -99,6 +145,10 @@
   PassRegistration<TestLinalgElementwiseFusion>();
 }
 
+void registerTestLinalgControlFuseByExpansion() {
+  PassRegistration<TestLinalgControlFuseByExpansion>();
+}
+
 void registerTestPushExpandingReshape() {
   PassRegistration<TestPushExpandingReshape>();
 }
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
 void registerTestPushExpandingReshape();
@@ -165,6 +166,7 @@
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgControlFuseByExpansion();
   mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestPushExpandingReshape();
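The two call sites hand different values to the control function: FoldWithProducerReshapeOpByExpansion passes a tensor_collapse_shape result as the producer, while the new FoldReshapeWithGenericOpByExpansion passes the generic op's result, with the tensor_expand_shape owning the consumer operand. A control function can distinguish the cases the same way skipUnitDimReshape does; in the sketch below the helper name makeReshapeControlFn is hypothetical and the true returns stand in for real policy decisions.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

linalg::ControlElementwiseOpsFusionFn makeReshapeControlFn() {
  return [](const OpResult &producer, OpOperand &consumer) {
    // Producer-side fold: a tensor_collapse_shape result feeds the
    // elementwise op under rewrite.
    if (producer.getDefiningOp<linalg::TensorCollapseShapeOp>())
      return true; // stand-in policy for the collapse case
    // Consumer-side fold: the operand belongs to a tensor_expand_shape
    // whose source is the generic op's result.
    if (isa<linalg::TensorExpandShapeOp>(consumer.getOwner()))
      return true; // stand-in policy for the expand case
    // Otherwise defer to the default unit-dim heuristic.
    return linalg::skipUnitDimReshape(producer, consumer);
  };
}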