diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -294,6 +294,17 @@ // Returns the builders of this operation. ArrayRef getBuilders() const { return builders; } + // Returns the preferred getter name for the accessor. + std::string getGetterName(StringRef name) const { + return getGetterNames(name).front(); + } + + // Returns the getter names for the accessor. + SmallVector getGetterNames(StringRef name) const; + + // Returns the setter names for the accessor. + SmallVector getSetterNames(StringRef name) const; + private: // Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -642,3 +643,58 @@ -> OperandOrAttribute { return attrOrOperandMapping[index]; } + +// Helper to return the names for accessor. +static SmallVector +getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) { + Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix(); + std::string prefix; + if (prefixType != Dialect::EmitPrefix::Raw) + prefix = isGetter ? "get" : "set"; + + SmallVector names; + bool rawToo = prefixType == Dialect::EmitPrefix::Both; + + auto skip = [&](StringRef newName) { + bool shouldSkip = newName == "getOperands"; + if (!shouldSkip) + return false; + + // This note could be avoided where the final function generated would + // have been identical. 
But preferably in the op definition avoiding using + // the generic name and then getting a more specialize type is better. + PrintNote(op.getLoc(), + "Skipping generation of prefixed accessor `" + newName + + "` as it overlaps with default one; generating raw form (`" + + name + "`) still"); + return true; + }; + + if (!prefix.empty()) { + if (name.startswith(prefix)) + llvm::report_fatal_error("done"); + names.push_back(prefix + convertToCamelFromSnakeCase(name, true)); + // Skip cases which would overlap with default ones for now. + if (skip(names.back())) { + rawToo = true; + names.clear(); + } else { + LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() + << "::" << names.back() << "\");\n" + << "WITH_GETTER(\"" << op.getQualCppClassName() + << "Adaptor::" << names.back() << "\");\n";); + } + } + + if (prefix.empty() || rawToo) + names.push_back(name.str()); + return names; +} + +SmallVector Operator::getGetterNames(StringRef name) const { + return getGetterOrSetterNames(/*isGetter=*/true, *this, name); +} + +SmallVector Operator::getSetterNames(StringRef name) const { + return getGetterOrSetterNames(/*isGetter=*/false, *this, name); +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -339,7 +339,7 @@ Optional TestBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return targetOperandsMutable(); + return getTargetOperandsMutable(); } //===----------------------------------------------------------------------===// @@ -369,7 +369,7 @@ LogicalResult matchAndRewrite(FoldToCallOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, TypeRange(), op.calleeAttr(), + rewriter.replaceOpWithNewOp(op, TypeRange(), op.getCalleeAttr(), ValueRange()); return success(); } @@ -597,8 +597,8 @@ static void print(OpAsmPrinter 
&p, IsolatedRegionOp op) { p << "test.isolated_region "; p.printOperand(op.getOperand()); - p.shadowRegionArgs(op.region(), op.getOperand()); - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); + p.shadowRegionArgs(op.getRegion(), op.getOperand()); + p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// @@ -622,7 +622,7 @@ static void print(OpAsmPrinter &p, GraphRegionOp op) { p << "test.graph_region "; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); + p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); } RegionKind GraphRegionOp::getRegionKind(unsigned index) { @@ -642,7 +642,7 @@ static void print(OpAsmPrinter &p, AffineScopeOp op) { p << "test.affine_scope "; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); + p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// @@ -678,7 +678,7 @@ } static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { - p << " " << op.keyword(); + p << " " << op.getKeyword(); } //===----------------------------------------------------------------------===// @@ -717,7 +717,7 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) { p << " wraps "; - p.printGenericOp(&op.region().front().front()); + p.printGenericOp(&op.getRegion().front().front()); } //===----------------------------------------------------------------------===// @@ -762,7 +762,7 @@ } OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { - return operand(); + return getOperand(); } OpFoldResult TestOpConstant::fold(ArrayRef operands) { @@ -971,7 +971,7 @@ // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. 
- bool namesDisagree = op.names().size() != op.getNumResults(); + bool namesDisagree = op.getNames().size() != op.getNumResults(); SmallString<32> resultNameStr; for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { @@ -979,7 +979,7 @@ llvm::raw_svector_ostream tmpStream(resultNameStr); p.printOperand(op.getResult(i), tmpStream); - auto expectedName = op.names()[i].dyn_cast(); + auto expectedName = op.getNames()[i].dyn_cast(); if (!expectedName || tmpStream.str().drop_front() != expectedName.getValue()) { namesDisagree = true; @@ -997,7 +997,7 @@ void StringAttrPrettyNameOp::getAsmResultNames( function_ref setNameFn) { - auto value = names(); + auto value = getNames(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = value[i].dyn_cast()) if (!str.getValue().empty()) @@ -1014,15 +1014,15 @@ p << ": " << op.getOperandTypes(); p.printArrowTypeList(op.getResultTypes()); p << " then"; - p.printRegion(op.thenRegion(), + p.printRegion(op.getThenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " else"; - p.printRegion(op.elseRegion(), + p.printRegion(op.getElseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " join"; - p.printRegion(op.joinRegion(), + p.printRegion(op.getJoinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } @@ -1064,15 +1064,15 @@ // We always branch to the join region. if (index.hasValue()) { if (index.getValue() < 2) - regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); + regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else regions.push_back(RegionSuccessor(getResults())); return; } // The then and else regions are the entry regions of this op. 
- regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); - regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); + regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); + regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -28,7 +28,7 @@ let cppNamespace = "::test"; // Temporarily flipping to _Both (given this is test only/not intended for // general use, this won't be following the 2 week process here). - let emitAccessorPrefix = kEmitAccessorPrefix_Both; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let hasCanonicalizer = 1; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; @@ -305,9 +305,9 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> { let results = (outs AnyTensor:$output); DerivedTypeAttr element_dtype = - DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">; + DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">; DerivedAttr size = DerivedAttr<"int", - "return output().getType().cast().getSizeInBits();", + "return getOutput().getType().cast().getSizeInBits();", "$_builder.getI32IntegerAttr($_self)">; } @@ -374,13 +374,10 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", [CallOpInterface]> { - let arguments = (ins Variadic:$inputs, SymbolRefAttr:$callee); + let arguments = (ins Variadic:$arg_operands, SymbolRefAttr:$callee); let results = (outs Variadic); let extraClassDeclaration = [{ - /// Get the argument operands to the called function. - operand_range getArgOperands() { return inputs(); } - /// Return the callee of this operation. 
::mlir::CallInterfaceCallable getCallableForCallee() { return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"); @@ -394,7 +391,7 @@ let results = (outs FunctionType); let extraClassDeclaration = [{ - ::mlir::Region *getCallableRegion() { return &body(); } + ::mlir::Region *getCallableRegion() { return &getBody(); } ::llvm::ArrayRef<::mlir::Type> getCallableResults() { return getType().cast<::mlir::FunctionType>().getResults(); } @@ -673,7 +670,7 @@ let arguments = (ins AnyAttr:$attr); let verifier = [{ - if (this->attr().hasTrait()) + if (this->getAttr().hasTrait()) return success(); return this->emitError("'attr' attribute should have trait 'TestAttrTrait'"); }]; @@ -2340,6 +2337,10 @@ std::string getLibraryCallName() { return ""; } + + // To conform with interface requirement on operand naming. + mlir::ValueRange inputs() { return getInputs(); } + mlir::ValueRange outputs() { return getOutputs(); } }]; } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -32,8 +32,8 @@ static void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) { // Turn the no result op to a one-result op. - rewriter.create(op.getLoc(), op.operand().getType(), - op.operand()); + rewriter.create(op.getLoc(), op.getOperand().getType(), + op.getOperand()); } static bool getFirstI32Result(Operation *op, Value &value) { @@ -531,7 +531,7 @@ PatternRewriter &rewriter) const final { // Decrement the depth of the op in-place. rewriter.updateRootInPlace(op, [&] { - op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); + op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); }); return success(); } @@ -705,7 +705,7 @@ // Mark the bound recursion operation as dynamically legal. 
target.addDynamicallyLegalOp( - [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); + [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); // Handle a partial conversion. if (mode == ConversionMode::Partial) { @@ -1026,9 +1026,9 @@ LogicalResult matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - Block &firstBlock = op.body().front(); + Block &firstBlock = op.getBody().front(); Operation *branchOp = firstBlock.getTerminator(); - Block *secondBlock = &*(std::next(op.body().begin())); + Block *secondBlock = &*(std::next(op.getBody().begin())); auto succOperands = branchOp->getOperands(); SmallVector replacements(succOperands); rewriter.eraseOp(branchOp); @@ -1073,7 +1073,7 @@ op->getParentOfType(); if (!parentOp) return failure(); - Block &innerBlock = op.region().front(); + Block &innerBlock = op.getRegion().front(); TerminatorOp innerTerminator = cast(innerBlock.getTerminator()); rewriter.mergeBlockBefore(&innerBlock, op); @@ -1104,7 +1104,7 @@ /// Expect the op to have a single block after legalization. target.addDynamicallyLegalOp( [&](TestMergeBlocksOp op) -> bool { - return llvm::hasSingleElement(op.body()); + return llvm::hasSingleElement(op.getBody()); }); /// Only allow `test.br` within test.merge_blocks op. diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -51,7 +51,7 @@ // Inline the functional region operation, but only clone the internal // region if there is more than one use. 
if (failed(inlineRegion( - interface, &callee.body(), caller, caller.getArgOperands(), + interface, &callee.getBody(), caller, caller.getArgOperands(), caller.getResults(), caller.getLoc(), /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse()))) continue; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -358,10 +358,6 @@ // The emitter containing all of the locally emitted verification functions. const StaticVerifierFunctionEmitter &staticVerifierEmitter; - - // A map of attribute names (including implicit attributes) registered to the - // current operation, to the relative order in which they were registered. - llvm::MapVector attributeNames; }; } // end anonymous namespace @@ -525,62 +521,6 @@ void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } -// Helper to return the names for accessor. -static SmallVector -getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) { - Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix(); - std::string prefix; - if (prefixType != Dialect::EmitPrefix::Raw) - prefix = isGetter ? "get" : "set"; - - SmallVector names; - bool rawToo = prefixType == Dialect::EmitPrefix::Both; - - auto skip = [&](StringRef newName) { - bool shouldSkip = newName == "getOperands"; - if (!shouldSkip) - return false; - - // This note could be avoided where the final function generated would - // have been identical. But preferably in the op definition avoiding using - // the generic name and then getting a more specialize type is better. 
- PrintNote(op.getLoc(), - "Skipping generation of prefixed accessor `" + newName + - "` as it overlaps with default one; generating raw form (`" + - name + "`) still"); - return true; - }; - - if (!prefix.empty()) { - names.push_back(prefix + convertToCamelFromSnakeCase(name, true)); - // Skip cases which would overlap with default ones for now. - if (skip(names.back())) { - rawToo = true; - names.clear(); - } else { - LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() - << "::" << names.back() << "\");\n" - << "WITH_GETTER(\"" << op.getQualCppClassName() - << "Adaptor::" << names.back() << "\");\n";); - } - } - - if (prefix.empty() || rawToo) - names.push_back(name.str()); - return names; -} -static SmallVector getGetterNames(const Operator &op, - StringRef name) { - return getGetterOrSetterNames(/*isGetter=*/true, op, name); -} -static std::string getGetterName(const Operator &op, StringRef name) { - return getGetterOrSetterNames(/*isGetter=*/true, op, name).front(); -} -static SmallVector getSetterNames(const Operator &op, - StringRef name) { - return getGetterOrSetterNames(/*isGetter=*/false, op, name); -} - static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName, const Operator &op) { if (m) @@ -593,6 +533,10 @@ #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) void OpEmitter::genAttrNameGetters() { + // A map of attribute names (including implicit attributes) registered to the + // current operation, to the relative order in which they were registered. + llvm::MapVector attributeNames; + // Enumerate the attribute names of this op, assigning each a relative // ordering. auto addAttrName = [&](StringRef name) { @@ -602,10 +546,12 @@ for (const NamedAttribute &namedAttr : op.getAttributes()) addAttrName(namedAttr.name); // Include key attributes from several traits as implicitly registered. 
+ std::string operandSizes = "operand_segment_sizes"; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) - addAttrName("operand_segment_sizes"); + addAttrName(operandSizes); + std::string attrSizes = "result_segment_sizes"; if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) - addAttrName("result_segment_sizes"); + addAttrName(attrSizes); // Emit the getAttributeNames method. { @@ -656,7 +602,7 @@ // users. const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; for (const std::pair &attrIt : attributeNames) { - for (StringRef name : getGetterNames(op, attrIt.first)) { + for (StringRef name : op.getGetterNames(attrIt.first)) { std::string methodName = (name + "AttrName").str(); // Generate the non-static variant. @@ -734,7 +680,7 @@ }; for (const NamedAttribute &namedAttr : op.getAttributes()) { - for (StringRef name : getGetterNames(op, namedAttr.name)) { + for (StringRef name : op.getGetterNames(namedAttr.name)) { if (namedAttr.attr.isDerivedAttr()) { emitDerivedAttr(name, namedAttr.attr); } else { @@ -777,8 +723,9 @@ if (!nonMaterializable.empty()) { std::string attrs; llvm::raw_string_ostream os(attrs); - interleaveComma(nonMaterializable, os, - [&](const NamedAttribute &attr) { os << attr.name; }); + interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { + os << op.getGetterName(attr.name); + }); PrintWarning( op.getLoc(), formatv( @@ -799,8 +746,9 @@ derivedAttrs, body, [&](const NamedAttribute &namedAttr) { auto tmpl = namedAttr.attr.getConvertFromStorageCall(); - body << " {" << namedAttr.name << "AttrName(),\n" - << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()") + std::string name = op.getGetterName(namedAttr.name); + body << " {" << name << "AttrName(),\n" + << tgfmt(tmpl, &fctx.withSelf(name + "()") .withBuilder("odsBuilder") .addSubst("_ctx", "ctx")) << "}"; @@ -826,8 +774,8 @@ for (const NamedAttribute &namedAttr : op.getAttributes()) { if (!namedAttr.attr.isDerivedAttr()) - for (auto names : 
llvm::zip(getSetterNames(op, namedAttr.name), - getGetterNames(op, namedAttr.name))) + for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), + op.getGetterNames(namedAttr.name))) emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), namedAttr.attr); } @@ -843,7 +791,7 @@ "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); if (!method) return; - method->body() << " return (*this)->removeAttr(" << getGetterName(op, name) + method->body() << " return (*this)->removeAttr(" << op.getGetterName(name) << "AttrName());"; }; @@ -945,7 +893,7 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - for (StringRef name : getGetterNames(op, operand.name)) { + for (StringRef name : op.getGetterNames(operand.name)) { if (operand.isOptional()) { m = opClass.addMethodAndPrune("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); @@ -953,8 +901,8 @@ << " return operands.empty() ? ::mlir::Value() : " "*operands.begin();"; } else if (operand.isVariadicOfVariadic()) { - StringRef segmentAttr = - operand.constraint.getVariadicOfVariadicSegmentSizeAttr(); + std::string segmentAttr = op.getGetterName( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); if (isAdaptor) { m = opClass.addMethodAndPrune( "::llvm::SmallVector<::mlir::ValueRange>", name); @@ -982,13 +930,12 @@ } void OpEmitter::genNamedOperandGetters() { - // Build the code snippet used for initializing the operand_segment_sizes + // Build the code snippet used for initializing the operand_segment_sizes // array. 
std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - attrSizeInitCode = - formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()") - .str(); + std::string attr = op.getGetterName("operand_segment_sizes") + "AttrName()"; + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); } generateNamedOperandGetters( @@ -1008,7 +955,7 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - for (StringRef name : getGetterNames(op, operand.name)) { + for (StringRef name : op.getGetterNames(operand.name)) { auto *m = opClass.addMethodAndPrune( operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange" : "::mlir::MutableOperandRange", @@ -1022,7 +969,7 @@ if (attrSizedOperands) body << ", ::mlir::MutableOperandRange::OperandSegment(" << i << "u, *getOperation()->getAttrDictionary().getNamed(" - "operand_segment_sizesAttrName()))"; + << op.getGetterName("operand_segment_sizes") << "AttrName()))"; body << ");\n"; // If this operand is a nested variadic, we split the range into a @@ -1032,8 +979,7 @@ // body << " return " "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" - << getGetterName( - op, + << op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) << "AttrName()));\n"; } else { @@ -1076,9 +1022,8 @@ // Build the initializer string for the result segment size attribute. 
std::string attrSizeInitCode; if (attrSizedResults) { - attrSizeInitCode = - formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()") - .str(); + std::string attr = op.getGetterName("result_segment_sizes") + "AttrName()"; + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); } generateValueRangeStartAndEnd( @@ -1096,7 +1041,7 @@ const auto &result = op.getResult(i); if (result.name.empty()) continue; - for (StringRef name : getGetterNames(op, result.name)) { + for (StringRef name : op.getGetterNames(result.name)) { if (result.isOptional()) { m = opClass.addMethodAndPrune("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); @@ -1123,7 +1068,7 @@ if (region.name.empty()) continue; - for (StringRef name : getGetterNames(op, region.name)) { + for (StringRef name : op.getGetterNames(region.name)) { // Generate the accessors for a variadic region. if (region.isVariadic()) { auto *m = opClass.addMethodAndPrune( @@ -1148,7 +1093,7 @@ if (successor.name.empty()) continue; - for (StringRef name : getGetterNames(op, successor.name)) { + for (StringRef name : op.getGetterNames(successor.name)) { // Generate the accessors for a variadic successor list. 
if (successor.isVariadic()) { auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name); @@ -1430,7 +1375,7 @@ std::string resultType; const auto &namedAttr = op.getAttribute(0); - body << " auto attrName = " << getGetterName(op, namedAttr.name) + body << " auto attrName = " << op.getGetterName(namedAttr.name) << "AttrName(" << builderOpState << ".name);\n" " for (auto attr : attributes) {\n" @@ -1746,8 +1691,8 @@ << " for (::mlir::ValueRange range : " << argName << ")\n" << " rangeSegments.push_back(range.size());\n" << " " << builderOpState << ".addAttribute(" - << getGetterName( - op, operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) + << op.getGetterName( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) << "AttrName(" << builderOpState << ".name), " << odsBuilder << ".getI32TensorAttr(rangeSegments));" << " }\n"; @@ -1761,9 +1706,9 @@ // If the operation has the operand segment size attribute, add it here. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << " " << builderOpState - << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState - << ".name), " + std::string sizes = op.getGetterName("operand_segment_sizes"); + body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" + << builderOpState << ".name), " << "odsBuilder.getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { const NamedTypeConstraint &operand = op.getOperand(i); @@ -1816,10 +1761,10 @@ std::string value = std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, getGetterName(op, namedAttr.name), value); + builderOpState, op.getGetterName(namedAttr.name), value); } else { body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, getGetterName(op, namedAttr.name), + builderOpState, op.getGetterName(namedAttr.name), namedAttr.name); } if (emitNotNullCheck) @@ -2255,7 
+2200,7 @@ ? "{0}()" : "::mlir::MutableArrayRef<::mlir::Region>((*this)" "->getRegion({1}))", - region.name, i); + op.getGetterName(region.name), i); body << ") {\n"; auto constraint = tgfmt(region.constraint.getConditionTemplate(), &verifyCtx.withSelf("region")) @@ -2497,8 +2442,8 @@ ERROR_IF_PRUNED(m, "getOperands", op); m->body() << " return odsOperands;"; } - std::string sizeAttrInit = - formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); + std::string attr = op.getGetterName("operand_segment_sizes"); + std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr); generateNamedOperandGetters(op, adaptor, /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", @@ -2543,7 +2488,8 @@ const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; if (!attr.isDerivedAttr()) - emitAttr(name, attr); + for (auto emitName : op.getGetterNames(name)) + emitAttr(emitName, attr); } unsigned numRegions = op.getNumRegions(); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -873,7 +873,8 @@ } /// Generate the storage code required for parsing the given element. 
-static void genElementParserStorage(Element *element, OpMethodBody &body) { +static void genElementParserStorage(Element *element, const Operator &op, + OpMethodBody &body) { if (auto *optional = dyn_cast(element)) { auto elements = optional->getThenElements(); @@ -885,13 +886,13 @@ elidedAnchorElement = anchor; for (auto &childElement : elements) if (&childElement != elidedAnchorElement) - genElementParserStorage(&childElement, body); + genElementParserStorage(&childElement, op, body); for (auto &childElement : optional->getElseElements()) - genElementParserStorage(&childElement, body); + genElementParserStorage(&childElement, op, body); } else if (auto *custom = dyn_cast(element)) { for (auto ¶mElement : custom->getArguments()) - genElementParserStorage(¶mElement, body); + genElementParserStorage(¶mElement, op, body); } else if (isa(element)) { body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " @@ -1188,7 +1189,7 @@ // allows for referencing these variables in the presence of optional // groupings. for (auto &element : elements) - genElementParserStorage(&*element, body); + genElementParserStorage(&*element, op, body); // A format context used when parsing attributes with buildable types. FmtContext attrTypeCtx; @@ -1735,36 +1736,38 @@ /// Generate the printer for a custom directive parameter. 
static void genCustomDirectiveParameterPrinter(Element *element, + const Operator &op, OpMethodBody &body) { if (auto *attr = dyn_cast(element)) { - body << attr->getVar()->name << "Attr()"; + body << op.getGetterName(attr->getVar()->name) << "Attr()"; } else if (isa(element)) { body << "getOperation()->getAttrDictionary()"; } else if (auto *operand = dyn_cast(element)) { - body << operand->getVar()->name << "()"; + body << op.getGetterName(operand->getVar()->name) << "()"; } else if (auto *region = dyn_cast(element)) { - body << region->getVar()->name << "()"; + body << op.getGetterName(region->getVar()->name) << "()"; } else if (auto *successor = dyn_cast(element)) { - body << successor->getVar()->name << "()"; + body << op.getGetterName(successor->getVar()->name) << "()"; } else if (auto *dir = dyn_cast(element)) { - genCustomDirectiveParameterPrinter(dir->getOperand(), body); + genCustomDirectiveParameterPrinter(dir->getOperand(), op, body); } else if (auto *dir = dyn_cast(element)) { auto *typeOperand = dir->getOperand(); auto *operand = dyn_cast(typeOperand); auto *var = operand ? operand->getVar() : cast(typeOperand)->getVar(); + std::string name = op.getGetterName(var->name); if (var->isVariadic()) - body << var->name << "().getTypes()"; + body << name << "().getTypes()"; else if (var->isOptional()) - body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + body << llvm::formatv("({0}() ? {0}().getType() : Type())", name); else - body << var->name << "().getType()"; + body << name << "().getType()"; } else { llvm_unreachable("unknown custom directive parameter"); } @@ -1772,11 +1775,11 @@ /// Generate the printer for a custom directive. 
static void genCustomDirectivePrinter(CustomDirective *customDir, - OpMethodBody &body) { + const Operator &op, OpMethodBody &body) { body << " print" << customDir->getName() << "(p, *this"; for (Element &param : customDir->getArguments()) { body << ", "; - genCustomDirectiveParameterPrinter(&param, body); + genCustomDirectiveParameterPrinter(&param, op, body); } body << ");\n"; } @@ -1800,7 +1803,8 @@ } /// Generate the C++ for an operand to a (*-)type directive. -static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { +static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, + OpMethodBody &body) { if (isa(arg)) return body << "getOperation()->getOperandTypes()"; if (isa(arg)) @@ -1808,26 +1812,29 @@ auto *operand = dyn_cast(arg); auto *var = operand ? operand->getVar() : cast(arg)->getVar(); if (var->isVariadicOfVariadic()) - return body << llvm::formatv("{0}().join().getTypes()", var->name); + return body << llvm::formatv("{0}().join().getTypes()", + op.getGetterName(var->name)); if (var->isVariadic()) - return body << var->name << "().getTypes()"; + return body << op.getGetterName(var->name) << "().getTypes()"; if (var->isOptional()) return body << llvm::formatv( "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " "::llvm::ArrayRef<::mlir::Type>())", - var->name); - return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name - << "().getType())"; + op.getGetterName(var->name)); + return body << "::llvm::ArrayRef<::mlir::Type>(" + << op.getGetterName(var->name) << "().getType())"; } /// Generate the printer for an enum attribute. -static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) { +static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, + OpMethodBody &body) { Attribute baseAttr = var->attr.getBaseAttr(); const EnumAttr &enumAttr = cast(baseAttr); std::vector cases = enumAttr.getAllCases(); body << llvm::formatv(enumAttrBeginPrinterCode, - (var->attr.isOptional() ? 
"*" : "") + var->name, + (var->attr.isOptional() ? "*" : "") + + op.getGetterName(var->name), enumAttr.getSymbolToStringFnName()); // Get a string containing all of the cases that can't be represented with a @@ -1897,25 +1904,28 @@ } /// Generate the check for the anchor of an optional group. -static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) { +static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op, + OpMethodBody &body) { TypeSwitch(anchor) .Case([&](auto *element) { const NamedTypeConstraint *var = element->getVar(); + std::string name = op.getGetterName(var->name); if (var->isOptional()) - body << " if (" << var->name << "()) {\n"; + body << " if (" << name << "()) {\n"; else if (var->isVariadic()) - body << " if (!" << var->name << "().empty()) {\n"; + body << " if (!" << name << "().empty()) {\n"; }) .Case([&](RegionVariable *element) { const NamedRegion *var = element->getVar(); + std::string name = op.getGetterName(var->name); // TODO: Add a check for optional regions here when ODS supports it. - body << " if (!" << var->name << "().empty()) {\n"; + body << " if (!" << name << "().empty()) {\n"; }) .Case([&](TypeDirective *element) { - genOptionalGroupPrinterAnchor(element->getOperand(), body); + genOptionalGroupPrinterAnchor(element->getOperand(), op, body); }) .Case([&](FunctionalTypeDirective *element) { - genOptionalGroupPrinterAnchor(element->getInputs(), body); + genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) .Case([&](AttributeVariable *attr) { body << " if ((*this)->getAttr(\"" << attr->getVar()->name @@ -1943,7 +1953,7 @@ if (OptionalElement *optional = dyn_cast(element)) { // Emit the check for the presence of the anchor element. Element *anchor = optional->getAnchor(); - genOptionalGroupPrinterAnchor(anchor, body); + genOptionalGroupPrinterAnchor(anchor, op, body); // If the anchor is a unit attribute, we don't need to print it. 
When // parsing, we will add this attribute if this group is present. @@ -1998,47 +2008,53 @@ // If we are formatting as an enum, symbolize the attribute as a string. if (canFormatEnumAttr(var)) - return genEnumAttrPrinter(var, body); + return genEnumAttrPrinter(var, op, body); // If we are formatting as a symbol name, handle it as a symbol name. if (shouldFormatSymbolNameAttr(var)) { - body << " p.printSymbolName(" << var->name << "Attr().getValue());\n"; + body << " p.printSymbolName(" << op.getGetterName(var->name) + << "Attr().getValue());\n"; return; } // Elide the attribute type if it is buildable. if (attr->getTypeBuilder()) - body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + body << " p.printAttributeWithoutType(" << op.getGetterName(var->name) + << "Attr());\n"; else - body << " p.printAttribute(" << var->name << "Attr());\n"; + body << " p.printAttribute(" << op.getGetterName(var->name) + << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { if (operand->getVar()->isVariadicOfVariadic()) { - body << " ::llvm::interleaveComma(" << operand->getVar()->name + body << " ::llvm::interleaveComma(" + << op.getGetterName(operand->getVar()->name) << "(), p, [&](const auto &operands) { p << \"(\" << operands << " "\")\"; });\n"; } else if (operand->getVar()->isOptional()) { - body << " if (::mlir::Value value = " << operand->getVar()->name - << "())\n" + body << " if (::mlir::Value value = " + << op.getGetterName(operand->getVar()->name) << "())\n" << " p << value;\n"; } else { - body << " p << " << operand->getVar()->name << "();\n"; + body << " p << " << op.getGetterName(operand->getVar()->name) << "();\n"; } } else if (auto *region = dyn_cast(element)) { const NamedRegion *var = region->getVar(); + std::string name = op.getGetterName(var->name); if (var->isVariadic()) { - genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait); + genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait); } else { - 
genRegionPrinter(var->name + "()", body, hasImplicitTermTrait); + genRegionPrinter(name + "()", body, hasImplicitTermTrait); } } else if (auto *successor = dyn_cast(element)) { const NamedSuccessor *var = successor->getVar(); + std::string name = op.getGetterName(var->name); if (var->isVariadic()) - body << " ::llvm::interleaveComma(" << var->name << "(), p);\n"; + body << " ::llvm::interleaveComma(" << name << "(), p);\n"; else - body << " p << " << var->name << "();\n"; + body << " p << " << name << "();\n"; } else if (auto *dir = dyn_cast(element)) { - genCustomDirectivePrinter(dir, body); + genCustomDirectivePrinter(dir, op, body); } else if (isa(element)) { body << " p << getOperation()->getOperands();\n"; } else if (isa(element)) { @@ -2052,16 +2068,16 @@ body << llvm::formatv(" ::llvm::interleaveComma({0}().getTypes(), p, " "[&](::mlir::TypeRange types) {{ p << \"(\" << " "types << \")\"; });\n", - operand->getVar()->name); + op.getGetterName(operand->getVar()->name)); return; } } body << " p << "; - genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " p.printFunctionalType("; - genTypeOperandPrinter(dir->getInputs(), body) << ", "; - genTypeOperandPrinter(dir->getResults(), body) << ");\n"; + genTypeOperandPrinter(dir->getInputs(), op, body) << ", "; + genTypeOperandPrinter(dir->getResults(), op, body) << ");\n"; } else { llvm_unreachable("unknown format element"); }