diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -846,8 +846,7 @@ auto resultID = getNextID(); APInt value = intAttr.getValue(); unsigned bitwidth = value.getBitWidth(); - bool isSigned = value.isSignedIntN(bitwidth); - + bool isSigned = intAttr.getType().isSignedInteger(); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -264,4 +264,17 @@ %0 = spirv.Constant dense<1> : tensor<2x2x3xi32> : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24> } + + // CHECK-LABEL: @signless_int_const_bit_extension + spirv.func @signless_int_const_bit_extension() -> (i16) "None" { + // CHECK: spirv.Constant -1 : i16 + %signless_minus_one = spirv.Constant -1 : i16 + spirv.ReturnValue %signless_minus_one : i16 + } + // CHECK-LABEL: @signed_int_const_bit_extension + spirv.func @signed_int_const_bit_extension() -> (si16) "None" { + // CHECK: spirv.Constant -1 : si16 + %signed_minus_one = spirv.Constant -1 : si16 + spirv.ReturnValue %signed_minus_one : si16 + } } diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -76,6 +76,27 @@ builder.getStringAttr(name), nullptr); } + // Inserts an Integer or a Vector of Integers constant of value 'val'. 
+ spirv::ConstantOp AddConstInt(Type type, APInt val) { /* Test-fixture helper: creates a constant op of type `type` carrying the value `val`, using an OpBuilder constructed on module->getRegion(). Two shapes of `type` are handled (the first dyn_cast match and the second dyn_cast match); anything else reaches llvm_unreachable. NOTE(review): the template arguments of dyn_cast and builder.create are not visible in this view; judging by the variable names they are presumably IntegerType, VectorType and spirv::ConstantOp - confirm against the real file. */ + OpBuilder builder(module->getRegion()); + auto loc = UnknownLoc::get(&context); /* the op is created with an UnknownLoc */ + + if (auto intType = dyn_cast(type)) { /* scalar case: intType serves only as the condition, the attribute is built from `type` itself */ + return builder.create( + loc, type, builder.getIntegerAttr(type, val)); + } + if (auto vectorType = dyn_cast(type)) { /* vector case: only taken further when the element type also matches the inner dyn_cast */ + Type elemType = vectorType.getElementType(); + if (auto intType = dyn_cast(elemType)) { /* the attribute is built by DenseElementsAttr::get from the vector type and one APInt, obtained by passing val through IntegerAttr::get(elemType, val).getValue(); presumably this yields a splat of val - TODO confirm */ + return builder.create( + loc, type, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, val).getValue())); + } + } /* falls through here for a vector whose element type did not match, and for every other type */ + llvm_unreachable("unimplemented types for AddConstInt()"); + } + /// Handles a SPIR-V instruction with the given `opcode` and `operand`. /// Returns true to interrupt. using HandleFn = llvm::function_ref operands) { + return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && + operands[2] == 65535; + }; + EXPECT_TRUE(scanInstruction(hasSignlessVal)); + + auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef operands) { + return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && + operands[2] == 4294967295; + }; + EXPECT_TRUE(scanInstruction(hasSignedVal)); +} + TEST_F(SerializationTest, ContainsSymbolName) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0");