diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -731,4 +731,64 @@
   }];
 }
 
+// Binary operations of llvm.atomicrmw. `_and`, `_or` and `_xor` are prefixed
+// with an underscore because the bare words are alternative operator tokens in
+// C++ and therefore cannot name the generated enumerators.
+def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
+def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
+def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
+def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3>;
+def AtomicBinOpNand : I64EnumAttrCase<"nand", 4>;
+def AtomicBinOpOr : I64EnumAttrCase<"_or", 5>;
+def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6>;
+def AtomicBinOpMax : I64EnumAttrCase<"max", 7>;
+def AtomicBinOpMin : I64EnumAttrCase<"min", 8>;
+def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9>;
+def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10>;
+def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11>;
+def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12>;
+def AtomicBinOp : I64EnumAttr<
+    "AtomicBinOp",
+    "llvm.atomicrmw binary operations",
+    [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
+     AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
+     AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
+     AtomicBinOpFSub]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+// Atomic orderings. The numbering skips 3.
+// NOTE(review): presumably the values mirror llvm::AtomicOrdering, where 3 is
+// set aside for 'consume' -- confirm against llvm/Support/AtomicOrdering.h.
+def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0>;
+def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1>;
+def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2>;
+def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 4>;
+def AtomicOrderingRelease : I64EnumAttrCase<"release", 5>;
+def AtomicOrderingAcquireRelease : I64EnumAttrCase<"acq_rel", 6>;
+def AtomicOrderingSequentiallyConsistent : I64EnumAttrCase<"seq_cst", 7>;
+def AtomicOrdering : I64EnumAttr<
+    "AtomicOrdering",
+    "Atomic ordering for LLVM's memory model",
+    [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
+     AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcquireRelease,
+     AtomicOrderingSequentiallyConsistent]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+// Atomic read-modify-write operation. Its LLVM IR form is produced with
+// `builder.CreateAtomicRMW` from the converted `bin_op` and `ordering`.
+def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
+    Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val,
+               AtomicOrdering:$ordering)>,
+    Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
+                                   getLLVMAtomicOrdering($ordering));
+  }];
+  let parser = [{ return parseAtomicRMWOp(parser, result); }];
+  let printer = [{ printAtomicRMWOp(p, *this); }];
+  let verifier = "return ::verify(*this);";
+}
+
 #endif // LLVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1368,6 +1368,122 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Printer, parser and verifier for LLVM::AtomicRMWOp.
+//===----------------------------------------------------------------------===//
+
+// Prints the custom form of an AtomicRMWOp: the operation name, the quoted
+// `bin_op` and `ordering` keywords, both operands, the remaining attributes
+// and a trailing function type built from the operand and result types.
+static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
+  p << op.getOperationName() << " ";
+  p << '"' << stringifyAtomicBinOp(op.bin_op()) << "\" ";
+  p << '"' << stringifyAtomicOrdering(op.ordering()) << "\" ";
+  p << op.ptr() << ", " << op.val();
+  p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
+  p << " : (" << op.ptr().getType() << ", " << op.val().getType() << ") -> "
+    << op.res().getType();
+}
+
+// Parses an AtomicRMWOp. `bin_op` and `ordering` are first read as string
+// attributes and then replaced in place by their i64 enum values.
+//
+// <operation> ::= `llvm.atomicrmw` string-literal string-literal
+//                 ssa-use `,` ssa-use attribute-dict? `:` type
+static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  Type type;
+  StringAttr binOp, ordering;
+  llvm::SMLoc binOpLoc, orderingLoc, trailingTypeLoc;
+  OpAsmParser::OperandType ptr, val;
+  if (parser.getCurrentLocation(&binOpLoc) ||
+      parser.parseAttribute(binOp, "bin_op", result.attributes) ||
+      parser.getCurrentLocation(&orderingLoc) ||
+      parser.parseAttribute(ordering, "ordering", result.attributes) ||
+      parser.parseOperand(ptr) || parser.parseComma() ||
+      parser.parseOperand(val) ||
+      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
+    return failure();
+
+  // Extract the result type from the trailing function type.
+  auto funcType = type.dyn_cast<FunctionType>();
+  if (!funcType || funcType.getNumInputs() != 2 ||
+      funcType.getNumResults() != 1)
+    return parser.emitError(
+        trailingTypeLoc,
+        "expected trailing function type with two arguments and one result");
+
+  if (parser.resolveOperand(ptr, funcType.getInput(0), result.operands) ||
+      parser.resolveOperand(val, funcType.getInput(1), result.operands))
+    return failure();
+
+  // Replace the string attribute `bin_op` with an integer attribute.
+  auto binOpKind = symbolizeAtomicBinOp(binOp.getValue());
+  if (!binOpKind) {
+    return parser.emitError(binOpLoc)
+           << "'" << binOp.getValue()
+           << "' is an incorrect value of the 'bin_op' attribute";
+  }
+
+  auto binOpValue = static_cast<int64_t>(binOpKind.getValue());
+  auto binOpAttr = parser.getBuilder().getI64IntegerAttr(binOpValue);
+  // `bin_op` was parsed first, so it is expected at index 0.
+  result.attributes[0].second = binOpAttr;
+
+  // Replace the string attribute `ordering` with an integer attribute.
+  auto orderingKind = symbolizeAtomicOrdering(ordering.getValue());
+  if (!orderingKind) {
+    return parser.emitError(orderingLoc)
+           << "'" << ordering.getValue()
+           << "' is an incorrect value of the 'ordering' attribute";
+  }
+
+  auto orderingValue = static_cast<int64_t>(orderingKind.getValue());
+  auto orderingAttr = parser.getBuilder().getI64IntegerAttr(orderingValue);
+  // `ordering` was parsed second, so it is expected at index 1.
+  result.attributes[1].second = orderingAttr;
+
+  result.addTypes(funcType.getResults());
+  return success();
+}
+
+// Verifies an AtomicRMWOp: operand #0 must be an LLVM pointer whose element
+// type equals the type of operand #1, the result type must equal that type,
+// and the type must suit `bin_op` (floating point for fadd/fsub; i8, i16, i32,
+// i64, half, float or double for xchg; i8, i16, i32 or i64 otherwise).
+static LogicalResult verify(AtomicRMWOp op) {
+  auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
+  if (!ptrType.isPointerTy())
+    return op.emitOpError("expected LLVM IR pointer type for operand #0");
+  auto valType = op.val().getType().cast<LLVM::LLVMType>();
+  if (valType != ptrType.getPointerElementTy())
+    return op.emitOpError("expected LLVM IR element type for operand #0 to "
+                          "match type for operand #1");
+  auto resType = op.res().getType().cast<LLVM::LLVMType>();
+  if (resType != valType)
+    return op.emitOpError(
+        "expected LLVM IR result type to match type for operand #1");
+  if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
+    if (!valType.getUnderlyingType()->isFloatingPointTy())
+      return op.emitOpError("expected LLVM IR floating point type");
+  } else if (op.bin_op() == AtomicBinOp::xchg) {
+    if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
+        !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
+        !valType.getUnderlyingType()->isHalfTy() && !valType.isFloatTy() &&
+        !valType.isDoubleTy())
+      return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
+  } else {
+    if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
+        !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
+      return op.emitOpError("expected LLVM IR integer type");
+  }
+  // NOTE(review): `ordering` is not validated here. LLVM's own verifier is
+  // believed to reject `not_atomic` and `unordered` on atomicrmw -- confirm
+  // and consider diagnosing them in this function.
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -150,6 +150,60 @@
   llvm_unreachable("incorrect comparison predicate");
 }
 
+/// Maps an LLVM dialect atomic binary operation onto its LLVM IR counterpart.
+static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
+  switch (op) {
+  case LLVM::AtomicBinOp::xchg:
+    return llvm::AtomicRMWInst::BinOp::Xchg;
+  case LLVM::AtomicBinOp::add:
+    return llvm::AtomicRMWInst::BinOp::Add;
+  case LLVM::AtomicBinOp::sub:
+    return llvm::AtomicRMWInst::BinOp::Sub;
+  case LLVM::AtomicBinOp::_and:
+    return llvm::AtomicRMWInst::BinOp::And;
+  case LLVM::AtomicBinOp::nand:
+    return llvm::AtomicRMWInst::BinOp::Nand;
+  case LLVM::AtomicBinOp::_or:
+    return llvm::AtomicRMWInst::BinOp::Or;
+  case LLVM::AtomicBinOp::_xor:
+    return llvm::AtomicRMWInst::BinOp::Xor;
+  case LLVM::AtomicBinOp::max:
+    return llvm::AtomicRMWInst::BinOp::Max;
+  case LLVM::AtomicBinOp::min:
+    return llvm::AtomicRMWInst::BinOp::Min;
+  case LLVM::AtomicBinOp::umax:
+    return llvm::AtomicRMWInst::BinOp::UMax;
+  case LLVM::AtomicBinOp::umin:
+    return llvm::AtomicRMWInst::BinOp::UMin;
+  case LLVM::AtomicBinOp::fadd:
+    return llvm::AtomicRMWInst::BinOp::FAdd;
+  case LLVM::AtomicBinOp::fsub:
+    return llvm::AtomicRMWInst::BinOp::FSub;
+  }
+  llvm_unreachable("incorrect atomic binary operator");
+}
+
+/// Maps an LLVM dialect atomic ordering onto its LLVM IR counterpart.
+static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
+  switch (ordering) {
+  case LLVM::AtomicOrdering::not_atomic:
+    return llvm::AtomicOrdering::NotAtomic;
+  case LLVM::AtomicOrdering::unordered:
+    return llvm::AtomicOrdering::Unordered;
+  case LLVM::AtomicOrdering::monotonic:
+    return llvm::AtomicOrdering::Monotonic;
+  case LLVM::AtomicOrdering::acquire:
+    return llvm::AtomicOrdering::Acquire;
+  case LLVM::AtomicOrdering::release:
+    return llvm::AtomicOrdering::Release;
+  case LLVM::AtomicOrdering::acq_rel:
+    return llvm::AtomicOrdering::AcquireRelease;
+  case LLVM::AtomicOrdering::seq_cst:
+    return llvm::AtomicOrdering::SequentiallyConsistent;
+  }
+  llvm_unreachable("incorrect atomic ordering");
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`. LLVM IR Builder does not have a generic interface so
 /// this has to be a long chain of `if`s calling different functions with a
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -382,3 +382,51 @@
   %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
   llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
 }
+
+// -----
+// CHECK-LABEL: @atomicrmw_expected_ptr
+func @atomicrmw_expected_ptr(%f32 : !llvm.float) {
+  // expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
+  %0 = llvm.atomicrmw "fadd" "unordered" %f32, %f32 : (!llvm.float, !llvm.float) -> !llvm.float
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @atomicrmw_mismatched_operands
+func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %i32 : !llvm.i32) {
+  // expected-error@+1 {{expected LLVM IR element type for operand #0 to match type for operand #1}}
+  %0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %i32 : (!llvm<"float*">, !llvm.i32) -> !llvm.float
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @atomicrmw_mismatched_result
+func @atomicrmw_mismatched_result(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
+  // expected-error@+1 {{expected LLVM IR result type to match type for operand #1}}
+  %0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.i32
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @atomicrmw_expected_float
+func @atomicrmw_expected_float(%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) {
+  // expected-error@+1 {{expected LLVM IR floating point type}}
+  %0 = llvm.atomicrmw "fadd" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @atomicrmw_unexpected_xchg_type
+func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm<"i1*">, %i1 : !llvm.i1) {
+  // expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
+  %0 = llvm.atomicrmw "xchg" "unordered" %i1_ptr, %i1 : (!llvm<"i1*">, !llvm.i1) -> !llvm.i1
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @atomicrmw_expected_int
+func @atomicrmw_expected_int(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
+  // expected-error@+1 {{expected LLVM IR integer type}}
+  %0 = llvm.atomicrmw "max" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -218,3 +218,10 @@
   %1 = llvm.mlir.null : !llvm<"{void(i32, void()*)*, i64}*">
   llvm.return
 }
+
+// CHECK-LABEL: @atomics
+func @atomics(%arg0 : !llvm<"float*">, %arg1 : !llvm.float) {
+  // CHECK: llvm.atomicrmw "fadd" "unordered" %{{.*}}, %{{.*}} : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  %0 = llvm.atomicrmw "fadd" "unordered" %arg0, %arg1 : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  llvm.return
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1039,3 +1039,36 @@
   // CHECK: ret i32* null
   llvm.return %0 : !llvm<"i32*">
 }
+
+// CHECK-LABEL: @atomics
+llvm.func @atomics(
+    %f32_ptr : !llvm<"float*">, %f32 : !llvm.float,
+    %i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) -> !llvm.float {
+  // CHECK: atomicrmw fadd float* %{{.*}}, float %{{.*}} unordered
+  %0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  // CHECK: atomicrmw fsub float* %{{.*}}, float %{{.*}} unordered
+  %1 = llvm.atomicrmw "fsub" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  // CHECK: atomicrmw xchg float* %{{.*}}, float %{{.*}} monotonic
+  %2 = llvm.atomicrmw "xchg" "monotonic" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
+  // CHECK: atomicrmw add i32* %{{.*}}, i32 %{{.*}} acquire
+  %3 = llvm.atomicrmw "add" "acquire" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw sub i32* %{{.*}}, i32 %{{.*}} release
+  %4 = llvm.atomicrmw "sub" "release" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw and i32* %{{.*}}, i32 %{{.*}} acq_rel
+  %5 = llvm.atomicrmw "_and" "acq_rel" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw nand i32* %{{.*}}, i32 %{{.*}} seq_cst
+  %6 = llvm.atomicrmw "nand" "seq_cst" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw or i32* %{{.*}}, i32 %{{.*}} unordered
+  %7 = llvm.atomicrmw "_or" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw xor i32* %{{.*}}, i32 %{{.*}} unordered
+  %8 = llvm.atomicrmw "_xor" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw max i32* %{{.*}}, i32 %{{.*}} unordered
+  %9 = llvm.atomicrmw "max" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw min i32* %{{.*}}, i32 %{{.*}} unordered
+  %10 = llvm.atomicrmw "min" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw umax i32* %{{.*}}, i32 %{{.*}} unordered
+  %11 = llvm.atomicrmw "umax" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  // CHECK: atomicrmw umin i32* %{{.*}}, i32 %{{.*}} unordered
+  %12 = llvm.atomicrmw "umin" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
+  llvm.return %0 : !llvm.float
+}