diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1563,7 +1563,8 @@ mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; - if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || + if (parser.parseLParen() || + parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || parser.parseEqual()) return mlir::failure(); @@ -1581,8 +1582,9 @@ mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput; if (parser.parseKeyword("and") || parser.parseLParen() || - parser.parseRegionArgument(iterateVar) || parser.parseEqual() || - parser.parseOperand(iterateInput) || parser.parseRParen() || + parser.parseOperand(iterateVar, /*allowResultNumber=*/false) || + parser.parseEqual() || parser.parseOperand(iterateInput) || + parser.parseRParen() || parser.resolveOperand(iterateInput, i1Type, result.operands)) return mlir::failure(); @@ -1876,7 +1878,8 @@ auto &builder = parser.getBuilder(); mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. - if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) + if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || + parser.parseEqual()) return mlir::failure(); // Parse loop bounds. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -584,7 +584,7 @@ } /// These are the supported delimiters around operand lists and region - /// argument lists, used by parseOperandList and parseRegionArgumentList. + /// argument lists, used by parseOperandList. enum class Delimiter { /// Zero or more operands with no delimiters. 
None, @@ -1110,22 +1110,27 @@ Optional> parsedAttributes = llvm::None, Optional parsedFnType = llvm::None) = 0; - /// Parse a single operand. - virtual ParseResult parseOperand(UnresolvedOperand &result) = 0; + /// Parse a single SSA value operand name along with a result number if + /// `allowResultNumber` is true. + virtual ParseResult parseOperand(UnresolvedOperand &result, + bool allowResultNumber = true) = 0; /// Parse a single operand if present. virtual OptionalParseResult - parseOptionalOperand(UnresolvedOperand &result) = 0; + parseOptionalOperand(UnresolvedOperand &result, + bool allowResultNumber = true) = 0; /// Parse zero or more SSA comma-separated operand references with a specified /// surrounding delimiter, and an optional required operand count. - virtual ParseResult - parseOperandList(SmallVectorImpl &result, - int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) = 0; + virtual ParseResult parseOperandList( + SmallVectorImpl &result, int requiredOperandCount = -1, + Delimiter delimiter = Delimiter::None, bool allowResultNumber = true) = 0; + ParseResult parseOperandList(SmallVectorImpl &result, - Delimiter delimiter) { - return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter); + Delimiter delimiter, + bool allowResultNumber = true) { + return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter, + allowResultNumber); } /// Parse zero or more trailing SSA comma-separated trailing operand @@ -1243,29 +1248,6 @@ ArrayRef argTypes = {}, bool enableNameShadowing = false) = 0; - /// Parse a region argument, this argument is resolved when calling - /// 'parseRegion'. - virtual ParseResult parseRegionArgument(UnresolvedOperand &argument) = 0; - - /// Parse zero or more region arguments with a specified surrounding - /// delimiter, and an optional required argument count. Region arguments - /// define new values; so this also checks if values with the same names have - /// not been defined yet. 
- virtual ParseResult - parseRegionArgumentList(SmallVectorImpl &result, - int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) = 0; - virtual ParseResult - parseRegionArgumentList(SmallVectorImpl &result, - Delimiter delimiter) { - return parseRegionArgumentList(result, /*requiredOperandCount=*/-1, - delimiter); - } - - /// Parse a region argument if present. - virtual ParseResult - parseOptionalRegionArgument(UnresolvedOperand &argument) = 0; - //===--------------------------------------------------------------------===// // Successor Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1433,7 +1433,8 @@ auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand inductionVariable; // Parse the induction variable followed by '='. - if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) + if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || + parser.parseEqual()) return failure(); // Parse loop bounds. 
@@ -3527,8 +3528,8 @@ auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); SmallVector ivs; - if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || + if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false) || parser.parseEqual() || parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || parser.parseKeyword("to") || diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -489,7 +489,7 @@ parser.parseSuccessor(defaultDestination)) return failure(); if (succeeded(parser.parseOptionalLParen())) { - if (parser.parseRegionArgumentList(defaultOperands) || + if (parser.parseOperandList(defaultOperands) || parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) return failure(); } @@ -509,7 +509,7 @@ failed(parser.parseSuccessor(destination))) return failure(); if (succeeded(parser.parseOptionalLParen())) { - if (failed(parser.parseRegionArgumentList(operands)) || + if (failed(parser.parseOperandList(operands)) || failed(parser.parseColonTypeList(operandTypes)) || failed(parser.parseRParen())) return failure(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -539,8 +539,9 @@ MutableArrayRef indices) { assert(indices.size() == 3 && "space for three indices expected"); SmallVector args; - if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3, - OpAsmParser::Delimiter::Paren) || + if (parser.parseOperandList(args, /*requiredOperandCount=*/3, + OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false) ||
parser.parseKeyword("in") || parser.parseLParen()) return failure(); std::move(args.begin(), args.end(), indices.begin()); @@ -548,8 +549,8 @@ for (int i = 0; i < 3; ++i) { if (i != 0 && parser.parseComma()) return failure(); - if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() || - parser.parseOperand(sizes[i])) + if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) || + parser.parseEqual() || parser.parseOperand(sizes[i])) return failure(); } @@ -869,7 +870,8 @@ OpAsmParser::UnresolvedOperand arg; Type type; - if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) + if (parser.parseOperand(arg, /*allowResultNumber=*/false) || + parser.parseColonType(type)) return failure(); args.push_back(arg); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -332,7 +332,7 @@ if (parser.parseColon() || parser.parseSuccessor(destination)) return failure(); if (!parser.parseOptionalLParen()) { - if (parser.parseRegionArgumentList(operands) || + if (parser.parseOperandList(operands) || parser.parseColonTypeList(operandTypes) || parser.parseRParen()) return failure(); } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -70,7 +70,8 @@ OpAsmParser::UnresolvedOperand arg; Type type; - if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) + if (parser.parseOperand(arg, /*allowResultNumber=*/false) || + parser.parseColonType(type)) return failure(); args.push_back(arg); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@
-524,8 +524,8 @@ SmallVectorImpl &loopVarTypes, UnitAttr &inclusive) { // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; - if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren)) + if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false)) return failure(); size_t numIVs = ivs.size(); @@ -587,8 +587,8 @@ ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; - if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren)) + if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false)) return failure(); int numIVs = static_cast(ivs.size()); Type loopVarType; diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -103,7 +103,7 @@ // Parse the loop variable followed by type. OpAsmParser::UnresolvedOperand loopVariable; Type loopVariableType; - if (parser.parseRegionArgument(loopVariable) || + if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) || parser.parseColonType(loopVariableType)) return failure(); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -401,7 +401,8 @@ auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. - if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) + if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) || + parser.parseEqual()) return failure(); // Parse loop bounds. 
@@ -1975,8 +1976,8 @@ auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; - if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren)) + if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false)) return failure(); // Parse loop bounds. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4698,7 +4698,7 @@ OpAsmParser::UnresolvedOperand laneId; // Parse predicate operand. - if (parser.parseLParen() || parser.parseRegionArgument(laneId) || + if (parser.parseLParen() || parser.parseOperand(laneId) || parser.parseRParen()) return failure(); diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -30,8 +30,12 @@ // Parse argument name if present. OpAsmParser::UnresolvedOperand argument; Type argumentType; - if (succeeded(parser.parseOptionalRegionArgument(argument)) && - !argument.name.empty()) { + auto hadSSAValue = parser.parseOptionalOperand(argument, + /*allowResultNumber=*/false); + if (hadSSAValue.hasValue()) { + if (failed(hadSSAValue.getValue())) + return failure(); // Argument was present but malformed. + + // Reject this if the preceding argument was missing a name. if (argNames.empty() && !argTypes.empty()) return parser.emitError(loc, "expected type instead of SSA identifier"); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -268,8 +268,10 @@ ParseResult parseOptionalSSAUseList(SmallVectorImpl &results); - /// Parse a single SSA use into 'result'.
- ParseResult parseSSAUse(UnresolvedOperand &result); + /// Parse a single SSA use into 'result'. If 'allowResultNumber' is true then + /// we allow #42 syntax. + ParseResult parseSSAUse(UnresolvedOperand &result, + bool allowResultNumber = true); /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. @@ -699,7 +701,8 @@ /// /// ssa-use ::= ssa-id /// -ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result) { +ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result, + bool allowResultNumber) { result.name = getTokenSpelling(); result.number = 0; result.location = getToken().getLoc(); @@ -708,6 +711,9 @@ // If we have an attribute ID, it is a result number. if (getToken().is(Token::hash_identifier)) { + if (!allowResultNumber) + return emitError("result number not allowed in argument list"); + if (auto value = getToken().getHashIdentifierNumber()) result.number = value.getValue(); else @@ -1267,9 +1273,10 @@ //===--------------------------------------------------------------------===// /// Parse a single operand. - ParseResult parseOperand(UnresolvedOperand &result) override { + ParseResult parseOperand(UnresolvedOperand &result, + bool allowResultNumber = true) override { OperationParser::UnresolvedOperand useInfo; - if (parser.parseSSAUse(useInfo)) + if (parser.parseSSAUse(useInfo, allowResultNumber)) return failure(); result = {useInfo.location, useInfo.name, useInfo.number, {}}; @@ -1279,9 +1286,11 @@ } /// Parse a single operand if present. - OptionalParseResult parseOptionalOperand(UnresolvedOperand &result) override { + OptionalParseResult + parseOptionalOperand(UnresolvedOperand &result, + bool allowResultNumber = true) override { if (parser.getToken().is(Token::percent_identifier)) - return parseOperand(result); + return parseOperand(result, allowResultNumber); return llvm::None; } @@ -1289,17 +1298,8 @@ /// surrounding delimiter, and an optional required operand count. 
ParseResult parseOperandList(SmallVectorImpl &result, int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) override { - return parseOperandOrRegionArgList(result, /*isOperandList=*/true, - requiredOperandCount, delimiter); - } - - /// Parse zero or more SSA comma-separated operand or region arguments with - /// optional surrounding delimiter and required operand count. - ParseResult - parseOperandOrRegionArgList(SmallVectorImpl &result, - bool isOperandList, int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) { + Delimiter delimiter = Delimiter::None, + bool allowResultNumber = true) override { auto startLoc = parser.getToken().getLoc(); // The no-delimiter case has some special handling for better diagnostics. @@ -1322,8 +1322,7 @@ auto parseOneOperand = [&]() -> ParseResult { UnresolvedOperand operandOrArg; - if (isOperandList ? parseOperand(operandOrArg) - : parseRegionArgument(operandOrArg)) + if (parseOperand(operandOrArg, allowResultNumber)) return failure(); result.push_back(operandOrArg); return success(); @@ -1472,28 +1471,6 @@ return success(); } - /// Parse a region argument. The type of the argument will be resolved later - /// by a call to `parseRegion`. - ParseResult parseRegionArgument(UnresolvedOperand &argument) override { - return parseOperand(argument); - } - - /// Parse a region argument if present. 
- ParseResult - parseOptionalRegionArgument(UnresolvedOperand &argument) override { - if (parser.getToken().isNot(Token::percent_identifier)) - return success(); - return parseRegionArgument(argument); - } - - ParseResult - parseRegionArgumentList(SmallVectorImpl &result, - int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) override { - return parseOperandOrRegionArgList(result, /*isOperandList=*/false, - requiredOperandCount, delimiter); - } - //===--------------------------------------------------------------------===// // Successor Parsing //===--------------------------------------------------------------------===// @@ -1539,8 +1516,8 @@ auto parseElt = [&]() -> ParseResult { UnresolvedOperand regionArg, operand; - if (parseRegionArgument(regionArg) || parseEqual() || - parseOperand(operand)) + if (parseOperand(regionArg, /*allowResultNumber=*/false) || + parseEqual() || parseOperand(operand)) return failure(); lhs.push_back(regionArg); rhs.push_back(operand); @@ -1561,8 +1538,9 @@ auto parseElt = [&]() -> ParseResult { UnresolvedOperand regionArg, operand; Type type; - if (parseRegionArgument(regionArg) || parseEqual() || - parseOperand(operand) || parseColon() || parseType(type)) + if (parseOperand(regionArg, /*allowResultNumber=*/false) || + parseEqual() || parseOperand(operand) || parseColon() || + parseType(type)) return failure(); lhs.push_back(regionArg); rhs.push_back(operand); diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -380,3 +380,13 @@ } return %res : f32 } + + +// ----- + +func.func @result_number() { + // expected-error@+1 {{result number not allowed}} + affine.for %n0#0 = 0 to 7 { + } + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp 
@@ -875,7 +875,8 @@ ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector ivsInfo; // Parse list of region arguments without a delimiter. - if (parser.parseRegionArgumentList(ivsInfo)) + if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false)) return failure(); // Parse the body region.