diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -41,8 +41,8 @@ ArrayRef argAttrs, ArrayRef resultAttrs); void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef argAttrs, + ArrayRef resultAttrs); /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -52,26 +52,20 @@ using FuncTypeBuilder = function_ref, ArrayRef, VariadicFlag, std::string &)>; -/// Parses function arguments using `parser`. The `allowVariadic` argument -/// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types, -/// attributes and locations of the arguments. -ParseResult parseFunctionArgumentList( - OpAsmParser &parser, bool allowAttributes, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic); - /// Parses a function signature using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The /// trailing arguments are populated by this function with names, types, /// attributes and locations of the arguments and those of the results. -ParseResult parseFunctionSignature( - OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs); +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + +/// Get a function type corresponding to an array of arguments (which have +/// types) and a set of result types. +Type getFunctionType(Builder &builder, ArrayRef argAttrs, + ArrayRef resultTypes); /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -633,14 +633,14 @@ /// unlike `OpBuilder::getType`, this method does not implicitly insert a /// context parameter. template - T getChecked(SMLoc loc, ParamsT &&... params) { + T getChecked(SMLoc loc, ParamsT &&...params) { return T::getChecked([&] { return emitError(loc); }, std::forward(params)...); } /// A variant of `getChecked` that uses the result of `getNameLoc` to emit /// errors. template - T getChecked(ParamsT &&... params) { + T getChecked(ParamsT &&...params) { return T::getChecked([&] { return emitError(getNameLoc()); }, std::forward(params)...); } @@ -1093,7 +1093,6 @@ SMLoc location; // Location of the token. StringRef name; // Value name, e.g. %42 or %abc unsigned number; // Number, e.g. 12 for an operand like %xyz#12 - Optional sourceLoc; // Source location specifier if present. }; /// Parse different components, viz., use-info of operand(s), successor(s), @@ -1219,34 +1218,64 @@ SmallVectorImpl &symbOperands, AffineExpr &expr) = 0; + //===--------------------------------------------------------------------===// + // Argument Parsing + //===--------------------------------------------------------------------===// + + struct Argument { + UnresolvedOperand ssaName; // SourceLoc, SSA name, result # + Type type; // Type. + DictionaryAttr attrs; // Attributes if present. + Optional sourceLoc; // Source location specifier if present. + }; + + /// Parse a single argument with the following syntax: + /// + /// `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)` + /// + /// If `allowType` is false or `allowAttrs` are false then the respective + /// parts of the grammar are not parsed. + virtual ParseResult parseArgument(Argument &result, bool allowType = true, + bool allowAttrs = true) = 0; + + /// Parse a single argument if present. + virtual OptionalParseResult parseOptionalArgument(Argument &result, + bool allowType = true, + bool allowAttrs = true) = 0; + + /// Parse zero or more arguments with a specified surrounding delimiter. + virtual ParseResult parseArgumentList(SmallVectorImpl &result, + Delimiter delimiter = Delimiter::None, + bool allowType = true, + bool allowAttrs = true) = 0; + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// /// Parses a region. Any parsed blocks are appended to 'region' and must be /// moved to the op regions after the op is created. The first block of the - /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is - /// set to true, the argument names are allowed to shadow the names of other - /// existing SSA values defined above the region scope. 'enableNameShadowing' - /// can only be set to true for regions attached to operations that are - /// 'IsolatedFromAbove'. + /// region takes 'arguments'. + /// + /// If 'enableNameShadowing' is set to true, the argument names are allowed to + /// shadow the names of other existing SSA values defined above the region + /// scope. 'enableNameShadowing' can only be set to true for regions attached + /// to operations that are 'IsolatedFromAbove'. virtual ParseResult parseRegion(Region ®ion, - ArrayRef arguments = {}, - ArrayRef argTypes = {}, + ArrayRef arguments = {}, bool enableNameShadowing = false) = 0; /// Parses a region if present. - virtual OptionalParseResult parseOptionalRegion( - Region ®ion, ArrayRef arguments = {}, - ArrayRef argTypes = {}, bool enableNameShadowing = false) = 0; + virtual OptionalParseResult + parseOptionalRegion(Region ®ion, ArrayRef arguments = {}, + bool enableNameShadowing = false) = 0; /// Parses a region if present. If the region is present, a new region is /// allocated and placed in `region`. If no region is present or on failure, /// `region` remains untouched. virtual OptionalParseResult parseOptionalRegion(std::unique_ptr ®ion, - ArrayRef arguments = {}, - ArrayRef argTypes = {}, + ArrayRef arguments = {}, bool enableNameShadowing = false) = 0; //===--------------------------------------------------------------------===// @@ -1269,7 +1298,7 @@ /// Parse a list of assignments of the form /// (%x1 = %y1, %x2 = %y2, ...) - ParseResult parseAssignmentList(SmallVectorImpl &lhs, + ParseResult parseAssignmentList(SmallVectorImpl &lhs, SmallVectorImpl &rhs) { OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); if (!result.hasValue()) @@ -1278,26 +1307,8 @@ } virtual OptionalParseResult - parseOptionalAssignmentList(SmallVectorImpl &lhs, + parseOptionalAssignmentList(SmallVectorImpl &lhs, SmallVectorImpl &rhs) = 0; - - /// Parse a list of assignments of the form - /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...) - ParseResult - parseAssignmentListWithTypes(SmallVectorImpl &lhs, - SmallVectorImpl &rhs, - SmallVectorImpl &types) { - OptionalParseResult result = - parseOptionalAssignmentListWithTypes(lhs, rhs, types); - if (!result.hasValue()) - return emitError(getCurrentLocation(), "expected '('"); - return result.getValue(); - } - - virtual OptionalParseResult - parseOptionalAssignmentListWithTypes(SmallVectorImpl &lhs, - SmallVectorImpl &rhs, - SmallVectorImpl &types) = 0; }; //===--------------------------------------------------------------------===// @@ -1339,7 +1350,6 @@ virtual AliasResult getAlias(Type type, raw_ostream &os) const { return AliasResult::NoAlias; } - }; } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1431,9 +1431,11 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); - OpAsmParser::UnresolvedOperand inductionVariable; + OpAsmParser::Argument inductionVariable; + inductionVariable.type = builder.getIndexType(); // Parse the induction variable followed by '='. - if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || + if (parser.parseArgument(inductionVariable, + /*allowType=*/false, /*allowAttrs=*/false) || parser.parseEqual()) return failure(); @@ -1463,8 +1465,10 @@ } // Parse the optional initial iteration arguments. - SmallVector regionArgs, operands; - SmallVector argTypes; + SmallVector regionArgs; + SmallVector operands; + + // Induction variable. regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { @@ -1473,23 +1477,24 @@ parser.parseArrowTypeList(result.types)) return failure(); // Resolve input operands. - for (auto operandType : llvm::zip(operands, result.types)) - if (parser.resolveOperand(std::get<0>(operandType), - std::get<1>(operandType), result.operands)) + for (auto argOperandType : llvm::zip( + MutableArrayRef(regionArgs).drop_front(), + operands, result.types)) { + Type type = std::get<2>(argOperandType); + std::get<0>(argOperandType).type = type; + if (parser.resolveOperand(std::get<1>(argOperandType), type, + result.operands)) return failure(); + } } - // Induction variable. - Type indexType = builder.getIndexType(); - argTypes.push_back(indexType); - // Loop carried variables. - argTypes.append(result.types.begin(), result.types.end()); + // Parse the body region. Region *body = result.addRegion(); - if (regionArgs.size() != argTypes.size()) + if (regionArgs.size() != result.types.size() + 1) return parser.emitError( parser.getNameLoc(), "mismatch between the number of loop-carried values and results"); - if (parser.parseRegion(*body, regionArgs, argTypes)) + if (parser.parseRegion(*body, regionArgs)) return failure(); AffineForOp::ensureTerminator(*body, builder, result.location); @@ -1548,7 +1553,8 @@ void AffineForOp::print(OpAsmPrinter &p) { p << ' '; - p.printOperand(getBody()->getArgument(0)); + p.printRegionArgument(getBody()->getArgument(0), /*argAtrs=*/{}, + /*omitType=*/true); p << " = "; printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p); p << " to "; @@ -3527,9 +3533,9 @@ OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false) || + SmallVector ivs; + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren, + /*allowType=*/false, /*allowAttrs=*/false) || parser.parseEqual() || parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || parser.parseKeyword("to") || @@ -3600,8 +3606,9 @@ // Now parse the body. Region *body = result.addRegion(); - SmallVector types(ivs.size(), indexType); - if (parser.parseRegion(*body, ivs, types) || + for (auto &iv : ivs) + iv.type = indexType; + if (parser.parseRegion(*body, ivs) || parser.parseOptionalAttrDict(result.attributes)) return failure(); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -178,21 +178,20 @@ // Parse async value operands (%value as %unwrapped : !async.value). SmallVector valueArgs; - SmallVector unwrappedArgs; + SmallVector unwrappedArgs; SmallVector valueTypes; - SmallVector unwrappedTypes; // Parse a single instance of `%value as %unwrapped : !async.value`. auto parseAsyncValueArg = [&]() -> ParseResult { if (parser.parseOperand(valueArgs.emplace_back()) || parser.parseKeyword("as") || - parser.parseOperand(unwrappedArgs.emplace_back()) || + parser.parseArgument(unwrappedArgs.emplace_back(), /*allowType*/ false, + /*allowAttrs*/ false) || parser.parseColonType(valueTypes.emplace_back())) return failure(); auto valueTy = valueTypes.back().dyn_cast(); - unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); - + unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type(); return success(); }; @@ -227,12 +226,7 @@ // Parse asynchronous region. Region *body = result.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, - /*argTypes=*/{unwrappedTypes}, - /*enableNameShadowing=*/false)) - return failure(); - - return success(); + return parser.parseRegion(*body, /*arguments=*/unwrappedArgs); } LogicalResult ExecuteOp::verifyRegions() { diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -622,8 +622,17 @@ Type index = parser.getBuilder().getIndexType(); SmallVector dataTypes( LaunchOp::kNumConfigRegionAttributes, index); + + SmallVector regionArguments; + for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) { + OpAsmParser::Argument arg; + arg.ssaName = std::get<0>(ssaValueAndType); + arg.type = std::get<1>(ssaValueAndType); + regionArguments.push_back(arg); + } + Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs, dataTypes) || + if (parser.parseRegion(*body, regionArguments) || parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -758,11 +767,17 @@ SmallVectorImpl &argTypes) { if (parser.parseOptionalKeyword("args")) return success(); - SmallVector argAttrs; - bool isVariadic = false; - return function_interface_impl::parseFunctionArgumentList( - parser, /*allowAttributes=*/false, - /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); + + SmallVector args; + if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, + /*allowAttrs*/ false)) + return failure(); + for (auto &arg : args) { + argNames.push_back(arg.ssaName); + argTypes.push_back(arg.type); + }; + return success(); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, @@ -852,32 +867,14 @@ /// keyword provided as argument. static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, - SmallVectorImpl &args, - SmallVectorImpl &argTypes) { + SmallVectorImpl &args) { // If we could not parse the keyword, just assume empty list and succeed. if (failed(parser.parseOptionalKeyword(keyword))) return success(); - if (failed(parser.parseLParen())) - return failure(); - - // Early exit for an empty list. - if (succeeded(parser.parseOptionalRParen())) - return success(); - - do { - OpAsmParser::UnresolvedOperand arg; - Type type; - - if (parser.parseOperand(arg, /*allowResultNumber=*/false) || - parser.parseColonType(type)) - return failure(); - - args.push_back(arg); - argTypes.push_back(type); - } while (succeeded(parser.parseOptionalComma())); - - return parser.parseRParen(); + return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, + /*allowAttrs=*/false); } /// Parses a GPU function. @@ -886,10 +883,8 @@ /// (`->` function-result-list)? memory-attribution `kernel`? /// function-attributes? region ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; + SmallVector entryArgs; + SmallVector resultAttrs; SmallVector resultTypes; bool isVariadic; @@ -901,34 +896,41 @@ auto signatureLocation = parser.getCurrentLocation(); if (failed(function_interface_impl::parseFunctionSignature( - parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs))) + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs))) return failure(); - if (entryArgs.empty() && !argTypes.empty()) + if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty()) return parser.emitError(signatureLocation) << "gpu.func requires named arguments"; // Construct the function type. More types will be added to the region, but // not to the function type. Builder &builder = parser.getBuilder(); + + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); auto type = builder.getFunctionType(argTypes, resultTypes); result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); + function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, + resultAttrs); + // Parse workgroup memory attributions. if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), - entryArgs, argTypes))) + entryArgs))) return failure(); // Store the number of operands we just parsed as the number of workgroup // memory attributions. - unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs(); + unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs(); result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(), builder.getI64IntegerAttr(numWorkgroupAttrs)); // Parse private memory attributions. - if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), - entryArgs, argTypes))) + if (failed( + parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), entryArgs))) return failure(); // Parse the kernel attribute if present. @@ -939,13 +941,11 @@ // Parse attributes. if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, - resultAttrs); // Parse the region. If no argument names were provided, take all names // (including those of attributions) from the entry block. auto *body = result.addRegion(); - return parser.parseRegion(*body, entryArgs, argTypes); + return parser.parseRegion(*body, entryArgs); } static void printAttributions(OpAsmPrinter &p, StringRef keyword, @@ -1078,16 +1078,14 @@ ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - - // If module attributes are present, parse them. - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + result.attributes) || + // If module attributes are present, parse them. + parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); // Parse the module body. auto *body = result.addRegion(); - if (parser.parseRegion(*body, None, None)) + if (parser.parseRegion(*body, {})) return failure(); // Ensure that this module has a valid terminator. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2152,10 +2152,8 @@ parser, result, LLVM::Linkage::External))); StringAttr nameAttr; - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; + SmallVector entryArgs; + SmallVector resultAttrs; SmallVector resultTypes; bool isVariadic; @@ -2163,10 +2161,13 @@ if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || function_interface_impl::parseFunctionSignature( - parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, + resultAttrs)) return failure(); + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); auto type = buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, function_interface_impl::VariadicFlag(isVariadic)); @@ -2178,11 +2179,11 @@ if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, - argAttrs, resultAttrs); + entryArgs, resultAttrs); auto *body = result.addRegion(); - OptionalParseResult parseResult = parser.parseOptionalRegion( - *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs); return failure(parseResult.hasValue() && failed(*parseResult)); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -799,10 +799,8 @@ failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); - SmallVector regionOperands; std::unique_ptr region = std::make_unique(); - SmallVector operandTypes, regionTypes; - if (parser.parseRegion(*region, regionOperands, regionTypes)) + if (parser.parseRegion(*region, {})) return failure(); result.addRegion(std::move(region)); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -275,7 +275,7 @@ return failure(); // Parse the body region. - if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) + if (parser.parseRegion(*bodyRegion, /*arguments=*/{})) return failure(); AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), result.location); @@ -1215,7 +1215,7 @@ return failure(); Region *body = result.addRegion(); - if (parser.parseRegion(*body, llvm::None, llvm::None) || + if (parser.parseRegion(*body, {}) || parser.parseOptionalAttrDict(result.attributes)) return failure(); result.types.push_back(memrefType.cast().getElementType()); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -523,20 +523,17 @@ SmallVectorImpl &steps, SmallVectorImpl &loopVarTypes, UnitAttr &inclusive) { // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false)) - return failure(); - - size_t numIVs = ivs.size(); + SmallVector ivs; Type loopVarType; - if (parser.parseColonType(loopVarType) || + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren, + /*allowTypes=*/false, /*allowAttrs=*/false) || + parser.parseColonType(loopVarType) || // Parse loop bounds. parser.parseEqual() || - parser.parseOperandList(lowerBound, numIVs, + parser.parseOperandList(lowerBound, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.parseKeyword("to") || - parser.parseOperandList(upperBound, numIVs, + parser.parseOperandList(upperBound, ivs.size(), OpAsmParser::Delimiter::Paren)) return failure(); @@ -545,15 +542,14 @@ // Parse step values. if (parser.parseKeyword("step") || - parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren)) + parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) return failure(); // Now parse the body. - loopVarTypes = SmallVector(numIVs, loopVarType); - SmallVector blockArgs(ivs); - if (parser.parseRegion(region, blockArgs, loopVarTypes)) - return failure(); - return success(); + loopVarTypes = SmallVector(ivs.size(), loopVarType); + for (auto &iv : ivs) + iv.type = loopVarType; + return parser.parseRegion(region, ivs); } void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, @@ -582,33 +578,29 @@ /// clause ::= TODO ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false)) - return failure(); - int numIVs = static_cast(ivs.size()); + SmallVector ivs; Type loopVarType; - if (parser.parseColonType(loopVarType)) - return failure(); - // Parse loop bounds. - SmallVector lower; - if (parser.parseEqual() || - parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(lower, loopVarType, result.operands)) - return failure(); - SmallVector upper; - if (parser.parseKeyword("to") || - parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(upper, loopVarType, result.operands)) - return failure(); - - // Parse step values. - SmallVector steps; - if (parser.parseKeyword("step") || - parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || + SmallVector lower, upper, steps; + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren, + /*allowType=*/false, /*allowAttrs=*/false) || + parser.parseColonType(loopVarType) || + // Parse loop bounds. + parser.parseEqual() || + parser.parseOperandList(lower, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(lower, loopVarType, result.operands) || + parser.parseKeyword("to") || + parser.parseOperandList(upper, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(upper, loopVarType, result.operands) || + // Parse step values. + parser.parseKeyword("step") || + parser.parseOperandList(steps, ivs.size(), + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, loopVarType, result.operands)) return failure(); + int numIVs = static_cast(ivs.size()); SmallVector segments{numIVs, numIVs, numIVs}; // TODO: Add parseClauses() when we support clauses result.addAttribute("operand_segment_sizes", @@ -616,11 +608,9 @@ // Now parse the body. Region *body = result.addRegion(); - SmallVector ivTypes(numIVs, loopVarType); - SmallVector blockArgs(ivs); - if (parser.parseRegion(*body, blockArgs, ivTypes)) - return failure(); - return success(); + for (auto &iv : ivs) + iv.type = loopVarType; + return parser.parseRegion(*body, ivs); } void SimdLoopOp::print(OpAsmPrinter &p) { diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -101,10 +101,9 @@ ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the loop variable followed by type. - OpAsmParser::UnresolvedOperand loopVariable; - Type loopVariableType; - if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) || - parser.parseColonType(loopVariableType)) + OpAsmParser::Argument loopVariable; + if (parser.parseArgument(loopVariable, /*allowType=*/true, + /*allowAttrs=*/false)) return failure(); // Parse the "in" keyword. @@ -117,13 +116,13 @@ return failure(); // Resolve the operand. - Type rangeType = pdl::RangeType::get(loopVariableType); + Type rangeType = pdl::RangeType::get(loopVariable.type); if (parser.resolveOperand(operandInfo, rangeType, result.operands)) return failure(); // Parse the body region. Region *body = result.addRegion(); - if (parser.parseRegion(*body, {loopVariable}, {loopVariableType})) + if (parser.parseRegion(*body, loopVariable)) return failure(); // Parse the attribute dictionary. diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -399,15 +399,17 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); - OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; - // Parse the induction variable followed by '='. - if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || - parser.parseEqual()) - return failure(); - - // Parse loop bounds. Type indexType = builder.getIndexType(); - if (parser.parseOperand(lb) || + + OpAsmParser::Argument inductionVariable; + inductionVariable.type = indexType; + OpAsmParser::UnresolvedOperand lb, ub, step; + // Parse the induction variable followed by '='. + if (parser.parseArgument(inductionVariable, /*allowType=*/false, + /*allowAttrs=*/false) || + parser.parseEqual() || + // Parse loop bounds. + parser.parseOperand(lb) || parser.resolveOperand(lb, indexType, result.operands) || parser.parseKeyword("to") || parser.parseOperand(ub) || parser.resolveOperand(ub, indexType, result.operands) || @@ -416,8 +418,8 @@ return failure(); // Parse the optional initial iteration arguments. - SmallVector regionArgs, operands; - SmallVector argTypes; + SmallVector regionArgs; + SmallVector operands; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { @@ -425,24 +427,27 @@ if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(result.types)) return failure(); + // Resolve input operands. - for (auto operandType : llvm::zip(operands, result.types)) - if (parser.resolveOperand(std::get<0>(operandType), - std::get<1>(operandType), result.operands)) + for (auto argOperandType : llvm::zip( + MutableArrayRef(regionArgs).drop_front(), + operands, result.types)) { + Type type = std::get<2>(argOperandType); + std::get<0>(argOperandType).type = type; + if (parser.resolveOperand(std::get<1>(argOperandType), type, + result.operands)) return failure(); + } } - // Induction variable. - argTypes.push_back(indexType); - // Loop carried variables - argTypes.append(result.types.begin(), result.types.end()); - // Parse the body region. - Region *body = result.addRegion(); - if (regionArgs.size() != argTypes.size()) + + if (regionArgs.size() != result.types.size() + 1) return parser.emitError( parser.getNameLoc(), "mismatch in number of loop-carried values and defined values"); - if (parser.parseRegion(*body, regionArgs, argTypes)) + // Parse the body region. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) return failure(); ForOp::ensureTerminator(*body, builder, result.location); @@ -1975,9 +1980,9 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false)) + SmallVector ivs; + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren, + /*allowTypes=*/false, /*allowAttrs=*/false)) return failure(); // Parse loop bounds. @@ -2016,8 +2021,9 @@ // Now parse the body. Region *body = result.addRegion(); - SmallVector types(ivs.size(), builder.getIndexType()); - if (parser.parseRegion(*body, ivs, types)) + for (auto &iv : ivs) + iv.type = builder.getIndexType(); + if (parser.parseRegion(*body, ivs)) return failure(); // Set `operand_segment_sizes` attribute. @@ -2370,7 +2376,8 @@ /// assignment-list ::= assignment | assignment `,` assignment-list /// assignment ::= ssa-value `=` ssa-value ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector regionArgs, operands; + SmallVector regionArgs; + SmallVector operands; Region *before = result.addRegion(); Region *after = result.addRegion(); @@ -2399,10 +2406,13 @@ result.operands))) return failure(); - return failure( - parser.parseRegion(*before, regionArgs, functionType.getInputs()) || - parser.parseKeyword("do") || parser.parseRegion(*after) || - parser.parseOptionalAttrDictWithKeyword(result.attributes)); + // Propagate the types into the region arguments. + for (size_t i = 0, e = regionArgs.size(); i != e; ++i) + regionArgs[i].type = functionType.getInput(i); + + return failure(parser.parseRegion(*before, regionArgs) || + parser.parseKeyword("do") || parser.parseRegion(*after) || + parser.parseOptionalAttrDictWithKeyword(result.attributes)); } /// Prints a `while` op. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2193,10 +2193,8 @@ //===----------------------------------------------------------------------===// ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; + SmallVector entryArgs; + SmallVector resultAttrs; SmallVector resultTypes; auto &builder = parser.getBuilder(); @@ -2209,10 +2207,13 @@ // Parse the function signature. bool isVariadic = false; if (function_interface_impl::parseFunctionSignature( - parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) return failure(); + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); state.addAttribute(FunctionOpInterface::getTypeAttrName(), TypeAttr::get(fnType)); @@ -2227,15 +2228,13 @@ return failure(); // Add the attributes to the function arguments. - assert(argAttrs.size() == argTypes.size()); assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, + function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs, resultAttrs); // Parse the optional function body. auto *body = state.addRegion(); - OptionalParseResult result = parser.parseOptionalRegion( - *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes); + OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs); return failure(result.hasValue() && failed(*result)); } diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -13,83 +13,60 @@ using namespace mlir; -ParseResult mlir::function_interface_impl::parseFunctionArgumentList( +static ParseResult parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic) { - if (parser.parseLParen()) - return failure(); - - // The argument list either has to consistently have ssa-id's followed by - // types, or just be a type list. It isn't ok to sometimes have SSA ID's and - // sometimes not. - auto parseArgument = [&]() -> ParseResult { - SMLoc loc = parser.getCurrentLocation(); - - // Parse argument name if present. - OpAsmParser::UnresolvedOperand argument; - Type argumentType; - auto hadSSAValue = parser.parseOptionalOperand(argument, - /*allowResultNumber=*/false); - if (hadSSAValue.hasValue()) { - if (failed(hadSSAValue.getValue())) - return failure(); // Argument was present but malformed. - - // Reject this if the preceding argument was missing a name. - if (argNames.empty() && !argTypes.empty()) - return parser.emitError(loc, "expected type instead of SSA identifier"); - - // Parse required type. - if (parser.parseColonType(argumentType)) - return failure(); - } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { - isVariadic = true; - return success(); - } else if (!argNames.empty()) { - // Reject this if the preceding argument had a name. - return parser.emitError(loc, "expected SSA identifier"); - } else if (parser.parseType(argumentType)) { - return failure(); - } - - // Add the argument type. - argTypes.push_back(argumentType); - - // Parse any argument attributes and source location information. - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs) || - parser.parseOptionalLocationSpecifier(argument.sourceLoc)) - return failure(); - - if (!allowAttributes && !attrs.empty()) - return parser.emitError(loc, "expected arguments without attributes"); - argAttrs.push_back(attrs); - - // If we had an argument name, then remember the parsed argument. - if (!argument.name.empty()) - argNames.push_back(argument); - return success(); - }; + SmallVectorImpl &arguments, bool &isVariadic) { - // Parse the function arguments. + // Parse the function arguments. The argument list either has to consistently + // have ssa-id's followed by types, or just be a type list. It isn't ok to + // sometimes have SSA ID's and sometimes not. isVariadic = false; - if (failed(parser.parseOptionalRParen())) { - do { - unsigned numTypedArguments = argTypes.size(); - if (parseArgument()) - return failure(); - - SMLoc loc = parser.getCurrentLocation(); - if (argTypes.size() == numTypedArguments && - succeeded(parser.parseOptionalComma())) - return parser.emitError( - loc, "variadic arguments must be in the end of the argument list"); - } while (succeeded(parser.parseOptionalComma())); - parser.parseRParen(); - } - return success(); + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + // Ellipsis must be at end of the list. + if (isVariadic) + return parser.emitError( + parser.getCurrentLocation(), + "variadic arguments must be in the end of the argument list"); + + // Handle ellipsis as a special case. + if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { + // This is a variadic designator. + isVariadic = true; + return success(); // Stop parsing arguments. + } + // Parse argument name if present. + OpAsmParser::Argument argument; + auto argPresent = parser.parseOptionalArgument( + argument, /*allowType=*/true, /*allowAttrs*/ allowAttributes); + if (argPresent.hasValue()) { + if (failed(argPresent.getValue())) + return failure(); // Present but malformed. + + // Reject this if the preceding argument was missing a name. + if (!arguments.empty() && arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected type instead of SSA identifier"); + + } else { + argument.ssaName.location = parser.getCurrentLocation(); + // Otherwise we just have a type list without SSA names. Reject + // this if the preceding argument had a name. + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected SSA identifier"); + + NamedAttrList attrs; + if (parser.parseType(argument.type) || + parser.parseOptionalAttrDict(attrs) || + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + return failure(); + argument.attrs = attrs.getDictionary(parser.getContext()); + } + arguments.push_back(argument); + return success(); + }); } /// Parse a function result list. @@ -103,7 +80,7 @@ /// static ParseResult parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs) { + SmallVectorImpl &resultAttrs) { if (failed(parser.parseOptionalLParen())) { // We already know that there is no `(`, so parse a type. // Because there is no `(`, it cannot be a function type. @@ -120,83 +97,76 @@ return success(); // Parse individual function results. - do { - resultTypes.emplace_back(); - resultAttrs.emplace_back(); - if (parser.parseType(resultTypes.back()) || - parser.parseOptionalAttrDict(resultAttrs.back())) { - return failure(); - } - } while (succeeded(parser.parseOptionalComma())); + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + resultTypes.emplace_back(); + resultAttrs.emplace_back(); + NamedAttrList attrs; + if (parser.parseType(resultTypes.back()) || + parser.parseOptionalAttrDict(attrs)) + return failure(); + resultAttrs.back() = attrs.getDictionary(parser.getContext()); + return success(); + })) + return failure(); + return parser.parseRParen(); } ParseResult mlir::function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs) { + SmallVectorImpl &arguments, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { bool allowArgAttrs = true; - if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames, - argTypes, argAttrs, isVariadic)) + if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, arguments, + isVariadic)) return failure(); if (succeeded(parser.parseOptionalArrow())) return parseFunctionResultList(parser, resultTypes, resultAttrs); return success(); } -/// Implementation of `addArgAndResultAttrs` that is attribute list type -/// agnostic. -template -static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs, - AttrArrayBuildFnT &&buildAttrArrayFn) { - auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); }; +void mlir::function_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs) { + auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { + return attrs && !attrs.empty(); + }; + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; // Add the attributes to the function arguments. - if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { - ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); + if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) result.addAttribute(function_interface_impl::getArgDictAttrName(), - attrDicts); - } + getArrayAttr(argAttrs)); + // Add the attributes to the function results. - if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { - ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); + if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) result.addAttribute(function_interface_impl::getResultDictAttrName(), - attrDicts); - } + getArrayAttr(resultAttrs)); } void mlir::function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, ArrayRef argAttrs, + Builder &builder, OperationState &result, + ArrayRef args, ArrayRef resultAttrs) { - auto buildFn = [](ArrayRef attrs) { - return ArrayRef(attrs.data(), attrs.size()); - }; - addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); -} -void mlir::function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs) { - MLIRContext *context = builder.getContext(); - auto buildFn = [=](ArrayRef attrs) { - return llvm::to_vector<8>( - llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute { - return attrList.getDictionary(context); - })); - }; - addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); + SmallVector argAttrs; + for (const auto &arg : args) + argAttrs.push_back(arg.attrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); } ParseResult mlir::function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, FuncTypeBuilder funcTypeBuilder) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; + SmallVector entryArgs; + SmallVector resultAttrs; SmallVector resultTypes; auto &builder = parser.getBuilder(); @@ -212,11 +182,15 @@ // Parse the function signature. SMLoc signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; - if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, - argAttrs, isVariadic, resultTypes, resultAttrs)) + if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, + resultTypes, resultAttrs)) return failure(); std::string errorMessage; + SmallVector argTypes; + argTypes.reserve(entryArgs.size()); + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); Type type = funcTypeBuilder(builder, argTypes, resultTypes, VariadicFlag(isVariadic), errorMessage); if (!type) { @@ -246,17 +220,16 @@ result.attributes.append(parsedAttributes); // Add the attributes to the function arguments. - assert(argAttrs.size() == argTypes.size()); assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. auto *body = result.addRegion(); SMLoc loc = parser.getCurrentLocation(); - OptionalParseResult parseResult = parser.parseOptionalRegion( - *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes, - /*enableNameShadowing=*/false); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs, + /*enableNameShadowing=*/false); if (parseResult.hasValue()) { if (failed(*parseResult)) return failure(); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -301,11 +301,8 @@ return success(); }; - if (parseCommaSeparatedList(Delimiter::Braces, parseElt, - " in attribute dictionary")) - return failure(); - - return success(); + return parseCommaSeparatedList(Delimiter::Braces, parseElt, + " in attribute dictionary"); } /// Parse a float attribute. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -249,6 +249,7 @@ //===--------------------------------------------------------------------===// using UnresolvedOperand = OpAsmParser::UnresolvedOperand; + using Argument = OpAsmParser::Argument; struct DeferredLocInfo { SMLoc loc; @@ -364,16 +365,13 @@ /// Parse a region into 'region' with the provided entry block arguments. /// 'isIsolatedNameScope' indicates if the naming scope of this region is /// isolated from those above. - ParseResult - parseRegion(Region ®ion, - ArrayRef> entryArguments, - bool isIsolatedNameScope = false); + ParseResult parseRegion(Region ®ion, ArrayRef entryArguments, + bool isIsolatedNameScope = false); /// Parse a region body into 'region'. - ParseResult - parseRegionBody(Region ®ion, SMLoc startLoc, - ArrayRef> entryArguments, - bool isIsolatedNameScope); + ParseResult parseRegionBody(Region ®ion, SMLoc startLoc, + ArrayRef entryArguments, + bool isIsolatedNameScope); //===--------------------------------------------------------------------===// // Block Parsing @@ -947,7 +945,7 @@ unsigned opResI = 0; for (ResultRecord &resIt : resultIDs) { for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { - if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes, {}}, + if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes}, op->getResult(opResI++))) return failure(); } @@ -1279,10 +1277,8 @@ if (parser.parseSSAUse(useInfo, allowResultNumber)) return failure(); - result = {useInfo.location, useInfo.name, useInfo.number, {}}; - - // Parse a source locator on the operand if present. - return parseOptionalLocationSpecifier(result.sourceLoc); + result = {useInfo.location, useInfo.name, useInfo.number}; + return success(); } /// Parse a single operand if present. @@ -1321,11 +1317,7 @@ } auto parseOneOperand = [&]() -> ParseResult { - UnresolvedOperand operandOrArg; - if (parseOperand(operandOrArg, allowResultNumber)) - return failure(); - result.push_back(operandOrArg); - return success(); + return parseOperand(result.emplace_back(), allowResultNumber); }; if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list")) @@ -1402,52 +1394,89 @@ return parser.parseAffineExprOfSSAIds(expr, parseElement); } + //===--------------------------------------------------------------------===// + // Argument Parsing + //===--------------------------------------------------------------------===// + + /// Parse a single argument with the following syntax: + /// + /// `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)` + /// + /// If `allowType` is false or `allowAttrs` are false then the respective + /// parts of the grammar are not parsed. + ParseResult parseArgument(Argument &result, bool allowType, + bool allowAttrs) override { + NamedAttrList attrs; + if (parseOperand(result.ssaName, /*allowResultNumber=*/false) || + (allowType && parseColonType(result.type)) || + (allowAttrs && parseOptionalAttrDict(attrs)) || + parseOptionalLocationSpecifier(result.sourceLoc)) + return failure(); + result.attrs = attrs.getDictionary(getContext()); + return success(); + } + + /// Parse a single argument if present. + OptionalParseResult parseOptionalArgument(Argument &result, bool allowType, + bool allowAttrs) override { + if (parser.getToken().is(Token::percent_identifier)) + return parseArgument(result, allowType, allowAttrs); + return llvm::None; + } + + ParseResult parseArgumentList(SmallVectorImpl &result, + Delimiter delimiter = Delimiter::None, + bool allowType = true, + bool allowAttrs = true) override { + // The no-delimiter case has some special handling for the empty case. + if (delimiter == Delimiter::None && + parser.getToken().isNot(Token::percent_identifier)) + return success(); + + auto parseOneArgument = [&]() -> ParseResult { + return parseArgument(result.emplace_back(), allowType, allowAttrs); + }; + return parseCommaSeparatedList(delimiter, parseOneArgument, + " in argument list"); + } + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// /// Parse a region that takes `arguments` of `argTypes` types. This /// effectively defines the SSA values of `arguments` and assigns their type. - ParseResult parseRegion(Region ®ion, ArrayRef arguments, - ArrayRef argTypes, + ParseResult parseRegion(Region ®ion, ArrayRef arguments, bool enableNameShadowing) override { - assert(arguments.size() == argTypes.size() && - "mismatching number of arguments and types"); - - SmallVector, 2> - regionArguments; - for (auto pair : llvm::zip(arguments, argTypes)) - regionArguments.emplace_back(std::get<0>(pair), std::get<1>(pair)); - // Try to parse the region. (void)isIsolatedFromAbove; assert((!enableNameShadowing || isIsolatedFromAbove) && "name shadowing is only allowed on isolated regions"); - if (parser.parseRegion(region, regionArguments, enableNameShadowing)) + if (parser.parseRegion(region, arguments, enableNameShadowing)) return failure(); return success(); } /// Parses a region if present. OptionalParseResult parseOptionalRegion(Region ®ion, - ArrayRef arguments, - ArrayRef argTypes, + ArrayRef arguments, bool enableNameShadowing) override { if (parser.getToken().isNot(Token::l_brace)) return llvm::None; - return parseRegion(region, arguments, argTypes, enableNameShadowing); + return parseRegion(region, arguments, enableNameShadowing); } /// Parses a region if present. If the region is present, a new region is /// allocated and placed in `region`. If no region is present, `region` /// remains untouched. - OptionalParseResult parseOptionalRegion( - std::unique_ptr ®ion, ArrayRef arguments, - ArrayRef argTypes, bool enableNameShadowing = false) override { + OptionalParseResult + parseOptionalRegion(std::unique_ptr ®ion, + ArrayRef arguments, + bool enableNameShadowing = false) override { if (parser.getToken().isNot(Token::l_brace)) return llvm::None; std::unique_ptr newRegion = std::make_unique(); - if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing)) + if (parseRegion(*newRegion, arguments, enableNameShadowing)) return failure(); region = std::move(newRegion); @@ -1492,42 +1521,16 @@ /// Parse a list of assignments of the form /// (%x1 = %y1, %x2 = %y2, ...). OptionalParseResult parseOptionalAssignmentList( - SmallVectorImpl &lhs, + SmallVectorImpl &lhs, SmallVectorImpl &rhs) override { if (failed(parseOptionalLParen())) return llvm::None; auto parseElt = [&]() -> ParseResult { - UnresolvedOperand regionArg, operand; - if (parseOperand(regionArg, /*allowResultNumber=*/false) || - parseEqual() || parseOperand(operand)) - return failure(); - lhs.push_back(regionArg); - rhs.push_back(operand); - return success(); - }; - return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); - } - - /// Parse a list of assignments of the form - /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). - OptionalParseResult - parseOptionalAssignmentListWithTypes(SmallVectorImpl &lhs, - SmallVectorImpl &rhs, - SmallVectorImpl &types) override { - if (failed(parseOptionalLParen())) - return llvm::None; - - auto parseElt = [&]() -> ParseResult { - UnresolvedOperand regionArg, operand; - Type type; - if (parseOperand(regionArg, /*allowResultNumber=*/false) || - parseEqual() || parseOperand(operand) || parseColon() || - parseType(type)) + if (parseArgument(lhs.emplace_back(), /*allowType=*/false, + /*allowAttrs=*/false) || + parseEqual() || parseOperand(rhs.emplace_back())) return failure(); - lhs.push_back(regionArg); - rhs.push_back(operand); - types.push_back(type); return success(); }; return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); @@ -1749,11 +1752,9 @@ // Region Parsing //===----------------------------------------------------------------------===// -ParseResult OperationParser::parseRegion( - Region ®ion, - ArrayRef> - entryArguments, - bool isIsolatedNameScope) { +ParseResult OperationParser::parseRegion(Region ®ion, + ArrayRef entryArguments, + bool isIsolatedNameScope) { // Parse the '{'. Token lBraceTok = getToken(); if (parseToken(Token::l_brace, "expected '{' to begin a region")) @@ -1778,11 +1779,9 @@ return success(); } -ParseResult OperationParser::parseRegionBody( - Region ®ion, SMLoc startLoc, - ArrayRef> - entryArguments, - bool isIsolatedNameScope) { +ParseResult OperationParser::parseRegionBody(Region ®ion, SMLoc startLoc, + ArrayRef entryArguments, + bool isIsolatedNameScope) { auto currentPt = opBuilder.saveInsertionPoint(); // Push a new named value scope. @@ -1798,14 +1797,14 @@ if (state.asmState && getToken().isNot(Token::caret_identifier)) state.asmState->addDefinition(block, startLoc); - // Add arguments to the entry block. - if (!entryArguments.empty()) { + // Add arguments to the entry block if we had the form with explicit names. + if (!entryArguments.empty() && !entryArguments[0].ssaName.name.empty()) { // If we had named arguments, then don't allow a block name. if (getToken().is(Token::caret_identifier)) return emitError("invalid block name in region with named arguments"); - for (auto &placeholderArgPair : entryArguments) { - auto &argInfo = placeholderArgPair.first; + for (auto &entryArg : entryArguments) { + auto &argInfo = entryArg.ssaName; // Ensure that the argument was not already defined. if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { @@ -1815,10 +1814,10 @@ .attachNote(getEncodedSourceLocation(*defLoc)) << "previously referenced here"; } - Location loc = argInfo.sourceLoc.hasValue() - ? argInfo.sourceLoc.getValue() + Location loc = entryArg.sourceLoc.hasValue() + ? entryArg.sourceLoc.getValue() : getEncodedSourceLocation(argInfo.location); - BlockArgument arg = block->addArgument(placeholderArgPair.second, loc); + BlockArgument arg = block->addArgument(entryArg.type, loc); // Add a definition of this arg to the assembly state if provided. if (state.asmState) diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -202,7 +202,7 @@ module attributes {gpu.container_module} { func.func @launch_func_kernel_operand_attr(%sz : index) { - // expected-error@+1 {{expected arguments without attributes}} + // expected-error@+1 {{expected ')' in argument list}} gpu.launch_func @foo::@bar blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%sz : index {foo}) return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -13,8 +13,9 @@ // CHECK: arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) %2 = arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) + // CHECK: affine.for %arg0 loc("IVlocation") = 0 to 8 { // CHECK: } loc(fused["foo", "mysource.cc":10:8]) - affine.for %i0 = 0 to 8 { + affine.for %i0 loc("IVlocation") = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -691,18 +691,16 @@ ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand argInfo; - Type argType = parser.getBuilder().getIndexType(); - // Parse the input operand. - if (parser.parseOperand(argInfo) || - parser.resolveOperand(argInfo, argType, result.operands)) + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + if (parser.parseOperand(argInfo.ssaName) || + parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) return failure(); // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, argType, - /*enableNameShadowing=*/true); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); } void IsolatedRegionOp::print(OpAsmPrinter &p) { @@ -930,17 +928,17 @@ //===----------------------------------------------------------------------===// ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector ivsInfo; + SmallVector ivsInfo; // Parse list of region arguments without a delimiter. - if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None, - /*allowResultNumber=*/false)) + if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None, + /*allowTypes=*/false, /*allowAttrs=*/false)) return failure(); // Parse the body region. Region *body = result.addRegion(); - auto &builder = parser.getBuilder(); - SmallVector argTypes(ivsInfo.size(), builder.getIndexType()); - return parser.parseRegion(*body, ivsInfo, argTypes); + for (auto &iv : ivsInfo) + iv.type = parser.getBuilder().getIndexType(); + return parser.parseRegion(*body, ivsInfo); } void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }