diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -23,6 +23,7 @@
 namespace linalg {
 
+struct LinalgElementwiseFusionOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
@@ -69,9 +70,40 @@
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
+using ControlElementwiseOpsFusionFn =
+    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
+/// Options that control fusion of elementwise operations.
+struct LinalgElementwiseFusionOptions {
+  /// Enable fusion of reshapes that introduce unit-dimensions into the shape
+  /// with elementwise operations. By default this is disabled.
+  bool allowFoldingUnitDimReshapes = false;
+
+  LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
+    allowFoldingUnitDimReshapes = val;
+    return *this;
+  }
+
+  /// Function that allows the caller to control when to stop fusion. Once a
+  /// producer is deemed fusable with the consumer (structurally), this
+  /// callback can be used to abort the fusion based on non-structural
+  /// constraints. This is the hook for cost models to control the amount of
+  /// fusion done.
+  ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn =
+      [](const OpResult & /*producer*/, const OpOperand & /*consumer*/) {
+        return true;
+      };
+
+  LinalgElementwiseFusionOptions &
+  setControlElementwiseOpsFusionFn(ControlElementwiseOpsFusionFn fun) {
+    controlElementwiseOpsFusionFn = std::move(fun);
+    return *this;
+  }
+};
+
 /// Patterns for fusing linalg operations on tensors.
 void populateElementwiseOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+    RewritePatternSet &patterns,
+    LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`,
 /// and permutes the loop nest according to `interchangeVector`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -48,6 +48,10 @@
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
+  // Currently only operations with a single result are supported.
+  if (producer.getNumOutputs() != 1)
+    return false;
+
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
@@ -209,10 +213,12 @@
 static Optional<SmallVector<Value, 1>>
 fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       const ControlElementwiseOpsFusionFn &controlFn,
                        PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
+      !controlFn(producer->getResult(0), consumerOpOperand))
     return llvm::None;
 
   unsigned numFusedOperands =
@@ -1041,18 +1047,22 @@
 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
 template <typename LinalgOpTy>
-struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+public:
+  FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics())
       return failure();
     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
-      ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
+    for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
+      ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
       if (!constantOp ||
-          !constantOp.value().cast<DenseElementsAttr>().isSplat())
+          !constantOp.value().cast<DenseElementsAttr>().isSplat() ||
+          !controlFn(constantOp->getResult(0), operand.value()))
         continue;
 
       // The indexing_maps for the operands of the fused operation are same as
@@ -1099,11 +1109,15 @@
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
 };
 } // namespace
 
 static Optional<SmallVector<Value, 1>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+                   const ControlElementwiseOpsFusionFn &controlFn) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1128,17 @@
     return llvm::None;
 
   return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                                rewriter);
+                                controlFn, rewriter);
 }
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
+public:
+  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
@@ -1132,7 +1149,7 @@
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand, controlFn);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1140,6 +1157,9 @@
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
 };
 
 /// Pass that fuses generic ops on tensors. Used only for testing.
@@ -1148,7 +1168,10 @@
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(
+        patterns,
+        LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
+            allowFoldingUnitDimReshapes));
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1193,14 +1216,14 @@
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
   patterns
       .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
            FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
-          context);
-  populateFoldReshapeOpsByExpansionPatterns(patterns,
-                                            allowFoldingUnitDimReshapes);
+          context, options.controlElementwiseOpsFusionFn);
+  populateFoldReshapeOpsByExpansionPatterns(
+      patterns, options.allowFoldingUnitDimReshapes);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#binary2Dpointwise = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = ["parallel", "parallel"]
+}
+#ternary2Dpointwise = {
+  indexing_maps = [#map0, #map0, #map0, #map0],
+  iterator_types = ["parallel", "parallel"]
+}
+func @test_fusion_limit(
+    %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>, %arg5 : tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %0 = linalg.generic #binary2Dpointwise
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+      %1 = mulf %arg6, %arg7 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %2 = linalg.generic #binary2Dpointwise
+      ins(%arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+      %3 = mulf %arg6, %arg7 : f32
+      linalg.yield %3 : f32
+  } -> tensor<?x?xf32>
+  %4 = linalg.generic #binary2Dpointwise
+      ins(%arg4, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+      %5 = mulf %arg6, %arg7 : f32
+      linalg.yield %5 : f32
+  } -> tensor<?x?xf32>
+  %6 = linalg.generic #ternary2Dpointwise
+      ins(%0, %2, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32, %arg9 : f32):
+      %7 = addf %arg6, %arg7 : f32
+      %8 = addf %7, %arg8 : f32
+      linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %6 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @test_fusion_limit
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
+//       CHECK:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
+//       CHECK:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CHECK:   return %[[OP3]]
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
   TestGpuRewrite.cpp
   TestInlining.cpp
   TestLinalgCodegenStrategy.cpp
+  TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
diff --git a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
@@ -0,0 +1,79 @@
+//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing fusion of elementwise operations in
+// Linalg, in particular the options that control it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+
+static void addOperands(Operation *op, llvm::SetVector<Value> &operandSet) {
+  if (!op)
+    return;
+  TypeSwitch<Operation *, void>(op)
+      .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+        operandSet.insert(linalgOp.getInputs().begin(),
+                          linalgOp.getInputs().end());
+      })
+      .Default([&](Operation *operation) {
+        operandSet.insert(operation->operand_begin(),
+                          operation->operand_end());
+      });
+}
+
+template <int limit>
+static bool setFusedOpOperandLimit(const OpResult &producer,
+                                   const OpOperand &consumer) {
+  llvm::SetVector<Value> fusedOpOperands;
+  if (producer.getOwner()->getNumResults() != 1)
+    return false;
+  addOperands(consumer.getOwner(), fusedOpOperands);
+  fusedOpOperands.remove(producer);
+  addOperands(producer.getOwner(), fusedOpOperands);
+  return fusedOpOperands.size() <= limit;
+}
+
+namespace {
+struct TestLinalgElementwiseFusion
+    : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getFunction();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::populateElementwiseOpsFusionPatterns(
+        fusionPatterns,
+        linalg::LinalgElementwiseFusionOptions()
+            .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
+
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
+  }
+};
+} // namespace
+
+namespace test {
+void registerTestLinalgElementwiseFusion() {
+  PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
+      "test-linalg-elementwise-fusion-patterns",
+      "Test Linalg elementwise operation fusion patterns");
+}
+} // namespace test
+
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgGreedyFusion();
@@ -154,6 +155,7 @@
   test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
+  test::registerTestLinalgElementwiseFusion();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
   test::registerTestLinalgGreedyFusion();
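
---

For reference, a minimal consumer-side sketch of the new API (not part of the patch above). It wires the elementwise fusion patterns into a rewrite with a custom control callback, here a hypothetical single-use heuristic that aborts fusion whenever the producer's result has other users, so fusion never duplicates computation. Only populateElementwiseOpsFusionPatterns, LinalgElementwiseFusionOptions, and the callback signature come from the patch; the helper name applyControlledFusion and the heuristic itself are illustrative assumptions.

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Hypothetical helper: run elementwise fusion over a function, fusing a
  // producer into its consumer only when the produced value has exactly one
  // use. Sketch only; assumes the MLIR snapshot this patch applies to.
  static void applyControlledFusion(FuncOp funcOp) {
    RewritePatternSet patterns(funcOp.getContext());
    linalg::populateElementwiseOpsFusionPatterns(
        patterns,
        linalg::LinalgElementwiseFusionOptions()
            .setAllowFoldingUnitDimReshapes(true)
            .setControlElementwiseOpsFusionFn(
                [](const OpResult &producer, const OpOperand &consumer) {
                  // Abort fusion when the result feeds multiple consumers,
                  // since the producer body would be cloned into each one.
                  return producer.hasOneUse();
                }));
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }

The default callback returns true unconditionally, so callers that pass no options keep the old fuse-everything behavior; the TestLinalgElementwiseFusion pass above exercises the same hook with an operand-count limit instead.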