diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3040,9 +3040,6 @@ def SPV_IOrUIVec4 : SPV_Vec4; def SPV_Int32Vec4 : SPV_Vec4; -// TODO(antiagainst): Use a more appropriate way to model optional operands -class SPV_Optional : Variadic; - // TODO(ravishankarm): From 1.4, this should also include Composite type. def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -240,7 +240,7 @@ ); let results = (outs - SPV_Optional:$result + Optional:$result ); let autogenSerialization = 0; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -30,7 +30,7 @@ SPV_ScopeAttr:$execution_scope, SPV_GroupOperationAttr:$group_operation, SPV_ScalarOrVectorOf:$value, - SPV_Optional:$cluster_size + Optional:$cluster_size ); let results = (outs diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -469,7 +469,7 @@ let arguments = (ins SPV_StorageClassAttr:$storage_class, - SPV_Optional:$initializer + Optional:$initializer ); let results = (outs diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -155,7 +155,7 @@ groupOperation = rewriter.create( \ loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ spirv::GroupOperation::Reduce, inputElement, \ - /*cluster_size=*/ArrayRef()); \ + /*cluster_size=*/nullptr); \ } break switch (*binaryOpKind) { CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -2291,6 +2291,10 @@ << operands[0]; } + // Use null type to mean no result type. + if (isVoidType(resultType)) + resultType = nullptr; + auto resultID = operands[1]; auto functionID = operands[2]; @@ -2306,18 +2310,12 @@ arguments.push_back(value); } - SmallVector resultTypes; - if (!isVoidType(resultType)) { - resultTypes.push_back(resultType); - } - auto opFunctionCall = opBuilder.create( - unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName), + unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), arguments); - if (!resultTypes.empty()) { + if (resultType) valueMap[resultID] = opFunctionCall.getResult(0); - } return success(); } diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -202,7 +202,7 @@ spv.module Logical GLSL450 { spv.func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () "None" { - // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} + // expected-error @+1 {{result group starting at #0 requires 0 or 1 element, but found 2}} %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) spv.Return } diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -548,7 +548,7 @@ if quantifier == '': arg_type = 'SPV_Type' elif quantifier == '?': - arg_type = 'SPV_Optional' + arg_type = 'Optional' else: arg_type = 'Variadic' elif kind == 'IdMemorySemantics' or kind == 'IdScope':