Allow pointer-typed values in llvm.atomicrmw for the 'xchg' bin_op.

Summary of the hunks below:
  * LLVMOps.td: LLVM_AtomicRMWType additionally accepts LLVM_AnyPointer.
  * LLVMDialect.cpp: the 'xchg' branch of the verifier now calls
    isTypeCompatibleWithAtomicOp with isPointerTypeAllowed=true; the
    floating-point branch and the integer-only fallback are unchanged
    context.
  * invalid.mlir: the expected diagnostic for a non-pointer operand #0 now
    lists "LLVM pointer type" among the accepted pointee types.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1832,7 +1832,7 @@
 // Atomic operations.
 //
 
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, AnyInteger]>;
+def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnyInteger]>;
 
 def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
       TypesMatchWith<"result #0 and operand #1 have the same type",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2450,7 +2450,7 @@
     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
       return emitOpError("expected LLVM IR floating point type");
   } else if (getBinOp() == AtomicBinOp::xchg) {
-    if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/false))
+    if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/true))
       return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
   } else {
     auto intType = llvm::dyn_cast<IntegerType>(valType);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -633,7 +633,7 @@
 // -----
 
 func.func @atomicrmw_expected_ptr(%f32 : f32) {
-  // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or integer}}
+  // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM pointer type or integer}}
   %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (f32, f32) -> f32
   llvm.return
 }