diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -102,6 +102,73 @@
 
 // -----
 
+def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from Value only if Original Value equals Comparator,
+    and
+
+    3) store the New Value back through Pointer only if Original Value
+    equaled Comparator.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be an integer type scalar.
+
+    Use Equal for the memory semantics of this instruction when Value and
+    Original Value compare equal.
+
+    Use Unequal for the memory semantics of this instruction when Value and
+    Original Value compare unequal. Unequal must not be set to Release or
+    Acquire and Release. In addition, Unequal cannot be set to a stronger
+    memory-order than Equal.
+
+    The type of Value must be the same as Result Type. The type of the
+    value pointed to by Pointer must be the same as Result Type. This type
+    must also match the type of Comparator.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-compare-exchange-op ::=
+        `spv.AtomicCompareExchange` scope memory-semantics memory-semantics
+                                    ssa-use `,` ssa-use `,` ssa-use
+                                    `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicCompareExchange "Workgroup" "Acquire" "None"
+                                    %pointer, %value, %comparator
+                                    : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$equal_semantics,
+    SPV_MemorySemanticsAttr:$unequal_semantics,
+    SPV_Integer:$value,
+    SPV_Integer:$comparator
+  );
+
+  let results = (outs
+    SPV_Integer:$result
+  );
+}
+
+// -----
+
 def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
   let summary = "Deprecated (use OpAtomicCompareExchange).";
 
@@ -151,6 +218,58 @@
 
 // -----
 
+def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from copying Value, and
+
+    3) store the New Value back through Pointer.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be a scalar of integer type or floating-point type.
+
+    The type of Value must be the same as Result Type. The type of the
+    value pointed to by Pointer must be the same as Result Type.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-exchange-op ::=
+        `spv.AtomicExchange` scope memory-semantics
+                             ssa-use `,` ssa-use `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicExchange "Workgroup" "Acquire" %pointer, %value
+                            : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$semantics,
+    SPV_Numerical:$value
+  );
+
+  let results = (outs
+    SPV_Numerical:$result
+  );
+}
+
+// -----
+
 def SPV_AtomicIAddOp : SPV_AtomicUpdateWithValueOp<"AtomicIAdd", []> {
   let summary = [{
     Perform the following steps atomically with respect to any other atomic
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3352,6 +3352,8 @@
 def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>;
 def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
 def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
+def SPV_OC_OpAtomicExchange : I32EnumAttrCase<"OpAtomicExchange", 229>;
+def SPV_OC_OpAtomicCompareExchange : I32EnumAttrCase<"OpAtomicCompareExchange", 230>;
 def SPV_OC_OpAtomicCompareExchangeWeak : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>;
 def SPV_OC_OpAtomicIIncrement : I32EnumAttrCase<"OpAtomicIIncrement", 232>;
 def SPV_OC_OpAtomicIDecrement : I32EnumAttrCase<"OpAtomicIDecrement", 233>;
@@ -3442,6 +3444,7 @@
       SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
       SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
       SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+      SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
       SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
       SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
       SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
diff --git
a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1138,12 +1138,19 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.AtomicCompareExchangeWeak
-//===----------------------------------------------------------------------===//
+/// Prints the custom form shared by spv.AtomicCompareExchange and
+/// spv.AtomicCompareExchangeWeak: quoted scope and semantics, operands, type.
+template <typename T>
+static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
+  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
+          << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
+          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
+          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
+}
 
-static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
-                                                    OperationState &state) {
+/// Parser counterpart of printAtomicCompareExchangeImpl, shared by both ops.
+static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
+                                                  OperationState &state) {
   spirv::Scope memoryScope;
   spirv::MemorySemantics equalSemantics, unequalSemantics;
   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
@@ -1173,15 +1180,10 @@
   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
 }
 
-static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
-                  OpAsmPrinter &printer) {
-  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
-          << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
-          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
-          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
-}
-
-static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
+/// Verifies the operand/result type relations that the spec requires of both
+/// spv.AtomicCompareExchange and spv.AtomicCompareExchangeWeak.
+template <typename T>
+static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
   // According to the spec:
   // "The type of Value must be the same as Result Type. The type of the value
   // pointed to by Pointer must be the same as Result Type. This type must also
@@ -1197,8 +1199,10 @@
                "result, but found ")
            << atomOp.comparator().getType() << " vs " << atomOp.getType();
 
-  Type pointeeType =
-      atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
+  Type pointeeType = atomOp.pointer()
+                         .getType()
+                         .template cast<spirv::PointerType>()
+                         .getPointeeType();
   if (atomOp.getType() != pointeeType)
     return atomOp.emitOpError(
                "pointer operand's pointee type must have the same "
@@ -1212,6 +1216,100 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.AtomicCompareExchange
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAtomicCompareExchangeOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  return parseAtomicCompareExchangeImpl(parser, state);
+}
+
+static void print(spirv::AtomicCompareExchangeOp atomOp,
+                  OpAsmPrinter &printer) {
+  printAtomicCompareExchangeImpl(atomOp, printer);
+}
+
+static LogicalResult verify(spirv::AtomicCompareExchangeOp atomOp) {
+  return verifyAtomicCompareExchangeImpl(atomOp);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.AtomicCompareExchangeWeak
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
+                                                    OperationState &state) {
+  return parseAtomicCompareExchangeImpl(parser, state);
+}
+
+static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
+                  OpAsmPrinter &printer) {
+  printAtomicCompareExchangeImpl(atomOp, printer);
+}
+
+static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
+  return verifyAtomicCompareExchangeImpl(atomOp);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+static void print(spirv::AtomicExchangeOp atomOp, OpAsmPrinter &printer) {
+  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
+          << stringifyMemorySemantics(atomOp.semantics()) << "\" "
+          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
+}
+
+/// Parses `"<scope>" "<semantics>" %pointer, %value : <pointer-type>`. The
+/// value operand and the result both take the pointer's pointee type.
+static ParseResult parseAtomicExchangeOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  spirv::Scope memoryScope;
+  spirv::MemorySemantics semantics;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  Type type;
+  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
+      parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
+      parser.parseOperandList(operandInfo, 2))
+    return failure();
+
+  // Capture the location before ':' so a non-pointer type error points there.
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType)
+    return parser.emitError(loc, "expected pointer type");
+
+  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
+                             parser.getNameLoc(), state.operands))
+    return failure();
+
+  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
+}
+
+/// Checks that the value operand, the pointee type of the pointer operand and
+/// the result all have the same type, as the spec requires.
+static LogicalResult verify(spirv::AtomicExchangeOp atomOp) {
+  if (atomOp.getType() != atomOp.value().getType())
+    return atomOp.emitOpError("value operand must have the same type as the op "
+                              "result, but found ")
+           << atomOp.value().getType() << " vs " << atomOp.getType();
+
+  Type pointeeType =
+      atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
+  if (atomOp.getType() != pointeeType)
+    return atomOp.emitOpError(
+               "pointer operand's pointee type must have the same "
+               "as the op result type, but found ")
+           << pointeeType << " vs " << atomOp.getType();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.BitcastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Target/SPIRV/atomic-ops.mlir b/mlir/test/Target/SPIRV/atomic-ops.mlir
--- a/mlir/test/Target/SPIRV/atomic-ops.mlir
+++ b/mlir/test/Target/SPIRV/atomic-ops.mlir
@@ -27,6 +27,10 @@
     %10 = spv.AtomicUMin "Device" "Release" %ptr, %value : !spv.ptr<i32, Workgroup>
     // CHECK: spv.AtomicXor "Workgroup" "AcquireRelease" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
     %11 = spv.AtomicXor "Workgroup" "AcquireRelease" %ptr, %value : !spv.ptr<i32, Workgroup>
+    // CHECK: spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+    %12 = spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+    // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+    %13 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<i32, Workgroup>
     spv.ReturnValue %0: i32
   }
 }