Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4221,6 +4221,9 @@ def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; +def SPV_OC_OpPtrCastToGeneric : I32EnumAttrCase<"OpPtrCastToGeneric", 121>; +def SPV_OC_OpGenericCastToPtr : I32EnumAttrCase<"OpGenericCastToPtr", 122>; +def SPV_OC_OpGenericCastToPtrExplicit : I32EnumAttrCase<"OpGenericCastToPtrExplicit", 123>; def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; def SPV_OC_OpSNegate : I32EnumAttrCase<"OpSNegate", 126>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; @@ -4372,7 +4375,8 @@ SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpPtrCastToGeneric, + SPV_OC_OpGenericCastToPtr, SPV_OC_OpGenericCastToPtrExplicit, SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -331,4 +331,95 @@ }]; } +// ----- + +def SPV_PtrCastToGenericOp : SPV_CastOp<"PtrCastToGeneric", + SPV_AnyPtr, + SPV_AnyPtr, + [NoSideEffect]> { + let summary = [{ + Convert pointer to Generic Storage. 
+ }]; + + let description = [{ + Convert a pointer’s Storage Class to Generic. + + Result Type must be an OpTypePointer. Its Storage + Class must be Generic. + + Pointer must point to the Workgroup, CrossWorkgroup, + or Function Storage Class. + + Result Type and Pointer must point to the same type. + + #### Example: + + ```mlir + %1 = spv.PtrCastToGeneric %0 : !spv.ptr<f32, CrossWorkgroup> to + !spv.ptr<f32, Generic> + ``` + }]; +} + +// ----- + +def SPV_GenericCastToPtrOp : SPV_CastOp<"GenericCastToPtr", + SPV_AnyPtr, + SPV_AnyPtr, + [NoSideEffect]> { + let summary = [{ + Convert pointer from Generic Storage to Workgroup, CrossWorkgroup, + or Function. + }]; + + let description = [{ + Convert a pointer’s Storage Class to a non-Generic class. + + Result Type must be an OpTypePointer. Its Storage Class + must be Workgroup, CrossWorkgroup, or Function. + + Pointer must point to the Generic Storage Class. + + Result Type and Pointer must point to the same type. + + #### Example: + + ```mlir + %1 = spv.GenericCastToPtr %0 : !spv.ptr<f32, Generic> to + !spv.ptr<f32, CrossWorkgroup> + ``` + }]; +} + +// ----- + +def SPV_GenericCastToPtrExplicitOp : SPV_CastOp<"GenericCastToPtrExplicit", + SPV_AnyPtr, + SPV_AnyPtr, + [NoSideEffect]> { + let summary = [{ + Convert pointer from Generic Storage to Workgroup, CrossWorkgroup, + or Function explicitly. + }]; + + let description = [{ + Attempts to explicitly convert Pointer to a non-Generic + storage-class pointer value. + + Result Type must be an OpTypePointer. Its Storage Class must be + Workgroup, CrossWorkgroup, or Function. + + Pointer must have a type of OpTypePointer whose Type is the same as + the Type of Result Type. Pointer must point to the Generic Storage Class. + If the cast fails, the instruction result is an OpConstantNull pointer + in the Storage Class of Result Type. 
+ + #### Example: + + ```mlir + %1 = spv.GenericCastToPtrExplicit %0 : !spv.ptr<f32, Generic> to + !spv.ptr<f32, CrossWorkgroup> + ``` + }]; +} #endif // MLIR_DIALECT_SPIRV_IR_CAST_OPS Index: mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1525,6 +1525,93 @@ return success(); }
+//===----------------------------------------------------------------------===//
+// Generic storage class pointer casts
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `storageClass` is Workgroup, CrossWorkgroup, or Function,
+/// i.e. a storage class the generic pointer casts accept on their non-Generic
+/// side.
+static bool isNonGenericCastStorageClass(spirv::StorageClass storageClass) {
+  return storageClass == spirv::StorageClass::Workgroup ||
+         storageClass == spirv::StorageClass::CrossWorkgroup ||
+         storageClass == spirv::StorageClass::Function;
+}
+
+/// Verifies that `operandType` and `resultType` have the same pointee type.
+/// The diagnostic wording is matched by the expected-error checks in
+/// mlir/test/Dialect/SPIRV/IR/cast-ops.mlir; keep the two in sync.
+static LogicalResult verifySamePointeeType(Operation *op,
+                                           spirv::PointerType operandType,
+                                           spirv::PointerType resultType) {
+  Type operandPointeeType = operandType.getPointeeType();
+  Type resultPointeeType = resultType.getPointeeType();
+  if (operandPointeeType != resultPointeeType)
+    return op->emitOpError("pointer operand's pointee type must have the same "
+                           "as the op result type, but found ")
+           << operandPointeeType << " vs " << resultPointeeType;
+  return success();
+}
+
+/// Shared verifier for spv.GenericCastToPtr and spv.GenericCastToPtrExplicit,
+/// whose constraints are identical: the operand must be a Generic pointer, the
+/// result a Workgroup, CrossWorkgroup, or Function pointer, and both must
+/// point to the same type.
+static LogicalResult verifyCastFromGeneric(Operation *op, Type operandType,
+                                           Type resultType) {
+  auto operandPtrType = operandType.dyn_cast<spirv::PointerType>();
+  auto resultPtrType = resultType.dyn_cast<spirv::PointerType>();
+  if (!operandPtrType || !resultPtrType)
+    return op->emitOpError("operand and result expected to be pointer type");
+
+  if (operandPtrType.getStorageClass() != spirv::StorageClass::Generic)
+    return op->emitOpError("pointer type must be of storage class Generic");
+
+  if (!isNonGenericCastStorageClass(resultPtrType.getStorageClass()))
+    return op->emitOpError("result must point to the Workgroup, "
+                           "CrossWorkgroup, or Function Storage Class");
+
+  return verifySamePointeeType(op, operandPtrType, resultPtrType);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.PtrCastToGenericOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::PtrCastToGenericOp::verify() {
+  auto operandPtrType = operand().getType().dyn_cast<spirv::PointerType>();
+  auto resultPtrType = result().getType().dyn_cast<spirv::PointerType>();
+  if (!operandPtrType || !resultPtrType)
+    return emitOpError("operand and result expected to be pointer type");
+
+  if (!isNonGenericCastStorageClass(operandPtrType.getStorageClass()))
+    return emitOpError("pointer must point to the Workgroup, CrossWorkgroup"
+                       ", or Function Storage Class");
+
+  if (resultPtrType.getStorageClass() != spirv::StorageClass::Generic)
+    return emitOpError("result type must be of storage class Generic");
+
+  return verifySamePointeeType(getOperation(), operandPtrType, resultPtrType);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GenericCastToPtrOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GenericCastToPtrOp::verify() {
+  return verifyCastFromGeneric(getOperation(), operand().getType(),
+                               result().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GenericCastToPtrExplicitOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
+  return verifyCastFromGeneric(getOperation(), operand().getType(),
+                               result().getType());
+}
+
 //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===// Index: mlir/test/Dialect/SPIRV/IR/cast-ops.mlir =================================================================== --- mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -260,3 +260,106 @@ spv.ReturnValue %0 : i64 } +// ----- + +//===----------------------------------------------------------------------===// +// spv.PtrCastToGeneric +//===----------------------------------------------------------------------===// + +func.func @ptrcasttogeneric1(%arg0 : !spv.ptr) { + // CHECK: {{%.*}} = spv.PtrCastToGeneric {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} +// ----- + +func.func @ptrcasttogeneric2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result type must be of storage class Generic}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = 
spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr, Generic> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtr +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptr1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptr2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptr3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptr4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrExplicit +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptrexplicit1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptrexplicit2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit3(%arg0 : !spv.ptr) { + // 
expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} Index: mlir/test/Target/SPIRV/cast-ops.mlir =================================================================== --- mlir/test/Target/SPIRV/cast-ops.mlir +++ mlir/test/Target/SPIRV/cast-ops.mlir @@ -71,3 +71,23 @@ spv.ReturnValue %0 : i64 } } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + spv.func @ptr_cast_to_generic(%arg0 : !spv.ptr) "None" { + // CHECK: {{%.*}} = spv.PtrCastToGeneric {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + spv.Return + } + spv.func @generic_cast_to_ptr(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } + spv.func @generic_cast_to_ptr_explicit(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } +}