diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -116,6 +116,10 @@
 /// \note This needs to be kept in sync with interop.h enum kmp_interop_type_t.:
 enum class OMPInteropType { Unknown, Target, TargetSync };
 
+/// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
+/// EQ stands for ==, MAX for > and MIN for < (the ordop of the comparison).
+enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
+
 } // end namespace omp
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1198,7 +1198,7 @@
       const function_ref<Value *(Value *XOld, IRBuilder<> &IRB)>;
 
 private:
-  enum AtomicKind { Read, Write, Update, Capture };
+  enum AtomicKind { Read, Write, Update, Capture, Compare };
 
   /// Determine whether to emit flush or not
   ///
@@ -1344,6 +1344,44 @@
                       AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr,
                       bool IsPostfixUpdate, bool IsXBinopExpr);
 
+  /// Emit atomic compare for constructs (scalar data types only):
+  /// cond-update-atomic:
+  /// x = x ordop expr ? expr : x;
+  /// x = expr ordop x ? expr : x;
+  /// x = x == e ? d : x;
+  /// x = e == x ? d : x; (this one is not in the spec)
+  /// cond-update-stmt:
+  /// if (x ordop expr) { x = expr; }
+  /// if (expr ordop x) { x = expr; }
+  /// if (x == e) { x = d; }
+  /// if (e == x) { x = d; } (this one is not in the spec)
+  ///
+  /// \param Loc          The insert and source location description.
+  /// \param X            The target atomic pointer to be updated.
+  /// \param E            The expected value ('e') for forms that use an
+  ///                     equality comparison or an expression ('expr') for
+  ///                     forms that use 'ordop' (logically an atomic maximum or
+  ///                     minimum).
+  /// \param D            The desired value for forms that use an equality
+  ///                     comparison. For forms that use 'ordop', it should be
+  ///                     \p nullptr.
+  /// \param AO           Atomic ordering of the generated atomic instructions.
+  /// \param Op           Atomic compare operation. It can only be ==, <, or >.
+  /// \param IsXBinopExpr True if the conditional statement is in the form where
+  ///                     x is on LHS. It only matters for < or >.
+  ///
+  /// \note With ==, integer and pointer values are compared directly, while
+  ///       floating-point values are compared by bit pattern (so, e.g., +0.0
+  ///       and -0.0 compare unequal). Only integer types are supported for
+  ///       < and >.
+  ///
+  /// \return Insertion point after generated atomic compare IR.
+  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
+                                    AtomicOpValue &X, Value *E, Value *D,
+                                    AtomicOrdering AO,
+                                    omp::OMPAtomicCompareOp Op,
+                                    bool IsXBinopExpr);
+
   /// Create the control flow structure of a canonical OpenMP loop.
   ///
   /// The emitted loop will be disconnected, i.e. no edge to the loop's
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3171,6 +3171,7 @@
     }
     break;
   case Write:
+  case Compare:
   case Update:
     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
         AO == AtomicOrdering::SequentiallyConsistent) {
@@ -3472,6 +3473,80 @@
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
+    const LocationDescription &Loc, AtomicOpValue &X, Value *E, Value *D,
+    AtomicOrdering AO, OMPAtomicCompareOp Op, bool IsXBinopExpr) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  assert(X.Var->getType()->isPointerTy() &&
+         "OMP atomic expects a pointer to target memory");
+  assert((X.ElemTy->isFloatingPointTy() || X.ElemTy->isIntegerTy() ||
+          X.ElemTy->isPointerTy()) &&
+         "OMP atomic compare expected a scalar type");
+  assert(E && "OMP atomic compare expects a value to compare against");
+
+  if (Op == OMPAtomicCompareOp::EQ) {
+    assert(D && "OMP atomic compare with == expects a desired value");
+    Value *XAddr = X.Var;
+    Value *Expected = E;
+    Value *Desired = D;
+    // cmpxchg only accepts integer or pointer operands, so floating-point
+    // values are exchanged through an integer of the same width. Integer and
+    // pointer values are used as-is; a pointer has no scalar bit width from
+    // which such an integer type could be built.
+    if (X.ElemTy->isFloatingPointTy()) {
+      unsigned Addrspace =
+          cast<PointerType>(X.Var->getType())->getAddressSpace();
+      IntegerType *IntCastTy =
+          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
+      XAddr = Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
+      Expected = Builder.CreateBitCast(E, IntCastTy);
+      Desired = Builder.CreateBitCast(D, IntCastTy);
+    }
+    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicCmpXchg(XAddr, Expected, Desired, MaybeAlign(),
+                                      AO, Failure);
+  } else {
+    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
+           "Op should be either max or min at this point");
+    assert(X.ElemTy->isIntegerTy() &&
+           "max and min operators only support integer type");
+
+    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
+    // Let's take max as example.
+    // OpenMP form:
+    // x = x > expr ? expr : x;
+    // LLVM form:
+    // *ptr = *ptr > val ? *ptr : val;
+    // We need to transform to LLVM form.
+    // x = x <= expr ? x : expr;
+    AtomicRMWInst::BinOp NewOp;
+    if (IsXBinopExpr) {
+      if (X.IsSigned)
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
+                                              : AtomicRMWInst::Max;
+      else
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
+                                              : AtomicRMWInst::UMax;
+    } else {
+      if (X.IsSigned)
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
+                                              : AtomicRMWInst::Min;
+      else
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
+                                              : AtomicRMWInst::UMin;
+    }
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
+  }
+
+  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
+
+  return Builder.saveIP();
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3031,6 +3031,99 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, OMPAtomicCompare) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  LLVMContext &Ctx = M->getContext();
+  IntegerType *Int32 = Type::getInt32Ty(Ctx);
+  AllocaInst *XVal = Builder.CreateAlloca(Int32);
+  XVal->setName("x");
+  StoreInst *Init =
+      Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
+
+  OpenMPIRBuilder::AtomicOpValue XSigned = {XVal, Int32, true, false};
+  OpenMPIRBuilder::AtomicOpValue XUnsigned = {XVal, Int32, false, false};
+  AtomicOrdering AO = AtomicOrdering::Monotonic;
+  ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
+  ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
+  OMPAtomicCompareOp OpMax = OMPAtomicCompareOp::MAX;
+  OMPAtomicCompareOp OpEQ = OMPAtomicCompareOp::EQ;
+
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr,
+                                                   nullptr, AO, OpMax, true));
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XUnsigned, Expr,
+                                                   nullptr, AO, OpMax, false));
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr, D,
+                                                   AO, OpEQ, true));
+
+  BasicBlock *EntryBB = BB;
+  EXPECT_EQ(EntryBB->getParent()->size(), 1U);
+  EXPECT_EQ(EntryBB->size(), 5U);
+
+  AtomicRMWInst *ARWM1 = dyn_cast<AtomicRMWInst>(Init->getNextNode());
+  ASSERT_NE(ARWM1, nullptr);
+  EXPECT_EQ(ARWM1->getPointerOperand(), XVal);
+  EXPECT_EQ(ARWM1->getValOperand(), Expr);
+  EXPECT_EQ(ARWM1->getOperation(), AtomicRMWInst::Min);
+
+  AtomicRMWInst *ARWM2 = dyn_cast<AtomicRMWInst>(ARWM1->getNextNode());
+  ASSERT_NE(ARWM2, nullptr);
+  EXPECT_EQ(ARWM2->getPointerOperand(), XVal);
+  EXPECT_EQ(ARWM2->getValOperand(), Expr);
+  EXPECT_EQ(ARWM2->getOperation(), AtomicRMWInst::UMax);
+
+  AtomicCmpXchgInst *AXCHG = dyn_cast<AtomicCmpXchgInst>(ARWM2->getNextNode());
+  ASSERT_NE(AXCHG, nullptr);
+  EXPECT_EQ(AXCHG->getPointerOperand(), XVal);
+  EXPECT_EQ(AXCHG->getCompareOperand(), Expr);
+  EXPECT_EQ(AXCHG->getNewValOperand(), D);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareFloat) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  LLVMContext &Ctx = M->getContext();
+  Type *FloatTy = Type::getFloatTy(Ctx);
+  AllocaInst *XVal = Builder.CreateAlloca(FloatTy);
+  XVal->setName("x");
+  Builder.CreateStore(ConstantFP::get(FloatTy, 0.0), XVal);
+
+  OpenMPIRBuilder::AtomicOpValue X = {XVal, FloatTy, false, false};
+  Constant *E = ConstantFP::get(FloatTy, 0.0);
+  Constant *D = ConstantFP::get(FloatTy, 1.0);
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(
+      Builder, X, E, D, AtomicOrdering::Monotonic, OMPAtomicCompareOp::EQ,
+      true));
+
+  // cmpxchg does not accept floating-point operands, so the exchange must
+  // have been emitted on an integer of the same width.
+  AtomicCmpXchgInst *AXCHG = nullptr;
+  for (Instruction &I : *BB) {
+    AXCHG = dyn_cast<AtomicCmpXchgInst>(&I);
+    if (AXCHG)
+      break;
+  }
+  ASSERT_NE(AXCHG, nullptr);
+  EXPECT_TRUE(AXCHG->getCompareOperand()->getType()->isIntegerTy(32));
+  EXPECT_TRUE(AXCHG->getNewValOperand()->getType()->isIntegerTy(32));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>