diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -2671,6 +2671,74 @@
       SPV_IF_R8ui
     ]>;
 
+def SPV_IO_None : BitEnumAttrCase<"None", 0x0000>;
+def SPV_IO_Bias : BitEnumAttrCase<"Bias", 0x0001> {
+  list<Availability> availability = [
+    Capability<[SPV_C_Shader]>
+  ];
+}
+def SPV_IO_Lod : BitEnumAttrCase<"Lod", 0x0002>;
+def SPV_IO_Grad : BitEnumAttrCase<"Grad", 0x0004>;
+def SPV_IO_ConstOffset : BitEnumAttrCase<"ConstOffset", 0x0008>;
+def SPV_IO_Offset : BitEnumAttrCase<"Offset", 0x0010> {
+  list<Availability> availability = [
+    Capability<[SPV_C_ImageGatherExtended]>
+  ];
+}
+def SPV_IO_ConstOffsets : BitEnumAttrCase<"ConstOffsets", 0x0020> {
+  list<Availability> availability = [
+    Capability<[SPV_C_ImageGatherExtended]>
+  ];
+}
+def SPV_IO_Sample : BitEnumAttrCase<"Sample", 0x0040>;
+def SPV_IO_MinLod : BitEnumAttrCase<"MinLod", 0x0080> {
+  list<Availability> availability = [
+    Capability<[SPV_C_MinLod]>
+  ];
+}
+def SPV_IO_MakeTexelAvailable : BitEnumAttrCase<"MakeTexelAvailable", 0x0100> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_5>,
+    Capability<[SPV_C_VulkanMemoryModel]>
+  ];
+}
+def SPV_IO_MakeTexelVisible : BitEnumAttrCase<"MakeTexelVisible", 0x0200> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_5>,
+    Capability<[SPV_C_VulkanMemoryModel]>
+  ];
+}
+def SPV_IO_NonPrivateTexel : BitEnumAttrCase<"NonPrivateTexel", 0x0400> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_5>,
+    Capability<[SPV_C_VulkanMemoryModel]>
+  ];
+}
+def SPV_IO_VolatileTexel : BitEnumAttrCase<"VolatileTexel", 0x0800> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_5>,
+    Capability<[SPV_C_VulkanMemoryModel]>
+  ];
+}
+def SPV_IO_SignExtend : BitEnumAttrCase<"SignExtend", 0x1000> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_4>
+  ];
+}
+def SPV_IO_ZeroExtend : BitEnumAttrCase<"ZeroExtend", 0x2000> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_4>
+  ];
+}
+
+def SPV_ImageOperandAttr :
+    SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", [
+      SPV_IO_None, SPV_IO_Bias, SPV_IO_Lod, SPV_IO_Grad, SPV_IO_ConstOffset,
+      SPV_IO_Offset, SPV_IO_ConstOffsets, SPV_IO_Sample, SPV_IO_MinLod,
+      SPV_IO_MakeTexelAvailable, SPV_IO_MakeTexelVisible, SPV_IO_NonPrivateTexel,
+      SPV_IO_VolatileTexel, SPV_IO_SignExtend, SPV_IO_ZeroExtend
+    ]>;
+
 def SPV_LT_Export : I32EnumAttrCase<"Export", 0> {
   list<Availability> availability = [
     Capability<[SPV_C_Linkage]>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
@@ -42,11 +42,19 @@
 
     Image Operands encodes what operands follow, as per Image Operands.
 
+    ```
+    image-operands ::= `"None"` | `"Bias"` | `"Lod"` | `"Grad"`
+                      | `"ConstOffset"` | `"Offset"` | `"ConstOffsets"`
+                      | `"Sample"` | `"MinLod"` | `"MakeTexelAvailable"`
+                      | `"MakeTexelVisible"` | `"NonPrivateTexel"`
+                      | `"VolatileTexel"` | `"SignExtend"` | `"ZeroExtend"`
+    ```
 
     #### Example:
 
     ```mlir
     %0 = spv.ImageDrefGather %1 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 -> vector<4xi32>
+    %0 = spv.ImageDrefGather %1 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 ["NonPrivateTexel"] -> vector<4xi32>
     ```
   }];
 
@@ -60,14 +68,21 @@
   let arguments = (ins
     SPV_AnySampledImage:$sampledimage,
     SPV_ScalarOrVectorOf<SPV_Float>:$coordinate,
-    SPV_Float:$dref
+    SPV_Float:$dref,
+    OptionalAttr<SPV_ImageOperandAttr>:$imageoperands,
+    Variadic<SPV_Type>:$operand_arguments
   );
 
   let results = (outs
     SPV_Vector:$result
   );
 
-  let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage) `,` $coordinate `:` type($coordinate) `,` $dref `:` type($dref) `->` type($result)";
+  let assemblyFormat = [{$sampledimage `:` type($sampledimage) `,`
+                          $coordinate `:` type($coordinate) `,` $dref `:` type($dref)
+                          custom<ImageOperands>($imageoperands)
+                          ( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
+                          attr-dict
+                          `->` type($result)}];
 
   let verifier = [{ return ::verify(*this); }];
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -311,6 +311,73 @@
   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
 }
 
+/// Parses an optional `["<image-operands>"]` clause. On success `attr` holds
+/// the parsed value; when no `[` is present `attr` is left unchanged.
+static ParseResult parseImageOperands(OpAsmParser &parser,
+                                      spirv::ImageOperandsAttr &attr) {
+  // Expect image operands
+  if (parser.parseOptionalLSquare())
+    return success();
+
+  spirv::ImageOperands imageOperands;
+  if (parseEnumStrAttr(imageOperands, parser))
+    return failure();
+
+  attr = spirv::ImageOperandsAttr::get(parser.getBuilder().getContext(),
+                                       imageOperands);
+
+  return parser.parseRSquare();
+}
+
+/// Prints `attr` as `["<image-operands>"]`; prints nothing when `attr` is
+/// null. `imageOp` is unused.
+static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
+                               spirv::ImageOperandsAttr attr) {
+  if (attr) {
+    auto strImageOperands = stringifyImageOperands(attr.getValue());
+    printer << "[\"" << strImageOperands << "\"]";
+  }
+}
+
+/// Verifies that the optional Image Operands `attr` of `imageOp` agrees with
+/// the trailing `operands`. Emits an error on `imageOp` and returns failure
+/// otherwise.
+template <typename Op>
+static LogicalResult verifyImageOperands(Op imageOp,
+                                         spirv::ImageOperandsAttr attr,
+                                         Operation::operand_range operands) {
+  if (!attr) {
+    if (operands.empty())
+      return success();
+
+    return imageOp.emitError("the Image Operands should encode what operands "
+                             "follow, as per Image Operands");
+  }
+
+  // TODO: Add the validation rules for the following Image Operands.
+  spirv::ImageOperands noSupportOperands =
+      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
+      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
+      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
+      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
+      spirv::ImageOperands::MakeTexelAvailable |
+      spirv::ImageOperands::MakeTexelVisible |
+      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
+
+  // This is reachable from user-written IR, so report a diagnostic instead of
+  // aborting.
+  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
+    return imageOp.emitError("unimplemented operands of Image Operands");
+
+  // The remaining Image Operands (None, NonPrivateTexel, VolatileTexel) take
+  // no extra operands, so nothing may follow them.
+  if (!operands.empty())
+    return imageOp.emitError("the Image Operands should encode what operands "
+                             "follow, as per Image Operands");
+
+  return success();
+}
+
 static LogicalResult verifyCastOp(Operation *op,
                                   bool requireSameBitWidth = true,
                                   bool skipBitWidthCheck = false) {
@@ -3656,7 +3723,6 @@
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) {
-  // TODO: Support optional operands.
   VectorType resultType =
       imageDrefGatherOp.result().getType().cast<VectorType>();
   auto sampledImageType = imageDrefGatherOp.sampledimage()
@@ -3688,7 +3754,10 @@
     return imageDrefGatherOp.emitOpError(
         "the MS operand of the underlying image type must be 0");
 
-  return success();
+  spirv::ImageOperandsAttr attr = imageDrefGatherOp.imageoperandsAttr();
+  auto operandArguments = imageDrefGatherOp.operand_arguments();
+
+  return verifyImageOperands(imageDrefGatherOp, attr, operandArguments);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
-
 //===----------------------------------------------------------------------===//
 // spv.ImageDrefGather
 //===----------------------------------------------------------------------===//
@@ -13,6 +12,38 @@
 
 // -----
 
+func @image_dref_gather_with_single_imageoperands(%arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
+  // CHECK: spv.ImageDrefGather {{.*}} ["NonPrivateTexel"] -> vector<4xi32>
+  %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 ["NonPrivateTexel"] -> vector<4xi32>
+  spv.Return
+}
+
+// -----
+
+func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
+  // expected-error @+1 {{the Image Operands should encode what operands follow, as per Image Operands}}
+  %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 (%arg2, %arg2 : f32, f32) -> vector<4xi32>
+  spv.Return
+}
+
+// -----
+
+func @image_dref_gather_with_unsupported_imageoperands(%arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
+  // expected-error @+1 {{unimplemented operands of Image Operands}}
+  %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 ["Bias"] (%arg2 : f32) -> vector<4xi32>
+  spv.Return
+}
+
+// -----
+
+func @image_dref_gather_with_unexpected_imageoperand_arguments(%arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
+  // expected-error @+1 {{the Image Operands should encode what operands follow, as per Image Operands}}
+  %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 ["NonPrivateTexel"] (%arg2 : f32) -> vector<4xi32>
+  spv.Return
+}
+
+// -----
+
 func @image_dref_gather_error_result_type(%arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
   // expected-error @+1 {{result type must be a vector of four components}}
   %0 = spv.ImageDrefGather %arg0 : !spv.sampled_image<!spv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<3xi32>