diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1345,7 +1345,7 @@
                       bool IsPostfixUpdate, bool IsXBinopExpr);
 
   /// Emit atomic compare for constructs: --- Only scalar data types
-  /// cond-update-atomic:
+  /// cond-expr-stmt:
   /// x = x ordop expr ? expr : x;
   /// x = expr ordop x ? expr : x;
   /// x = x == e ? d : x;
@@ -1355,9 +1355,21 @@
   /// if (expr ordop x) { x = expr; }
   /// if (x == e) { x = d; }
   /// if (e == x) { x = d; } (this one is not in the spec)
+  /// conditional-update-capture-atomic:
+  /// v = x; cond-update-stmt; (IsPostfixUpdate=true, IsFailOnly=false)
+  /// cond-update-stmt; v = x; (IsPostfixUpdate=false, IsFailOnly=false)
+  /// if (x == e) { x = d; } else { v = x; } (IsPostfixUpdate=false,
+  ///                                         IsFailOnly=true)
+  /// r = x == e; if (r) { x = d; } (IsPostfixUpdate=false, IsFailOnly=false)
+  /// r = x == e; if (r) { x = d; } else { v = x; } (IsPostfixUpdate=false,
+  ///                                                IsFailOnly=true)
   ///
   /// \param Loc The insert and source location description.
   /// \param X The target atomic pointer to be updated.
+  /// \param V Memory address where to store captured value (for
+  ///          compare capture only).
+  /// \param R Memory address where to store comparison result
+  ///          (for compare capture with '==' only).
   /// \param E The expected value ('e') for forms that use an
   ///          equality comparison or an expression ('expr') for
   ///          forms that use 'ordop' (logically an atomic maximum or
@@ -1369,13 +1381,19 @@
   /// \param OP Atomic compare operation. It can only be ==, <, or >.
   /// \param IsXBinopExpr True if the conditional statement is in the form where
   ///                     x is on LHS. It only matters for < or >.
+  /// \param IsPostfixUpdate  True if original value of 'x' must be stored in
+  ///                         'v', not an updated one (for compare capture
+  ///                         only).
+  /// \param IsFailOnly True if the original value of 'x' is stored to 'v'
+  ///                   only when the comparison fails. This is only valid for
+  ///                   the case the comparison is '=='.
   ///
   /// \return Insertion point after generated atomic capture IR.
-  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
-                                    AtomicOpValue &X, Value *E, Value *D,
-                                    AtomicOrdering AO,
-                                    omp::OMPAtomicCompareOp Op,
-                                    bool IsXBinopExpr);
+  InsertPointTy
+  createAtomicCompare(const LocationDescription &Loc, AtomicOpValue &X,
+                      AtomicOpValue &V, AtomicOpValue &R, Value *E, Value *D,
+                      AtomicOrdering AO, omp::OMPAtomicCompareOp Op,
+                      bool IsXBinopExpr, bool IsPostfixUpdate, bool IsFailOnly);
 
   /// Create the control flow structure of a canonical OpenMP loop.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3475,8 +3475,11 @@
 }
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
-    const LocationDescription &Loc, AtomicOpValue &X, Value *E, Value *D,
-    AtomicOrdering AO, OMPAtomicCompareOp Op, bool IsXBinopExpr) {
+    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
+    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
+    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
+    bool IsFailOnly) {
+
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -3484,14 +3487,94 @@
          "OMP atomic expects a pointer to target memory");
   assert((X.ElemTy->isIntegerTy() || X.ElemTy->isPointerTy()) &&
          "OMP atomic compare expected a integer scalar type");
+  // compare capture
+  if (V.Var) {
+    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
+    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
+  }
 
   if (Op == OMPAtomicCompareOp::EQ) {
     AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
-    // We don't need the result for now.
-    (void)Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
+    AtomicCmpXchgInst *Result =
+        Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
+    if (V.Var) {
+      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
+      assert(OldValue->getType() == V.ElemTy &&
+             "OldValue and V must be of same type");
+      if (IsPostfixUpdate) {
+        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
+      } else {
+        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
+        if (IsFailOnly) {
+          // CurBB----
+          //   |     |
+          //   v     |
+          // ContBB  |
+          //   |     |
+          //   v     |
+          // ExitBB <-
+          //
+          // where ContBB only contains the store of old value to 'v'.
+          BasicBlock *CurBB = Builder.GetInsertBlock();
+          Instruction *CurBBTI = CurBB->getTerminator();
+          // splitBasicBlock needs a terminator. If there is none yet, add a
+          // placeholder that is removed again once the CFG is in place.
+          const bool HasPlaceholderTI = !CurBBTI;
+          if (HasPlaceholderTI)
+            CurBBTI = Builder.CreateUnreachable();
+          BasicBlock *ExitBB = CurBB->splitBasicBlock(
+              CurBBTI, X.Var->getName() + ".atomic.exit");
+          BasicBlock *ContBB = CurBB->splitBasicBlock(
+              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
+          // Each split leaves an unconditional branch behind. Drop the one
+          // in ContBB (re-created below) and the one in CurBB (replaced by
+          // the conditional branch) so every block has one terminator.
+          ContBB->getTerminator()->eraseFromParent();
+          CurBB->getTerminator()->eraseFromParent();
+
+          // The builder may still refer to the old terminator, which now
+          // lives in ExitBB, so re-anchor it at the end of CurBB.
+          Builder.SetInsertPoint(CurBB);
+          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
+
+          Builder.SetInsertPoint(ContBB);
+          Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
+          Builder.CreateBr(ExitBB);
+
+          if (HasPlaceholderTI) {
+            // Remove the placeholder and continue at the end of ExitBB.
+            CurBBTI->eraseFromParent();
+            Builder.SetInsertPoint(ExitBB);
+          } else {
+            // Continue right before the original terminator (now in ExitBB).
+            Builder.SetInsertPoint(CurBBTI);
+          }
+        } else {
+          // 'x' is set to 'd' only when the exchange succeeds, so the new
+          // value of 'x' is 'd' on success and the loaded value otherwise.
+          Value *CapturedValue =
+              Builder.CreateSelect(SuccessOrFail, D, OldValue);
+          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
+        }
+      }
+    }
+    // The comparison result has to be stored.
+    if (R.Var) {
+      assert(R.Var->getType()->isPointerTy() &&
+             "r.var must be of pointer type");
+      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
+
+      // 'r = x == e' yields 0 or 1: zero-extend the i1 success flag of the
+      // cmpxchg to the integer type of 'r'.
+      Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
+      Value *ResultCast = Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
+      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
+    }
   } else {
     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
            "Op should be either max or min at this point");
+    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
+    assert(!R.Var && "R is only valid when the comparison is ==");
 
     // Reverse the ordop as the OpenMP forms are different from LLVM forms.
     // Let's take max as example.
@@ -3517,8 +3600,38 @@
         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                               : AtomicRMWInst::UMin;
     }
-    // We dont' need the result for now.
-    (void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
+
+    AtomicRMWInst *OldValue =
+        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
+    if (V.Var) {
+      Value *CapturedValue = nullptr;
+      if (IsPostfixUpdate) {
+        CapturedValue = OldValue;
+      } else {
+        CmpInst::Predicate Pred;
+        switch (NewOp) {
+        case AtomicRMWInst::Max:
+          Pred = CmpInst::ICMP_SGT;
+          break;
+        case AtomicRMWInst::UMax:
+          Pred = CmpInst::ICMP_UGT;
+          break;
+        case AtomicRMWInst::Min:
+          Pred = CmpInst::ICMP_SLT;
+          break;
+        case AtomicRMWInst::UMin:
+          Pred = CmpInst::ICMP_ULT;
+          break;
+        default:
+          llvm_unreachable("unexpected comparison op");
+        }
+        // Recompute the value atomicrmw stored, e.g. for max it is
+        // 'old > e ? old : e'.
+        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
+        CapturedValue = Builder.CreateSelect(NonAtomicCmp, OldValue, E);
+      }
+      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
+    }
   }
 
   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);