diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2907,7 +2907,7 @@ // for the definition of the following types and type categories. def SPV_Void : TypeAlias; -def SPV_Bool : IntOfWidths<[1]>; +def SPV_Bool : I<1>; def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; def SPV_Float : FloatOfWidths<[16, 32, 64]>; def SPV_Float16or32 : FloatOfWidths<[16, 32]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -245,6 +245,11 @@ ); let autogenSerialization = 0; + + let assemblyFormat = [{ + $callee `(` $arguments `)` attr-dict `:` + functional-type($arguments, results) + }]; } // ----- @@ -412,6 +417,8 @@ ); let results = (outs); + + let assemblyFormat = "$value attr-dict `:` type($value)"; } def SPV_SelectionOp : SPV_Op<"selection", [InFunctionScope]> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -65,6 +65,8 @@ ); let verifier = [{ return success(); }]; + + let assemblyFormat = "$predicate attr-dict `:` type($result)"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -404,6 +404,8 @@ let hasOpcode = 0; let autogenSerialization = 0; + + let assemblyFormat = "attr-dict `:` type($result)"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -409,6 +409,8 @@ let hasOpcode = 0; let autogenSerialization = 0; + + let assemblyFormat = "$spec_const attr-dict `:` type($reference)"; } def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1635,46 +1635,6 @@ // spv.FunctionCall //===----------------------------------------------------------------------===// -static ParseResult parseFunctionCallOp(OpAsmParser &parser, - OperationState &state) { - FlatSymbolRefAttr calleeAttr; - FunctionType type; - SmallVector operands; - auto loc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, kCallee, state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseColonType(type)) { - return failure(); - } - - auto funcType = type.dyn_cast(); - if (!funcType) { - return parser.emitError(loc, "expected function type, but provided ") - << type; - } - - if (funcType.getNumResults() > 1) { - return parser.emitError(loc, "expected callee function to have 0 or 1 " - "result, but provided ") - << funcType.getNumResults(); - } - - return failure(parser.addTypesToList(funcType.getResults(), state.types) || - parser.resolveOperands(operands, funcType.getInputs(), loc, - state.operands)); -} - -static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) { - SmallVector argTypes(functionCallOp.getOperandTypes()); - SmallVector resultTypes(functionCallOp.getResultTypes()); - Type functionType = - FunctionType::get(argTypes, resultTypes, functionCallOp.getContext()); - - printer << spirv::FunctionCallOp::getOperationName() << ' ' - << functionCallOp.getAttr(kCallee) << '(' - << functionCallOp.arguments() << ") : " << functionType; -} - static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { auto fnName = functionCallOp.callee(); @@ -2398,24 +2358,6 @@ // spv._reference_of //===----------------------------------------------------------------------===// -static ParseResult parseReferenceOfOp(OpAsmParser &parser, - OperationState &state) { - FlatSymbolRefAttr constRefAttr; - Type type; - if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName, - state.attributes) || - parser.parseColonType(type)) { - return failure(); - } - return parser.addTypeToList(type, state.types); -} - -static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) { - printer << spirv::ReferenceOfOp::getOperationName() << ' '; - printer.printSymbolName(referenceOfOp.spec_const()); - printer << " : " << referenceOfOp.reference().getType(); -} - static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { auto moduleOp = referenceOfOp.getParentOfType(); auto specConstOp = @@ -2449,20 +2391,6 @@ // spv.ReturnValue //===----------------------------------------------------------------------===// -static ParseResult parseReturnValueOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType retValInfo; - Type retValType; - return failure(parser.parseOperand(retValInfo) || - parser.parseColonType(retValType) || - parser.resolveOperand(retValInfo, retValType, state.operands)); -} - -static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { - printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value() - << " : " << retValOp.value().getType(); -} - static LogicalResult verify(spirv::ReturnValueOp retValOp) { auto funcOp = retValOp.getParentOfType(); auto numFnResults = funcOp.getType().getNumResults(); @@ -2897,44 +2825,6 @@ return verifyMemoryAccessAttribute(storeOp); } -//===----------------------------------------------------------------------===// -// spv.SubgroupBallotKHROp -//===----------------------------------------------------------------------===// - -static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type resultType; - IntegerType i1Type = parser.getBuilder().getI1Type(); - if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || - parser.resolveOperand(operandInfo, i1Type, state.operands)) - return failure(); - - return parser.addTypeToList(resultType, state.types); -} - -static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { - printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ' - << ballotOp.predicate() << " : " << ballotOp.getType(); -} - -//===----------------------------------------------------------------------===// -// spv.Undef -//===----------------------------------------------------------------------===// - -static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &state) { - Type type; - if (parser.parseColonType(type)) { - return failure(); - } - state.addTypes(type); - return success(); -} - -static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) { - printer << spirv::UndefOp::getOperationName() << " : " << undefOp.getType(); -} - //===----------------------------------------------------------------------===// // spv.Unreachable //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -189,7 +189,7 @@ spv.module "Logical" "GLSL450" { func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} - %0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) + %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) spv.Return } }