diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
 
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallString.h"
@@ -43,6 +44,20 @@
     return pointer.cast<T>().getElementType();
   }
 };
+
+/// Dialect fold interface controlling where the folder materializes
+/// constants for operations nested inside OpenMP regions.
+struct OpenMPDialectFoldInterface : public DialectFoldInterface {
+  using DialectFoldInterface::DialectFoldInterface;
+
+  /// Returns true when `region` belongs to an `omp.target` operation, so
+  /// constants used inside it are materialized in that region rather than
+  /// being hoisted into an enclosing one.
+  bool shouldMaterializeInto(Region *region) const final {
+    // Avoid folding constants across target regions.
+    return isa<TargetOp>(region->getParentOp());
+  }
+};
 } // namespace
 
 void OpenMPDialect::initialize() {
@@ -55,6 +70,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
       >();
 
+  addInterface<OpenMPDialectFoldInterface>();
   LLVM::LLVMPointerType::attachInterface<
       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
diff --git a/mlir/test/Dialect/OpenMP/canonicalize.mlir b/mlir/test/Dialect/OpenMP/canonicalize.mlir
--- a/mlir/test/Dialect/OpenMP/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenMP/canonicalize.mlir
@@ -126,3 +126,19 @@
 // CHECK: omp.parallel
 // CHECK: func.call @foo() : () -> ()
 // CHECK: omp.terminator
+
+// -----
+
+func.func @constant_hoisting_target(%x : !llvm.ptr) {
+  omp.target {
+    %c1 = arith.constant 10 : i32
+    llvm.store %c1, %x : i32, !llvm.ptr
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @constant_hoisting_target
+// CHECK-NOT: arith.constant
+// CHECK: omp.target
+// CHECK-NEXT: arith.constant