[mlir][spirv] Add spv.ImageDrefGather op

Adds the OpImageDrefGather opcode (97) and an SPV_Float32 type alias, defines
spv.ImageDrefGather with a custom verifier, and adds parser/verifier tests plus
a serialization round-trip test.

The verifier requires a four-component result vector whose component type
matches the sampled type of the underlying image (unless the sampled type is
void), an image Dim of 2D, Cube, or Rect, and a single-sampled (MS = 0) image.
The Dref operand is constrained to SPV_Float32 because the op description
(following the SPIR-V spec) requires a 32-bit floating-point scalar.
Optional image operands are not supported yet.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3048,6 +3048,7 @@ def SPV_Bool : TypeAlias; def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>; def SPV_Int32 : TypeAlias; +def SPV_Float32 : TypeAlias; def SPV_Float : FloatOfWidths<[16, 32, 64]>; def SPV_Float16or32 : FloatOfWidths<[16, 32]>; def SPV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], @@ -3202,6 +3203,7 @@ def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>; +def SPV_OC_OpImageDrefGather : I32EnumAttrCase<"OpImageDrefGather", 97>; def SPV_OC_OpImage : I32EnumAttrCase<"OpImage", 100>; def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; @@ -3341,37 +3343,38 @@ SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, - SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImage, - SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, - SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, - SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, - SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, - SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, - SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, - SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, - SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, 
SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, - SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, - SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, - SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, - SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, - SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, - SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, - SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, - SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, - SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, - SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, - SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, - SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, - SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, - SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, + SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, + SPV_OC_OpImage, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, + SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, + SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, SPV_OC_OpSNegate, + SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, + SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, + SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, + SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, 
SPV_OC_OpIsNan, + SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, + SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, + SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, + SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, + SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, + SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, + SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, + SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, + SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, + SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, + SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, + SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, + SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, + SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, + SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, + SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak, + SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, + SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, + SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, + SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, + SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, + SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine, + SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td --- 
a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td @@ -17,6 +17,63 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +// ----- + +def SPV_ImageDrefGatherOp : SPV_Op<"ImageDrefGather", [NoSideEffect]> { + let summary = "Gathers the requested depth-comparison from four texels."; + + let description = [{ + Result Type must be a vector of four components of floating-point type + or integer type. Its components must be the same as Sampled Type of the + underlying OpTypeImage (unless that underlying Sampled Type is + OpTypeVoid). It has one component per gathered texel. + + Sampled Image must be an object whose type is OpTypeSampledImage. Its + OpTypeImage must have a Dim of 2D, Cube, or Rect. The MS operand of the + underlying OpTypeImage must be 0. + + Coordinate must be a scalar or vector of floating-point type. It + contains (u[, v] … [, array layer]) as needed by the definition of + Sampled Image. + + Dref is the depth-comparison reference value. It must be a 32-bit + floating-point type scalar. + + Image Operands encodes what operands follow, as per Image Operands. 
+ + + #### Example: + + ```mlir + %0 = spv.ImageDrefGather %1 : !spv.sampled_image>, %2 : vector<4xf32>, %3 : f32 -> vector<4xi32> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Shader]> + ]; + + let arguments = (ins + SPV_AnySampledImage:$sampledimage, + SPV_ScalarOrVectorOf:$coordinate, + SPV_Float32:$dref + ); + + let results = (outs + SPV_Vector:$result + ); + + let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage) `,` $coordinate `:` type($coordinate) `,` $dref `:` type($dref) `->` type($result)"; + + let verifier = [{ return ::verify(*this); }]; + +} + + // ----- def SPV_ImageOp : SPV_Op<"Image", diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3608,6 +3608,46 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.ImageDrefGather +//===----------------------------------------------------------------------===// + +static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) { + // TODO: Support optional operands. 
+ VectorType resultType = + imageDrefGatherOp.result().getType().cast(); + auto sampledImageType = imageDrefGatherOp.sampledimage() + .getType() + .cast(); + auto imageType = sampledImageType.getImageType().cast(); + + if (resultType.getNumElements() != 4) + return imageDrefGatherOp.emitOpError( + "result type must be a vector of four components"); + + Type elementType = resultType.getElementType(); + Type sampledElementType = imageType.getElementType(); + if (!sampledElementType.isa() && elementType != sampledElementType) + return imageDrefGatherOp.emitOpError( + "the component type of result must be the same as sampled type of the " + "underlying image type"); + + spirv::Dim imageDim = imageType.getDim(); + spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); + + if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && + imageDim != spirv::Dim::Rect) + return imageDrefGatherOp.emitOpError( + "the Dim operand of the underlying image type must be 2D, Cube, or " + "Rect"); + + if (imageMS != spirv::ImageSamplingInfo::SingleSampled) + return imageDrefGatherOp.emitOpError( + "the MS operand of the underlying image type must be 0"); + + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -1,5 +1,50 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spv.ImageDrefGather +//===----------------------------------------------------------------------===// + +func @image_dref_gather(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { + // CHECK: spv.ImageDrefGather {{.*}} : !spv.sampled_image>, {{.*}} : vector<4xf32>, {{.*}} : f32 -> vector<4xi32> + %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : 
vector<4xf32>, %arg2 : f32 -> vector<4xi32> + spv.Return +} + +// ----- + +func @image_dref_gather_error_result_type(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { + // expected-error @+1 {{result type must be a vector of four components}} + %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<3xi32> + spv.Return +} + +// ----- + +func @image_dref_gather_error_same_type(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { + // expected-error @+1 {{the component type of result must be the same as sampled type of the underlying image type}} + %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xf32> + spv.Return +} + +// ----- + +func @image_dref_gather_error_dim(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { + // expected-error @+1 {{the Dim operand of the underlying image type must be 2D, Cube, or Rect}} + %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xi32> + spv.Return +} + +// ----- + +func @image_dref_gather_error_ms(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { + // expected-error @+1 {{the MS operand of the underlying image type must be 0}} + %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xi32> + spv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.Image //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir --- a/mlir/test/Target/SPIRV/image-ops.mlir +++ b/mlir/test/Target/SPIRV/image-ops.mlir @@ -1,9 +1,11 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - spv.func @image(%arg0 : !spv.sampled_image>) "None" { + spv.func 
@image(%arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) "None" { // CHECK: {{%.*}} = spv.Image {{%.*}} : !spv.sampled_image> %0 = spv.Image %arg0 : !spv.sampled_image> + // CHECK: {{%.*}} = spv.ImageDrefGather {{%.*}} : !spv.sampled_image>, {{%.*}} : vector<4xf32>, {{%.*}} : f32 -> vector<4xf32> + %1 = spv.ImageDrefGather %arg0 : !spv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xf32> spv.Return } }