diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3205,6 +3205,7 @@ def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>; def SPV_OC_OpImageDrefGather : I32EnumAttrCase<"OpImageDrefGather", 97>; def SPV_OC_OpImage : I32EnumAttrCase<"OpImage", 100>; +def SPV_OC_OpImageQuerySize : I32EnumAttrCase<"OpImageQuerySize", 104>; def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; @@ -3344,37 +3345,37 @@ SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, - SPV_OC_OpImage, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, - SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, - SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, SPV_OC_OpSNegate, - SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, - SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, - SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, - SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, - SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, - SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, - SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, - SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, - SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, - SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, - SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, - SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, - SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, - SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, - SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, - SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, - SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, - SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, - SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak, - SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, - SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, - SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, - SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, - SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, - SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine, - SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, + SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU, + SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, + SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, + SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, + SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, + SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, + SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, + SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, + SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, + SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, + SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, + SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, + SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, + SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, + SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, + SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, + SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td @@ -70,9 +70,66 @@ let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage) `,` $coordinate `:` type($coordinate) `,` $dref `:` type($dref) `->` type($result)"; let verifier = [{ return ::verify(*this); }]; - } +// ----- + +def SPV_ImageQuerySizeOp : SPV_Op<"ImageQuerySize", [NoSideEffect]> { + let summary = "Query the dimensions of Image, with no level of detail."; + + let description = [{ + Result Type must be an integer type scalar or vector. The number of + components must be: + + 1 for the 1D and Buffer dimensionalities, + + 2 for the 2D, Cube, and Rect dimensionalities, + + 3 for the 3D dimensionality, + + plus 1 more if the image type is arrayed. This vector is filled in with + (width [, height] [, elements]) where elements is the number of layers + in an image array or the number of cubes in a cube-map array. + + Image must be an object whose type is OpTypeImage. Its Dim operand must + be one of those listed under Result Type, above. Additionally, if its + Dim is 1D, 2D, 3D, or Cube, it must also have either an MS of 1 or a + Sampled of 0 or 2. There is no implicit level-of-detail consumed by this + instruction. See OpImageQuerySizeLod for querying images having level of + detail. This operation is allowed on an image decorated as NonReadable. + See the client API specification for additional image type restrictions. + + + + #### Example: + + ```mlir + %3 = spv.ImageQuerySize %0 : !spv.image -> i32 + %4 = spv.ImageQuerySize %1 : !spv.image -> vector<2xi32> + %5 = spv.ImageQuerySize %2 : !spv.image -> vector<3xi32> + ``` + + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_ImageQuery, SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyImage:$image + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = "attr-dict $image `:` type($image) `->` type($result)"; + + let verifier = [{return ::verify(*this);}]; +} // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3651,6 +3651,71 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.ImageQuerySize +//===----------------------------------------------------------------------===// + +static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) { + spirv::ImageType imageType = + imageQuerySizeOp.image().getType().cast(); + Type resultType = imageQuerySizeOp.result().getType(); + + spirv::Dim dim = imageType.getDim(); + spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); + spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); + switch (dim) { + case spirv::Dim::Dim1D: + case spirv::Dim::Dim2D: + case spirv::Dim::Dim3D: + case spirv::Dim::Cube: + if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled || + samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown || + samplerInfo == spirv::ImageSamplerUseInfo::NoSampler)) + return imageQuerySizeOp.emitError( + "if Dim is 1D, 2D, 3D, or Cube, " + "it must also have either an MS of 1 or a Sampled of 0 or 2"); + break; + case spirv::Dim::Buffer: + case spirv::Dim::Rect: + break; + default: + return imageQuerySizeOp.emitError("the Dim operand of the image type must " + "be 1D, 2D, 3D, Buffer, Cube, or Rect"); + } + + unsigned componentNumber = 0; + switch (dim) { + case spirv::Dim::Dim1D: + case spirv::Dim::Buffer: + componentNumber = 1; + break; + case spirv::Dim::Dim2D: + case spirv::Dim::Cube: + case spirv::Dim::Rect: + componentNumber = 2; + break; + case spirv::Dim::Dim3D: + componentNumber = 3; + break; + default: + break; + } + + if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) + componentNumber += 1; + + unsigned resultComponentNumber = 1; + if (auto resultVectorType = resultType.dyn_cast()) + resultComponentNumber = resultVectorType.getNumElements(); + + if (componentNumber != resultComponentNumber) + return imageQuerySizeOp.emitError("expected the result to have ") + << componentNumber << " component(s), but found " + << resultComponentNumber << " component(s)"; + + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -53,4 +53,50 @@ // CHECK: spv.Image {{.*}} : !spv.sampled_image> %0 = spv.Image %arg0 : !spv.sampled_image> return -} \ No newline at end of file +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.ImageQuerySize +//===----------------------------------------------------------------------===// + +func @image_query_size(%arg0 : !spv.image) -> () { + // CHECK: {{%.*}} = spv.ImageQuerySize %arg0 : !spv.image -> i32 + %0 = spv.ImageQuerySize %arg0 : !spv.image -> i32 + spv.Return +} + +// ----- + +func @image_query_size_error_dim(%arg0 : !spv.image) -> () { + // expected-error @+1 {{the Dim operand of the image type must be 1D, 2D, 3D, Buffer, Cube, or Rect}} + %0 = spv.ImageQuerySize %arg0 : !spv.image -> i32 + spv.Return +} + +// ----- + +func @image_query_size_error_dim_sample(%arg0 : !spv.image) -> () { + // expected-error @+1 {{if Dim is 1D, 2D, 3D, or Cube, it must also have either an MS of 1 or a Sampled of 0 or 2}} + %0 = spv.ImageQuerySize %arg0 : !spv.image -> i32 + spv.Return +} + +// ----- + +func @image_query_size_error_result1(%arg0 : !spv.image) -> () { + // expected-error @+1 {{expected the result to have 4 component(s), but found 3 component(s)}} + %0 = spv.ImageQuerySize %arg0 : !spv.image -> vector<3xi32> + spv.Return +} + +// ----- + +func @image_query_size_error_result2(%arg0 : !spv.image) -> () { + // expected-error @+1 {{expected the result to have 1 component(s), but found 2 component(s)}} + %0 = spv.ImageQuerySize %arg0 : !spv.image -> vector<2xi32> + spv.Return +} + +// ----- diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir --- a/mlir/test/Target/SPIRV/image-ops.mlir +++ b/mlir/test/Target/SPIRV/image-ops.mlir @@ -8,4 +8,9 @@ %1 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xf32> spv.Return } + spv.func @image_query_size(%arg0 : !spv.image) "None" { + // CHECK: {{%.*}} = spv.ImageQuerySize %arg0 : !spv.image -> vector<2xi32> + %0 = spv.ImageQuerySize %arg0 : !spv.image -> vector<2xi32> + spv.Return + } }