diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -514,7 +514,7 @@ // ----- -def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> { +def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> { let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; let description = [{ @@ -546,7 +546,7 @@ // ----- -def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer> { +def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer, [UnsignedOp]> { let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2."; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td @@ -438,7 +438,7 @@ // ----- -def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> { +def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", [UnsignedOp]> { let summary = [{ Perform the following steps atomically with respect to any other atomic accesses within Scope to the same location: @@ -480,7 +480,7 @@ // ----- -def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> { +def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", [UnsignedOp]> { let summary = [{ Perform the following steps atomically with respect to any other atomic accesses within Scope to the same location: diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3115,6 +3115,10 @@ "op must appear in a module-like op's block",
CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>; +// Marks ops that interpret their integer operands as unsigned values; maps to +// the C++ trait mlir::OpTrait::spirv::UnsignedOp (see SPIRVOpTraits.h). +def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">; + //===----------------------------------------------------------------------===// // SPIR-V opcode specification //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -232,7 +232,8 @@ // ----- -def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> { +def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", + [UnsignedOp]> { let summary = "Extract a bit field from an object, without sign extension."; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -196,7 +196,10 @@ // ----- -def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> { +def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", + SPV_Float, + SPV_Integer, + [UnsignedOp]> { let summary = [{ Convert value numerically from unsigned integer to floating point. }]; @@ -298,7 +301,10 @@ // ----- -def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> { +def SPV_UConvertOp : SPV_CastOp<"UConvert", + SPV_Integer, + SPV_Integer, + [UnsignedOp]> { let summary = [{ Convert unsigned width. This is either a truncate or a zero extend.
}]; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -869,7 +869,9 @@ // ----- -def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> { +def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", + SPV_Integer, + [UnsignedOp]> { let summary = [{ Unsigned-integer comparison if Operand 1 is greater than Operand 2. }]; @@ -902,7 +904,9 @@ // ----- -def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> { +def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", + SPV_Integer, + [UnsignedOp]> { let summary = [{ Unsigned-integer comparison if Operand 1 is greater than or equal to Operand 2. @@ -936,7 +940,9 @@ // ----- -def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> { +def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", + SPV_Integer, + [UnsignedOp]> { let summary = [{ Unsigned-integer comparison if Operand 1 is less than Operand 2. }]; @@ -970,7 +976,7 @@ // ----- def SPV_ULessThanEqualOp : - SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> { + SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp]> { let summary = [{ Unsigned-integer comparison if Operand 1 is less than or equal to Operand 2.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -631,7 +631,9 @@ // ----- def SPV_GroupNonUniformUMaxOp : - SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax", SPV_Integer, []> { + SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax", + SPV_Integer, + [UnsignedOp]> { let summary = [{ An unsigned integer maximum group operation of all Value operands contributed by active invocations in the group. @@ -681,7 +683,9 @@ // ----- def SPV_GroupNonUniformUMinOp : - SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin", SPV_Integer, []> { + SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin", + SPV_Integer, + [UnsignedOp]> { let summary = [{ An unsigned integer minimum group operation of all Value operands contributed by active invocations in the group. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h @@ -0,0 +1,33 @@ +//===- SPIRVOpTraits.h - MLIR SPIR-V operation traits -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares C++ classes for some of the operation traits in the SPIR-V +// dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_ +#define MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { +namespace spirv { + +/// Marks ops that interpret their integer operands as unsigned values, e.g. +/// spirv::UDivOp and spirv::ULessThanOp. The trait is empty; it exists so that +/// patterns can query `hasTrait<OpTrait::spirv::UnsignedOp>()`. +template <typename ConcreteType> +class UnsignedOp : public TraitBase<ConcreteType, UnsignedOp> {}; + +} // namespace spirv +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ +#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -187,36 +187,6 @@ offset); } -/// Returns true if the operator is operating on unsigned integers. -/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information -/// to the ops themselves.
-template <typename SPIRVOp> -bool isUnsignedOp() { - return false; -} - -#define CHECK_UNSIGNED_OP(SPIRVOp) \ - template <> \ - bool isUnsignedOp<SPIRVOp>() { \ - return true; \ - } - -CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp) -CHECK_UNSIGNED_OP(spirv::AtomicUMinOp) -CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp) -CHECK_UNSIGNED_OP(spirv::ConvertUToFOp) -CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp) -CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp) -CHECK_UNSIGNED_OP(spirv::UConvertOp) -CHECK_UNSIGNED_OP(spirv::UDivOp) -CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp) -CHECK_UNSIGNED_OP(spirv::UGreaterThanOp) -CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp) -CHECK_UNSIGNED_OP(spirv::ULessThanOp) -CHECK_UNSIGNED_OP(spirv::UModOp) - -#undef CHECK_UNSIGNED_OP - /// Returns true if the allocations of type `t` can be lowered to SPIR-V. static bool isAllocationSupported(MemRefType t) { // Currently only support workgroup local memory allocations with static @@ -334,7 +304,8 @@ auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); - if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) { + if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && + dstType != operation.getType()) { return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } @@ -799,7 +770,7 @@ switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ - if (isUnsignedOp<spirvOp>() && \ + if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ operandType != this->typeConverter.convertType(operandType)) { \ return cmpIOp.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \