diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4061,6 +4061,7 @@
 def SPV_OC_OpUConvert                  : I32EnumAttrCase<"OpUConvert", 113>;
 def SPV_OC_OpSConvert                  : I32EnumAttrCase<"OpSConvert", 114>;
 def SPV_OC_OpFConvert                  : I32EnumAttrCase<"OpFConvert", 115>;
+def SPV_OC_OpPtrCastToGeneric          : I32EnumAttrCase<"OpPtrCastToGeneric", 121>;
 def SPV_OC_OpBitcast                   : I32EnumAttrCase<"OpBitcast", 124>;
 def SPV_OC_OpSNegate                   : I32EnumAttrCase<"OpSNegate", 126>;
 def SPV_OC_OpFNegate                   : I32EnumAttrCase<"OpFNegate", 127>;
@@ -4198,36 +4199,37 @@
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather,
       SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
-      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
-      SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
-      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
-      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
-      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
-      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
-      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
-      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
-      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
-      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
-      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
-      SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
-      SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
-      SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
-      SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
-      SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
-      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
-      SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
+      SPV_OC_OpPtrCastToGeneric, SPV_OC_OpBitcast, SPV_OC_OpSNegate,
+      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
+      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
+      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan,
+      SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual,
+      SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
+      SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
+      SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
+      SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
+      SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
+      SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
+      SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
+      SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
+      SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
+      SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
+      SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
+      SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
+      SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
+      SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
+      SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange,
+      SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak,
+      SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
+      SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
+      SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
+      SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
+      SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
+      SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine,
+      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
       SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
       SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
       SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -110,37 +110,6 @@
 }


-// -----
-
-def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
-  let summary = [{
-    Compute the correctly rounded floating-point representation of the sum
-    of c with the infinitely precise product of a and b. Rounding of
-    intermediate products shall not occur. Edge case results are per the
-    IEEE 754-2008 standard.
-  }];
-
-  let description = [{
-    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
-    floating-point values.
-
-    All of the operands, including the Result Type operand, must be of the
-    same type.
-
-    <!-- End of AutoGen section -->
-
-    ```
-    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
-               float-scalar-vector-type
-    ```mlir
-
-    ```
-    %0 = spv.OCL.fma %a, %b, %c : f32
-    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
-    ```
-  }];
-}
-
 // -----

 def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> {
@@ -331,6 +300,37 @@

 // -----

+def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
+  let summary = [{
+    Compute the correctly rounded floating-point representation of the sum
+    of c with the infinitely precise product of a and b. Rounding of
+    intermediate products shall not occur. Edge case results are per the
+    IEEE 754-2008 standard.
+  }];
+
+  let description = [{
+    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    ```mlir
+    %0 = spv.OCL.fma %a, %b, %c : f32
+    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
   let summary = "Compute the natural logarithm of x.";

@@ -512,6 +512,44 @@

 // -----

+def SPV_PtrCastToGenericOp : SPV_Op<"PtrCastToGeneric", [NoSideEffect]> {
+  let summary = "Convert a pointer’s Storage Class to Generic.";
+
+  let description = [{
+    Result Type must be an OpTypePointer. Its Storage Class must be Generic.
+
+    Pointer must point to the Workgroup, CrossWorkgroup, or Function Storage
+    Class.
+
+    Result Type and Pointer must point to the same type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %1 = spv.PtrCastToGeneric %0 : !spv.ptr<f32, CrossWorkgroup> to !spv.ptr<f32, Generic>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Kernel]>
+  ];
+
+  let arguments = (ins
+    SPV_Type:$pointer
+  );
+
+  let results = (outs
+    SPV_Type:$result
+  );
+}
+
+// -----
+
 def SPV_OCLSAbsOp : SPV_OCLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> {
   let summary = "Absolute value of operand";

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4424,6 +4424,71 @@
   return verifyAccessChain(*this, indices());
 }

+//===----------------------------------------------------------------------===//
+// spv.PtrCastToGenericOp
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::PtrCastToGenericOp::parse(OpAsmParser &parser,
+                                             OperationState &state) {
+  OpAsmParser::OperandType ptrInfo;
+  Type ptrType, resultType;
+
+  // Grammar: ssa-use attr-dict `:` pointer-type `to` pointer-type
+  if (parser.parseOperand(ptrInfo) ||
+      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
+      parser.parseType(ptrType) ||
+      parser.resolveOperand(ptrInfo, ptrType, state.operands) ||
+      parser.parseKeywordType("to", resultType))
+    return failure();
+
+  state.addTypes(resultType);
+  return success();
+}
+
+void spirv::PtrCastToGenericOp::print(OpAsmPrinter &p) {
+  p << " " << getOperand();
+  // Print attributes too so they survive a print/parse round trip.
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << " : " << getOperand().getType() << " to " << getResult().getType();
+}
+
+LogicalResult spirv::PtrCastToGenericOp::verify() {
+  auto operandType = getOperand().getType().dyn_cast<spirv::PointerType>();
+  if (!operandType)
+    return emitOpError("operand must be a pointer type, but got ")
+           << getOperand().getType();
+
+  auto resultType = getResult().getType().dyn_cast<spirv::PointerType>();
+  if (!resultType)
+    return emitOpError("result must be a pointer type, but got ")
+           << getResult().getType();
+
+  // Only these storage classes may be cast to Generic.
+  spirv::StorageClass operandStorage = operandType.getStorageClass();
+  if (operandStorage != spirv::StorageClass::Workgroup &&
+      operandStorage != spirv::StorageClass::CrossWorkgroup &&
+      operandStorage != spirv::StorageClass::Function)
+    return emitOpError("pointer must point to the Workgroup, CrossWorkgroup, "
+                       "or Function Storage Class, but got ")
+           << spirv::stringifyStorageClass(operandStorage);
+
+  spirv::StorageClass resultStorage = resultType.getStorageClass();
+  if (resultStorage != spirv::StorageClass::Generic)
+    return emitOpError("result type must be of storage class Generic, "
+                       "but got ")
+           << spirv::stringifyStorageClass(resultStorage);
+
+  // The cast changes only the storage class; the pointee type must match.
+  Type operandPointee = operandType.getPointeeType();
+  Type resultPointee = resultType.getPointeeType();
+  if (operandPointee != resultPointee)
+    return emitOpError("pointer operand and result must point to the same "
+                       "type, but found ")
+           << operandPointee << " vs " << resultPointee;
+
+  return success();
+}
+
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -185,3 +185,39 @@
   %2 = spv.OCL.fma %a, %b, %c : vector<3xf32>
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.PtrCastToGeneric
+//===----------------------------------------------------------------------===//
+
+func @cast_to_generic(%a: !spv.ptr<f32, CrossWorkgroup>) {
+  // CHECK: spv.PtrCastToGeneric {{%[^ ]+}} : !spv.ptr<f32, CrossWorkgroup> to !spv.ptr<f32, Generic>
+  %0 = spv.PtrCastToGeneric %a : !spv.ptr<f32, CrossWorkgroup> to !spv.ptr<f32, Generic>
+  return
+}
+
+// -----
+
+func @cast_to_generic_bad_source(%a: !spv.ptr<f32, StorageBuffer>) {
+  // expected-error @+1 {{pointer must point to the Workgroup, CrossWorkgroup, or Function Storage Class}}
+  %0 = spv.PtrCastToGeneric %a : !spv.ptr<f32, StorageBuffer> to !spv.ptr<f32, Generic>
+  return
+}
+
+// -----
+
+func @cast_to_generic_bad_result(%a: !spv.ptr<f32, CrossWorkgroup>) {
+  // expected-error @+1 {{result type must be of storage class Generic}}
+  %0 = spv.PtrCastToGeneric %a : !spv.ptr<f32, CrossWorkgroup> to !spv.ptr<f32, Function>
+  return
+}
+
+// -----
+
+func @cast_to_generic_pointee_mismatch(%a: !spv.ptr<f32, CrossWorkgroup>) {
+  // expected-error @+1 {{pointer operand and result must point to the same type}}
+  %0 = spv.PtrCastToGeneric %a : !spv.ptr<f32, CrossWorkgroup> to !spv.ptr<i32, Generic>
+  return
+}