diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -805,16 +805,10 @@ return emitError(unknownLoc, "duplicate function definition/declaration"); } - auto functionControl = spirv::symbolizeFunctionControl(operands[2]); - if (!functionControl) { + auto fnControl = spirv::symbolizeFunctionControl(operands[2]); + if (!fnControl) { return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; } - if (functionControl.getValue() != spirv::FunctionControl::None) { - /// TODO: Handle different function controls - return emitError(unknownLoc, "unhandled Function Control: '") - << spirv::stringifyFunctionControl(functionControl.getValue()) - << "'"; - } Type fnType = getType(operands[3]); if (!fnType || !fnType.isa()) { @@ -831,8 +825,8 @@ } std::string fnName = getFunctionSymbol(operands[1]); - auto funcOp = - opBuilder.create(unknownLoc, fnName, functionType); + auto funcOp = opBuilder.create( + unknownLoc, fnName, functionType, fnControl.getValue()); curFunction = funcMap[operands[1]] = funcOp; LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " << fnType << ", id = " << operands[1] << ") --\n"); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -775,8 +775,7 @@ operands.push_back(resTypeID); auto funcID = getOrCreateFunctionID(op.getName()); operands.push_back(funcID); - // TODO: Support other function control options. - operands.push_back(static_cast(spirv::FunctionControl::None)); + operands.push_back(static_cast(op.function_control())); operands.push_back(fnTypeID); encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); diff --git a/mlir/test/Dialect/SPIRV/Serialization/module.mlir b/mlir/test/Dialect/SPIRV/Serialization/module.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/module.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/module.mlir @@ -1,13 +1,13 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s // CHECK: spv.module Logical GLSL450 requires #spv.vce { -// CHECK-NEXT: spv.func @foo() "None" { +// CHECK-NEXT: spv.func @foo() "Inline" { // CHECK-NEXT: spv.Return // CHECK-NEXT: } // CHECK-NEXT: } spv.module Logical GLSL450 requires #spv.vce { - spv.func @foo() -> () "None" { + spv.func @foo() -> () "Inline" { spv.Return } }