diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2954,7 +2954,7 @@ // for the definition of the following types and type categories. def SPV_Void : TypeAlias; -def SPV_Bool : IntOfWidths<[1]>; +def SPV_Bool : I<1>; def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; def SPV_Float : FloatOfWidths<[16, 32, 64]>; def SPV_Float16or32 : FloatOfWidths<[16, 32]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -245,6 +245,11 @@ ); let autogenSerialization = 0; + + let assemblyFormat = [{ + $callee `(` $arguments `)` attr-dict `:` + functional-type($arguments, results) + }]; } // ----- @@ -412,6 +417,8 @@ ); let results = (outs); + + let assemblyFormat = "$value attr-dict `:` type($value)"; } def SPV_SelectionOp : SPV_Op<"selection", [InFunctionScope]> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -65,6 +65,8 @@ ); let verifier = [{ return success(); }]; + + let assemblyFormat = "$predicate attr-dict `:` type($result)"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -408,6 +408,8 @@ let hasOpcode = 0; let autogenSerialization = 0; + + let assemblyFormat = "attr-dict `:` type($result)"; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -57,6 +57,8 @@ let builders = [OpBuilder<[{Builder *builder, OperationState &state, spirv::GlobalVariableOp var}]>]; + + let assemblyFormat = "$variable attr-dict `:` type($pointer)"; } def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { @@ -409,6 +411,8 @@ let hasOpcode = 0; let autogenSerialization = 0; + + let assemblyFormat = "$spec_const attr-dict `:` type($reference)"; } def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -941,36 +941,6 @@ build(builder, state, var.type(), builder->getSymbolRefAttr(var)); } -static ParseResult parseAddressOfOp(OpAsmParser &parser, - OperationState &state) { - FlatSymbolRefAttr varRefAttr; - Type type; - if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName, - state.attributes) || - parser.parseColonType(type)) { - return failure(); - } - auto ptrType = type.dyn_cast(); - if (!ptrType) { - return parser.emitError(parser.getCurrentLocation(), - "expected spv.ptr type"); - } - state.addTypes(ptrType); - return success(); -} - -static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) { - SmallVector elidedAttrs; - printer << spirv::AddressOfOp::getOperationName(); - - // Print symbol name. - printer << ' '; - printer.printSymbolName(addressOfOp.variable()); - - // Print the type. - printer << " : " << addressOfOp.pointer().getType(); -} - static LogicalResult verify(spirv::AddressOfOp addressOfOp) { auto varOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom(addressOfOp.getParentOp(), @@ -1736,45 +1706,6 @@ // spv.FunctionCall //===----------------------------------------------------------------------===// -static ParseResult parseFunctionCallOp(OpAsmParser &parser, - OperationState &state) { - FlatSymbolRefAttr calleeAttr; - FunctionType type; - SmallVector operands; - auto loc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, kCallee, state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseColonType(type)) { - return failure(); - } - - auto funcType = type.dyn_cast(); - if (!funcType) { - return parser.emitError(loc, "expected function type, but provided ") - << type; - } - - if (funcType.getNumResults() > 1) { - return parser.emitError(loc, "expected callee function to have 0 or 1 " - "result, but provided ") - << funcType.getNumResults(); - } - - return failure(parser.addTypesToList(funcType.getResults(), state.types) || - parser.resolveOperands(operands, funcType.getInputs(), loc, - state.operands)); -} - -static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) { - SmallVector argTypes(functionCallOp.getOperandTypes()); - Type functionType = FunctionType::get( - argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext()); - - printer << spirv::FunctionCallOp::getOperationName() << ' ' - << functionCallOp.getAttr(kCallee) << '(' - << functionCallOp.arguments() << ") : " << functionType; -} - static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { auto fnName = functionCallOp.callee(); @@ -2533,24 +2464,6 @@ // spv._reference_of //===----------------------------------------------------------------------===// -static ParseResult parseReferenceOfOp(OpAsmParser &parser, - OperationState &state) { - FlatSymbolRefAttr constRefAttr; - Type type; - if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName, - state.attributes) || - parser.parseColonType(type)) { - return failure(); - } - return parser.addTypeToList(type, state.types); -} - -static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) { - printer << spirv::ReferenceOfOp::getOperationName() << ' '; - printer.printSymbolName(referenceOfOp.spec_const()); - printer << " : " << referenceOfOp.reference().getType(); -} - static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { auto specConstOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(), @@ -2584,20 +2497,6 @@ // spv.ReturnValue //===----------------------------------------------------------------------===// -static ParseResult parseReturnValueOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType retValInfo; - Type retValType; - return failure(parser.parseOperand(retValInfo) || - parser.parseColonType(retValType) || - parser.resolveOperand(retValInfo, retValType, state.operands)); -} - -static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { - printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value() - << " : " << retValOp.value().getType(); -} - static LogicalResult verify(spirv::ReturnValueOp retValOp) { auto funcOp = retValOp.getParentOfType(); auto numFnResults = funcOp.getType().getNumResults(); @@ -3033,44 +2932,6 @@ } //===----------------------------------------------------------------------===// -// spv.SubgroupBallotKHROp -//===----------------------------------------------------------------------===// - -static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type resultType; - IntegerType i1Type = parser.getBuilder().getI1Type(); - if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || - parser.resolveOperand(operandInfo, i1Type, state.operands)) - return failure(); - - return parser.addTypeToList(resultType, state.types); -} - -static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { - printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ' - << ballotOp.predicate() << " : " << ballotOp.getType(); -} - -//===----------------------------------------------------------------------===// -// spv.Undef -//===----------------------------------------------------------------------===// - -static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &state) { - Type type; - if (parser.parseColonType(type)) { - return failure(); - } - state.addTypes(type); - return success(); -} - -static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) { - printer << spirv::UndefOp::getOperationName() << " : " << undefOp.getType(); -} - -//===----------------------------------------------------------------------===// // spv.Unreachable //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -202,7 +202,7 @@ spv.module "Logical" "GLSL450" { func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} - %0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) + %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) spv.Return } }