NOTE(review): OMPAtomicDirective::Create now reserves 6 children (one per
DataPositionTy position, now including POS_D and POS_Cond).
OMPAtomicDirective::CreateEmpty is not part of this patch; presumably it must
reserve the same number, or directives read back from a serialized AST have
no slots for 'd'/'cond'. Confirm before landing.

diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -2863,6 +2863,8 @@
     POS_V,
     POS_E,
     POS_UpdateExpr,
+    POS_D,
+    POS_Cond,
   };
 
   /// Set 'x' part of the associated expression/statement.
@@ -2877,6 +2879,10 @@
   void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
   /// Set 'expr' part of the associated expression/statement.
   void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
+  /// Set 'd' part of the associated expression/statement.
+  void setD(Expr *D) { Data->getChildren()[DataPositionTy::POS_D] = D; }
+  /// Set conditional expression in `atomic compare`.
+  void setCond(Expr *C) { Data->getChildren()[DataPositionTy::POS_Cond] = C; }
 
 public:
   /// Creates directive with a list of \a Clauses and 'x', 'v' and 'expr'
@@ -2894,6 +2900,8 @@
   /// \param UE Helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
+  /// \param D 'd' part of the associated expression/statement.
+  /// \param Cond Conditional expression in `atomic compare` construct.
   /// \param IsXLHSInRHSPart true if \a UE has the first form and false if the
   /// second.
   /// \param IsPostfixUpdate true if original value of 'x' must be stored in
@@ -2901,7 +2909,8 @@
   static OMPAtomicDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-         Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate);
+         Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+         bool IsPostfixUpdate);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2951,6 +2960,20 @@
   const Expr *getExpr() const {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
   }
+  /// Get 'd' part of the associated expression/statement.
+  Expr *getD() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  const Expr *getD() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  /// Get conditional expression in `atomic compare`.
+  Expr *getCondExpr() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_Cond]);
+  }
+  const Expr *getCondExpr() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_Cond]);
+  }
 
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPAtomicDirectiveClass;
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -863,16 +863,20 @@
                                                    !IsStandalone);
 }
 
-OMPAtomicDirective *OMPAtomicDirective::Create(
-    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-    Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate) {
+OMPAtomicDirective *
+OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                           Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E,
+                           Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+                           bool IsPostfixUpdate) {
   auto *Dir = createDirective<OMPAtomicDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/4, StartLoc, EndLoc);
+      C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
   Dir->setX(X);
   Dir->setV(V);
   Dir->setExpr(E);
   Dir->setUpdateExpr(UE);
+  Dir->setD(D);
+  Dir->setCond(Cond);
   Dir->IsXLHSInRHSPart = IsXLHSInRHSPart;
   Dir->IsPostfixUpdate = IsPostfixUpdate;
   return Dir;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6011,11 +6011,57 @@
   }
 }
 
+/// Emit an 'atomic compare' construct using the OpenMPIRBuilder.
+///
+/// \p CE is the comparison of the conditional statement; only ==, < and > are
+/// supported (mapped to EQ, MIN and MAX). \p E is the value compared with 'x'
+/// and \p D, which may be null, is forwarded as 'd'. \p IsXBinopExpr is true
+/// when 'x' is the left-hand operand of \p CE. \p Loc is currently unused.
+static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
+                                     llvm::AtomicOrdering AO, const Expr *X,
+                                     const Expr *E, const Expr *D,
+                                     const Expr *CE, bool IsXBinopExpr,
+                                     SourceLocation Loc) {
+  llvm::OpenMPIRBuilder &OMPBuilder =
+      CGF.CGM.getOpenMPRuntime().getOMPBuilder();
+
+  OMPAtomicCompareOp Op;
+  assert(isa<BinaryOperator>(CE) && "CE is not a BinaryOperator");
+  switch (cast<BinaryOperator>(CE)->getOpcode()) {
+  case BO_EQ:
+    Op = OMPAtomicCompareOp::EQ;
+    break;
+  case BO_LT:
+    Op = OMPAtomicCompareOp::MIN;
+    break;
+  case BO_GT:
+    Op = OMPAtomicCompareOp::MAX;
+    break;
+  default:
+    llvm_unreachable("unsupported atomic compare binary operator");
+  }
+
+  LValue XLVal = CGF.EmitLValue(X);
+  Address XAddr = XLVal.getAddress(CGF);
+  llvm::Value *EVal = CGF.EmitScalarExpr(E);
+  llvm::Value *DVal = D ? CGF.EmitScalarExpr(D) : nullptr;
+
+  // Set fields by name; IsSigned picks signed vs. unsigned min/max.
+  llvm::OpenMPIRBuilder::AtomicOpValue XOpVal;
+  XOpVal.Var = XAddr.getPointer();
+  XOpVal.ElemTy = XAddr.getElementType();
+  XOpVal.IsSigned = X->getType()->hasSignedIntegerRepresentation();
+  XOpVal.IsVolatile = X->getType().isVolatileQualified();
+
+  CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare(
+      CGF.Builder, XOpVal, EVal, DVal, AO, Op, IsXBinopExpr));
+}
+
 static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
                               llvm::AtomicOrdering AO, bool IsPostfixUpdate,
                               const Expr *X, const Expr *V, const Expr *E,
-                              const Expr *UE, bool IsXLHSInRHSPart,
-                              SourceLocation Loc) {
+                              const Expr *UE, const Expr *D, const Expr *CE,
+                              bool IsXLHSInRHSPart, SourceLocation Loc) {
   switch (Kind) {
   case OMPC_read:
     emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
@@ -6032,7 +6078,7 @@
                              IsXLHSInRHSPart, Loc);
     break;
   case OMPC_compare:
-    // Do nothing here as we already emit an error.
+    emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc);
     break;
   case OMPC_if:
   case OMPC_final:
@@ -6178,8 +6224,8 @@
   LexicalScope Scope(*this, S.getSourceRange());
   EmitStopPoint(S.getAssociatedStmt());
   emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
-                    S.getExpr(), S.getUpdateExpr(), S.isXLHSInRHSPart(),
-                    S.getBeginLoc());
+                    S.getExpr(), S.getUpdateExpr(), S.getD(), S.getCondExpr(),
+                    S.isXLHSInRHSPart(), S.getBeginLoc());
 }
 
 static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11372,6 +11372,8 @@
   Expr *V = nullptr;
   Expr *E = nullptr;
   Expr *UE = nullptr;
+  Expr *D = nullptr;
+  Expr *CE = nullptr;
   bool IsXLHSInRHSPart = false;
   bool IsPostfixUpdate = false;
   // OpenMP [2.12.6, atomic Construct]
@@ -11768,17 +11770,18 @@
           << ErrorInfo.Error << ErrorInfo.NoteRange;
       return StmtError();
     }
-    // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
-    // code gen.
-    unsigned DiagID = Diags.getCustomDiagID(
-        DiagnosticsEngine::Error, "atomic compare is not supported for now");
-    Diag(AtomicKindLoc, DiagID);
+    X = Checker.getX();
+    E = Checker.getE();
+    D = Checker.getD();
+    CE = Checker.getCond();
+    // We reuse this bool variable to tell if it is in the form 'x ordop expr'.
+    IsXLHSInRHSPart = Checker.isXBinopExpr();
   }
 
   setFunctionHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E, UE, IsXLHSInRHSPart,
+                                    X, V, E, UE, D, CE, IsXLHSInRHSPart,
                                     IsPostfixUpdate);
 }
 
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1357,7 +1357,6 @@
   /// if (e == x) { x = d; } (this one is not in the spec)
   ///
   /// \param Loc The insert and source location description.
-  /// \param AllocIP Instruction to create AllocaInst before.
   /// \param X The target atomic pointer to be updated.
   /// \param E The expected value ('e') for forms that use an
   ///          equality comparison or an expression ('expr') for
@@ -1370,15 +1369,14 @@
-  /// \param OP Atomic compare operation. It can only be ==, <, or >.
+  /// \param Op Atomic compare operation. It can only be ==, <, or >.
   /// \param IsXBinopExpr True if the conditional statement is in the form where
   ///                     x is on LHS. It only matters for < or >.
-  /// \param IsSignedOp If the operation is signed or unsigned. It only
-  ///                   matters for < or >.
+  /// Whether < and > compare signed or unsigned is taken from \a X.IsSigned.
   ///
-  /// \return Insertion point after generated atomic capture IR.
-  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
-                                    Instruction *AllocIP, AtomicOpValue &X,
-                                    Value *E, Value *D, AtomicOrdering AO,
-                                    omp::OMPAtomicCompareOp Op,
-                                    bool IsXBinopExpr, bool IsSignedOp);
+  /// \return Insertion point after generated atomic compare IR.
+  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
+                                    AtomicOpValue &X, Value *E, Value *D,
+                                    AtomicOrdering AO,
+                                    omp::OMPAtomicCompareOp Op,
+                                    bool IsXBinopExpr);
 
   /// Create the control flow structure of a canonical OpenMP loop.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3473,10 +3473,11 @@
   return Builder.saveIP();
 }
 
-OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
-    const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X,
-    Value *E, Value *D, AtomicOrdering AO, OMPAtomicCompareOp Op,
-    bool IsXBinopExpr, bool IsSignedOp) {
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createAtomicCompare(const LocationDescription &Loc,
+                                     AtomicOpValue &X, Value *E, Value *D,
+                                     AtomicOrdering AO, OMPAtomicCompareOp Op,
+                                     bool IsXBinopExpr) {
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -3514,14 +3515,14 @@
     // x = x <= expr ? x : expr;
     AtomicRMWInst::BinOp NewOp;
     if (IsXBinopExpr) {
-      if (IsSignedOp)
+      if (X.IsSigned)
         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                               : AtomicRMWInst::Max;
       else
         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                               : AtomicRMWInst::UMax;
     } else {
-      if (IsSignedOp)
+      if (X.IsSigned)
         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                               : AtomicRMWInst::Min;
       else