diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -54,23 +54,24 @@ /// Parses function arguments using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments. +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments. ParseResult parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic); + SmallVectorImpl> &argLocations, bool &isVariadic); /// Parses a function signature using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments and those of the results. +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments and those of the results. ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, + SmallVectorImpl> &argLocations, bool &isVariadic, SmallVectorImpl &resultTypes, SmallVectorImpl &resultAttrs); diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1205,20 +1205,23 @@ /// Parses a region. Any parsed blocks are appended to 'region' and must be /// moved to the op regions after the op is created. The first block of the - /// region takes 'arguments' of types 'argTypes'.
If 'enableNameShadowing' is - /// set to true, the argument names are allowed to shadow the names of other - /// existing SSA values defined above the region scope. 'enableNameShadowing' - /// can only be set to true for regions attached to operations that are - /// 'IsolatedFromAbove. - virtual ParseResult parseRegion(Region ®ion, - ArrayRef arguments = {}, - ArrayRef argTypes = {}, - bool enableNameShadowing = false) = 0; + /// region takes 'arguments' of types 'argTypes'. If `argLocations` is + /// non-empty it contains an optional location to be attached to each + /// argument. If 'enableNameShadowing' is set to true, the argument names are + /// allowed to shadow the names of other existing SSA values defined above the + /// region scope. 'enableNameShadowing' can only be set to true for regions + /// attached to operations that are 'IsolatedFromAbove'. + virtual ParseResult + parseRegion(Region ®ion, ArrayRef arguments = {}, + ArrayRef argTypes = {}, + ArrayRef> argLocations = {}, + bool enableNameShadowing = false) = 0; /// Parses a region if present. virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef arguments = {}, ArrayRef argTypes = {}, + ArrayRef> argLocations = {}, bool enableNameShadowing = false) = 0; /// Parses a region if present.
If the region is present, a new region is diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -219,6 +219,7 @@ Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, /*argTypes=*/{unwrappedTypes}, + /*argLocations=*/{}, /*enableNameShadowing=*/false)) return failure(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -684,10 +684,12 @@ if (parser.parseOptionalKeyword("args")) return success(); SmallVector argAttrs; + SmallVector, 4> argLocations; bool isVariadic = false; return function_like_impl::parseFunctionArgumentList( parser, /*allowAttributes=*/false, - /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); + /*allowVariadic=*/false, argNames, argTypes, argAttrs, argLocations, + isVariadic); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, @@ -794,6 +796,7 @@ SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; + SmallVector, 8> argLocations; bool isVariadic; // Parse the function name.
@@ -805,7 +808,7 @@ auto signatureLocation = parser.getCurrentLocation(); if (failed(function_like_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs))) + argLocations, isVariadic, resultTypes, resultAttrs))) return failure(); if (entryArgs.empty() && !argTypes.empty()) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2036,6 +2036,7 @@ SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; + SmallVector, 8> argLocations; bool isVariadic; auto signatureLocation = parser.getCurrentLocation(); @@ -2043,7 +2044,7 @@ result.attributes) || function_like_impl::parseFunctionSignature( parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + argLocations, isVariadic, resultTypes, resultAttrs)) return failure(); auto type = diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1947,6 +1947,7 @@ SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; + SmallVector, 4> argLocations; auto &builder = parser.getBuilder(); // Parse the name as a symbol.
@@ -1959,7 +1960,7 @@ bool isVariadic = false; if (function_like_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + argLocations, isVariadic, resultTypes, resultAttrs)) return failure(); auto fnType = builder.getFunctionType(argTypes, resultTypes); diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -17,7 +17,7 @@ OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic) { + SmallVectorImpl> &argLocations, bool &isVariadic) { if (parser.parseLParen()) return failure(); @@ -60,11 +60,12 @@ return parser.emitError(loc, "expected arguments without attributes"); argAttrs.push_back(attrs); - // Parse a location if specified. TODO: Don't drop it on the floor. + // Parse a location if specified. Optional explicitLoc; if (!argument.name.empty() && parser.parseOptionalLocationSpecifier(explicitLoc)) return failure(); + argLocations.push_back(explicitLoc); return success(); }; @@ -130,17 +131,19 @@ /// Parses a function signature using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments and those of the results. +/// trailing arguments are populated by this function with names, types, +/// attributes and optional locations of the arguments and the types and +/// attributes of the results.
ParseResult mlir::function_like_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl> &argLocations, bool &isVariadic, + SmallVectorImpl &resultTypes, SmallVectorImpl &resultAttrs) { bool allowArgAttrs = true; if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames, - argTypes, argAttrs, isVariadic)) + argTypes, argAttrs, argLocations, isVariadic)) return failure(); if (succeeded(parser.parseOptionalArrow())) return parseFunctionResultList(parser, resultTypes, resultAttrs); @@ -199,6 +202,7 @@ SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; + SmallVector, 4> argLocations; auto &builder = parser.getBuilder(); // Parse visibility. @@ -214,7 +218,8 @@ llvm::SMLoc signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, - argAttrs, isVariadic, resultTypes, resultAttrs)) + argAttrs, argLocations, isVariadic, resultTypes, + resultAttrs)) return failure(); std::string errorMessage; @@ -257,6 +262,7 @@ llvm::SMLoc loc = parser.getCurrentLocation(); OptionalParseResult parseResult = parser.parseOptionalRegion( *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes, + entryArgs.empty() ? ArrayRef>() : argLocations, /*enableNameShadowing=*/false); if (parseResult.hasValue()) { if (failed(*parseResult)) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -358,16 +358,19 @@ //===--------------------------------------------------------------------===// /// Parse a region into 'region' with the provided entry block arguments. - /// 'isIsolatedNameScope' indicates if the naming scope of this region is - /// isolated from those above.
+ /// If non-empty, 'argLocations' contains an optional location for each + /// argument. 'isIsolatedNameScope' indicates if the naming scope of this + /// region is isolated from those above. ParseResult parseRegion(Region ®ion, ArrayRef> entryArguments, + ArrayRef> argLocations = {}, bool isIsolatedNameScope = false); /// Parse a region body into 'region'. ParseResult parseRegionBody(Region ®ion, llvm::SMLoc startLoc, ArrayRef> entryArguments, + ArrayRef> argLocations, bool isIsolatedNameScope); //===--------------------------------------------------------------------===// @@ -1426,6 +1429,7 @@ /// effectively defines the SSA values of `arguments` and assigns their type. ParseResult parseRegion(Region ®ion, ArrayRef arguments, ArrayRef argTypes, + ArrayRef> argLocations, bool enableNameShadowing) override { assert(arguments.size() == argTypes.size() && "mismatching number of arguments and types"); @@ -1444,19 +1448,22 @@ (void)isIsolatedFromAbove; assert((!enableNameShadowing || isIsolatedFromAbove) && "name shadowing is only allowed on isolated regions"); - if (parser.parseRegion(region, regionArguments, enableNameShadowing)) + if (parser.parseRegion(region, regionArguments, argLocations, + enableNameShadowing)) return failure(); return success(); } /// Parses a region if present. - OptionalParseResult parseOptionalRegion(Region ®ion, - ArrayRef arguments, - ArrayRef argTypes, - bool enableNameShadowing) override { + OptionalParseResult + parseOptionalRegion(Region ®ion, ArrayRef arguments, + ArrayRef argTypes, + ArrayRef> argLocations, + bool enableNameShadowing) override { if (parser.getToken().isNot(Token::l_brace)) return llvm::None; - return parseRegion(region, arguments, argTypes, enableNameShadowing); + return parseRegion(region, arguments, argTypes, argLocations, + enableNameShadowing); } /// Parses a region if present.
If the region is present, a new region is @@ -1469,7 +1476,8 @@ if (parser.getToken().isNot(Token::l_brace)) return llvm::None; std::unique_ptr newRegion = std::make_unique(); - if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing)) + if (parseRegion(*newRegion, arguments, argTypes, /*argLocations=*/{}, + enableNameShadowing)) return failure(); region = std::move(newRegion); @@ -1758,7 +1766,7 @@ ParseResult OperationParser::parseRegion( Region ®ion, ArrayRef> entryArguments, - bool isIsolatedNameScope) { + ArrayRef> argLocations, bool isIsolatedNameScope) { // Parse the '{'. Token lBraceTok = getToken(); if (parseToken(Token::l_brace, "expected '{' to begin a region")) @@ -1770,7 +1778,7 @@ // Parse the region body. if ((!entryArguments.empty() || getToken().isNot(Token::r_brace)) && - parseRegionBody(region, lBraceTok.getLoc(), entryArguments, + parseRegionBody(region, lBraceTok.getLoc(), entryArguments, argLocations, isIsolatedNameScope)) { return failure(); } @@ -1786,7 +1794,8 @@ ParseResult OperationParser::parseRegionBody( Region ®ion, llvm::SMLoc startLoc, ArrayRef> entryArguments, - bool isIsolatedNameScope) { + ArrayRef> argLocations, bool isIsolatedNameScope) { + assert(argLocations.empty() || argLocations.size() == entryArguments.size()); auto currentPt = opBuilder.saveInsertionPoint(); // Push a new named value scope. @@ -1808,7 +1817,8 @@ if (getToken().is(Token::caret_identifier)) return emitError("invalid block name in region with named arguments"); - for (auto &placeholderArgPair : entryArguments) { + for (unsigned argIndex = 0; argIndex < entryArguments.size(); ++argIndex) { + auto &placeholderArgPair = entryArguments[argIndex]; auto &argInfo = placeholderArgPair.first; // Ensure that the argument was not already defined.
@@ -1818,7 +1828,10 @@ .attachNote(getEncodedSourceLocation(*defLoc)) << "previously referenced here"; } - auto loc = getEncodedSourceLocation(placeholderArgPair.first.loc); + Location loc = + (!argLocations.empty() && argLocations[argIndex].hasValue()) + ? argLocations[argIndex].getValue() + : getEncodedSourceLocation(placeholderArgPair.first.loc); BlockArgument arg = block->addArgument(placeholderArgPair.second, loc); // Add a definition of this arg to the assembly state if provided. diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -52,7 +52,7 @@ // CHECK-LABEL: func @argLocs( // CHECK-SAME: %arg0: i32 loc({{.*}}locations.mlir":[[# @LINE+1]]:15), func @argLocs(%x: i32, -// CHECK-SAME: %arg1: i64 loc({{.*}}locations.mlir":[[# @LINE+1]]:15)) +// CHECK-SAME: %arg1: i64 loc("hotdog")) %y: i64 loc("hotdog")) { return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -595,7 +595,7 @@ // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, argType, + return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, /*enableNameShadowing=*/true); }