diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h --- a/mlir/include/mlir/TableGen/OpClass.h +++ b/mlir/include/mlir/TableGen/OpClass.h @@ -88,13 +88,13 @@ Property property, bool declOnly); virtual ~OpMethod() = default; - OpMethodBody &body(); + OpMethodBody &body() { return methodBody; } // Returns true if this is a static method. - bool isStatic() const; + bool isStatic() const { return properties & MP_Static; } // Returns true if this is a private method. - bool isPrivate() const; + bool isPrivate() const { return properties & MP_Private; } // Writes the method as a declaration to the given `os`. virtual void writeDeclTo(raw_ostream &os) const; diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -10,15 +10,12 @@ #include "llvm/TableGen/Record.h" using namespace mlir; +using namespace mlir::tblgen; -bool tblgen::NamedTypeConstraint::hasPredicate() const { +bool NamedTypeConstraint::hasPredicate() const { return !constraint.getPredicate().isNull(); } -bool tblgen::NamedTypeConstraint::isOptional() const { - return constraint.isOptional(); -} +bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); } -bool tblgen::NamedTypeConstraint::isVariadic() const { - return constraint.isVariadic(); -} +bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); } diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -16,6 +16,7 @@ #include "llvm/TableGen/Record.h" using namespace mlir; +using namespace mlir::tblgen; using llvm::CodeInit; using llvm::DefInit; @@ -28,41 +29,35 @@ static StringRef getValueAsString(const Init *init) { if (const auto *code = dyn_cast(init)) return code->getValue().trim(); - else if (const auto *str = dyn_cast(init)) + if (const auto *str = dyn_cast(init)) return str->getValue().trim(); return {}; } -tblgen::AttrConstraint::AttrConstraint(const Record *record) +AttrConstraint::AttrConstraint(const Record *record) : Constraint(Constraint::CK_Attr, record) { assert(isSubClassOf("AttrConstraint") && "must be subclass of TableGen 'AttrConstraint' class"); } -bool tblgen::AttrConstraint::isSubClassOf(StringRef className) const { +bool AttrConstraint::isSubClassOf(StringRef className) const { return def->isSubClassOf(className); } -tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) { +Attribute::Attribute(const Record *record) : AttrConstraint(record) { assert(record->isSubClassOf("Attr") && "must be subclass of TableGen 'Attr' class"); } -tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} +Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} -bool tblgen::Attribute::isDerivedAttr() const { - return isSubClassOf("DerivedAttr"); -} +bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } -bool tblgen::Attribute::isTypeAttr() const { - return isSubClassOf("TypeAttrBase"); -} +bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } -bool tblgen::Attribute::isEnumAttr() const { - return isSubClassOf("EnumAttrInfo"); -} +bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } -StringRef tblgen::Attribute::getStorageType() const { +StringRef Attribute::getStorageType() const { const auto *init = def->getValueInit("storageType"); auto type = getValueAsString(init); if (type.empty()) @@ -70,35 +65,35 @@ return type; } -StringRef tblgen::Attribute::getReturnType() const { +StringRef Attribute::getReturnType() const { const auto *init = def->getValueInit("returnType"); return getValueAsString(init); } // Return the type constraint corresponding to the type of this attribute, or // None if this is not a TypedAttr. -llvm::Optional tblgen::Attribute::getValueType() const { +llvm::Optional Attribute::getValueType() const { if (auto *defInit = dyn_cast(def->getValueInit("valueType"))) - return tblgen::Type(defInit->getDef()); + return Type(defInit->getDef()); return llvm::None; } -StringRef tblgen::Attribute::getConvertFromStorageCall() const { +StringRef Attribute::getConvertFromStorageCall() const { const auto *init = def->getValueInit("convertFromStorage"); return getValueAsString(init); } -bool tblgen::Attribute::isConstBuildable() const { +bool Attribute::isConstBuildable() const { const auto *init = def->getValueInit("constBuilderCall"); return !getValueAsString(init).empty(); } -StringRef tblgen::Attribute::getConstBuilderTemplate() const { +StringRef Attribute::getConstBuilderTemplate() const { const auto *init = def->getValueInit("constBuilderCall"); return getValueAsString(init); } -tblgen::Attribute tblgen::Attribute::getBaseAttr() const { +Attribute Attribute::getBaseAttr() const { if (const auto *defInit = llvm::dyn_cast(def->getValueInit("baseAttr"))) { return Attribute(defInit).getBaseAttr(); @@ -106,178 +101,166 @@ return *this; } -bool tblgen::Attribute::hasDefaultValue() const { +bool Attribute::hasDefaultValue() const { const auto *init = def->getValueInit("defaultValue"); return !getValueAsString(init).empty(); } -StringRef tblgen::Attribute::getDefaultValue() const { +StringRef Attribute::getDefaultValue() const { const auto *init = def->getValueInit("defaultValue"); return getValueAsString(init); } -bool tblgen::Attribute::isOptional() const { - return def->getValueAsBit("isOptional"); -} +bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } -StringRef tblgen::Attribute::getAttrDefName() const { +StringRef Attribute::getAttrDefName() const { if (def->isAnonymous()) { return getBaseAttr().def->getName(); } return def->getName(); } -StringRef tblgen::Attribute::getDerivedCodeBody() const { +StringRef Attribute::getDerivedCodeBody() const { assert(isDerivedAttr() && "only derived attribute has 'body' field"); return def->getValueAsString("body"); } -tblgen::Dialect tblgen::Attribute::getDialect() const { +Dialect Attribute::getDialect() const { return Dialect(def->getValueAsDef("dialect")); } -tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { +ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && "must be subclass of TableGen 'ConstantAttr' class"); } -tblgen::Attribute tblgen::ConstantAttr::getAttribute() const { +Attribute ConstantAttr::getAttribute() const { return Attribute(def->getValueAsDef("attr")); } -StringRef tblgen::ConstantAttr::getConstantValue() const { +StringRef ConstantAttr::getConstantValue() const { return def->getValueAsString("value"); } -tblgen::EnumAttrCase::EnumAttrCase(const llvm::Record *record) - : Attribute(record) { +EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { assert(isSubClassOf("EnumAttrCaseInfo") && "must be subclass of TableGen 'EnumAttrInfo' class"); } -tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) +EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) : EnumAttrCase(init->getDef()) {} -bool tblgen::EnumAttrCase::isStrCase() const { - return isSubClassOf("StrEnumAttrCase"); -} +bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } -StringRef tblgen::EnumAttrCase::getSymbol() const { +StringRef EnumAttrCase::getSymbol() const { return def->getValueAsString("symbol"); } -StringRef tblgen::EnumAttrCase::getStr() const { - return def->getValueAsString("str"); -} +StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } -int64_t tblgen::EnumAttrCase::getValue() const { - return def->getValueAsInt("value"); -} +int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } -const llvm::Record &tblgen::EnumAttrCase::getDef() const { return *def; } +const llvm::Record &EnumAttrCase::getDef() const { return *def; } -tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { +EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { assert(isSubClassOf("EnumAttrInfo") && "must be subclass of TableGen 'EnumAttr' class"); } -tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} +EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} -tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) - : EnumAttr(init->getDef()) {} +EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} -bool tblgen::EnumAttr::classof(const Attribute *attr) { +bool EnumAttr::classof(const Attribute *attr) { return attr->isSubClassOf("EnumAttrInfo"); } -bool tblgen::EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } +bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } -StringRef tblgen::EnumAttr::getEnumClassName() const { +StringRef EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } -StringRef tblgen::EnumAttr::getCppNamespace() const { +StringRef EnumAttr::getCppNamespace() const { return def->getValueAsString("cppNamespace"); } -StringRef tblgen::EnumAttr::getUnderlyingType() const { +StringRef EnumAttr::getUnderlyingType() const { return def->getValueAsString("underlyingType"); } -StringRef tblgen::EnumAttr::getUnderlyingToSymbolFnName() const { +StringRef EnumAttr::getUnderlyingToSymbolFnName() const { return def->getValueAsString("underlyingToSymbolFnName"); } -StringRef tblgen::EnumAttr::getStringToSymbolFnName() const { +StringRef EnumAttr::getStringToSymbolFnName() const { return def->getValueAsString("stringToSymbolFnName"); } -StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { +StringRef EnumAttr::getSymbolToStringFnName() const { return def->getValueAsString("symbolToStringFnName"); } -StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const { +StringRef EnumAttr::getSymbolToStringFnRetType() const { return def->getValueAsString("symbolToStringFnRetType"); } -StringRef tblgen::EnumAttr::getMaxEnumValFnName() const { +StringRef EnumAttr::getMaxEnumValFnName() const { return def->getValueAsString("maxEnumValFnName"); } -std::vector tblgen::EnumAttr::getAllCases() const { +std::vector EnumAttr::getAllCases() const { const auto *inits = def->getValueAsListInit("enumerants"); - std::vector cases; + std::vector cases; cases.reserve(inits->size()); for (const llvm::Init *init : *inits) { - cases.push_back(tblgen::EnumAttrCase(cast(init))); + cases.push_back(EnumAttrCase(cast(init))); } return cases; } -tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record) - : def(record) { +StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("StructFieldAttr") && "must be subclass of TableGen 'StructFieldAttr' class"); } -tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record) +StructFieldAttr::StructFieldAttr(const llvm::Record &record) : StructFieldAttr(&record) {} -tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) +StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) : StructFieldAttr(init->getDef()) {} -StringRef tblgen::StructFieldAttr::getName() const { +StringRef StructFieldAttr::getName() const { return def->getValueAsString("name"); } -tblgen::Attribute tblgen::StructFieldAttr::getType() const { +Attribute StructFieldAttr::getType() const { auto init = def->getValueInit("type"); - return tblgen::Attribute(cast(init)); + return Attribute(cast(init)); } -tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { +StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { assert(isSubClassOf("StructAttr") && "must be subclass of TableGen 'StructAttr' class"); } -tblgen::StructAttr::StructAttr(const llvm::DefInit *init) +StructAttr::StructAttr(const llvm::DefInit *init) : StructAttr(init->getDef()) {} -StringRef tblgen::StructAttr::getStructClassName() const { +StringRef StructAttr::getStructClassName() const { return def->getValueAsString("className"); } -StringRef tblgen::StructAttr::getCppNamespace() const { +StringRef StructAttr::getCppNamespace() const { Dialect dialect(def->getValueAsDef("structDialect")); return dialect.getCppNamespace(); } -std::vector -tblgen::StructAttr::getAllFields() const { - std::vector attributes; +std::vector StructAttr::getAllFields() const { + std::vector attributes; const auto *inits = def->getValueAsListInit("fields"); attributes.reserve(inits->size()); @@ -289,4 +272,4 @@ return attributes; } -const char *mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; +const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -13,18 +13,16 @@ #include "mlir/TableGen/Dialect.h" #include "llvm/TableGen/Record.h" -namespace mlir { -namespace tblgen { +using namespace mlir; +using namespace mlir::tblgen; -StringRef tblgen::Dialect::getName() const { - return def->getValueAsString("name"); -} +StringRef Dialect::getName() const { return def->getValueAsString("name"); } -StringRef tblgen::Dialect::getCppNamespace() const { +StringRef Dialect::getCppNamespace() const { return def->getValueAsString("cppNamespace"); } -std::string tblgen::Dialect::getCppClassName() const { +std::string Dialect::getCppClassName() const { // Simply use the name and remove any '_' tokens. std::string cppName = def->getName().str(); llvm::erase_if(cppName, [](char c) { return c == '_'; }); @@ -40,32 +38,32 @@ return ""; } -StringRef tblgen::Dialect::getSummary() const { +StringRef Dialect::getSummary() const { return getAsStringOrEmpty(*def, "summary"); } -StringRef tblgen::Dialect::getDescription() const { +StringRef Dialect::getDescription() const { return getAsStringOrEmpty(*def, "description"); } -llvm::Optional tblgen::Dialect::getExtraClassDeclaration() const { +llvm::Optional Dialect::getExtraClassDeclaration() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? llvm::Optional() : value; } -bool tblgen::Dialect::hasConstantMaterializer() const { +bool Dialect::hasConstantMaterializer() const { return def->getValueAsBit("hasConstantMaterializer"); } -bool tblgen::Dialect::hasOperationAttrVerify() const { +bool Dialect::hasOperationAttrVerify() const { return def->getValueAsBit("hasOperationAttrVerify"); } -bool tblgen::Dialect::hasRegionArgAttrVerify() const { +bool Dialect::hasRegionArgAttrVerify() const { return def->getValueAsBit("hasRegionArgAttrVerify"); } -bool tblgen::Dialect::hasRegionResultAttrVerify() const { +bool Dialect::hasRegionResultAttrVerify() const { return def->getValueAsBit("hasRegionResultAttrVerify"); } @@ -76,6 +74,3 @@ bool Dialect::operator<(const Dialect &other) const { return getName() < other.getName(); } - -} // end namespace tblgen -} // end namespace mlir diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp --- a/mlir/lib/TableGen/Format.cpp +++ b/mlir/lib/TableGen/Format.cpp @@ -21,28 +21,28 @@ // Marker to indicate an error happened when replacing a placeholder. const char *const kMarkerForNoSubst = ""; -FmtContext &tblgen::FmtContext::addSubst(StringRef placeholder, Twine subst) { +FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) { customSubstMap[placeholder] = subst.str(); return *this; } -FmtContext &tblgen::FmtContext::withBuilder(Twine subst) { +FmtContext &FmtContext::withBuilder(Twine subst) { builtinSubstMap[PHKind::Builder] = subst.str(); return *this; } -FmtContext &tblgen::FmtContext::withOp(Twine subst) { +FmtContext &FmtContext::withOp(Twine subst) { builtinSubstMap[PHKind::Op] = subst.str(); return *this; } -FmtContext &tblgen::FmtContext::withSelf(Twine subst) { +FmtContext &FmtContext::withSelf(Twine subst) { builtinSubstMap[PHKind::Self] = subst.str(); return *this; } Optional -tblgen::FmtContext::getSubstFor(FmtContext::PHKind placeholder) const { +FmtContext::getSubstFor(FmtContext::PHKind placeholder) const { if (placeholder == FmtContext::PHKind::None || placeholder == FmtContext::PHKind::Custom) return {}; @@ -52,15 +52,14 @@ return StringRef(it->second); } -Optional -tblgen::FmtContext::getSubstFor(StringRef placeholder) const { +Optional FmtContext::getSubstFor(StringRef placeholder) const { auto it = customSubstMap.find(placeholder); if (it == customSubstMap.end()) return {}; return StringRef(it->second); } -FmtContext::PHKind tblgen::FmtContext::getPlaceHolderKind(StringRef str) { +FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) { return llvm::StringSwitch(str) .Case("_builder", FmtContext::PHKind::Builder) .Case("_op", FmtContext::PHKind::Op) @@ -70,7 +69,7 @@ } std::pair -tblgen::FmtObjectBase::splitFmtSegment(StringRef fmt) { +FmtObjectBase::splitFmtSegment(StringRef fmt) { size_t begin = fmt.find_first_of('$'); if (begin == StringRef::npos) { // No placeholders: the whole format string should be returned as a diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -13,22 +13,23 @@ #include "llvm/Support/raw_ostream.h" using namespace mlir; +using namespace mlir::tblgen; //===----------------------------------------------------------------------===// // OpMethodSignature definitions //===----------------------------------------------------------------------===// -tblgen::OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name, - StringRef params) +OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name, + StringRef params) : returnType(retType), methodName(name), parameters(params) {} -void tblgen::OpMethodSignature::writeDeclTo(raw_ostream &os) const { +void OpMethodSignature::writeDeclTo(raw_ostream &os) const { os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName << "(" << parameters << ")"; } -void tblgen::OpMethodSignature::writeDefTo(raw_ostream &os, - StringRef namePrefix) const { +void OpMethodSignature::writeDefTo(raw_ostream &os, + StringRef namePrefix) const { // We need to remove the default values for parameters in method definition. // TODO: We are using '=' and ',' as delimiters for parameter // initializers. This is incorrect for initializer list with more than one @@ -50,7 +51,7 @@ << removeParamDefaultValue(parameters) << ")"; } -bool tblgen::OpMethodSignature::elideSpaceAfterType(StringRef type) { +bool OpMethodSignature::elideSpaceAfterType(StringRef type) { return type.empty() || type.endswith("&") || type.endswith("*"); } @@ -58,28 +59,27 @@ // OpMethodBody definitions //===----------------------------------------------------------------------===// -tblgen::OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} +OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} -tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(Twine content) { +OpMethodBody &OpMethodBody::operator<<(Twine content) { if (isEffective) body.append(content.str()); return *this; } -tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(int content) { +OpMethodBody &OpMethodBody::operator<<(int content) { if (isEffective) body.append(std::to_string(content)); return *this; } -tblgen::OpMethodBody & -tblgen::OpMethodBody::operator<<(const FmtObjectBase &content) { +OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { if (isEffective) body.append(content.str()); return *this; } -void tblgen::OpMethodBody::writeTo(raw_ostream &os) const { +void OpMethodBody::writeTo(raw_ostream &os) const { auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); os << bodyRef; if (bodyRef.empty() || bodyRef.back() != '\n') @@ -90,18 +90,11 @@ // OpMethod definitions //===----------------------------------------------------------------------===// -tblgen::OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params, - OpMethod::Property property, bool declOnly) +OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params, + OpMethod::Property property, bool declOnly) : properties(property), isDeclOnly(declOnly), methodSignature(retType, name, params), methodBody(declOnly) {} - -tblgen::OpMethodBody &tblgen::OpMethod::body() { return methodBody; } - -bool tblgen::OpMethod::isStatic() const { return properties & MP_Static; } - -bool tblgen::OpMethod::isPrivate() const { return properties & MP_Private; } - -void tblgen::OpMethod::writeDeclTo(raw_ostream &os) const { +void OpMethod::writeDeclTo(raw_ostream &os) const { os.indent(2); if (isStatic()) os << "static "; @@ -109,7 +102,7 @@ os << ";"; } -void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { +void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { if (isDeclOnly) return; @@ -123,14 +116,12 @@ // OpConstructor definitions //===----------------------------------------------------------------------===// -void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name, - StringRef value) { +void OpConstructor::addMemberInitializer(StringRef name, StringRef value) { memberInitializers.append(std::string(llvm::formatv( "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); } -void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os, - StringRef namePrefix) const { +void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { if (isDeclOnly) return; @@ -144,25 +135,21 @@ // Class definitions //===----------------------------------------------------------------------===// -tblgen::Class::Class(StringRef name) : className(name) {} +Class::Class(StringRef name) : className(name) {} -tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name, - StringRef params, - OpMethod::Property property, - bool declOnly) { +OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params, + OpMethod::Property property, bool declOnly) { methods.emplace_back(retType, name, params, property, declOnly); return methods.back(); } -tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params, - bool declOnly) { +OpConstructor &Class::newConstructor(StringRef params, bool declOnly) { constructors.emplace_back("", getClassName(), params, OpMethod::MP_Constructor, declOnly); return constructors.back(); } -void tblgen::Class::newField(StringRef type, StringRef name, - StringRef defaultValue) { +void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { std::string varName = formatv("{0} {1}", type, name).str(); std::string field = defaultValue.empty() ? varName @@ -170,7 +157,7 @@ fields.push_back(std::move(field)); } -void tblgen::Class::writeDeclTo(raw_ostream &os) const { +void Class::writeDeclTo(raw_ostream &os) const { bool hasPrivateMethod = false; os << "class " << className << " {\n"; os << "public:\n"; @@ -200,7 +187,7 @@ os << "};\n"; } -void tblgen::Class::writeDefTo(raw_ostream &os) const { +void Class::writeDefTo(raw_ostream &os) const { for (const auto &method : llvm::concat(constructors, methods)) { method.writeDefTo(os, className); @@ -212,16 +199,16 @@ // OpClass definitions //===----------------------------------------------------------------------===// -tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) +OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) : Class(name), extraClassDeclaration(extraClassDeclaration) {} -void tblgen::OpClass::addTrait(Twine trait) { +void OpClass::addTrait(Twine trait) { auto traitStr = trait.str(); if (traitsSet.insert(traitStr).second) traitsVec.push_back(std::move(traitStr)); } -void tblgen::OpClass::writeDeclTo(raw_ostream &os) const { +void OpClass::writeDeclTo(raw_ostream &os) const { os << "class " << className << " : public ::mlir::Op<" << className; for (const auto &trait : traitsVec) os << ", " << trait; diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -44,7 +44,7 @@ } std::string PredOpTrait::getPredTemplate() const { - auto pred = tblgen::Pred(def->getValueInit("predicate")); + auto pred = Pred(def->getValueInit("predicate")); return pred.getCondition(); } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -27,12 +27,13 @@ #define DEBUG_TYPE "mlir-tblgen-operator" using namespace mlir; +using namespace mlir::tblgen; using llvm::DagInit; using llvm::DefInit; using llvm::Record; -tblgen::Operator::Operator(const llvm::Record &def) +Operator::Operator(const llvm::Record &def) : dialect(def.getValueAsDef("opDialect")), def(def) { // The first `_` in the op's TableGen def name is treated as separating the // dialect prefix and the op class name. The dialect prefix will be ignored if @@ -51,7 +52,7 @@ populateOpStructure(); } -std::string tblgen::Operator::getOperationName() const { +std::string Operator::getOperationName() const { auto prefix = dialect.getName(); auto opName = def.getValueAsString("opName"); if (prefix.empty()) @@ -59,62 +60,58 @@ return std::string(llvm::formatv("{0}.{1}", prefix, opName)); } -std::string tblgen::Operator::getAdaptorName() const { +std::string Operator::getAdaptorName() const { return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); } -StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } +StringRef Operator::getDialectName() const { return dialect.getName(); } -StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } +StringRef Operator::getCppClassName() const { return cppClassName; } -std::string tblgen::Operator::getQualCppClassName() const { +std::string Operator::getQualCppClassName() const { auto prefix = dialect.getCppNamespace(); if (prefix.empty()) return std::string(cppClassName); return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName)); } -int tblgen::Operator::getNumResults() const { +int Operator::getNumResults() const { DagInit *results = def.getValueAsDag("results"); return results->getNumArgs(); } -StringRef tblgen::Operator::getExtraClassDeclaration() const { +StringRef Operator::getExtraClassDeclaration() const { constexpr auto attr = "extraClassDeclaration"; if (def.isValueUnset(attr)) return {}; return def.getValueAsString(attr); } -const llvm::Record &tblgen::Operator::getDef() const { return def; } +const llvm::Record &Operator::getDef() const { return def; } -bool tblgen::Operator::skipDefaultBuilders() const { +bool Operator::skipDefaultBuilders() const { return def.getValueAsBit("skipDefaultBuilders"); } -auto tblgen::Operator::result_begin() -> value_iterator { - return results.begin(); -} +auto Operator::result_begin() -> value_iterator { return results.begin(); } -auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } +auto Operator::result_end() -> value_iterator { return results.end(); } -auto tblgen::Operator::getResults() -> value_range { +auto Operator::getResults() -> value_range { return {result_begin(), result_end()}; } -tblgen::TypeConstraint -tblgen::Operator::getResultTypeConstraint(int index) const { +TypeConstraint Operator::getResultTypeConstraint(int index) const { DagInit *results = def.getValueAsDag("results"); return TypeConstraint(cast(results->getArg(index))); } -StringRef tblgen::Operator::getResultName(int index) const { +StringRef Operator::getResultName(int index) const { DagInit *results = def.getValueAsDag("results"); return results->getArgNameStr(index); } -auto tblgen::Operator::getResultDecorators(int index) const - -> var_decorator_range { +auto Operator::getResultDecorators(int index) const -> var_decorator_range { Record *result = cast(def.getValueAsDag("results")->getArg(index))->getDef(); if (!result->isSubClassOf("OpVariable")) @@ -122,42 +119,37 @@ return *result->getValueAsListInit("decorators"); } -unsigned tblgen::Operator::getNumVariableLengthResults() const { +unsigned Operator::getNumVariableLengthResults() const { return llvm::count_if(results, [](const NamedTypeConstraint &c) { return c.constraint.isVariableLength(); }); } -unsigned tblgen::Operator::getNumVariableLengthOperands() const { +unsigned Operator::getNumVariableLengthOperands() const { return llvm::count_if(operands, [](const NamedTypeConstraint &c) { return c.constraint.isVariableLength(); }); } -bool tblgen::Operator::hasSingleVariadicArg() const { - return getNumArgs() == 1 && getArg(0).is() && +bool Operator::hasSingleVariadicArg() const { + return getNumArgs() == 1 && getArg(0).is() && getOperand(0).isVariadic(); } -tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const { - return arguments.begin(); -} +Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } -tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const { - return arguments.end(); -} +Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } -tblgen::Operator::arg_range tblgen::Operator::getArgs() const { +Operator::arg_range Operator::getArgs() const { return {arg_begin(), arg_end()}; } -StringRef tblgen::Operator::getArgName(int index) const { +StringRef Operator::getArgName(int index) const { DagInit *argumentValues = def.getValueAsDag("arguments"); return argumentValues->getArgName(index)->getValue(); } -auto tblgen::Operator::getArgDecorators(int index) const - -> var_decorator_range { +auto Operator::getArgDecorators(int index) const -> var_decorator_range { Record *arg = cast(def.getValueAsDag("arguments")->getArg(index))->getDef(); if (!arg->isSubClassOf("OpVariable")) @@ -165,15 +157,15 @@ return *arg->getValueAsListInit("decorators"); } -const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { +const OpTrait *Operator::getTrait(StringRef trait) const { for (const auto &t : traits) { - if (const auto *opTrait = dyn_cast(&t)) { + if (const auto *opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return opTrait; - } else if (const auto *opTrait = dyn_cast(&t)) { + } else if (const auto *opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return opTrait; - } else if (const auto *opTrait = dyn_cast(&t)) { + } else if (const auto *opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return opTrait; } @@ -181,100 +173,90 @@ return nullptr; } -auto tblgen::Operator::region_begin() const -> const_region_iterator { +auto Operator::region_begin() const -> const_region_iterator { return regions.begin(); } -auto tblgen::Operator::region_end() const -> const_region_iterator { +auto Operator::region_end() const -> const_region_iterator { return regions.end(); } -auto tblgen::Operator::getRegions() const +auto Operator::getRegions() const -> llvm::iterator_range { return {region_begin(), region_end()}; } -unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } +unsigned Operator::getNumRegions() const { return regions.size(); } -const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { +const NamedRegion &Operator::getRegion(unsigned index) const { return regions[index]; } -unsigned tblgen::Operator::getNumVariadicRegions() const { +unsigned Operator::getNumVariadicRegions() const { return llvm::count_if(regions, [](const NamedRegion &c) { return c.isVariadic(); }); } -auto tblgen::Operator::successor_begin() const -> const_successor_iterator { +auto Operator::successor_begin() const -> const_successor_iterator { return successors.begin(); } -auto tblgen::Operator::successor_end() const -> const_successor_iterator { +auto Operator::successor_end() const -> const_successor_iterator { return successors.end(); } -auto tblgen::Operator::getSuccessors() const +auto Operator::getSuccessors() const -> llvm::iterator_range { return {successor_begin(), successor_end()}; } -unsigned tblgen::Operator::getNumSuccessors() const { - return successors.size(); -} +unsigned Operator::getNumSuccessors() const { return successors.size(); } -const tblgen::NamedSuccessor & -tblgen::Operator::getSuccessor(unsigned index) const { +const NamedSuccessor &Operator::getSuccessor(unsigned index) const { return successors[index]; } -unsigned tblgen::Operator::getNumVariadicSuccessors() const { +unsigned Operator::getNumVariadicSuccessors() const { return llvm::count_if(successors, [](const NamedSuccessor &c) { return c.isVariadic(); }); } -auto tblgen::Operator::trait_begin() const -> const_trait_iterator { +auto Operator::trait_begin() const -> const_trait_iterator { return traits.begin(); } -auto tblgen::Operator::trait_end() const -> const_trait_iterator { +auto Operator::trait_end() const -> const_trait_iterator { return traits.end(); } -auto tblgen::Operator::getTraits() const - -> llvm::iterator_range { +auto Operator::getTraits() const -> llvm::iterator_range { return {trait_begin(), trait_end()}; } -auto tblgen::Operator::attribute_begin() const -> attribute_iterator { +auto Operator::attribute_begin() const -> attribute_iterator { return attributes.begin(); } -auto tblgen::Operator::attribute_end() const -> attribute_iterator { +auto Operator::attribute_end() const -> attribute_iterator { return attributes.end(); } -auto tblgen::Operator::getAttributes() const +auto Operator::getAttributes() const -> llvm::iterator_range { return {attribute_begin(), attribute_end()}; } -auto tblgen::Operator::operand_begin() -> value_iterator { - return operands.begin(); -} -auto tblgen::Operator::operand_end() -> value_iterator { - return operands.end(); -} -auto tblgen::Operator::getOperands() -> value_range { +auto Operator::operand_begin() -> value_iterator { return operands.begin(); } +auto Operator::operand_end() -> value_iterator { return operands.end(); } +auto Operator::getOperands() -> value_range { return {operand_begin(), operand_end()}; } -auto tblgen::Operator::getArg(int index) const -> Argument { - return arguments[index]; -} +auto Operator::getArg(int index) const -> Argument { return arguments[index]; } // Mapping from result index to combined argument and result index. Arguments // are indexed to match getArg index, while the result indexes are mapped to // avoid overlap. static int resultIndex(int i) { return -1 - i; } -bool tblgen::Operator::isVariadic() const { +bool Operator::isVariadic() const { return any_of(llvm::concat(operands, results), [](const NamedTypeConstraint &op) { return op.isVariadic(); }); } -void tblgen::Operator::populateTypeInferenceInfo( +void Operator::populateTypeInferenceInfo( const llvm::StringMap &argumentsAndResultsIndex) { // If the type inference op interface is not registered, then do not attempt // to determine if the result types an be inferred. @@ -340,7 +322,7 @@ if (def.isSubClassOf( llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) return; - if (const auto *opTrait = dyn_cast(&trait)) + if (const auto *opTrait = dyn_cast(&trait)) if (&opTrait->getDef() == inferTrait) return; @@ -364,7 +346,7 @@ traits.push_back(OpTrait::create(inferTrait->getDefInit())); } -void tblgen::Operator::populateOpStructure() { +void Operator::populateOpStructure() { auto &recordKeeper = def.getRecords(); auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); auto *attrClass = recordKeeper.getClass("Attr"); @@ -541,42 +523,39 @@ LLVM_DEBUG(print(llvm::dbgs())); } -auto tblgen::Operator::getSameTypeAsResult(int index) const - -> ArrayRef { +auto Operator::getSameTypeAsResult(int index) const -> ArrayRef { assert(allResultTypesKnown()); return resultTypeMapping[index]; } -ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } +ArrayRef Operator::getLoc() const { return def.getLoc(); } -bool tblgen::Operator::hasDescription() const { +bool Operator::hasDescription() const { return def.getValue("description") != nullptr; } -StringRef tblgen::Operator::getDescription() const { +StringRef Operator::getDescription() const { return def.getValueAsString("description"); } -bool tblgen::Operator::hasSummary() const { - return def.getValue("summary") != nullptr; -} +bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; } -StringRef tblgen::Operator::getSummary() const { +StringRef Operator::getSummary() const { return def.getValueAsString("summary"); } -bool tblgen::Operator::hasAssemblyFormat() const { +bool Operator::hasAssemblyFormat() const { auto *valueInit = def.getValueInit("assemblyFormat"); return isa(valueInit); } -StringRef tblgen::Operator::getAssemblyFormat() const { +StringRef Operator::getAssemblyFormat() const { return TypeSwitch(def.getValueInit("assemblyFormat")) .Case( [&](auto *init) { return init->getValue(); }); } -void tblgen::Operator::print(llvm::raw_ostream &os) const { +void Operator::print(llvm::raw_ostream &os) const { os << "op '" << getOperationName() << "'\n"; for (Argument arg : arguments) { if (auto *attr = arg.dyn_cast()) @@ -586,12 +565,12 @@ } } -auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) +auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) -> VariableDecorator { return VariableDecorator(cast(init)->getDef()); } -auto tblgen::Operator::getArgToOperandOrAttribute(int index) const +auto Operator::getArgToOperandOrAttribute(int index) const -> OperandOrAttribute { return attrOrOperandMapping[index]; } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -22,80 +22,78 @@ #define DEBUG_TYPE "mlir-tblgen-pattern" using namespace mlir; +using namespace tblgen; using llvm::formatv; -using mlir::tblgen::Operator; //===----------------------------------------------------------------------===// // DagLeaf //===----------------------------------------------------------------------===// -bool tblgen::DagLeaf::isUnspecified() const { +bool DagLeaf::isUnspecified() const { return dyn_cast_or_null(def); } -bool tblgen::DagLeaf::isOperandMatcher() const { +bool DagLeaf::isOperandMatcher() const { // Operand matchers specify a type constraint. return isSubClassOf("TypeConstraint"); } -bool tblgen::DagLeaf::isAttrMatcher() const { +bool DagLeaf::isAttrMatcher() const { // Attribute matchers specify an attribute constraint. return isSubClassOf("AttrConstraint"); } -bool tblgen::DagLeaf::isNativeCodeCall() const { +bool DagLeaf::isNativeCodeCall() const { return isSubClassOf("NativeCodeCall"); } -bool tblgen::DagLeaf::isConstantAttr() const { - return isSubClassOf("ConstantAttr"); -} +bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } -bool tblgen::DagLeaf::isEnumAttrCase() const { +bool DagLeaf::isEnumAttrCase() const { return isSubClassOf("EnumAttrCaseInfo"); } -bool tblgen::DagLeaf::isStringAttr() const { +bool DagLeaf::isStringAttr() const { return isa(def); } -tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { +Constraint DagLeaf::getAsConstraint() const { assert((isOperandMatcher() || isAttrMatcher()) && "the DAG leaf must be operand or attribute"); return Constraint(cast(def)->getDef()); } -tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { +ConstantAttr DagLeaf::getAsConstantAttr() const { assert(isConstantAttr() && "the DAG leaf must be constant attribute"); return ConstantAttr(cast(def)); } -tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { +EnumAttrCase DagLeaf::getAsEnumAttrCase() const { assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); return EnumAttrCase(cast(def)); } -std::string tblgen::DagLeaf::getConditionTemplate() const { +std::string DagLeaf::getConditionTemplate() const { return getAsConstraint().getConditionTemplate(); } -llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { +llvm::StringRef DagLeaf::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(def)->getDef()->getValueAsString("expression"); } -std::string tblgen::DagLeaf::getStringAttr() const { +std::string DagLeaf::getStringAttr() const { assert(isStringAttr() && "the DAG leaf must be string attribute"); return def->getAsUnquotedString(); } -bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { +bool DagLeaf::isSubClassOf(StringRef superclass) const { if (auto *defInit = dyn_cast_or_null(def)) return defInit->getDef()->isSubClassOf(superclass); return false; } -void tblgen::DagLeaf::print(raw_ostream &os) const { +void DagLeaf::print(raw_ostream &os) const { if (def) def->print(os); } @@ -104,28 +102,26 @@ // DagNode //===----------------------------------------------------------------------===// -bool tblgen::DagNode::isNativeCodeCall() const { +bool DagNode::isNativeCodeCall() const { if (auto *defInit = dyn_cast_or_null(node->getOperator())) return defInit->getDef()->isSubClassOf("NativeCodeCall"); return false; } -bool tblgen::DagNode::isOperation() const { +bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); } -llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { +llvm::StringRef DagNode::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(node->getOperator()) ->getDef() ->getValueAsString("expression"); } -llvm::StringRef tblgen::DagNode::getSymbol() const { - return node->getNameStr(); -} +llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } -Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { +Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { llvm::Record *opDef = cast(node->getOperator())->getDef(); auto it = mapper->find(opDef); if (it != mapper->end()) @@ -134,7 +130,7 @@ .first->second; } -int tblgen::DagNode::getNumOps() const { +int DagNode::getNumOps() const { int count = isReplaceWithValue() ? 0 : 1; for (int i = 0, e = getNumArgs(); i != e; ++i) { if (auto child = getArgAsNestedDag(i)) @@ -143,36 +139,36 @@ return count; } -int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } +int DagNode::getNumArgs() const { return node->getNumArgs(); } -bool tblgen::DagNode::isNestedDagArg(unsigned index) const { +bool DagNode::isNestedDagArg(unsigned index) const { return isa(node->getArg(index)); } -tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { +DagNode DagNode::getArgAsNestedDag(unsigned index) const { return DagNode(dyn_cast_or_null(node->getArg(index))); } -tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { +DagLeaf DagNode::getArgAsLeaf(unsigned index) const { assert(!isNestedDagArg(index)); return DagLeaf(node->getArg(index)); } -StringRef tblgen::DagNode::getArgName(unsigned index) const { +StringRef DagNode::getArgName(unsigned index) const { return node->getArgNameStr(index); } -bool tblgen::DagNode::isReplaceWithValue() const { +bool DagNode::isReplaceWithValue() const { auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "replaceWithValue"; } -bool tblgen::DagNode::isLocationDirective() const { +bool DagNode::isLocationDirective() const { auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "location"; } -void tblgen::DagNode::print(raw_ostream &os) const { +void DagNode::print(raw_ostream &os) const { if (node) node->print(os); } @@ -181,8 +177,7 @@ // SymbolInfoMap //===----------------------------------------------------------------------===// -StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, - int *index) { +StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { StringRef name, indexStr; int idx = -1; std::tie(name, indexStr) = symbol.rsplit("__"); @@ -197,12 +192,11 @@ return name; } -tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, - SymbolInfo::Kind kind, - Optional index) +SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, + Optional index) : op(op), kind(kind), argIndex(index) {} -int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { +int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { switch (kind) { case Kind::Attr: case Kind::Operand: @@ -214,8 +208,7 @@ llvm_unreachable("unknown kind"); } -std::string -tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { +std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); switch (kind) { case Kind::Attr: { @@ -240,7 +233,7 @@ llvm_unreachable("unknown kind"); } -std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( +std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); switch (kind) { @@ -311,7 +304,7 @@ llvm_unreachable("unknown kind"); } -std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( +std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); switch (kind) { @@ -353,8 +346,8 @@ llvm_unreachable("unknown kind"); } -bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, - int argIndex) { +bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, + int argIndex) { StringRef name = getValuePackName(symbol); if (name != symbol) { auto error = formatv( @@ -369,26 +362,25 @@ return symbolInfoMap.insert({symbol, symInfo}).second; } -bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { +bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { StringRef name = getValuePackName(symbol); return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; } -bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { +bool SymbolInfoMap::bindValue(StringRef symbol) { return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; } -bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { +bool SymbolInfoMap::contains(StringRef symbol) const { return find(symbol) != symbolInfoMap.end(); } -tblgen::SymbolInfoMap::const_iterator -tblgen::SymbolInfoMap::find(StringRef key) const { +SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { StringRef name = getValuePackName(key); return symbolInfoMap.find(name); } -int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { +int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { StringRef name = getValuePackName(symbol); if (name != symbol) { // If there is a trailing index inside symbol, it references just one @@ -399,9 +391,9 @@ return find(name)->getValue().getStaticValueCount(); } -std::string -tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, - const char *separator) const { +std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, + const char *fmt, + const char *separator) const { int index = -1; StringRef name = getValuePackName(symbol, &index); @@ -414,9 +406,8 @@ return it->getValue().getValueAndRangeUse(name, index, fmt, separator); } -std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, - const char *fmt, - const char *separator) const { +std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, + const char *separator) const { int index = -1; StringRef name = getValuePackName(symbol, &index); @@ -433,32 +424,30 @@ // Pattern //==----------------------------------------------------------------------===// -tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) +Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) : def(*def), recordOpMap(mapper) {} -tblgen::DagNode tblgen::Pattern::getSourcePattern() const { - return tblgen::DagNode(def.getValueAsDag("sourcePattern")); +DagNode Pattern::getSourcePattern() const { + return DagNode(def.getValueAsDag("sourcePattern")); } -int tblgen::Pattern::getNumResultPatterns() const { +int Pattern::getNumResultPatterns() const { auto *results = def.getValueAsListInit("resultPatterns"); return results->size(); } -tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { +DagNode Pattern::getResultPattern(unsigned index) const { auto *results = def.getValueAsListInit("resultPatterns"); - return tblgen::DagNode(cast(results->getElement(index))); + return DagNode(cast(results->getElement(index))); } -void tblgen::Pattern::collectSourcePatternBoundSymbols( - tblgen::SymbolInfoMap &infoMap) { +void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); } -void tblgen::Pattern::collectResultPatternBoundSymbols( - tblgen::SymbolInfoMap &infoMap) { +void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { auto pattern = getResultPattern(i); @@ -467,17 +456,17 @@ LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); } -const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { +const Operator &Pattern::getSourceRootOp() { return getSourcePattern().getDialectOp(recordOpMap); } -tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { +Operator &Pattern::getDialectOp(DagNode node) { return node.getDialectOp(recordOpMap); } -std::vector tblgen::Pattern::getConstraints() const { +std::vector Pattern::getConstraints() const { auto *listInit = def.getValueAsListInit("constraints"); - std::vector ret; + std::vector ret; ret.reserve(listInit->size()); for (auto it : *listInit) { @@ -503,7 +492,7 @@ return ret; } -int tblgen::Pattern::getBenefit() const { +int Pattern::getBenefit() const { // The initial benefit value is a heuristic with number of ops in the source // pattern. int initBenefit = getSourcePattern().getNumOps(); @@ -515,8 +504,7 @@ return initBenefit + dyn_cast(delta->getArg(0))->getValue(); } -std::vector -tblgen::Pattern::getLocation() const { +std::vector Pattern::getLocation() const { std::vector> result; result.reserve(def.getLoc().size()); for (auto loc : def.getLoc()) { @@ -529,8 +517,8 @@ return result; } -void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, - bool isSrcPattern) { +void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, + bool isSrcPattern) { auto treeName = tree.getSymbol(); if (!tree.isOperation()) { if (!treeName.empty()) { diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -19,20 +19,21 @@ #include "llvm/TableGen/Record.h" using namespace mlir; +using namespace tblgen; // Construct a Predicate from a record. -tblgen::Pred::Pred(const llvm::Record *record) : def(record) { +Pred::Pred(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("Pred") && "must be a subclass of TableGen 'Pred' class"); } // Construct a Predicate from an initializer. -tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) { +Pred::Pred(const llvm::Init *init) : def(nullptr) { if (const auto *defInit = dyn_cast_or_null(init)) def = defInit->getDef(); } -std::string tblgen::Pred::getCondition() const { +std::string Pred::getCondition() const { // Static dispatch to subclasses. if (def->isSubClassOf("CombinedPred")) return static_cast(this)->getConditionImpl(); @@ -41,44 +42,44 @@ llvm_unreachable("Pred::getCondition must be overridden in subclasses"); } -bool tblgen::Pred::isCombined() const { +bool Pred::isCombined() const { return def && def->isSubClassOf("CombinedPred"); } -ArrayRef tblgen::Pred::getLoc() const { return def->getLoc(); } +ArrayRef Pred::getLoc() const { return def->getLoc(); } -tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) { +CPred::CPred(const llvm::Record *record) : Pred(record) { assert(def->isSubClassOf("CPred") && "must be a subclass of Tablegen 'CPred' class"); } -tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) { +CPred::CPred(const llvm::Init *init) : Pred(init) { assert((!def || def->isSubClassOf("CPred")) && "must be a subclass of Tablegen 'CPred' class"); } // Get condition of the C Predicate. -std::string tblgen::CPred::getConditionImpl() const { +std::string CPred::getConditionImpl() const { assert(!isNull() && "null predicate does not have a condition"); return std::string(def->getValueAsString("predExpr")); } -tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { +CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { assert(def->isSubClassOf("CombinedPred") && "must be a subclass of Tablegen 'CombinedPred' class"); } -tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { +CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { assert((!def || def->isSubClassOf("CombinedPred")) && "must be a subclass of Tablegen 'CombinedPred' class"); } -const llvm::Record *tblgen::CombinedPred::getCombinerDef() const { +const llvm::Record *CombinedPred::getCombinerDef() const { assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); return def->getValueAsDef("kind"); } -const std::vector tblgen::CombinedPred::getChildren() const { +const std::vector CombinedPred::getChildren() const { assert(def->getValue("children") && "CombinedPred must have a value 'children'"); return def->getValueAsListOfDefs("children"); @@ -101,7 +102,7 @@ // A node in a logical predicate tree. struct PredNode { PredCombinerKind kind; - const tblgen::Pred *predicate; + const Pred *predicate; SmallVector children; std::string expr; @@ -113,11 +114,11 @@ // Get a predicate tree node kind based on the kind used in the predicate // TableGen record. -static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) { +static PredCombinerKind getPredCombinerKind(const Pred &pred) { if (!pred.isCombined()) return PredCombinerKind::Leaf; - const auto &combinedPred = static_cast(pred); + const auto &combinedPred = static_cast(pred); return llvm::StringSwitch( combinedPred.getCombinerDef()->getName()) .Case("PredCombinerAnd", PredCombinerKind::And) @@ -137,7 +138,7 @@ // substitution, nodes are still pointing to the original TableGen record. // All nodes are created within "allocator". static PredNode * -buildPredicateTree(const tblgen::Pred &root, +buildPredicateTree(const Pred &root, llvm::SpecificBumpPtrAllocator &allocator, ArrayRef substitutions) { auto *rootNode = allocator.Allocate(); @@ -166,22 +167,22 @@ // list before continuing. auto allSubstitutions = llvm::to_vector<4>(substitutions); if (rootNode->kind == PredCombinerKind::SubstLeaves) { - const auto &substPred = static_cast(root); + const auto &substPred = static_cast(root); allSubstitutions.push_back( {substPred.getPattern(), substPred.getReplacement()}); } // If the current predicate is a ConcatPred, record the prefix and suffix. else if (rootNode->kind == PredCombinerKind::Concat) { - const auto &concatPred = static_cast(root); + const auto &concatPred = static_cast(root); rootNode->prefix = std::string(concatPred.getPrefix()); rootNode->suffix = std::string(concatPred.getSuffix()); } // Build child subtrees. - auto combined = static_cast(root); + auto combined = static_cast(root); for (const auto *record : combined.getChildren()) { auto childTree = - buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions); + buildPredicateTree(Pred(record), allocator, allSubstitutions); rootNode->children.push_back(childTree); } return rootNode; @@ -192,9 +193,10 @@ // children is known to be false(true), the result is also false(true). // Furthermore, for AND(OR) combined predicates, children that are known to be // true(false) don't have to be checked dynamically. -static PredNode *propagateGroundTruth( - PredNode *node, const llvm::SmallPtrSetImpl &knownTruePreds, - const llvm::SmallPtrSetImpl &knownFalsePreds) { +static PredNode * +propagateGroundTruth(PredNode *node, + const llvm::SmallPtrSetImpl &knownTruePreds, + const llvm::SmallPtrSetImpl &knownFalsePreds) { // If the current predicate is known to be true or false, change the kind of // the node and return immediately. if (knownTruePreds.count(node->predicate) != 0) { @@ -339,29 +341,29 @@ llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); } -std::string tblgen::CombinedPred::getConditionImpl() const { +std::string CombinedPred::getConditionImpl() const { llvm::SpecificBumpPtrAllocator allocator; auto predicateTree = buildPredicateTree(*this, allocator, {}); - predicateTree = propagateGroundTruth( - predicateTree, - /*knownTruePreds=*/llvm::SmallPtrSet(), - /*knownFalsePreds=*/llvm::SmallPtrSet()); + predicateTree = + propagateGroundTruth(predicateTree, + /*knownTruePreds=*/llvm::SmallPtrSet(), + /*knownFalsePreds=*/llvm::SmallPtrSet()); return getCombinedCondition(*predicateTree); } -StringRef tblgen::SubstLeavesPred::getPattern() const { +StringRef SubstLeavesPred::getPattern() const { return def->getValueAsString("pattern"); } -StringRef tblgen::SubstLeavesPred::getReplacement() const { +StringRef SubstLeavesPred::getReplacement() const { return def->getValueAsString("replacement"); } -StringRef tblgen::ConcatPred::getPrefix() const { +StringRef ConcatPred::getPrefix() const { return def->getValueAsString("prefix"); } -StringRef tblgen::ConcatPred::getSuffix() const { +StringRef ConcatPred::getSuffix() const { return def->getValueAsString("suffix"); }