diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -731,4 +731,56 @@
   }];
 }
 
+def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
+def AtomicBinOpAdd  : I64EnumAttrCase<"add", 1>;
+def AtomicBinOpSub  : I64EnumAttrCase<"sub", 2>;
+def AtomicBinOpAnd  : I64EnumAttrCase<"_and", 3>;
+def AtomicBinOpNand : I64EnumAttrCase<"nand", 4>;
+def AtomicBinOpOr   : I64EnumAttrCase<"_or", 5>;
+def AtomicBinOpXor  : I64EnumAttrCase<"_xor", 6>;
+def AtomicBinOpMax  : I64EnumAttrCase<"max", 7>;
+def AtomicBinOpMin  : I64EnumAttrCase<"min", 8>;
+def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9>;
+def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10>;
+def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11>;
+def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12>;
+def AtomicBinOp : I64EnumAttr<
+    "AtomicBinOp",
+    "llvm.atomicrmw binary operations",
+    [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
+     AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
+     AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
+     AtomicBinOpFSub]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0>;
+def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1>;
+def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2>;
+def AtomicOrderingAcquire   : I64EnumAttrCase<"acquire", 4>;
+def AtomicOrderingRelease   : I64EnumAttrCase<"release", 5>;
+def AtomicOrderingAcquireRelease : I64EnumAttrCase<"acq_rel", 6>;
+def AtomicOrderingSequentiallyConsistent : I64EnumAttrCase<"seq_cst", 7>;
+def AtomicOrdering : I64EnumAttr<
+    "AtomicOrdering",
+    "Atomic ordering for LLVM's memory model",
+    [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
+     AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcquireRelease,
+     AtomicOrderingSequentiallyConsistent]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
+    Arguments<(ins AtomicBinOp:$op, LLVM_Type:$ptr, LLVM_Type:$val,
+                   AtomicOrdering:$ordering)>,
+    Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($op), $ptr, $val,
+                                   getLLVMAtomicOrdering($ordering));
+  }];
+  let parser = [{ return parseAtomicRMWOp(parser, result); }];
+  let printer = [{ printAtomicRMWOp(p, *this); }];
+  let verifier = "return ::verify(*this);";
+}
+
 #endif // LLVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1368,6 +1368,140 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Printer, parser and verifier for LLVM::AtomicRMWOp.
+//===----------------------------------------------------------------------===//
+
+// Prints the `op` and `ordering` attributes as quoted strings, followed by the
+// operands, any remaining attributes and the trailing function type.
+static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
+  p << op.getOperationName() << " ";
+  p << "\"" << stringifyAtomicBinOp(op.op()) << "\" ";
+  p << "\"" << stringifyAtomicOrdering(op.ordering()) << "\" ";
+  p << *op.ptr() << ", " << *op.val();
+  p.printOptionalAttrDict(op.getAttrs(), {"op", "ordering"});
+  p << " : (" << op.ptr().getType() << ", " << op.val().getType() << ") -> "
+    << op.res().getType();
+}
+
+// <operation> ::= `llvm.atomicrmw` string-literal string-literal
+//                 ssa-use `,` ssa-use attribute-dict? `:` type
+static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  Attribute op;
+  Attribute ordering;
+  SmallVector<NamedAttribute, 4> attrs;
+  Type type;
+  llvm::SMLoc opLoc, orderingLoc, trailingTypeLoc;
+  OpAsmParser::OperandType ptr, val;
+  if (parser.getCurrentLocation(&opLoc) ||
+      parser.parseAttribute(op, "op", attrs) ||
+      parser.getCurrentLocation(&orderingLoc) ||
+      parser.parseAttribute(ordering, "ordering", attrs) ||
+      parser.parseOperand(ptr) || parser.parseComma() ||
+      parser.parseOperand(val) || parser.parseOptionalAttrDict(attrs) ||
+      parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
+      parser.parseType(type))
+    return failure();
+
+  // Extract the result type from the trailing function type.
+  auto funcType = type.dyn_cast<FunctionType>();
+  if (!funcType || funcType.getNumInputs() != 2 ||
+      funcType.getNumResults() != 1)
+    return parser.emitError(
+        trailingTypeLoc,
+        "expected trailing function type with two arguments and one result");
+
+  if (parser.resolveOperand(ptr, funcType.getInput(0), result.operands) ||
+      parser.resolveOperand(val, funcType.getInput(1), result.operands))
+    return failure();
+
+  // Replace the string attribute `op` with an integer attribute.
+  auto opStr = op.dyn_cast<StringAttr>();
+  if (!opStr)
+    return parser.emitError(opLoc, "expected 'op' attribute of string type");
+
+  auto opKind = symbolizeAtomicBinOp(opStr.getValue());
+  if (!opKind) {
+    return parser.emitError(opLoc)
+           << "'" << opStr.getValue()
+           << "' is an incorrect value of the 'op' attribute";
+  }
+
+  auto opValue = static_cast<int64_t>(opKind.getValue());
+  attrs[0].second = parser.getBuilder().getI64IntegerAttr(opValue);
+
+  // Replace the string attribute `ordering` with an integer attribute.
+  auto orderingStr = ordering.dyn_cast<StringAttr>();
+  if (!orderingStr)
+    return parser.emitError(orderingLoc,
+                            "expected 'ordering' attribute of string type");
+
+  auto orderingKind = symbolizeAtomicOrdering(orderingStr.getValue());
+  if (!orderingKind) {
+    return parser.emitError(orderingLoc)
+           << "'" << orderingStr.getValue()
+           << "' is an incorrect value of the 'ordering' attribute";
+  }
+
+  auto orderingValue = static_cast<int64_t>(orderingKind.getValue());
+  attrs[1].second = parser.getBuilder().getI64IntegerAttr(orderingValue);
+
+  result.attributes = attrs;
+  result.addTypes(funcType.getResults());
+  return success();
+}
+
+// Returns true if `op` is an atomic operation on floating point values.
+static bool isFPOperation(AtomicBinOp op) {
+  switch (op) {
+  case AtomicBinOp::fadd:
+  case AtomicBinOp::fsub:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Verifies that the operand and result types agree and that the operation
+// satisfies the type and ordering rules LLVM imposes on `atomicrmw`.
+static LogicalResult verify(AtomicRMWOp op) {
+  auto ptrType = op.ptr().getType().dyn_cast<LLVM::LLVMType>();
+  if (!ptrType || !ptrType.isPointerTy())
+    return op.emitOpError("expected LLVM IR pointer type for `ptr` operand");
+  auto valType = op.val().getType().dyn_cast<LLVM::LLVMType>();
+  if (!valType || valType != ptrType.getPointerElementTy())
+    return op.emitOpError("expected LLVM IR element type for `ptr` operand to "
+                          "match type for `val` operand");
+  auto resType = op.res().getType().dyn_cast<LLVM::LLVMType>();
+  if (!resType || resType != valType)
+    return op.emitOpError(
+        "expected LLVM IR result type to match type for `val` operand");
+
+  // Mirror the operand type rules of LLVM's IR verifier: fadd/fsub need a
+  // floating point type, xchg accepts integer or floating point types, and
+  // every other operation needs an integer type.
+  llvm::Type *underlyingValType = valType.getUnderlyingType();
+  if (isFPOperation(op.op())) {
+    if (!underlyingValType->isFloatingPointTy())
+      return op.emitOpError("expected LLVM IR floating point type");
+  } else if (op.op() == AtomicBinOp::xchg) {
+    if (!underlyingValType->isIntegerTy() &&
+        !underlyingValType->isFloatingPointTy())
+      return op.emitOpError(
+          "expected LLVM IR integer or floating point type for 'xchg'");
+  } else if (!underlyingValType->isIntegerTy()) {
+    return op.emitOpError("expected LLVM IR integer type");
+  }
+
+  // LLVM's IR verifier rejects `atomicrmw` with `not_atomic` or `unordered`
+  // ordering; diagnose it here rather than emitting invalid LLVM IR.
+  if (op.ordering() == AtomicOrdering::not_atomic ||
+      op.ordering() == AtomicOrdering::unordered)
+    return op.emitOpError("expected at least 'monotonic' ordering");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -150,6 +150,60 @@
   llvm_unreachable("incorrect comparison predicate");
 }
 
+// Converts an LLVM dialect atomic binary operation to its LLVM IR counterpart.
+static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
+  switch (op) {
+  case LLVM::AtomicBinOp::xchg:
+    return llvm::AtomicRMWInst::BinOp::Xchg;
+  case LLVM::AtomicBinOp::add:
+    return llvm::AtomicRMWInst::BinOp::Add;
+  case LLVM::AtomicBinOp::sub:
+    return llvm::AtomicRMWInst::BinOp::Sub;
+  case LLVM::AtomicBinOp::_and:
+    return llvm::AtomicRMWInst::BinOp::And;
+  case LLVM::AtomicBinOp::nand:
+    return llvm::AtomicRMWInst::BinOp::Nand;
+  case LLVM::AtomicBinOp::_or:
+    return llvm::AtomicRMWInst::BinOp::Or;
+  case LLVM::AtomicBinOp::_xor:
+    return llvm::AtomicRMWInst::BinOp::Xor;
+  case LLVM::AtomicBinOp::max:
+    return llvm::AtomicRMWInst::BinOp::Max;
+  case LLVM::AtomicBinOp::min:
+    return llvm::AtomicRMWInst::BinOp::Min;
+  case LLVM::AtomicBinOp::umax:
+    return llvm::AtomicRMWInst::BinOp::UMax;
+  case LLVM::AtomicBinOp::umin:
+    return llvm::AtomicRMWInst::BinOp::UMin;
+  case LLVM::AtomicBinOp::fadd:
+    return llvm::AtomicRMWInst::BinOp::FAdd;
+  case LLVM::AtomicBinOp::fsub:
+    return llvm::AtomicRMWInst::BinOp::FSub;
+  }
+  llvm_unreachable("incorrect atomic binary operator");
+}
+
+// Converts an LLVM dialect atomic ordering to its LLVM IR counterpart.
+static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
+  switch (ordering) {
+  case LLVM::AtomicOrdering::not_atomic:
+    return llvm::AtomicOrdering::NotAtomic;
+  case LLVM::AtomicOrdering::unordered:
+    return llvm::AtomicOrdering::Unordered;
+  case LLVM::AtomicOrdering::monotonic:
+    return llvm::AtomicOrdering::Monotonic;
+  case LLVM::AtomicOrdering::acquire:
+    return llvm::AtomicOrdering::Acquire;
+  case LLVM::AtomicOrdering::release:
+    return llvm::AtomicOrdering::Release;
+  case LLVM::AtomicOrdering::acq_rel:
+    return llvm::AtomicOrdering::AcquireRelease;
+  case LLVM::AtomicOrdering::seq_cst:
+    return llvm::AtomicOrdering::SequentiallyConsistent;
+  }
+  llvm_unreachable("incorrect atomic ordering");
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`. LLVM IR Builder does not have a generic interface so
 /// this has to be a long chain of `if`s calling different functions with a
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -218,3 +218,15 @@
   %1 = llvm.mlir.null : !llvm<"{void(i32, void()*)*, i64}*">
   llvm.return
 }
+
+// CHECK-LABEL: @atomics
+func @atomics(%arg0 : !llvm.i32, %arg1 : !llvm.float) {
+  %0 = llvm.alloca %arg0 x !llvm.float : (!llvm.i32) -> !llvm<"float*">
+  %1 = llvm.getelementptr %0[%arg0, %arg0] : (!llvm<"float*">, !llvm.i32, !llvm.i32) -> !llvm<"float*">
+  // CHECK: llvm.atomicrmw "fadd" "monotonic" %1, %arg1 {volatile} : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  %2 = llvm.atomicrmw "fadd" "monotonic" %1, %arg1 {volatile} : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  %3 = llvm.alloca %arg0 x !llvm.i32 : (!llvm.i32) -> !llvm<"i32*">
+  // CHECK: llvm.atomicrmw "add" "seq_cst" %3, %arg0 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  %4 = llvm.atomicrmw "add" "seq_cst" %3, %arg0 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  llvm.return
+}