diff --git a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt
--- a/clang/docs/tools/clang-formatted-files.txt
+++ b/clang/docs/tools/clang-formatted-files.txt
@@ -7766,6 +7766,7 @@
 mlir/include/mlir/Interfaces/CastInterfaces.h
 mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
 mlir/include/mlir/Interfaces/CopyOpInterface.h
+mlir/include/mlir/Interfaces/CSEInterfaces.h
 mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
 mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
 mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
diff --git a/mlir/include/mlir/Interfaces/CSEInterfaces.h b/mlir/include/mlir/Interfaces/CSEInterfaces.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CSEInterfaces.h
@@ -0,0 +1,33 @@
+//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_CSEINTERFACES_H_
+#define MLIR_INTERFACES_CSEINTERFACES_H_
+
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+class Operation;
+
+/// An interface that allows dialects to control specific aspects of common
+/// subexpression elimination (CSE) for the operations they define.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Returns true if CSE may replace operations nested in the regions of `op`
+  /// with equivalent values defined outside of `op`. The CSE pass only queries
+  /// this hook for operations that are not isolated from above. The default
+  /// implementation always returns true.
+  virtual bool subexpressionExtractionAllowed(Operation *op) const {
+    return true;
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_CSEINTERFACES_H_
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/CSEInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
@@ -61,7 +62,8 @@
 class CSEDriver {
 public:
   CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
-      : rewriter(rewriter), domInfo(domInfo) {}
+      : rewriter(rewriter), domInfo(domInfo),
+        interfaces(rewriter.getContext()) {}
 
   /// Simplify all operations within the given op.
   void simplify(Operation *op, bool *changed = nullptr);
@@ -122,6 +124,9 @@
   DominanceInfo *domInfo = nullptr;
   MemEffectsCache memEffectsCache;
 
+  /// CSE interfaces in the present context that can modify CSE behavior.
+  DialectInterfaceCollection<DialectCSEInterface> interfaces;
+
   // Various statistics.
   int64_t numCSE = 0;
   int64_t numDCE = 0;
@@ -289,7 +294,12 @@
       // If this operation is isolated above, we can't process nested regions
       // with the given 'knownValues' map. This would cause the insertion of
       // implicit captures in explicit capture only regions.
-      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+      // Also, avoid capturing known values from parent regions if this behavior
+      // is explicitly disabled for the given operation.
+      const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op);
+      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+          LLVM_UNLIKELY(cseInterface &&
+                        !cseInterface->subexpressionExtractionAllowed(&op))) {
         ScopedMapTy nestedKnownValues;
         for (auto &region : op.getRegions())
           simplifyRegion(nestedKnownValues, region);
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -520,3 +520,23 @@
   %2 = "test.op_with_memread"() : () -> (i32)
   return %0, %2, %1 : i32, i32, i32
 }
+
+// CHECK-LABEL: @no_cse_across_disabled_op
+func.func @no_cse_across_disabled_op() -> (i32) {
+  // CHECK-NEXT: %[[CONST1:.+]] = arith.constant 1 : i32
+  %0 = arith.constant 1 : i32
+
+  // CHECK-NEXT: test.no_cse_one_region_op
+  "test.no_cse_one_region_op"() ({
+    %1 = arith.constant 1 : i32
+    %2 = arith.addi %1, %1 : i32
+    "foo.yield"(%2) : (i32) -> ()
+
+    // CHECK-NEXT: %[[CONST2:.+]] = arith.constant 1 : i32
+    // CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[CONST2]], %[[CONST2]] : i32
+    // CHECK-NEXT: "foo.yield"(%[[SUM]]) : (i32) -> ()
+  }) : () -> ()
+
+  // CHECK: return %[[CONST1]] : i32
+  return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Interfaces/CSEInterfaces.h"
 #include "mlir/Interfaces/FoldInterfaces.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -273,6 +274,16 @@
   }
 };
 
+struct TestDialectCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  bool subexpressionExtractionAllowed(Operation *op) const final {
+    // Don't allow extracting common subexpressions from the region of these
+    // operations.
+    return !isa<NoCSEOneRegionOp>(op);
+  }
+};
+
 /// This class defines the interface for handling inlining with standard
 /// operations.
 struct TestInlinerInterface : public DialectInlinerInterface {
@@ -385,6 +396,7 @@
   auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
   addInterface<TestOpAsmInterface>(blobInterface);
 
-  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
+                TestReductionPatternInterface, TestBytecodeDialectInterface,
+                TestDialectCSEInterface>();
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2703,6 +2703,10 @@
   }];
 }
 
+def NoCSEOneRegionOp : TEST_Op<"no_cse_one_region_op", []> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops to upgrade base on the dialect versions
 //===----------------------------------------------------------------------===//