diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -805,11 +805,15 @@ auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; - // According to SPIR-V spec, "When the type's bit width is less than 32-bits, - // the literal's value appears in the low-order bits of the word, and the - // high-order bits must be 0 for a floating-point type, or 0 for an integer - // type with Signedness of 0, or sign extended when Signedness is 1." - if (bitwidth == 32 || bitwidth == 16) { + switch (bitwidth) { + // According to SPIR-V spec, "When the type's bit width is less than + // 32-bits, the literal's value appears in the low-order bits of the word, + // and the high-order bits must be 0 for a floating-point type, or 0 for an + // integer type with Signedness of 0, or sign extended when Signedness + // is 1." + case 32: + case 16: + case 8: { uint32_t word = 0; if (isSigned) { word = static_cast(value.getSExtValue()); @@ -818,10 +822,10 @@ } (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); - } - // According to SPIR-V spec: "When the type's bit width is larger than one - // word, the literal’s low-order words appear first." - else if (bitwidth == 64) { + } break; + // According to SPIR-V spec: "When the type's bit width is larger than one + // word, the literal’s low-order words appear first." + case 64: { struct DoubleWord { uint32_t word1; uint32_t word2; @@ -833,7 +837,8 @@ } (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); - } else { + } break; + default: { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss, /*isSigned=*/false); @@ -842,6 +847,7 @@ << bitwidth << "-bit integer literal: " << rss.str(); return 0; } + } if (!isSpec) { constIDMap[intAttr] = resultID; diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -85,6 +85,32 @@ spv.Return } + // CHECK-LABEL: @i8_const + spv.func @i8_const() -> () "None" { + // CHECK: spv.Constant 0 : i8 + %0 = spv.Constant 0 : i8 + // CHECK: spv.Constant -1 : i8 + %1 = spv.Constant 255 : i8 + + // CHECK: spv.Constant 0 : si8 + %2 = spv.Constant 0 : si8 + // CHECK: spv.Constant 127 : si8 + %3 = spv.Constant 127 : si8 + // CHECK: spv.Constant -128 : si8 + %4 = spv.Constant -128 : si8 + + // CHECK: spv.Constant 0 : i8 + %5 = spv.Constant 0 : ui8 + // CHECK: spv.Constant -1 : i8 + %6 = spv.Constant 255 : ui8 + + %10 = spv.IAdd %0, %1: i8 + %11 = spv.IAdd %2, %3: si8 + %12 = spv.IAdd %3, %4: si8 + %13 = spv.IAdd %5, %6: ui8 + spv.Return + } + // CHECK-LABEL: @float_const spv.func @float_const() -> () "None" { // CHECK: spv.Constant 0.000000e+00 : f32