diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -54,23 +54,24 @@ /// Parses function arguments using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments. +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments. ParseResult parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic); + SmallVectorImpl> &argLocations, bool &isVariadic); /// Parses a function signature using `parser`. The `allowVariadic` argument /// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments and those of the results. +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments and those of the results. ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, + SmallVectorImpl> &argLocations, bool &isVariadic, SmallVectorImpl &resultTypes, SmallVectorImpl &resultAttrs); diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1208,20 +1208,23 @@ /// Parses a region. Any parsed blocks are appended to 'region' and must be /// moved to the op regions after the op is created. The first block of the - /// region takes 'arguments' of types 'argTypes'. 
If 'enableNameShadowing' is - /// set to true, the argument names are allowed to shadow the names of other - /// existing SSA values defined above the region scope. 'enableNameShadowing' - /// can only be set to true for regions attached to operations that are - /// 'IsolatedFromAbove. - virtual ParseResult parseRegion(Region ®ion, - ArrayRef arguments = {}, - ArrayRef argTypes = {}, - bool enableNameShadowing = false) = 0; + /// region takes 'arguments' of types 'argTypes'. If `argLocations` is + /// non-empty it contains an optional location to be attached to each + /// argument. If 'enableNameShadowing' is set to true, the argument names are + /// allowed to shadow the names of other existing SSA values defined above the + /// region scope. 'enableNameShadowing' can only be set to true for regions + /// attached to operations that are 'IsolatedFromAbove'. + virtual ParseResult + parseRegion(Region ®ion, ArrayRef arguments = {}, + ArrayRef argTypes = {}, + ArrayRef> argLocations = {}, + bool enableNameShadowing = false) = 0; /// Parses a region if present. virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef arguments = {}, ArrayRef argTypes = {}, + ArrayRef> argLocations = {}, bool enableNameShadowing = false) = 0; /// Parses a region if present. 
If the region is present, a new region is diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -220,6 +220,7 @@ Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, /*argTypes=*/{unwrappedTypes}, + /*argLocations=*/{}, /*enableNameShadowing=*/false)) return failure(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -670,11 +670,13 @@ SmallVectorImpl &argTypes) { if (parser.parseOptionalKeyword("args")) return success(); - SmallVector argAttrs; + SmallVector argAttrs; + SmallVector> argLocations; bool isVariadic = false; return function_interface_impl::parseFunctionArgumentList( parser, /*allowAttributes=*/false, - /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); + /*allowVariadic=*/false, argNames, argTypes, argAttrs, argLocations, + isVariadic); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, @@ -776,11 +778,12 @@ /// (`->` function-result-list)? memory-attribution `kernel`? /// function-attributes? region static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; - SmallVector resultTypes; + SmallVector entryArgs; + SmallVector argAttrs; + SmallVector resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + SmallVector> argLocations; bool isVariadic; // Parse the function name. 
@@ -792,7 +795,7 @@ auto signatureLocation = parser.getCurrentLocation(); if (failed(function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs))) + argLocations, isVariadic, resultTypes, resultAttrs))) return failure(); if (entryArgs.empty() && !argTypes.empty()) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2033,11 +2033,12 @@ parser, result, LLVM::Linkage::External))); StringAttr nameAttr; - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; - SmallVector resultTypes; + SmallVector entryArgs; + SmallVector argAttrs; + SmallVector resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + SmallVector> argLocations; bool isVariadic; auto signatureLocation = parser.getCurrentLocation(); @@ -2045,7 +2046,7 @@ result.attributes) || function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + argLocations, isVariadic, resultTypes, resultAttrs)) return failure(); auto type = diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1942,11 +1942,12 @@ //===----------------------------------------------------------------------===// static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; - SmallVector resultTypes; + SmallVector entryArgs; + SmallVector argAttrs; + SmallVector resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + SmallVector> argLocations; auto &builder = parser.getBuilder(); // Parse the name as a symbol. 
@@ -1959,7 +1960,7 @@ bool isVariadic = false; if (function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, - isVariadic, resultTypes, resultAttrs)) + argLocations, isVariadic, resultTypes, resultAttrs)) return failure(); auto fnType = builder.getFunctionType(argTypes, resultTypes); diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -17,7 +17,7 @@ OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic) { + SmallVectorImpl> &argLocations, bool &isVariadic) { if (parser.parseLParen()) return failure(); @@ -60,11 +60,12 @@ return parser.emitError(loc, "expected arguments without attributes"); argAttrs.push_back(attrs); - // Parse a location if specified. TODO: Don't drop it on the floor. + // Parse a location if specified. 
Optional explicitLoc; if (!argument.name.empty() && parser.parseOptionalLocationSpecifier(explicitLoc)) return failure(); + argLocations.push_back(explicitLoc); return success(); }; @@ -132,11 +133,12 @@ OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, - bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl> &argLocations, bool &isVariadic, + SmallVectorImpl &resultTypes, SmallVectorImpl &resultAttrs) { bool allowArgAttrs = true; if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames, - argTypes, argAttrs, isVariadic)) + argTypes, argAttrs, argLocations, isVariadic)) return failure(); if (succeeded(parser.parseOptionalArrow())) return parseFunctionResultList(parser, resultTypes, resultAttrs); @@ -190,11 +192,12 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, FuncTypeBuilder funcTypeBuilder) { - SmallVector entryArgs; - SmallVector argAttrs; - SmallVector resultAttrs; - SmallVector argTypes; - SmallVector resultTypes; + SmallVector entryArgs; + SmallVector argAttrs; + SmallVector resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + SmallVector> argLocations; auto &builder = parser.getBuilder(); // Parse visibility. @@ -210,7 +213,8 @@ llvm::SMLoc signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, - argAttrs, isVariadic, resultTypes, resultAttrs)) + argAttrs, argLocations, isVariadic, resultTypes, + resultAttrs)) return failure(); std::string errorMessage; @@ -253,6 +257,7 @@ llvm::SMLoc loc = parser.getCurrentLocation(); OptionalParseResult parseResult = parser.parseOptionalRegion( *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes, + entryArgs.empty() ? 
ArrayRef>() : argLocations, /*enableNameShadowing=*/false); if (parseResult.hasValue()) { if (failed(*parseResult)) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -363,16 +363,19 @@ //===--------------------------------------------------------------------===// /// Parse a region into 'region' with the provided entry block arguments. - /// 'isIsolatedNameScope' indicates if the naming scope of this region is - /// isolated from those above. + /// If non-empty, 'argLocations' contains an optional location for each + /// argument. 'isIsolatedNameScope' indicates if the naming scope of this + /// region is isolated from those above. ParseResult parseRegion(Region ®ion, ArrayRef> entryArguments, + ArrayRef> argLocations = {}, bool isIsolatedNameScope = false); /// Parse a region body into 'region'. ParseResult parseRegionBody(Region ®ion, llvm::SMLoc startLoc, ArrayRef> entryArguments, + ArrayRef> argLocations, bool isIsolatedNameScope); //===--------------------------------------------------------------------===// @@ -1448,6 +1451,7 @@ /// effectively defines the SSA values of `arguments` and assigns their type. ParseResult parseRegion(Region ®ion, ArrayRef arguments, ArrayRef argTypes, + ArrayRef> argLocations, bool enableNameShadowing) override { assert(arguments.size() == argTypes.size() && "mismatching number of arguments and types"); @@ -1466,19 +1470,22 @@ (void)isIsolatedFromAbove; assert((!enableNameShadowing || isIsolatedFromAbove) && "name shadowing is only allowed on isolated regions"); - if (parser.parseRegion(region, regionArguments, enableNameShadowing)) + if (parser.parseRegion(region, regionArguments, argLocations, + enableNameShadowing)) return failure(); return success(); } /// Parses a region if present.
- OptionalParseResult parseOptionalRegion(Region ®ion, - ArrayRef arguments, - ArrayRef argTypes, - bool enableNameShadowing) override { + OptionalParseResult + parseOptionalRegion(Region ®ion, ArrayRef arguments, + ArrayRef argTypes, + ArrayRef> argLocations, + bool enableNameShadowing) override { if (parser.getToken().isNot(Token::l_brace)) return llvm::None; - return parseRegion(region, arguments, argTypes, enableNameShadowing); + return parseRegion(region, arguments, argTypes, argLocations, + enableNameShadowing); } /// Parses a region if present. If the region is present, a new region is @@ -1491,7 +1498,8 @@ if (parser.getToken().isNot(Token::l_brace)) return llvm::None; std::unique_ptr newRegion = std::make_unique(); - if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing)) + if (parseRegion(*newRegion, arguments, argTypes, /*argLocations=*/{}, + enableNameShadowing)) return failure(); region = std::move(newRegion); @@ -1815,7 +1823,7 @@ ParseResult OperationParser::parseRegion( Region ®ion, ArrayRef> entryArguments, - bool isIsolatedNameScope) { + ArrayRef> argLocations, bool isIsolatedNameScope) { // Parse the '{'. Token lBraceTok = getToken(); if (parseToken(Token::l_brace, "expected '{' to begin a region")) @@ -1827,7 +1835,7 @@ // Parse the region body. if ((!entryArguments.empty() || getToken().isNot(Token::r_brace)) && - parseRegionBody(region, lBraceTok.getLoc(), entryArguments, + parseRegionBody(region, lBraceTok.getLoc(), entryArguments, argLocations, isIsolatedNameScope)) { return failure(); } @@ -1843,7 +1851,8 @@ ParseResult OperationParser::parseRegionBody( Region ®ion, llvm::SMLoc startLoc, ArrayRef> entryArguments, - bool isIsolatedNameScope) { + ArrayRef> argLocations, bool isIsolatedNameScope) { + assert(argLocations.empty() || argLocations.size() == entryArguments.size()); auto currentPt = opBuilder.saveInsertionPoint(); // Push a new named value scope. 
@@ -1865,7 +1874,9 @@ if (getToken().is(Token::caret_identifier)) return emitError("invalid block name in region with named arguments"); - for (auto &placeholderArgPair : entryArguments) { + for (const auto &it : llvm::enumerate(entryArguments)) { + size_t argIndex = it.index(); + auto &placeholderArgPair = it.value(); auto &argInfo = placeholderArgPair.first; // Ensure that the argument was not already defined. @@ -1875,7 +1886,10 @@ .attachNote(getEncodedSourceLocation(*defLoc)) << "previously referenced here"; } - auto loc = getEncodedSourceLocation(placeholderArgPair.first.loc); + Location loc = + (!argLocations.empty() && argLocations[argIndex]) + ? *argLocations[argIndex] + : getEncodedSourceLocation(placeholderArgPair.first.loc); BlockArgument arg = block->addArgument(placeholderArgPair.second, loc); // Add a definition of this arg to the assembly state if provided. diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -52,7 +52,7 @@ // CHECK-LABEL: func @argLocs( // CHECK-SAME: %arg0: i32 loc({{.*}}locations.mlir":[[# @LINE+1]]:15), func @argLocs(%x: i32, -// CHECK-SAME: %arg1: i64 loc({{.*}}locations.mlir":[[# @LINE+1]]:15)) +// CHECK-SAME: %arg1: i64 loc("hotdog") %y: i64 loc("hotdog")) { return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -595,7 +595,7 @@ // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, argType, + return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, /*enableNameShadowing=*/true); }