diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -439,6 +439,10 @@ MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos); +/// Populates the default but unset attributes of the operation. +MLIR_CAPI_EXPORTED void +mlirOperationPopulateDefaultAttributes(MlirOperation op); + /// Returns the number of attributes attached to the operation. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op); diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -431,6 +431,7 @@ OperationName::PrintAssemblyFn printFn; OperationName::FoldHookFn foldHookFn; OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn; friend ExtensibleDialect; }; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -21,6 +21,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include <type_traits> @@ -182,6 +183,10 @@ static void getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) {} + /// This hook populates any unset default attrs. + static void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &, MLIRContext *) {} + protected: /// If the concrete type didn't implement a custom verifier hook, just fall /// back to this one which accepts everything. @@ -1869,6 +1874,10 @@ OpState::printOpName(op, p, defaultDialect); return cast<ConcreteType>(op).print(p); } + /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
+ static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() { + return ConcreteType::populateDefaultAttrs; + } /// Implementation of `VerifyInvariantsFn` OperationName hook. static LogicalResult verifyInvariants(Operation *op) { static_assert(hasNoDataMembers(), diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -467,6 +467,15 @@ setAttrs(attrs.getDictionary(getContext())); } + /// Populates unset attributes that have a default value. + void populateDefaultAttrs() { + if (auto registered = getRegisteredInfo()) { + NamedAttrList attrs(getAttrDictionary()); + registered->populateDefaultAttrs(attrs, getContext()); + setAttrs(attrs.getDictionary(getContext())); + } + } + //===--------------------------------------------------------------------===// // Blocks //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -17,6 +17,7 @@ #include "mlir/IR/BlockSupport.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -36,6 +37,7 @@ class DictionaryAttr; class ElementsAttr; class MutableOperandRangeRange; +class NamedAttrList; class Operation; struct OperationState; class OpAsmParser; @@ -69,6 +71,8 @@ using HasTraitFn = llvm::unique_function<bool(TypeID) const>; using ParseAssemblyFn = llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>; + using PopulateDefaultAttrsFn = llvm::unique_function<void( + const RegisteredOperationName &, NamedAttrList &, MLIRContext *) const>; using PrintAssemblyFn = llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>; using VerifyInvariantsFn = @@ -112,6 +116,7 @@ GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; HasTraitFn hasTraitFn; ParseAssemblyFn parseAssemblyFn; + PopulateDefaultAttrsFn populateDefaultAttrsFn; PrintAssemblyFn
printAssemblyFn; VerifyInvariantsFn verifyInvariantsFn; VerifyRegionInvariantsFn verifyRegionInvariantsFn; @@ -254,7 +259,8 @@ T::getParseAssemblyFn(), T::getPrintAssemblyFn(), T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(), T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), - T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames()); + T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(), + T::getPopulateDefaultAttrsFn()); } /// The use of this method is in general discouraged in favor of /// 'insert<CustomOp>(dialect)'. @@ -266,7 +272,8 @@ FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef<StringRef> attrNames); + ArrayRef<StringRef> attrNames, + PopulateDefaultAttrsFn &&populateDefaultAttrs); /// Return the dialect this operation is registered to. Dialect &getDialect() const { return *impl->dialect; } @@ -364,6 +371,10 @@ return impl->attributeNames; } + /// This hook implements the method to populate default attributes that are + /// unset. + void populateDefaultAttrs(NamedAttrList &attrs, MLIRContext *context) const; + /// Represent the operation name as an opaque pointer. (Used to support /// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -445,6 +445,10 @@ return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); } +void mlirOperationPopulateDefaultAttributes(MlirOperation op) { + unwrap(op)->populateDefaultAttrs(); +} + intptr_t mlirOperationGetNumAttributes(MlirOperation op) { return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); } diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -447,7 +447,8 @@ std::move(op->printFn), std::move(op->verifyFn), std::move(op->verifyRegionFn), std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), - detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); + detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}, + std::move(op->getPopulateDefaultAttrsFn)); } bool ExtensibleDialect::classof(const Dialect *dialect) { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -707,6 +707,11 @@ return impl->parseAssemblyFn(parser, result); } +void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs, + MLIRContext *context) const { + impl->populateDefaultAttrsFn(*this, attrs, context); +} + void RegisteredOperationName::insert( StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, @@ -714,7 +719,8 @@ VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef<StringRef> attrNames) { + ArrayRef<StringRef> attrNames, + PopulateDefaultAttrsFn &&populateDefaultAttrs) { MLIRContext *ctx = dialect.getContext(); auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 && @@ -769,6 +775,7 @@ impl.verifyInvariantsFn = std::move(verifyInvariants); impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants); impl.attributeNames = cachedAttrNames; + impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -430,6 +430,9 @@ // Generates getters for named successors. void genNamedSuccessorGetters(); + // Generates the method to populate default attributes. + void genSetDefaultBuildAttributes(); + // Generates builder methods for the operation. void genBuilder(); @@ -823,6 +826,7 @@ genAttrSetters(); genOptionalAttrRemovers(); genBuilder(); + genSetDefaultBuildAttributes(); genParser(); genPrinter(); genVerifier(); @@ -1587,6 +1591,46 @@ << llvm::join(resultTypes, ", ") << "});\n\n"; } +void OpEmitter::genSetDefaultBuildAttributes() { + // All done if no attributes have default values. + if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { + return !named.attr.hasDefaultValue(); + })) + return; + + SmallVector<MethodParameter> paramList; + paramList.emplace_back("const ::mlir::RegisteredOperationName &", + "registeredOpName"); + paramList.emplace_back("::mlir::NamedAttrList &", "attributes"); + paramList.emplace_back("::mlir::MLIRContext *", "context"); + auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList); + ERROR_IF_PRUNED(m, "populateDefaultAttrs", op); + auto &body = m->body(); + body.indent(); + + // Set default attributes.
+ body << "::mlir::Builder " << odsBuilder << "(context);\n"; + body << "auto attrNames = registeredOpName.getAttributeNames();\n"; + StringMap<int> attrIndex; + for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) { + attrIndex[it.value().first] = it.index(); + } + for (const NamedAttribute &namedAttr : op.getAttributes()) { + auto &attr = namedAttr.attr; + if (!attr.hasDefaultValue()) + continue; + auto index = attrIndex[namedAttr.name]; + body << "if (!attributes.get(attrNames[" << index << "])) {\n"; + FmtContext fctx; + fctx.withBuilder(odsBuilder); + std::string defaultValue = std::string( + tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n", + index, defaultValue); + body.unindent() << "}\n"; + } +} + void OpEmitter::genInferredTypeCollectiveParamBuilder() { SmallVector<MethodParameter> paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); @@ -1869,7 +1913,7 @@ auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OpBuilder &", odsBuilder); paramList.emplace_back("::mlir::OperationState &", builderOpState); switch (typeParamKind) { @@ -2879,7 +2923,7 @@ tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); body << " if (!attr)\n attr = " << defaultValue << ";\n"; } - body << " return attr;\n"; + body << "return attr;\n"; }; { diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/OperationSupport.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/BitVector.h" @@ -271,4 +272,22 @@
attrs.assign({}); ASSERT_TRUE(attrs.empty()); } + +TEST(OperandStorageTest, PopulateDefaultAttrs) { + MLIRContext context; + context.getOrLoadDialect<test::TestDialect>(); + Builder builder(&context); + + OpBuilder b(&context); + auto req1 = b.getI32IntegerAttr(10); + auto req2 = b.getI32IntegerAttr(60); + Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr, + nullptr, req2); + EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr); + op->populateDefaultAttrs(); + auto opt = op->getAttr("default_valued_attr"); + EXPECT_NE(opt, nullptr) << *op; + + op->destroy(); +} } // namespace