diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1198,7 +1198,7 @@
       const function_ref<Value *(Value *XOld, IRBuilder<> &IRB)>;
 
 private:
-  enum AtomicKind { Read, Write, Update, Capture };
+  enum AtomicKind { Read, Write, Update, Capture, Compare };
 
   /// Determine whether to emit flush or not
   ///
@@ -1344,6 +1344,44 @@
                       AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr,
                       bool IsPostfixUpdate, bool IsXBinopExpr);
 
+  /// Atomic compare operations accepted by createAtomicCompare. MIN and MAX
+  /// name the 'ordop' of the OpenMP forms: MIN is '<' and MAX is '>'.
+  enum AtomicCompareOp : unsigned { EQ, MIN, MAX };
+
+  /// Emit atomic compare for constructs: --- Only scalar data types
+  /// cond-update-atomic:
+  /// x = x ordop expr ? expr : x;
+  /// x = expr ordop x ? expr : x;
+  /// x = x == e ? d : x;
+  /// x = e == x ? d : x; (this one is not in the spec)
+  /// cond-update-stmt:
+  /// if (x ordop expr) { x = expr; }
+  /// if (expr ordop x) { x = expr; }
+  /// if (x == e) { x = d; }
+  /// if (e == x) { x = d; } (this one is not in the spec)
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param AllocIP Instruction to create AllocaInst before (currently
+  ///                unused).
+  /// \param X The target atomic pointer to be updated. For MIN and MAX, its
+  ///          IsSigned flag selects a signed or unsigned comparison.
+  /// \param E The expected value ('e') for forms that use an equality
+  ///          comparison or an expression ('expr') for forms that use
+  ///          'ordop' (logically an atomic maximum or minimum).
+  /// \param D The desired value for forms that use an equality
+  ///          comparison. For forms that use 'ordop', it should be
+  ///          \p nullptr.
+  /// \param AO Atomic ordering of the generated atomic instructions.
+  /// \param Op Atomic compare operation: EQ ('=='), MIN ('<') or MAX ('>').
+  /// \param IsXBinopExpr True if the conditional statement is in the form where
+  ///                     x is on LHS. Only meaningful for MIN and MAX.
+  ///
+  /// \return Insertion point after generated atomic compare IR.
+  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
+                                    Instruction *AllocIP, AtomicOpValue &X,
+                                    Value *E, Value *D, AtomicOrdering AO,
+                                    AtomicCompareOp Op, bool IsXBinopExpr);
+
   /// Create the control flow structure of a canonical OpenMP loop.
   ///
   /// The emitted loop will be disconnected, i.e. no edge to the loop's
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3472,6 +3472,74 @@
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createAtomicCompare(const LocationDescription &Loc,
+                                     Instruction *AllocIP, AtomicOpValue &X,
+                                     Value *E, Value *D, AtomicOrdering AO,
+                                     AtomicCompareOp Op, bool IsXBinopExpr) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  // Plain asserts rather than LLVM_DEBUG so the checks run in every
+  // assertion-enabled build. The element type is read from X.ElemTy because
+  // the pointee type of X.Var is unavailable with opaque pointers.
+  assert(X.Var->getType()->isPointerTy() &&
+         "OMP atomic expects a pointer to target memory");
+  assert((X.ElemTy->isFloatingPointTy() || X.ElemTy->isIntegerTy() ||
+          X.ElemTy->isPointerTy()) &&
+         "OMP atomic compare expected a scalar type");
+  assert(E->getType() == X.ElemTy &&
+         "OMP atomic compare expects 'e'/'expr' to have the type of x");
+
+  if (Op == AtomicCompareOp::EQ) {
+    assert(D && D->getType() == X.ElemTy &&
+           "OMP atomic compare '==' expects 'd' to have the type of x");
+    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
+    Value *XAddr = X.Var;
+    Value *Expected = E;
+    Value *Desired = D;
+    // cmpxchg only accepts integer or pointer operands, so a floating-point
+    // x is exchanged through its same-width integer bits. The address and
+    // both operands are cast together so the instruction stays well-typed.
+    if (X.ElemTy->isFloatingPointTy()) {
+      unsigned Addrspace = X.Var->getType()->getPointerAddressSpace();
+      IntegerType *IntCastTy =
+          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
+      XAddr = Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
+      Expected = Builder.CreateBitCast(E, IntCastTy);
+      Desired = Builder.CreateBitCast(D, IntCastTy);
+    }
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicCmpXchg(XAddr, Expected, Desired,
+                                      MaybeAlign(), AO, Failure);
+  } else {
+    assert((Op == AtomicCompareOp::MAX || Op == AtomicCompareOp::MIN) &&
+           "unexpected atomic compare operation");
+    assert(X.ElemTy->isIntegerTy() &&
+           "max and min operators only support integer type");
+
+    // MAX denotes the ordop '>' and MIN the ordop '<'. atomicrmw min/max
+    // store the smaller/larger of *ptr and val, so the OpenMP forms map as:
+    //   x = x > expr ? expr : x;  ->  x = min(x, expr)  (x on LHS: reversed)
+    //   x = expr > x ? expr : x;  ->  x = max(x, expr)  (x on RHS: direct)
+    // and symmetrically for '<'.
+    bool KeepsMin = IsXBinopExpr ? Op == AtomicCompareOp::MAX
+                                 : Op == AtomicCompareOp::MIN;
+    // The comparison is signed or unsigned according to the atomic operand.
+    AtomicRMWInst::BinOp NewOp;
+    if (X.IsSigned)
+      NewOp = KeepsMin ? AtomicRMWInst::Min : AtomicRMWInst::Max;
+    else
+      NewOp = KeepsMin ? AtomicRMWInst::UMin : AtomicRMWInst::UMax;
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
+  }
+
+  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
+
+  return Builder.saveIP();
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3031,6 +3031,60 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, OMPAtomicCompare) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  LLVMContext &Ctx = M->getContext();
+  IntegerType *Int32 = Type::getInt32Ty(Ctx);
+  AllocaInst *XVal = Builder.CreateAlloca(Int32);
+  XVal->setName("x");
+  StoreInst *Init =
+      Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
+
+  // x is a signed integer, so the min/max form lowers to signed atomicrmw.
+  OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, true, false};
+  AtomicOrdering AO = AtomicOrdering::Monotonic;
+  ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
+  ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
+  OpenMPIRBuilder::AtomicCompareOp OpMax =
+      OpenMPIRBuilder::AtomicCompareOp::MAX;
+  OpenMPIRBuilder::AtomicCompareOp OpEQ = OpenMPIRBuilder::AtomicCompareOp::EQ;
+
+  BasicBlock *EntryBB = BB;
+  Instruction *AllocIP = EntryBB->getFirstNonPHI();
+
+  // x = x > expr ? expr : x;  (x on LHS) lowers to atomicrmw min.
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, AllocIP, X, Expr,
+                                                   nullptr, AO, OpMax, true));
+  // x = x == e ? d : x;  lowers to cmpxchg.
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, AllocIP, X, Expr, D,
+                                                   AO, OpEQ, true));
+
+  EXPECT_EQ(EntryBB->getParent()->size(), 1U);
+  EXPECT_EQ(EntryBB->size(), 4U);
+
+  AtomicRMWInst *ARWM = dyn_cast<AtomicRMWInst>(Init->getNextNode());
+  ASSERT_NE(ARWM, nullptr);
+  EXPECT_EQ(ARWM->getPointerOperand(), XVal);
+  EXPECT_EQ(ARWM->getValOperand(), Expr);
+  EXPECT_EQ(ARWM->getOperation(), AtomicRMWInst::Min);
+
+  AtomicCmpXchgInst *AXCHG = dyn_cast<AtomicCmpXchgInst>(ARWM->getNextNode());
+  ASSERT_NE(AXCHG, nullptr);
+  EXPECT_EQ(AXCHG->getPointerOperand(), XVal);
+  EXPECT_EQ(AXCHG->getCompareOperand(), Expr);
+  EXPECT_EQ(AXCHG->getNewValOperand(), D);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>