[SPIR-V] Add a SignedOp op trait and attach it to signed-integer ops.

This patch:
- declares `SignedOp` as a NativeOpTrait (`spirv::SignedOp`) in SPIRVBase.td,
  next to the existing `UnsignedOp` trait;
- adds the matching C++ trait class `SignedOp` in SPIRVOpTraits.h, next to
  `UnsignedOp`;
- attaches `SignedOp` to BitFieldSExtract, ConvertSToF, SConvert,
  SGreaterThan, SGreaterThanEqual, SLessThan, SLessThanEqual,
  GroupNonUniformSMax and GroupNonUniformSMin, keeping
  `UsableInSpecConstantOp` on the ops that already had it.

Review notes:
- NOTE(review): this copy of the patch appears to have been flattened
  (newlines collapsed) and to have lost angle-bracket template arguments in
  the SPIRVOpTraits.h hunk, e.g. `template class SignedOp : public TraitBase {};`.
  Presumably this was `template <typename ConcreteType>` and
  `TraitBase<ConcreteType, SignedOp>`, as for UnsignedOp. TODO confirm
  against the source tree. In this form it most likely will not apply.
  Regenerate it from the tree rather than hand-editing.
- NOTE(review): the new C++ `SignedOp` class carries no `///` doc comment,
  unlike the `UsableInSpecConstantOp` trait that follows it in the same
  header. Consider adding one, e.g. "A trait to mark ops with signed
  semantics."
- NOTE(review): only the ops listed above are tagged. Whether other ops with
  signed semantics in files not touched here should also get the trait is
  not visible from this patch. Confirm the coverage is intentional.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3117,6 +3117,8 @@ def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">; +def SignedOp : NativeOpTrait<"spirv::SignedOp">; + def UsableInSpecConstantOp : NativeOpTrait<"spirv::UsableInSpecConstantOp">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -178,7 +178,8 @@ // ----- -def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> { +def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", + [SignedOp]> { let summary = "Extract a bit field from an object, with sign extension."; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -163,7 +163,10 @@ // ----- -def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> { +def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", + SPV_Float, + SPV_Integer, + [SignedOp]> { let summary = [{ Convert value numerically from signed integer to floating point. }]; @@ -270,7 +273,10 @@ // ----- -def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, [UsableInSpecConstantOp]> { +def SPV_SConvertOp : SPV_CastOp<"SConvert", + SPV_Integer, + SPV_Integer, + [UsableInSpecConstantOp, SignedOp]> { let summary = [{ Convert signed width. This is either a truncate or a sign extend. 
}]; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -690,7 +690,7 @@ def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, - [UsableInSpecConstantOp]> { + [UsableInSpecConstantOp, SignedOp]> { let summary = [{ Signed-integer comparison if Operand 1 is greater than Operand 2. }]; @@ -725,7 +725,8 @@ def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integer, - [UsableInSpecConstantOp]> { + [UsableInSpecConstantOp, + SignedOp]> { let summary = [{ Signed-integer comparison if Operand 1 is greater than or equal to Operand 2. @@ -761,7 +762,7 @@ def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, - [UsableInSpecConstantOp]> { + [UsableInSpecConstantOp, SignedOp]> { let summary = [{ Signed-integer comparison if Operand 1 is less than Operand 2. }]; @@ -796,7 +797,8 @@ def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, - [UsableInSpecConstantOp]> { + [UsableInSpecConstantOp, + SignedOp]> { let summary = [{ Signed-integer comparison if Operand 1 is less than or equal to Operand 2. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -533,7 +533,9 @@ // ----- def SPV_GroupNonUniformSMaxOp : - SPV_GroupNonUniformArithmeticOp<"GroupNonUniformSMax", SPV_Integer, []> { + SPV_GroupNonUniformArithmeticOp<"GroupNonUniformSMax", + SPV_Integer, + [SignedOp]> { let summary = [{ A signed integer maximum group operation of all Value operands contributed by active invocations in the group. 
@@ -582,7 +584,9 @@ // ----- def SPV_GroupNonUniformSMinOp : - SPV_GroupNonUniformArithmeticOp<"GroupNonUniformSMin", SPV_Integer, []> { + SPV_GroupNonUniformArithmeticOp<"GroupNonUniformSMin", + SPV_Integer, + [SignedOp]> { let summary = [{ A signed integer minimum group operation of all Value operands contributed by active invocations in the group. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h @@ -23,6 +23,9 @@ template class UnsignedOp : public TraitBase {}; +template +class SignedOp : public TraitBase {}; + /// A trait to mark ops that can be enclosed/wrapped in a /// `SpecConstantOperation` op. template