diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1166,9 +1166,9 @@ /// \param UpdateOp Code generator for complex expressions that cannot be /// expressed through atomicrmw instruction. /// \param VolatileX true if \a X volatile? - /// \param IsXLHSInRHSPart true if \a X is Left H.S. in Right H.S. part of - /// the update expression, false otherwise. - /// (e.g. true for X = X BinOp Expr) + /// \param IsXBinopExpr true if \a X is the left operand of BinOp in the + /// update expression (X = X BinOp Expr), false otherwise + /// (e.g. X = Expr BinOp X). /// /// \returns A pair of the old value of X before the update, and the value /// used for the update. @@ -1177,7 +1177,7 @@ AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, - bool IsXLHSInRHSPart); + bool IsXBinopExpr); /// Emit the binary op. described by \p RMWOp, using \p Src1 and \p Src2 . /// @@ -1235,9 +1235,9 @@ /// atomic will be generated. /// \param UpdateOp Code generator for complex expressions that cannot be /// expressed through atomicrmw instruction. - /// \param IsXLHSInRHSPart true if \a X is Left H.S. in Right H.S. part of - /// the update expression, false otherwise. - /// (e.g. true for X = X BinOp Expr) + /// \param IsXBinopExpr true if \a X is the left operand of BinOp in the + /// update expression (X = X BinOp Expr), false otherwise + /// (e.g. X = Expr BinOp X). /// /// \return Insertion point after generated atomic update IR. 
InsertPointTy createAtomicUpdate(const LocationDescription &Loc, @@ -1245,7 +1245,7 @@ Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp, - bool IsXLHSInRHSPart); + bool IsXBinopExpr); /// Emit atomic update for constructs: --- Only Scalar data types /// V = X; X = X BinOp Expr , @@ -1269,9 +1269,9 @@ /// expressed through atomicrmw instruction. /// \param UpdateExpr true if X is an in place update of the form /// X = X BinOp Expr or X = Expr BinOp X - /// \param IsXLHSInRHSPart true if X is Left H.S. in Right H.S. part of the - /// update expression, false otherwise. - /// (e.g. true for X = X BinOp Expr) + /// \param IsXBinopExpr true if X is the left operand of BinOp in the + /// update expression (X = X BinOp Expr), false otherwise + /// (e.g. X = Expr BinOp X). /// \param IsPostfixUpdate true if original value of 'x' must be stored in /// 'v', not an updated one. /// @@ -1281,7 +1281,7 @@ AtomicOpValue &X, AtomicOpValue &V, Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr, - bool IsPostfixUpdate, bool IsXLHSInRHSPart); + bool IsPostfixUpdate, bool IsXBinopExpr); /// Create the control flow structure of a canonical OpenMP loop. 
/// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -3077,7 +3077,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate( const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X, Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp, - AtomicUpdateCallbackTy &UpdateOp, bool IsXLHSInRHSPart) { + AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) { if (!updateToLocation(Loc)) return Loc.IP; @@ -3095,7 +3095,7 @@ }); emitAtomicUpdate(AllocIP, X.Var, Expr, AO, RMWOp, UpdateOp, X.IsVolatile, - IsXLHSInRHSPart); + IsXBinopExpr); checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update); return Builder.saveIP(); } @@ -3132,13 +3132,13 @@ OpenMPIRBuilder::emitAtomicUpdate(Instruction *AllocIP, Value *X, Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp, - bool VolatileX, bool IsXLHSInRHSPart) { + bool VolatileX, bool IsXBinopExpr) { Type *XElemTy = X->getType()->getPointerElementType(); bool DoCmpExch = ((RMWOp == AtomicRMWInst::BAD_BINOP) || (RMWOp == AtomicRMWInst::FAdd)) || (RMWOp == AtomicRMWInst::FSub) || - (RMWOp == AtomicRMWInst::Sub && !IsXLHSInRHSPart); + (RMWOp == AtomicRMWInst::Sub && !IsXBinopExpr); std::pair<Value *, Value *> Res; if (XElemTy->isIntegerTy() && !DoCmpExch) { @@ -3230,7 +3230,7 @@ const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X, AtomicOpValue &V, Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp, - bool UpdateExpr, bool IsPostfixUpdate, bool IsXLHSInRHSPart) { + bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) { if (!updateToLocation(Loc)) return Loc.IP; @@ -3249,9 +3249,8 @@ // If UpdateExpr is 'x' updated with some `expr` not based on 'x', // 'x' is simply atomically rewritten with 'expr'. AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? 
RMWOp : AtomicRMWInst::Xchg); - std::pair<Value *, Value *> Result = - emitAtomicUpdate(AllocIP, X.Var, Expr, AO, AtomicOp, UpdateOp, - X.IsVolatile, IsXLHSInRHSPart); + std::pair<Value *, Value *> Result = emitAtomicUpdate( + AllocIP, X.Var, Expr, AO, AtomicOp, UpdateOp, X.IsVolatile, IsXBinopExpr); Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second); Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -592,6 +592,44 @@ let verifier = [{ return verifyAtomicWriteOp(*this); }]; } +def ATOMIC_BINOP_KIND_ADD : I64EnumAttrCase<"ADD", 0>; +def ATOMIC_BINOP_KIND_MUL : I64EnumAttrCase<"MUL", 1>; +def ATOMIC_BINOP_KIND_SUB : I64EnumAttrCase<"SUB", 2>; +def ATOMIC_BINOP_KIND_DIV : I64EnumAttrCase<"DIV", 3>; +def ATOMIC_BINOP_KIND_AND : I64EnumAttrCase<"AND", 4>; +def ATOMIC_BINOP_KIND_OR : I64EnumAttrCase<"OR", 5>; +def ATOMIC_BINOP_KIND_XOR : I64EnumAttrCase<"XOR", 6>; +def ATOMIC_BINOP_KIND_SHIFT_RIGHT : I64EnumAttrCase<"SHIFTR", 7>; +def ATOMIC_BINOP_KIND_SHIFT_LEFT : I64EnumAttrCase<"SHIFTL", 8>; +def ATOMIC_BINOP_KIND_MAX : I64EnumAttrCase<"MAX", 9>; +def ATOMIC_BINOP_KIND_MIN : I64EnumAttrCase<"MIN", 10>; +def ATOMIC_BINOP_KIND_EQV : I64EnumAttrCase<"EQV", 11>; +def ATOMIC_BINOP_KIND_NEQV : I64EnumAttrCase<"NEQV", 12>; + +def AtomicBinOpKindAttr : I64EnumAttr< + "AtomicBinOpKind", "BinOp for Atomic Updates", + [ATOMIC_BINOP_KIND_ADD, ATOMIC_BINOP_KIND_MUL, ATOMIC_BINOP_KIND_SUB, + ATOMIC_BINOP_KIND_DIV, ATOMIC_BINOP_KIND_AND, ATOMIC_BINOP_KIND_OR, + ATOMIC_BINOP_KIND_XOR, ATOMIC_BINOP_KIND_SHIFT_RIGHT, + ATOMIC_BINOP_KIND_SHIFT_LEFT, ATOMIC_BINOP_KIND_MAX, + ATOMIC_BINOP_KIND_MIN, ATOMIC_BINOP_KIND_EQV, ATOMIC_BINOP_KIND_NEQV]> { + let cppNamespace = "::mlir::omp"; + let stringToSymbolFnName = "AtomicBinOpKindToEnum"; + let symbolToStringFnName = "AtomicBinOpKindToString"; 
+} + +def AtomicUpdateOp : OpenMP_Op<"atomic.update"> { + let arguments = (ins OpenMP_PointerLikeType:$x, + AnyType:$expr, + UnitAttr:$isXBinopExpr, + AtomicBinOpKindAttr:$binop, + DefaultValuedAttr<I64Attr, "0">:$hint, + OptionalAttr<MemoryOrderKindAttr>:$memory_order); + let parser = [{ return parseAtomicUpdateOp(parser, result); }]; + let printer = [{ return printAtomicUpdateOp(p, *this); }]; + let verifier = [{ return verifyAtomicUpdateOp(*this); }]; +} + //===----------------------------------------------------------------------===// // 2.19.5.7 declare reduction Directive //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1372,5 +1372,84 @@ return verifySynchronizationHint(op, op.hint()); } +//===----------------------------------------------------------------------===// +// AtomicUpdateOp +//===----------------------------------------------------------------------===// + +/// Parser for AtomicUpdateOp, where `y` or `z` must be the same value as `x`. +/// operation ::= `omp.atomic.update` x `=` y binop z atomic-clause-list +/// `:` type(x) `,` type(expr) +static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; + SmallVector<int> segments; + OpAsmParser::OperandType x, y, z; + Type xType, exprType; + StringRef binOp; + + // x = y `op` z : xtype, exprtype + if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) || + parser.parseKeyword(&binOp) || parser.parseOperand(z) || + parseClauses(parser, result, clauses, segments) || parser.parseColon() || + parser.parseType(xType) || parser.parseComma() || + parser.parseType(exprType) || + parser.resolveOperand(x, xType, result.operands)) { + return failure(); + } + + auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper()); + if (!binOpEnum) + return parser.emitError(parser.getNameLoc()) + << 
"invalid atomic bin op in atomic update"; + auto attr = + parser.getBuilder().getI64IntegerAttr(static_cast<int64_t>(*binOpEnum)); + result.addAttribute("binop", attr); + + OpAsmParser::OperandType expr; + if (x.name == y.name && x.number == y.number) { + expr = z; + result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr()); + } else if (x.name == z.name && x.number == z.number) { + expr = y; + } else { + return parser.emitError(parser.getNameLoc()) + << "atomic update variable " << x.name + << " not found in the RHS of the assignment statement in an" + " atomic.update operation"; + } + return parser.resolveOperand(expr, exprType, result.operands); +} + +/// Printer for AtomicUpdateOp +static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) { + p << " " << op.x() << " = "; + Value y, z; + if (op.isXBinopExpr()) { + y = op.x(); + z = op.expr(); + } else { + y = op.expr(); + z = op.x(); + } + p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z + << " "; + if (op.memory_order()) + p << "memory_order(" << op.memory_order().getValue() << ") "; + if (op.hintAttr()) + printSynchronizationHint(p, op, op.hintAttr()); + p << ": " << op.x().getType() << ", " << op.expr().getType(); +} + +/// Verifier for AtomicUpdateOp +static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { + if (op.memory_order()) { + StringRef memoryOrder = op.memory_order().getValue(); + if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) + return op.emitError( + "memory-order must not be acq_rel or acquire for atomic updates"); + } + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -601,6 +601,47 @@ // ----- +func @omp_atomic_update1(%x: memref<i32>, %expr: i32, %foo: memref<i32>) { + // expected-error @below {{atomic update variable 
%x not found in the RHS of the assignment statement in an atomic.update operation}} + omp.atomic.update %x = %foo add %expr : memref<i32>, i32 + return +} + +// ----- + +func @omp_atomic_update2(%x: memref<i32>, %expr: i32) { + // expected-error @below {{invalid atomic bin op in atomic update}} + omp.atomic.update %x = %x invalid %expr : memref<i32>, i32 + return +} + +// ----- + +func @omp_atomic_update3(%x: memref<i32>, %expr: i32) { + // expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}} + omp.atomic.update %x = %x add %expr memory_order(acq_rel) : memref<i32>, i32 + return +} + +// ----- + +func @omp_atomic_update4(%x: memref<i32>, %expr: i32) { + // expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}} + omp.atomic.update %x = %x add %expr memory_order(acquire) : memref<i32>, i32 + return +} + +// ----- + +// expected-note @below {{prior use here}} +func @omp_atomic_update5(%x: memref<i32>, %expr: i32) { + // expected-error @below {{use of value '%x' expects different type than prior uses: 'i32' vs 'memref<i32>'}} + omp.atomic.update %x = %x add %expr : i32, memref<i32> + return +} + +// ----- + func @omp_sections(%data_var1 : memref<i32>, %data_var2 : memref<i32>, %data_var3 : memref<i32>) -> () { // expected-error @below {{operand used in both private and firstprivate clauses}} omp.sections private(%data_var1 : memref<i32>) firstprivate(%data_var1 : memref<i32>) { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -524,6 +524,67 @@ return } +// CHECK-LABEL: omp_atomic_update +// CHECK-SAME: (%[[X:.*]]: memref<i32>, %[[EXPR:.*]]: i32, %[[XBOOL:.*]]: memref<i1>, %[[EXPRBOOL:.*]]: i1) +func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, %exprBool : i1) { + // CHECK: omp.atomic.update %[[X]] = %[[X]] add %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x add %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] sub %[[EXPR]] : 
memref<i32>, i32 + omp.atomic.update %x = %x sub %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] mul %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x mul %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] div %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x div %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] and %[[EXPRBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %xBool and %exprBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] or %[[EXPRBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %xBool or %exprBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] xor %[[EXPRBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %xBool xor %exprBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[X]] = %[[X]] shiftr %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x shiftr %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] shiftl %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x shiftl %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] max %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x max %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[X]] min %[[EXPR]] : memref<i32>, i32 + omp.atomic.update %x = %x min %expr : memref<i32>, i32 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] eqv %[[EXPRBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %xBool eqv %exprBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] neqv %[[EXPRBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %xBool neqv %exprBool : memref<i1>, i1 + + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] add %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr add %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] sub %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr sub %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] mul %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr mul %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = 
%[[EXPR]] div %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr div %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] and %[[XBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %exprBool and %xBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] or %[[XBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %exprBool or %xBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] xor %[[XBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %exprBool xor %xBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] shiftr %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr shiftr %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] shiftl %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr shiftl %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] max %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr max %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] min %[[X]] : memref<i32>, i32 + omp.atomic.update %x = %expr min %x : memref<i32>, i32 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] eqv %[[XBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %exprBool eqv %xBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] neqv %[[XBOOL]] : memref<i1>, i1 + omp.atomic.update %xBool = %exprBool neqv %xBool : memref<i1>, i1 + // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] add %[[X]] memory_order(seq_cst) hint(speculative) : memref<i32>, i32 + omp.atomic.update %x = %expr add %x hint(speculative) memory_order(seq_cst) : memref<i32>, i32 + return +} + // CHECK-LABEL: omp_sectionsop func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>, %data_var3 : memref<i32>, %redn_var : !llvm.ptr<f32>) {