Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4221,6 +4221,9 @@ def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; +def SPV_OC_OpPtrCastToGeneric : I32EnumAttrCase<"OpPtrCastToGeneric", 121>; +def SPV_OC_OpGenericCastToPtr : I32EnumAttrCase<"OpGenericCastToPtr", 122>; +def SPV_OC_OpGenericCastToPtrExplicit : I32EnumAttrCase<"OpGenericCastToPtrExplicit", 123>; def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; def SPV_OC_OpSNegate : I32EnumAttrCase<"OpSNegate", 126>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; @@ -4372,7 +4375,8 @@ SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpPtrCastToGeneric, + SPV_OC_OpGenericCastToPtr, SPV_OC_OpGenericCastToPtrExplicit, SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -331,4 +331,145 @@ }]; } +// ----- + +def SPV_PtrCastToGenericOp : SPV_Op<"PtrCastToGeneric", [NoSideEffect]> { + let summary = "Convert a pointer’s Storage Class to Generic."; + + let description = [{ + Result Type must be an
OpTypePointer. Its Storage Class must be Generic. + + Pointer must point to the Workgroup, CrossWorkgroup, or Function Storage + Class. + + Result Type and Pointer must point to the same type. + + <!-- End of AutoGen section --> + + #### Example: + + ```mlir + %1 = spv.PtrCastToGeneric %0 : !spv.ptr<f32, CrossWorkgroup> to + !spv.ptr<f32, Generic> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$operand + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; +} + +// ----- + +def SPV_GenericCastToPtrOp : SPV_Op<"GenericCastToPtr", [NoSideEffect]> { + let summary = "Convert a pointer’s Storage Class to a non-Generic class."; + + let description = [{ + Result Type must be an OpTypePointer. Its Storage Class must be + Workgroup, CrossWorkgroup, or Function. + + Pointer must point to the Generic Storage Class. + + Result Type and Pointer must point to the same type. + + <!-- End of AutoGen section --> + + #### Example: + + ```mlir + %1 = spv.GenericCastToPtr %0 : !spv.ptr<f32, Generic> to + !spv.ptr<f32, CrossWorkgroup> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$operand + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; +} + +// ----- + +def SPV_GenericCastToPtrExplicitOp : SPV_Op<"GenericCastToPtrExplicit", [NoSideEffect]> { + let summary = [{ + Attempts to explicitly convert Pointer to Storage storage-class pointer + value. + }]; + + let description = [{ + Result Type must be an OpTypePointer. Its Storage Class must be Storage. + + Pointer must have a type of OpTypePointer whose Type is the same as the + Type of Result Type. Pointer must point to the Generic Storage Class. If + the cast fails, the instruction result is an OpConstantNull pointer in + the Storage Storage Class.
+ + Storage must be one of the following literal values from Storage Class: + Workgroup, CrossWorkgroup, or Function. + + <!-- End of AutoGen section --> + + Unlike the SPIR-V instruction, this op has no separate Storage operand: + the Storage Class of the result pointer type is used (and serialized) + as Storage. + + #### Example: + + ```mlir + %1 = spv.GenericCastToPtrExplicit %0 : !spv.ptr<f32, Generic> to + !spv.ptr<f32, CrossWorkgroup> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$operand + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + + let autogenSerialization = 0; +} + #endif // MLIR_DIALECT_SPIRV_IR_CAST_OPS
Index: mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1525,6 +1525,85 @@ return success(); } +//===----------------------------------------------------------------------===// +// Generic pointer cast helpers +//===----------------------------------------------------------------------===// + +/// Returns true if `storage` is a Storage Class that can be cast to or from +/// Generic: Workgroup, CrossWorkgroup, or Function. +static bool isGenericCastableStorageClass(spirv::StorageClass storage) { + return storage == spirv::StorageClass::Workgroup || + storage == spirv::StorageClass::CrossWorkgroup || + storage == spirv::StorageClass::Function; +} + +/// Emits an error on `op` unless `operandType` and `resultType` have the same +/// pointee type. +static LogicalResult verifySamePointeeType(Operation *op, + spirv::PointerType operandType, + spirv::PointerType resultType) { + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return op->emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +/// Shared verifier for the casts out of the Generic Storage Class +/// (spv.GenericCastToPtr and spv.GenericCastToPtrExplicit). +static LogicalResult verifyCastFromGeneric(Operation *op, + spirv::PointerType operandType, + spirv::PointerType resultType) { + if (operandType.getStorageClass() != spirv::StorageClass::Generic) + return op->emitOpError("pointer type must be of storage class Generic"); + + if (!isGenericCastableStorageClass(resultType.getStorageClass())) + return op->emitOpError("result must point to the Workgroup, " + "CrossWorkgroup, or Function Storage Class"); + + return verifySamePointeeType(op, operandType, resultType); +} + +//===----------------------------------------------------------------------===// +// spv.PtrCastToGenericOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::PtrCastToGenericOp::verify() { + auto operandType = operand().getType().cast<spirv::PointerType>(); + auto resultType = result().getType().cast<spirv::PointerType>(); + + if (!isGenericCastableStorageClass(operandType.getStorageClass())) + return emitOpError("pointer must point to the Workgroup, CrossWorkgroup" + ", or Function Storage Class"); + + if (resultType.getStorageClass() != spirv::StorageClass::Generic) + return emitOpError("result type must be of storage class Generic"); + + return verifySamePointeeType(getOperation(), operandType, resultType); +} + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GenericCastToPtrOp::verify() { + return verifyCastFromGeneric(getOperation(), + operand().getType().cast<spirv::PointerType>(), + result().getType().cast<spirv::PointerType>()); +} + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrExplicitOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { + return verifyCastFromGeneric(getOperation(), + operand().getType().cast<spirv::PointerType>(), + result().getType().cast<spirv::PointerType>()); +} + //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===//
Index: mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp =================================================================== --- mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -523,6 +523,48 @@ return success(); } +template <> +LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>( + ArrayRef<uint32_t> words) { + // words: [result type id, result id, pointer id, Storage]. + if (words.size() != 4) { + return emitError(unknownLoc, + "expected 4 words in GenericCastToPtrExplicitOp" + " but got : ") + << words.size(); + } + Type type = getType(words[0]); + if (!type) + return emitError(unknownLoc, "unknown type result <id> : ") << words[0]; + auto resultPtrType = type.dyn_cast<spirv::PointerType>(); + if (!resultPtrType) + return emitError(unknownLoc, "result type of GenericCastToPtrExplicitOp " + "must be a pointer type, but got ") + << type; + + // The op has no Storage operand: it is implied by the result pointer's + // storage class, so reject binaries in which the two disagree. + if (static_cast<spirv::StorageClass>(words[3]) != + resultPtrType.getStorageClass()) + return emitError(unknownLoc, "Storage operand ") + << words[3] + << " of GenericCastToPtrExplicitOp does not match the storage " + "class of its result pointer type"; + + uint32_t valueID = words[1]; + Value arg = getValue(words[2]); + if (!arg) + return emitError(unknownLoc, "unknown result <id> : ") << words[2]; + + Location loc = createFileLineColLoc(opBuilder); + OperationState opState(loc, "spv.GenericCastToPtrExplicit"); + opState.addOperands(arg); + opState.addTypes(type); + Operation *op = opBuilder.create(opState); + valueMap[valueID] = op->getResult(0); + return success(); +} + // Pull in
auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS Index: mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp =================================================================== --- mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -667,6 +667,36 @@ return success(); } +template <> +LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( + spirv::GenericCastToPtrExplicitOp op) { + SmallVector<uint32_t, 4> operands; + Type resultTy = op->getResult(0).getType(); + Location loc = op->getLoc(); + uint32_t resultTypeID = 0; + if (failed(processType(loc, resultTy, resultTypeID))) + return failure(); + operands.push_back(resultTypeID); + + uint32_t resultID = getNextID(); + operands.push_back(resultID); + valueIDMap[op->getResult(0)] = resultID; + + for (Value operand : op->getOperands()) { + auto id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + // The op has no Storage operand; encode the result pointer's storage class. + spirv::StorageClass resultStorage = + resultTy.cast<spirv::PointerType>().getStorageClass(); + operands.push_back(static_cast<uint32_t>(resultStorage)); + encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, + operands); + return success(); +} + // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations.
Index: mlir/test/Dialect/SPIRV/IR/cast-ops.mlir =================================================================== --- mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -260,3 +260,106 @@ spv.ReturnValue %0 : i64 } +// ----- + +//===----------------------------------------------------------------------===// +// spv.PtrCastToGeneric +//===----------------------------------------------------------------------===// + +func.func @ptrcasttogeneric1(%arg0 : !spv.ptr) { + // CHECK: {{%.*}} = spv.PtrCastToGeneric {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} +// ----- + +func.func @ptrcasttogeneric2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result type must be of storage class Generic}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr, Generic> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtr +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptr1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptr2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- 
+ +func.func @genericcasttoptr3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptr4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrExplicit +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptrexplicit1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptrexplicit2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} Index: mlir/test/Target/SPIRV/cast-ops.mlir =================================================================== --- mlir/test/Target/SPIRV/cast-ops.mlir +++ mlir/test/Target/SPIRV/cast-ops.mlir @@ -71,3 +71,23 @@ spv.ReturnValue %0 : i64 } } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + spv.func
@ptr_cast_to_generic(%arg0 : !spv.ptr) "None" { + // CHECK: {{%.*}} = spv.PtrCastToGeneric {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + spv.Return + } + spv.func @generic_cast_to_ptr(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } + spv.func @generic_cast_to_ptr_explicit(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } +}