diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -3584,13 +3584,25 @@ AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) { // TODO: handle the case where XElemTy is not byte-sized or not a power of 2 // or a complex datatype. - bool DoCmpExch = (RMWOp == AtomicRMWInst::BAD_BINOP) || - (RMWOp == AtomicRMWInst::FAdd) || - (RMWOp == AtomicRMWInst::FSub) || - (RMWOp == AtomicRMWInst::Sub && !IsXBinopExpr) || !XElemTy; + bool emitRMWOp = false; + switch (RMWOp) { + case AtomicRMWInst::Add: + case AtomicRMWInst::And: + case AtomicRMWInst::Nand: + case AtomicRMWInst::Or: + case AtomicRMWInst::Xor: + emitRMWOp = XElemTy; + break; + case AtomicRMWInst::Sub: + emitRMWOp = (IsXBinopExpr && XElemTy); + break; + default: + emitRMWOp = false; + } + emitRMWOp &= XElemTy->isIntegerTy(); std::pair<Value *, Value *> Res; - if (XElemTy->isIntegerTy() && !DoCmpExch) { + if (emitRMWOp) { Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO); // not needed except in case of postfix captures. Generate anyway for // consistency with the else part. Will be removed with any DCE pass. 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -3167,6 +3167,73 @@ EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, OMPAtomicUpdateIntr) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + Type *IntTy = Type::getInt32Ty(M->getContext()); + AllocaInst *XVal = Builder.CreateAlloca(IntTy); + XVal->setName("AtomicVar"); + Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0), XVal); + OpenMPIRBuilder::AtomicOpValue X = {XVal, IntTy, false, false}; + AtomicOrdering AO = AtomicOrdering::Monotonic; + Constant *ConstVal = ConstantInt::get(Type::getInt32Ty(Ctx), 1); + Value *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1); + AtomicRMWInst::BinOp RMWOp = AtomicRMWInst::UMax; + bool IsXLHSInRHSPart = false; + + BasicBlock *EntryBB = BB; + OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB, + EntryBB->getFirstInsertionPt()); + Value *Sub = nullptr; + + auto UpdateOp = [&](Value *Atomic, IRBuilder<> &IRB) { + Sub = IRB.CreateSub(ConstVal, Atomic); + return Sub; + }; + Builder.restoreIP(OMPBuilder.createAtomicUpdate( + Builder, AllocaIP, X, Expr, AO, RMWOp, UpdateOp, IsXLHSInRHSPart)); + BasicBlock *ContBB = EntryBB->getSingleSuccessor(); + BranchInst *ContTI = dyn_cast<BranchInst>(ContBB->getTerminator()); + EXPECT_NE(ContTI, nullptr); + BasicBlock *EndBB = ContTI->getSuccessor(0); + EXPECT_TRUE(ContTI->isConditional()); + EXPECT_EQ(ContTI->getSuccessor(1), ContBB); + EXPECT_NE(EndBB, nullptr); + + PHINode *Phi = dyn_cast<PHINode>(&ContBB->front()); + EXPECT_NE(Phi, nullptr); + EXPECT_EQ(Phi->getNumIncomingValues(), 2U); + EXPECT_EQ(Phi->getIncomingBlock(0), EntryBB); + EXPECT_EQ(Phi->getIncomingBlock(1), ContBB); + + EXPECT_EQ(Sub->getNumUses(), 1U); + StoreInst 
*St = dyn_cast<StoreInst>(Sub->user_back()); + AllocaInst *UpdateTemp = dyn_cast<AllocaInst>(St->getPointerOperand()); + + ExtractValueInst *ExVI1 = + dyn_cast<ExtractValueInst>(Phi->getIncomingValueForBlock(ContBB)); + EXPECT_NE(ExVI1, nullptr); + AtomicCmpXchgInst *CmpExchg = + dyn_cast<AtomicCmpXchgInst>(ExVI1->getAggregateOperand()); + EXPECT_NE(CmpExchg, nullptr); + EXPECT_EQ(CmpExchg->getPointerOperand(), XVal); + EXPECT_EQ(CmpExchg->getCompareOperand(), Phi); + EXPECT_EQ(CmpExchg->getSuccessOrdering(), AtomicOrdering::Monotonic); + + LoadInst *Ld = dyn_cast<LoadInst>(CmpExchg->getNewValOperand()); + EXPECT_NE(Ld, nullptr); + EXPECT_EQ(UpdateTemp, Ld->getPointerOperand()); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + EXPECT_FALSE(verifyModule(*M, &errs())); +} + TEST_F(OpenMPIRBuilderTest, OMPAtomicCapture) { OpenMPIRBuilder OMPBuilder(*M); OMPBuilder.initialize(); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -1049,6 +1049,15 @@ %newval = "llvm.intr.smax"(%xval, %expr) : (i32, i32) -> i32 omp.yield(%newval : i32) } + // CHECK: %[[t1:.*]] = call i32 @llvm.umax.i32(i32 %[[x_old:.*]], i32 %[[expr]]) + // CHECK: store i32 %[[t1]], i32* %[[x_new:.*]] + // CHECK: %[[t2:.*]] = load i32, i32* %[[x_new]] + // CHECK: cmpxchg i32* %[[x]], i32 %[[x_old]], i32 %[[t2]] + omp.atomic.update %x : !llvm.ptr<i32> { + ^bb0(%xval: i32): + %newval = "llvm.intr.umax"(%xval, %expr) : (i32, i32) -> i32 + omp.yield(%newval : i32) + } llvm.return }