diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1251,15 +1251,15 @@ << operands[1]; } - auto scope = spirv::symbolizeScope(operands[2]); + auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); if (!scope) { return emitError(unknownLoc, "OpTypeCooperativeMatrix references undefined scope ") << operands[2]; } - unsigned rows = operands[3]; - unsigned columns = operands[4]; + unsigned rows = getConstantInt(operands[3]).getInt(); + unsigned columns = getConstantInt(operands[4]).getInt(); typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( elementTy, scope.getValue(), rows, columns); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1104,10 +1104,15 @@ return failure(); } typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; + auto getConstantOp = [&](uint32_t id) { + auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id); + return prepareConstantInt(loc, attr); + }; operands.push_back(elementTypeID); - operands.push_back(static_cast(cooperativeMatrixType.getScope())); - operands.push_back(cooperativeMatrixType.getRows()); - operands.push_back(cooperativeMatrixType.getColumns()); + operands.push_back( + getConstantOp(static_cast(cooperativeMatrixType.getScope()))); + operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); + operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); return success(); }