diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -213,16 +213,6 @@ GPUDialect::getKernelFuncAttrName()) != nullptr; } - /// Change the type of this function in place. This is an extremely - /// dangerous operation and it is up to the caller to ensure that this is - /// legal for this function, and to restore invariants: - /// - the entry block args must be updated to match the function params. - /// - the argument/result attributes may need an update: if the new type - /// has less parameters we drop the extra attributes, if there are more - /// parameters they won't have any attributes. - // TODO: consider removing this function thanks to rewrite patterns. - void setType(FunctionType newType); - /// Returns the number of buffers located in the workgroup memory. unsigned getNumWorkgroupAttributions() { return (*this)->getAttrOfType( diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -300,7 +300,7 @@ }]; let parameters = (ins ArrayRefParameter<"NamedAttribute", "">:$value); let builders = [ - AttrBuilder<(ins "ArrayRef":$value)> + AttrBuilder<(ins CArg<"ArrayRef", "llvm::None">:$value)> ]; let extraClassDeclaration = [{ using ValueType = ArrayRef; diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -20,7 +20,7 @@ namespace mlir { -namespace impl { +namespace function_like_impl { /// A named class for passing around the variadic flag. class VariadicFlag { @@ -37,6 +37,9 @@ /// `resultAttrs` arguments, to the list of operation attributes in `result`. 
-/// Internally, argument and result attributes are stored as dict attributes -/// with special names given by getResultAttrName, getArgumentAttrName. +/// Internally, argument and result attributes are stored as array attributes +/// of dictionaries, named by getArgDictAttrName and getResultDictAttrName. +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs); void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, ArrayRef resultAttrs); @@ -103,7 +106,7 @@ unsigned numResults, ArrayRef elided = {}); -} // namespace impl +} // namespace function_like_impl } // namespace mlir diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -20,45 +20,41 @@ namespace mlir { -namespace impl { +namespace function_like_impl { /// Return the name of the attribute used for function types. inline StringRef getTypeAttrName() { return "type"; } -/// Return the name of the attribute used for function arguments. -inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl &out) { - out.clear(); - return ("arg" + Twine(arg)).toStringRef(out); -} - -/// Returns true if the given name is a valid argument attribute name. -inline bool isArgAttrName(StringRef name) { - APInt unused; - return name.startswith("arg") && - !name.drop_front(3).getAsInteger(/*Radix=*/10, unused); -} +/// Return the name of the attribute used for function argument attributes. +inline StringRef getArgDictAttrName() { return "arg_attrs"; } -/// Return the name of the attribute used for function results. -inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl &out) { - out.clear(); - return ("result" + Twine(arg)).toStringRef(out); -} +/// Return the name of the attribute used for function result attributes. +inline StringRef getResultDictAttrName() { return "res_attrs"; } /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned.
-inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) { - SmallString<8> nameOut; - return op->getAttrOfType(getArgAttrName(index, nameOut)); -} +DictionaryAttr getArgAttrDict(Operation *op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) { - SmallString<8> nameOut; - return op->getAttrOfType(getResultAttrName(index, nameOut)); -} +DictionaryAttr getResultAttrDict(Operation *op, unsigned index); + +namespace detail { +/// Update the given index into an argument or result attribute dictionary. +void setArgResAttrDict(Operation *op, StringRef attrName, + unsigned numTotalIndices, unsigned index, + DictionaryAttr attrs); +} // namespace detail + +/// Set all of the argument or result attribute dictionaries for a function. The +/// size of `attrs` is expected to match the number of arguments/results of the +/// given `op`. +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); /// Return all of the attributes for the argument at 'index'. inline ArrayRef getArgAttrs(Operation *op, unsigned index) { @@ -87,7 +83,7 @@ /// Get a FunctionLike operation's body. Region &getFunctionBody(Operation *op); -} // namespace impl +} // namespace function_like_impl namespace OpTrait { @@ -142,7 +138,7 @@ bool isExternal() { return empty(); } Region &getBody() { - return ::mlir::impl::getFunctionBody(this->getOperation()); + return function_like_impl::getFunctionBody(this->getOperation()); } /// Delete all blocks from this function. @@ -194,7 +190,9 @@ //===--------------------------------------------------------------------===// /// Return the name of the attribute used for function types. 
- static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); } + static StringRef getTypeAttrName() { + return function_like_impl::getTypeAttrName(); + } TypeAttr getTypeAttr() { return this->getOperation()->template getAttrOfType( @@ -207,7 +205,7 @@ /// hide this one if the concrete class does not use FunctionType for the /// function type under the hood. FunctionType getType() { - return ::mlir::impl::getFunctionType(this->getOperation()); + return function_like_impl::getFunctionType(this->getOperation()); } /// Return the type of this function without the specified arguments and @@ -277,8 +275,8 @@ void eraseArguments(ArrayRef argIndices) { unsigned originalNumArgs = getNumArguments(); Type newType = getTypeWithoutArgsAndResults(argIndices, {}); - ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices, - originalNumArgs, newType); + function_like_impl::eraseFunctionArguments(this->getOperation(), argIndices, + originalNumArgs, newType); } /// Erase a single result at `resultIndex`. @@ -289,8 +287,8 @@ void eraseResults(ArrayRef resultIndices) { unsigned originalNumResults = getNumResults(); Type newType = getTypeWithoutArgsAndResults({}, resultIndices); - ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices, - originalNumResults, newType); + function_like_impl::eraseFunctionResults( + this->getOperation(), resultIndices, originalNumResults, newType); } //===--------------------------------------------------------------------===// @@ -306,14 +304,23 @@ /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { - return ::mlir::impl::getArgAttrs(this->getOperation(), index); + return function_like_impl::getArgAttrs(this->getOperation(), index); } - /// Return all argument attributes of this function. If an argument does not - /// have any attributes, the corresponding entry in `result` is nullptr. 
+ /// Return an ArrayAttr containing all argument attribute dictionaries of this + /// function, or nullptr if no arguments have attributes. + ArrayAttr getAllArgAttrs() { + return this->getOperation()->template getAttrOfType( + function_like_impl::getArgDictAttrName()); + } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { - for (unsigned i = 0, e = getNumArguments(); i != e; ++i) - result.emplace_back(getArgAttrDict(i)); + if (ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.resize(getNumArguments()); + } } /// Return the specified attribute, if present, for the argument at 'index', @@ -342,7 +349,19 @@ /// Set the attributes held by the argument at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. void setArgAttrs(unsigned index, DictionaryAttr attributes); - void setAllArgAttrs(ArrayRef attributes); + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayAttr attributes) { + assert(attributes.size() == getNumArguments()); + this->getOperation()->setAttr(function_like_impl::getArgDictAttrName(), + attributes); + } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -370,14 +389,23 @@ /// Return all of the attributes for the result at 'index'. 
ArrayRef getResultAttrs(unsigned index) { - return ::mlir::impl::getResultAttrs(this->getOperation(), index); + return function_like_impl::getResultAttrs(this->getOperation(), index); } - /// Return all result attributes of this function. If a result does not have - /// any attributes, the corresponding entry in `result` is nullptr. + /// Return an ArrayAttr containing all result attribute dictionaries of this + /// function, or nullptr if no results have attributes. + ArrayAttr getAllResultAttrs() { + return this->getOperation()->template getAttrOfType( + function_like_impl::getResultDictAttrName()); + } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { - for (unsigned i = 0, e = getNumResults(); i != e; ++i) - result.emplace_back(getResultAttrDict(i)); + if (ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.resize(getNumResults()); + } } /// Return the specified attribute, if present, for the result at 'index', @@ -402,10 +430,23 @@ /// Set the attributes held by the result at 'index'. void setResultAttrs(unsigned index, ArrayRef attributes); + /// Set the attributes held by the result at 'index'. `attributes` may be - /// null, in which case any existing argument attributes are removed. + /// null, in which case any existing result attributes are removed.
void setResultAttrs(unsigned index, DictionaryAttr attributes); - void setAllResultAttrs(ArrayRef attributes); + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayAttr attributes) { + assert(attributes.size() == getNumResults()); + this->getOperation()->setAttr(function_like_impl::getResultDictAttrName(), + attributes); + } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -422,25 +463,12 @@ Attribute removeResultAttr(unsigned index, Identifier name); protected: - /// Returns the attribute entry name for the set of argument attributes at - /// 'index'. - static StringRef getArgAttrName(unsigned index, SmallVectorImpl &out) { - return ::mlir::impl::getArgAttrName(index, out); - } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. DictionaryAttr getArgAttrDict(unsigned index) { assert(index < getNumArguments() && "invalid argument number"); - return ::mlir::impl::getArgAttrDict(this->getOperation(), index); - } - - /// Returns the attribute entry name for the set of result attributes at - /// 'index'. - static StringRef getResultAttrName(unsigned index, - SmallVectorImpl &out) { - return ::mlir::impl::getResultAttrName(index, out); + return function_like_impl::getArgAttrDict(this->getOperation(), index); } /// Returns the dictionary attribute corresponding to the result at 'index'. @@ -448,7 +476,7 @@ /// returned. 
DictionaryAttr getResultAttrDict(unsigned index) { assert(index < getNumResults() && "invalid result number"); - return ::mlir::impl::getResultAttrDict(this->getOperation(), index); + return function_like_impl::getResultAttrDict(this->getOperation(), index); } /// Hook for concrete classes to verify that the type attribute respects @@ -475,9 +503,7 @@ template LogicalResult FunctionLike::verifyTrait(Operation *op) { - MLIRContext *ctx = op->getContext(); auto funcOp = cast(op); - if (!funcOp.isTypeAttrValid()) return funcOp.emitOpError("requires a type attribute '") << getTypeAttrName() << '\''; @@ -485,35 +511,69 @@ if (failed(funcOp.verifyType())) return failure(); - for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) { - // Verify that all of the argument attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. - for (auto attr : funcOp.getArgAttrs(i)) { - if (!attr.first.strref().contains('.')) - return funcOp.emitOpError("arguments may only have dialect attributes"); - auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { - if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, - /*argIndex=*/i, attr))) - return failure(); + if (ArrayAttr allArgAttrs = funcOp.getAllArgAttrs()) { + unsigned numArgs = funcOp.getNumArguments(); + if (allArgAttrs.size() != numArgs) { + return funcOp.emitOpError() + << "expects argument attribute array `" + << function_like_impl::getArgDictAttrName() + << "` to have the same number of elements as the number of " + "function arguments, got " + << allArgAttrs.size() << ", but expected " << numArgs; + } + for (unsigned i = 0; i != numArgs; ++i) { + DictionaryAttr argAttrs = allArgAttrs[i].dyn_cast(); + if (!argAttrs) { + return funcOp.emitOpError() << "expects argument attribute dictionary " + "to be a DictionaryAttr, but got `" 
+ << allArgAttrs[i] << "`"; + } + + // Verify that all of the argument attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : argAttrs) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError( + "arguments may only have dialect attributes"); + if (Dialect *dialect = attr.first.getDialect()) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return failure(); + } } } } + if (ArrayAttr allResultAttrs = funcOp.getAllResultAttrs()) { + unsigned numResults = funcOp.getNumResults(); + if (allResultAttrs.size() != numResults) { + return funcOp.emitOpError() + << "expects result attribute array `" + << function_like_impl::getResultDictAttrName() + << "` to have the same number of elements as the number of " + "function results, got " + << allResultAttrs.size() << ", but expected " << numResults; + } + for (unsigned i = 0; i != numResults; ++i) { + DictionaryAttr resultAttrs = allResultAttrs[i].dyn_cast(); + if (!resultAttrs) { + return funcOp.emitOpError() << "expects result attribute dictionary " + "to be a DictionaryAttr, but got `" + << allResultAttrs[i] << "`"; + } - for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) { - // Verify that all of the result attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. 
- for (auto attr : funcOp.getResultAttrs(i)) { - if (!attr.first.strref().contains('.')) - return funcOp.emitOpError("results may only have dialect attributes"); - auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { - if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, - /*resultIndex=*/i, - attr))) - return failure(); + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : resultAttrs) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError("results may only have dialect attributes"); + if (Dialect *dialect = attr.first.getDialect()) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return failure(); + } } } } @@ -551,7 +611,7 @@ template void FunctionLike::setType(FunctionType newType) { - ::mlir::impl::setFunctionType(this->getOperation(), newType); + function_like_impl::setFunctionType(this->getOperation(), newType); } //===----------------------------------------------------------------------===// @@ -563,45 +623,19 @@ void FunctionLike::setArgAttrs( unsigned index, ArrayRef attributes) { assert(index < getNumArguments() && "invalid argument number"); - SmallString<8> nameOut; - getArgAttrName(index, nameOut); - Operation *op = this->getOperation(); - if (attributes.empty()) - return (void)op->removeAttr(nameOut); - op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } template void FunctionLike::setArgAttrs(unsigned index, DictionaryAttr attributes) { - assert(index < getNumArguments() && "invalid argument number"); - 
SmallString<8> nameOut; - if (!attributes || attributes.empty()) - this->getOperation()->removeAttr(getArgAttrName(index, nameOut)); - else - return this->getOperation()->setAttr(getArgAttrName(index, nameOut), - attributes); -} - -template -void FunctionLike::setAllArgAttrs( - ArrayRef attributes) { - assert(attributes.size() == getNumArguments()); - NamedAttrList attrs = this->getOperation()->getAttrs(); - - // Instead of calling setArgAttrs() multiple times, which rebuild the - // attribute dictionary every time, build a new list of attributes for the - // operation so that we rebuild the attribute dictionary in one shot. - SmallString<8> argAttrName; - for (unsigned i = 0, e = attributes.size(); i != e; ++i) { - StringRef attrName = getArgAttrName(i, argAttrName); - if (!attributes[i] || attributes[i].empty()) - attrs.erase(attrName); - else - attrs.set(attrName, attributes[i]); - } - this->getOperation()->setAttrs(attrs); + Operation *op = this->getOperation(); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, + attributes ? 
attributes : DictionaryAttr::get(op->getContext())); } /// If the an attribute exists with the specified name, change it to the new @@ -640,45 +674,20 @@ void FunctionLike::setResultAttrs( unsigned index, ArrayRef attributes) { assert(index < getNumResults() && "invalid result number"); - SmallString<8> nameOut; - getResultAttrName(index, nameOut); - - if (attributes.empty()) - return (void)this->getOperation()->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getResultDictAttrName(), getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } template void FunctionLike::setResultAttrs(unsigned index, DictionaryAttr attributes) { assert(index < getNumResults() && "invalid result number"); - SmallString<8> nameOut; - if (!attributes || attributes.empty()) - this->getOperation()->removeAttr(getResultAttrName(index, nameOut)); - else - this->getOperation()->setAttr(getResultAttrName(index, nameOut), - attributes); -} - -template -void FunctionLike::setAllResultAttrs( - ArrayRef attributes) { - assert(attributes.size() == getNumResults()); - NamedAttrList attrs = this->getOperation()->getAttrs(); - - // Instead of calling setResultAttrs() multiple times, which rebuild the - // attribute dictionary every time, build a new list of attributes for the - // operation so that we rebuild the attribute dictionary in one shot. 
- SmallString<8> resultAttrName; - for (unsigned i = 0, e = attributes.size(); i != e; ++i) { - StringRef attrName = getResultAttrName(i, resultAttrName); - if (!attributes[i] || attributes[i].empty()) - attrs.erase(attrName); - else - attrs.set(attrName, attributes[i]); - } - this->getOperation()->setAttrs(attrs); + Operation *op = this->getOperation(); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getResultDictAttrName(), getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -58,7 +58,7 @@ SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == impl::getTypeAttrName() || + attr.first == function_like_impl::getTypeAttrName() || attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -195,7 +195,7 @@ rewriter.getFunctionType(signatureConverter.getConvertedTypes(), llvm::None)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first == impl::getTypeAttrName() || + if (namedAttr.first == function_like_impl::getTypeAttrName() || namedAttr.first == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.first, namedAttr.second); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ 
b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1211,8 +1211,10 @@ SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || - (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) + attr.first == function_like_impl::getTypeAttrName() || + attr.first == "std.varargs" || + (filterArgAttrs && + attr.first == function_like_impl::getArgDictAttrName())) continue; result.push_back(attr); } @@ -1395,19 +1397,19 @@ SmallVector attributes; filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, attributes); - for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { - auto attr = impl::getArgAttrDict(funcOp, i); - if (!attr) - continue; - - auto mapping = result.getInputMapping(i); - assert(mapping.hasValue() && "unexpected deletion of function argument"); - - SmallString<8> name; - for (size_t j = 0; j < mapping->size; ++j) { - impl::getArgAttrName(mapping->inputNo + j, name); - attributes.push_back(rewriter.getNamedAttr(name, attr)); + if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { + SmallVector newArgAttrs( + llvmType.cast().getNumParams()); + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { + auto mapping = result.getInputMapping(i); + assert(mapping.hasValue() && + "unexpected deletion of function argument"); + for (size_t j = 0; j < mapping->size; ++j) + newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; } + attributes.push_back( + rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(), + rewriter.getArrayAttr(newArgAttrs))); } // Create an LLVM function, use external linkage by default until MLIR diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -599,9 +599,9 @@ return success(); SmallVector argAttrs; bool isVariadic = false; - return 
impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false, - /*allowVariadic=*/false, argNames, - argTypes, argAttrs, isVariadic); + return function_like_impl::parseFunctionArgumentList( + parser, /*allowAttributes=*/false, + /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, @@ -717,7 +717,7 @@ return failure(); auto signatureLocation = parser.getCurrentLocation(); - if (failed(impl::parseFunctionSignature( + if (failed(function_like_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, isVariadic, resultTypes, resultAttrs))) return failure(); @@ -756,7 +756,8 @@ // Parse attributes. if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, + resultAttrs); // Parse the region. If no argument names were provided, take all names // (including those of attributions) from the entry block. 
@@ -781,33 +782,22 @@ p.printSymbolName(op.getName()); FunctionType type = op.getType(); - impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), - /*isVariadic=*/false, type.getResults()); + function_like_impl::printFunctionSignature( + p, op.getOperation(), type.getInputs(), + /*isVariadic=*/false, type.getResults()); printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); if (op.isKernel()) p << ' ' << op.getKernelKeyword(); - impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), - type.getNumResults(), - {op.getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + function_like_impl::printFunctionAttributes( + p, op.getOperation(), type.getNumInputs(), type.getNumResults(), + {op.getNumWorkgroupAttributionsAttrName(), + GPUDialect::getKernelFuncAttrName()}); p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); } -void GPUFuncOp::setType(FunctionType newType) { - auto oldType = getType(); - assert(newType.getNumResults() == oldType.getNumResults() && - "unimplemented: changes to the number of results"); - - SmallVector nameBuf; - for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) - (*this)->removeAttr(getArgAttrName(i, nameBuf)); - - (*this)->setAttr(getTypeAttrName(), TypeAttr::get(newType)); -} - /// Hook for FunctionLike verifier. 
LogicalResult GPUFuncOp::verifyType() { Type type = getTypeAttr().getValue(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1732,21 +1732,19 @@ if (argAttrs.empty()) return; - unsigned numInputs = type.cast().getNumParams(); - assert(numInputs == argAttrs.size() && + assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - SmallString<8> argAttrName; - for (unsigned i = 0; i < numInputs; ++i) - if (DictionaryAttr argDict = argAttrs[i]) - result.addAttribute(getArgAttrName(i, argAttrName), argDict); + function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/llvm::None); } // Builds an LLVM function type from the given lists of input and output types. // Returns a null type if any of the types provided are non-LLVM types, or if // there is more than one output type. 
-static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, - ArrayRef inputs, ArrayRef outputs, - impl::VariadicFlag variadicFlag) { +static Type +buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, + ArrayRef inputs, ArrayRef outputs, + function_like_impl::VariadicFlag variadicFlag) { Builder &b = parser.getBuilder(); if (outputs.size() > 1) { parser.emitError(loc, "failed to construct function type: expected zero or " @@ -1803,22 +1801,23 @@ auto signatureLocation = parser.getCurrentLocation(); if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || - impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, - argTypes, argAttrs, isVariadic, resultTypes, - resultAttrs)) + function_like_impl::parseFunctionSignature( + parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, + isVariadic, resultTypes, resultAttrs)) return failure(); auto type = buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, - impl::VariadicFlag(isVariadic)); + function_like_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(function_like_impl::getTypeAttrName(), + TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, - resultAttrs); + function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result, + argAttrs, resultAttrs); auto *body = result.addRegion(); OptionalParseResult parseResult = parser.parseOptionalRegion( @@ -1846,9 +1845,10 @@ if (!returnType.isa()) resTypes.push_back(returnType); - impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); - impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), - {getLinkageAttrName()}); + function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), + resTypes); + 
function_like_impl::printFunctionAttributes( + p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); // Print the body if this is not an external function. Region &body = op.body(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -99,7 +99,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.startRootUpdate(op); - Region ®ion = mlir::impl::getFunctionBody(op); + Region ®ion = function_like_impl::getFunctionBody(op); SmallVector conversions; for (Block &block : llvm::drop_begin(region, 1)) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1783,13 +1783,14 @@ // Parse the function signature. bool isVariadic = false; - if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs, - argTypes, argAttrs, isVariadic, resultTypes, - resultAttrs)) + if (function_like_impl::parseFunctionSignature( + parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, + isVariadic, resultTypes, resultAttrs)) return failure(); auto fnType = builder.getFunctionType(argTypes, resultTypes); - state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType)); + state.addAttribute(function_like_impl::getTypeAttrName(), + TypeAttr::get(fnType)); // Parse the optional function control keyword. spirv::FunctionControl fnControl; @@ -1803,7 +1804,8 @@ // Add the attributes to the function arguments. assert(argAttrs.size() == argTypes.size()); assert(resultAttrs.size() == resultTypes.size()); - impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs); + function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, + resultAttrs); // Parse the optional function body. 
auto *body = state.addRegion(); @@ -1817,11 +1819,12 @@ printer << spirv::FuncOp::getOperationName() << " "; printer.printSymbolName(fnOp.sym_name()); auto fnType = fnOp.getType(); - impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), - /*isVariadic=*/false, fnType.getResults()); + function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), + /*isVariadic=*/false, + fnType.getResults()); printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control()) << "\""; - impl::printFunctionAttributes( + function_like_impl::printFunctionAttributes( printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -582,7 +582,7 @@ // Copy over all attributes other than the function name and type. 
for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first != impl::getTypeAttrName() && + if (namedAttr.first != function_like_impl::getTypeAttrName() && namedAttr.first != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.first, namedAttr.second); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -106,27 +106,25 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - SmallString<8> argAttrName; - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (DictionaryAttr argDict = argAttrs[i]) - state.addAttribute(getArgAttrName(i, argAttrName), argDict); + function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, + /*resultAttrs=*/llvm::None); } static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, - ArrayRef results, impl::VariadicFlag, - std::string &) { + ArrayRef results, + function_like_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; - return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, - buildFuncType); + return function_like_impl::parseFunctionLikeOp( + parser, result, /*allowVariadic=*/false, buildFuncType); } static void print(FuncOp op, OpAsmPrinter &p) { FunctionType fnType = op.getType(); - impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false, - fnType.getResults()); + function_like_impl::printFunctionLikeOp( + p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } static LogicalResult verify(FuncOp op) { @@ -170,30 +168,39 @@ /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { - FunctionType newType = getType(); + // Create the new function. 
+ FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the // argument to the input type vector. - bool isExternalFn = isExternal(); - if (!isExternalFn) { - SmallVector inputTypes; - inputTypes.reserve(newType.getNumInputs()); - for (unsigned i = 0, e = getNumArguments(); i != e; ++i) + if (!isExternal()) { + FunctionType oldType = getType(); + + unsigned oldNumArgs = oldType.getNumInputs(); + SmallVector newInputs; + newInputs.reserve(oldNumArgs); + for (unsigned i = 0; i != oldNumArgs; ++i) if (!mapper.contains(getArgument(i))) - inputTypes.push_back(newType.getInput(i)); - newType = FunctionType::get(getContext(), inputTypes, newType.getResults()); + newInputs.push_back(oldType.getInput(i)); + + /// If any of the arguments were dropped, update the type and drop any + /// necessary argument attributes. + if (newInputs.size() != oldNumArgs) { + newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, + oldType.getResults())); + + if (ArrayAttr argAttrs = getAllArgAttrs()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(newInputs.size()); + for (unsigned i = 0; i != oldNumArgs; ++i) + if (!mapper.contains(getArgument(i))) + newArgAttrs.push_back(argAttrs[i]); + newFunc.setAllArgAttrs(newArgAttrs); + } + } } - // Create the new function. - FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); - newFunc.setType(newType); - - /// Set the argument attributes for arguments that aren't being replaced. - for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) - if (isExternalFn || !mapper.contains(getArgument(i))) - newFunc.setArgAttrs(destI++, getArgAttrs(i)); - /// Clone the current function into the new one and return it. 
cloneInto(newFunc, mapper); return newFunc; diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -13,7 +13,7 @@ using namespace mlir; -ParseResult mlir::impl::parseFunctionArgumentList( +ParseResult mlir::function_like_impl::parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -125,7 +125,7 @@ /// indicates whether functions with variadic arguments are supported. The /// trailing arguments are populated by this function with names, types and /// attributes of the arguments and those of the results. -ParseResult mlir::impl::parseFunctionSignature( +ParseResult mlir::function_like_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -140,29 +140,53 @@ return success(); } -void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs) { - // Add the attributes to the function arguments. - SmallString<8> attrNameBuf; - for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) - if (!argAttrs[i].empty()) - result.addAttribute(getArgAttrName(i, attrNameBuf), - builder.getDictionaryAttr(argAttrs[i])); +/// Implementation of `addArgAndResultAttrs` that is attribute list type +/// agnostic. +template +static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs, + AttrArrayBuildFnT &&buildAttrArrayFn) { + auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); }; + // Add the attributes to the function arguments. 
+ if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { + ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); + result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts); + } // Add the attributes to the function results. - for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) - if (!resultAttrs[i].empty()) - result.addAttribute(getResultAttrName(i, attrNameBuf), - builder.getDictionaryAttr(resultAttrs[i])); + if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { + ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); + result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts); + } +} + +void mlir::function_like_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs) { + auto buildFn = [](ArrayRef attrs) { + return ArrayRef(attrs.data(), attrs.size()); + }; + addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); +} +void mlir::function_like_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs) { + MLIRContext *context = builder.getContext(); + auto buildFn = [=](ArrayRef attrs) { + return llvm::to_vector<8>( + llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute { + return attrList.getDictionary(context); + })); + }; + addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); } /// Parser implementation for function-like operations. Uses `funcTypeBuilder` /// to construct the custom function type given lists of input and output types. 
-ParseResult -mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - mlir::impl::FuncTypeBuilder funcTypeBuilder) { +ParseResult mlir::function_like_impl::parseFunctionLikeOp( + OpAsmParser &parser, OperationState &result, bool allowVariadic, + FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector argAttrs; SmallVector resultAttrs; @@ -187,13 +211,14 @@ return failure(); std::string errorMessage; - if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, - impl::VariadicFlag(isVariadic), errorMessage)) - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); - else + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + VariadicFlag(isVariadic), errorMessage); + if (!type) { return parser.emitError(signatureLocation) << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -236,35 +261,38 @@ return success(); } -// Print a function result list. +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. 
static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, - ArrayRef> attrs) { + ArrayAttr attrs) { assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + auto &os = p.getStream(); - bool needsParens = - types.size() > 1 || types[0].isa() || !attrs[0].empty(); + bool needsParens = types.size() > 1 || types[0].isa() || + (attrs && !attrs[0].cast().empty()); if (needsParens) os << '('; - llvm::interleaveComma( - llvm::zip(types, attrs), os, - [&](const std::tuple> &t) { - p.printType(std::get<0>(t)); - p.printOptionalAttrDict(std::get<1>(t)); - }); + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(attrs[i].cast().getValue()); + }); if (needsParens) os << ')'; } /// Print the signature of the function-like operation `op`. Assumes `op` has /// the FunctionLike trait and passed the verification. 
-void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, - bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_like_impl::printFunctionSignature( + OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; + ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -275,7 +303,8 @@ } p.printType(argTypes[i]); - p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); + if (argAttrs) + p.printOptionalAttrDict(argAttrs[i].cast().getValue()); } if (isVariadic) { @@ -288,9 +317,7 @@ if (!resultTypes.empty()) { p.getStream() << " -> "; - SmallVector, 4> resultAttrs; - for (int i = 0, e = resultTypes.size(); i < e; ++i) - resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); + auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); printFunctionResultList(p, resultTypes, resultAttrs); } } @@ -300,39 +327,25 @@ /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and /// passed the verification. -void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, - unsigned numInputs, - unsigned numResults, - ArrayRef elided) { +void mlir::function_like_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, + ArrayRef elided) { // Print out function attributes, if present. SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; + ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), + getArgDictAttrName(), getResultDictAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); - SmallString<8> attrNameBuf; - - // Ignore any argument attributes. 
- std::vector> argAttrStorage; - for (unsigned i = 0; i != numInputs; ++i) - if (op->getAttr(getArgAttrName(i, attrNameBuf))) - argAttrStorage.emplace_back(attrNameBuf); - ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); - - // Ignore any result attributes. - std::vector> resultAttrStorage; - for (unsigned i = 0; i != numResults; ++i) - if (op->getAttr(getResultAttrName(i, attrNameBuf))) - resultAttrStorage.emplace_back(attrNameBuf); - ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. -void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p, + Operation *op, + ArrayRef argTypes, + bool isVariadic, + ArrayRef resultTypes) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -31,103 +31,199 @@ // Function Arguments and Results. //===----------------------------------------------------------------------===// -void mlir::impl::eraseFunctionArguments(Operation *op, - ArrayRef argIndices, - unsigned originalNumArgs, - Type newType) { +static bool isEmptyAttrDict(Attribute attr) { + return attr.cast().empty(); +} + +DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op, + unsigned index) { + ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); + DictionaryAttr argAttrs = + attrs ? attrs[index].cast() : DictionaryAttr(); + return (argAttrs && !argAttrs.empty()) ? 
argAttrs : DictionaryAttr(); +} + +DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op, + unsigned index) { + ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); + DictionaryAttr resAttrs = + attrs ? attrs[index].cast() : DictionaryAttr(); + return (resAttrs && !resAttrs.empty()) ? resAttrs : DictionaryAttr(); +} + +void mlir::function_like_impl::detail::setArgResAttrDict( + Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, + DictionaryAttr attrs) { + ArrayAttr allAttrs = op->getAttrOfType(attrName); + if (!allAttrs) { + if (attrs.empty()) + return; + + // If this attribute is not empty, we need to create a new attribute array. + SmallVector newAttrs(numTotalIndices, + DictionaryAttr::get(op->getContext())); + newAttrs[index] = attrs; + op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + return; + } + // Check to see if the attribute is different from what we already have. + if (allAttrs[index] == attrs) + return; + + // If it is, check to see if the attribute array would now contain only empty + // dictionaries. + ArrayRef rawAttrArray = allAttrs.getValue(); + if (attrs.empty() && + llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { + op->removeAttr(attrName); + return; + } + + // Otherwise, create a new attribute array with the updated dictionary. + SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); + newAttrs[index] = attrs; + op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); +} + +/// Set all of the argument or result attribute dictionaries for a function. 
+static void setAllArgResAttrDicts(Operation *op, StringRef attrName, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + op->removeAttr(attrName); + else + op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +} + +void mlir::function_like_impl::setAllArgAttrDicts( + Operation *op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} +void mlir::function_like_impl::setAllArgAttrDicts(Operation *op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, getArgDictAttrName(), + llvm::to_vector<8>(wrappedAttrs)); +} + +void mlir::function_like_impl::setAllResultAttrDicts( + Operation *op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} +void mlir::function_like_impl::setAllResultAttrDicts( + Operation *op, ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, getResultDictAttrName(), + llvm::to_vector<8>(wrappedAttrs)); +} + +void mlir::function_like_impl::eraseFunctionArguments( + Operation *op, ArrayRef argIndices, unsigned originalNumArgs, + Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. // - Block arguments of entry block. Block &entry = op->getRegion(0).front(); - SmallString<8> nameBuf; - - // Collect arg attrs to set. - SmallVector newArgAttrs; - iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { - newArgAttrs.emplace_back(getArgAttrDict(op, i)); - }); - - // Remove any arg attrs that are no longer needed. - for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i) - op->removeAttr(getArgAttrName(i, nameBuf)); - - // Set the function type. 
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - // Set the new arg attrs, or remove them if empty. - for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) { - auto nameAttr = getArgAttrName(i, nameBuf); - if (newArgAttrs[i] && !newArgAttrs[i].empty()) - op->setAttr(nameAttr, newArgAttrs[i]); - else - op->removeAttr(nameAttr); + // Update the argument attributes of the function. + if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + SmallVector newArgAttrs; + newArgAttrs.reserve(argAttrs.size()); + iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { + newArgAttrs.emplace_back(argAttrs[i].cast()); + }); + setAllArgAttrDicts(op, newArgAttrs); } - // Update the entry block's arguments. + // Update the function type and any entry block arguments. + op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); entry.eraseArguments(argIndices); } -void mlir::impl::eraseFunctionResults(Operation *op, - ArrayRef resultIndices, - unsigned originalNumResults, - Type newType) { +void mlir::function_like_impl::eraseFunctionResults( + Operation *op, ArrayRef resultIndices, + unsigned originalNumResults, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. - SmallString<8> nameBuf; - - // Collect result attrs to set. - SmallVector newResultAttrs; - iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { - newResultAttrs.emplace_back(getResultAttrDict(op, i)); - }); - // Remove any result attrs that are no longer needed. - for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i) - op->removeAttr(getResultAttrName(i, nameBuf)); + // Update the result attributes of the function. 
+ if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + SmallVector newResultAttrs; + newResultAttrs.reserve(resAttrs.size()); + iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { + newResultAttrs.emplace_back(resAttrs[i].cast()); + }); + setAllResultAttrDicts(op, newResultAttrs); + } - // Set the function type. + // Update the function type. op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - - // Set the new result attrs, or remove them if empty. - for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) { - auto nameAttr = getResultAttrName(i, nameBuf); - if (newResultAttrs[i] && !newResultAttrs[i].empty()) - op->setAttr(nameAttr, newResultAttrs[i]); - else - op->removeAttr(nameAttr); - } } //===----------------------------------------------------------------------===// // Function type signature. //===----------------------------------------------------------------------===// -FunctionType mlir::impl::getFunctionType(Operation *op) { +FunctionType mlir::function_like_impl::getFunctionType(Operation *op) { assert(op->hasTrait()); - return op->getAttrOfType(mlir::impl::getTypeAttrName()) + return op->getAttrOfType(getTypeAttrName()) .getValue() .cast(); } -void mlir::impl::setFunctionType(Operation *op, FunctionType newType) { +void mlir::function_like_impl::setFunctionType(Operation *op, + FunctionType newType) { assert(op->hasTrait()); - SmallVector nameBuf; FunctionType oldType = getFunctionType(op); - - for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) - op->removeAttr(getArgAttrName(i, nameBuf)); - for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++) - op->removeAttr(getResultAttrName(i, nameBuf)); op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + + // Functor used to update the argument and result attributes of the function. 
+ auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, + unsigned newCount, auto setAttrFn) { + if (oldCount == newCount) + return; + // The new type has no arguments/results, just drop the attribute. + if (newCount == 0) { + op->removeAttr(attrName); + return; + } + ArrayAttr attrs = op->getAttrOfType(attrName); + if (!attrs) + return; + + // The new type has less arguments/results, take the first N attributes. + if (newCount < oldCount) + return setAttrFn(op, attrs.getValue().take_front(newCount)); + + // Otherwise, the new type has more arguments/results. Initialize the new + // arguments/results with empty attributes. + SmallVector newAttrs(attrs.begin(), attrs.end()); + newAttrs.resize(newCount); + setAttrFn(op, newAttrs); + }; + + // Update the argument and result attributes. + updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(), + newType.getNumInputs(), [&](Operation *op, auto &&attrs) { + setAllArgAttrDicts(op, attrs); + }); + updateAttrFn( + function_like_impl::getResultDictAttrName(), oldType.getNumResults(), + newType.getNumResults(), + [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); } //===----------------------------------------------------------------------===// // Function body. 
//===----------------------------------------------------------------------===// -Region &mlir::impl::getFunctionBody(Operation *op) { +Region &mlir::function_like_impl::getFunctionBody(Operation *op) { assert(op->hasTrait()); return op->getRegion(0); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2628,15 +2628,15 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = mlir::impl::getFunctionType(op); + FunctionType type = function_like_impl::getFunctionType(op); // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); SmallVector newResults; if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || failed(typeConverter->convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op), - *typeConverter, &result))) + failed(rewriter.convertRegionTypes( + &function_like_impl::getFunctionBody(op), *typeConverter, &result))) return failure(); // Update the function signature in-place. 
@@ -2644,7 +2644,7 @@ result.getConvertedTypes(), newResults); rewriter.updateRootInPlace( - op, [&] { mlir::impl::setFunctionType(op, newType); }); + op, [&] { function_like_impl::setFunctionType(op, newType); }); return success(); } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -35,7 +35,7 @@ // CHECK: attributes {xxx = {yyy = 42 : i64}} "llvm.func"() ({ }) {sym_name = "qux", type = !llvm.func, i64)>, - arg0 = {llvm.noalias = true}, xxx = {yyy = 42}} : () -> () + arg_attrs = [{llvm.noalias = true}, {}], xxx = {yyy = 42}} : () -> () // CHECK: llvm.func @roundtrip1() llvm.func @roundtrip1() diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -94,3 +94,22 @@ // expected-error@+1 {{'type' is an inferred attribute and should not be specified in the explicit attribute dictionary}} func private @invalid_symbol_type_attr() attributes { type = "x" } +// ----- + +// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } + +// ----- + +// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} +func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } + +// ----- + +// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +func private @invalid_res_attrs() attributes { res_attrs = [{}] } + +// ----- + +// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} +func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] } diff --git a/mlir/test/IR/test-func-set-type.mlir b/mlir/test/IR/test-func-set-type.mlir --- 
a/mlir/test/IR/test-func-set-type.mlir +++ b/mlir/test/IR/test-func-set-type.mlir @@ -9,7 +9,6 @@ // Test case: The setType call needs to erase some arg attrs. // CHECK: func private @erase_arg(f32 {test.A}) -// CHECK-NOT: attributes{{.*arg[0-9]}} func private @t(f32) func private @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B}) attributes {test.set_type_from = @t} @@ -19,7 +18,6 @@ // Test case: The setType call needs to erase some result attrs. // CHECK: func private @erase_result() -> (f32 {test.A}) -// CHECK-NOT: attributes{{.*result[0-9]}} func private @t() -> (f32) func private @erase_result() -> (f32 {test.A}, f32 {test.B}) attributes {test.set_type_from = @t}