diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -277,7 +277,7 @@ } patterns.insert(context, newArg); target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); + [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); } } } diff --git a/mlir/docs/DefiningDialects.md b/mlir/docs/DefiningDialects.md --- a/mlir/docs/DefiningDialects.md +++ b/mlir/docs/DefiningDialects.md @@ -279,59 +279,6 @@ See the documentation for [Canonicalization in MLIR](Canonicalization.md) for a much more detailed description about canonicalization patterns. -### C++ Accessor Prefix - -Historically, MLIR has generated accessors for operation components (such as attribute, operands, -results) using the tablegen definition name verbatim. This means that if an operation was defined -as: - -```tablegen -def MyOp : MyDialect<"op"> { - let arguments = (ins StrAttr:$value, StrAttr:$other_value); -} -``` - -It would have accessors generated for the `value` and `other_value` attributes as follows: - -```c++ -StringAttr MyOp::value(); -void MyOp::value(StringAttr newValue); - -StringAttr MyOp::other_value(); -void MyOp::other_value(StringAttr newValue); -``` - -Since then, we have decided to move accessors over to a style that matches the rest of the -code base. More specifically, this means that we prefix accessors with `get` and `set` -respectively, and transform `snake_style` names to camel case (`UpperCamel` when prefixed, -and `lowerCamel` for individual variable names). 
If we look at the same example as above, this -would produce: - -```c++ -StringAttr MyOp::getValue(); -void MyOp::setValue(StringAttr newValue); - -StringAttr MyOp::getOtherValue(); -void MyOp::setOtherValue(StringAttr newValue); -``` - -The form in which accessors are generated is controlled by the `emitAccessorPrefix` field. -This field may any of the following values: - -* `kEmitAccessorPrefix_Raw` - - Don't emit any `get`/`set` prefix. - -* `kEmitAccessorPrefix_Prefixed` - - Only emit with `get`/`set` prefix. - -* `kEmitAccessorPrefix_Both` - - Emit with **and** without prefix. - -All new dialects are strongly encouraged to use the default `kEmitAccessorPrefix_Prefixed` -value, as the `Raw` form is deprecated and in the process of being removed. - -Note: Remove this section when all dialects have been switched to the new accessor form. - ## Defining an Extensible dialect This section documents the design and API of the extensible dialects. Extensible diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -555,14 +555,14 @@ let extraClassDeclaration = [{ static StringRef getMapAttrStrName() { return "map"; } AffineMap getAffineMap() { return getMap(); } - ValueRange getMapOperands() { return operands(); } + ValueRange getMapOperands() { return getOperands(); } ValueRange getDimOperands() { - return OperandRange{operands().begin(), - operands().begin() + getMap().getNumDims()}; + return OperandRange{getOperands().begin(), + getOperands().begin() + getMap().getNumDims()}; } ValueRange getSymbolOperands() { - return OperandRange{operands().begin() + getMap().getNumDims(), - operands().end()}; + return OperandRange{getOperands().begin() + getMap().getNumDims(), + getOperands().end()}; } }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/include/mlir/IR/DialectBase.td 
b/mlir/include/mlir/IR/DialectBase.td --- a/mlir/include/mlir/IR/DialectBase.td +++ b/mlir/include/mlir/IR/DialectBase.td @@ -17,11 +17,6 @@ // Dialect definitions //===----------------------------------------------------------------------===// -// "Enum" values for emitAccessorPrefix of Dialect. -defvar kEmitAccessorPrefix_Raw = 0; // Don't emit any getter/setter prefix. -defvar kEmitAccessorPrefix_Prefixed = 1; // Only emit with getter/setter prefix. -defvar kEmitAccessorPrefix_Both = 2; // Emit without and with prefix. - class Dialect { // The name of the dialect. string name = ?; @@ -88,17 +83,6 @@ // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; - // Whether to emit raw/with no prefix or format changes, or emit with - // accessor with prefix only and UpperCamel suffix or to emit accessors with - // both. - // - // If emitting with prefix is specified then the attribute/operand's - // name is converted to UpperCamel from snake_case (which would result in - // leaving UpperCamel unchanged while also converting lowerCamel to - // UpperCamel) and prefixed with `get` or `set` depending on if it is a getter - // or setter. - int emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; - // If this dialect can be extended at runtime with new operations or types. bit isExtensible = 0; } diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -98,10 +98,6 @@ // Returns whether the dialect is defined. explicit operator bool() const { return def != nullptr; } - // Returns how the accessors should be prefixed in dialect. 
- enum class EmitPrefix { Raw = 0, Prefixed = 1, Both = 2 }; - EmitPrefix getEmitAccessorPrefix() const; - private: const llvm::Record *def; std::vector dependentDialects; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -302,16 +302,11 @@ // Returns the builders of this operation. ArrayRef getBuilders() const { return builders; } - // Returns the preferred getter name for the accessor. - std::string getGetterName(StringRef name) const { - return getGetterNames(name).front(); - } - - // Returns the getter names for the accessor. - SmallVector getGetterNames(StringRef name) const; + // Returns the getter name for the accessor of `name`. + std::string getGetterName(StringRef name) const; - // Returns the setter names for the accessor. - SmallVector getSetterNames(StringRef name) const; + // Returns the setter name for the accessor of `name`. + std::string getSetterName(StringRef name) const; private: // Populates the vectors containing operands, attributes, results and traits. 
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -103,7 +103,7 @@ LogicalResult matchAndRewrite(AffineMinOp op, PatternRewriter &rewriter) const override { Value reduced = - lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands()); + lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.getOperands()); if (!reduced) return failure(); @@ -119,7 +119,7 @@ LogicalResult matchAndRewrite(AffineMaxOp op, PatternRewriter &rewriter) const override { Value reduced = - lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands()); + lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.getOperands()); if (!reduced) return failure(); @@ -141,7 +141,7 @@ rewriter.replaceOpWithNewOp(op); return success(); } - rewriter.replaceOpWithNewOp(op, op.operands()); + rewriter.replaceOpWithNewOp(op, op.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -536,7 +536,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( async::YieldOp yieldOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType)) + if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType)) return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); Location loc = yieldOp.getLoc(); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -51,7 +51,7 @@ MutableOperandRange YieldOp::getMutableSuccessorOperands(Optional index) { - return 
operandsMutable(); + return getOperandsMutable(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -82,7 +82,7 @@ SmallVector newReturnValues; BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); DenseMap resultToArgs; - for (const auto &it : llvm::enumerate(returnOp.operands())) { + for (const auto &it : llvm::enumerate(returnOp.getOperands())) { bool erased = false; for (BlockArgument bbArg : funcOp.getArguments()) { Value val = it.value(); @@ -105,7 +105,7 @@ // Update function. funcOp.eraseResults(erasedResultIndices); - returnOp.operandsMutable().assign(newReturnValues); + returnOp.getOperandsMutable().assign(newReturnValues); // Update function calls. module.walk([&](func::CallOp callOp) { @@ -114,7 +114,7 @@ rewriter.setInsertionPoint(callOp); auto newCallOp = rewriter.create(callOp.getLoc(), funcOp, - callOp.operands()); + callOp.getOperands()); SmallVector newResults; int64_t nextResult = 0; for (int64_t i = 0; i < callOp.getNumResults(); ++i) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -483,7 +483,7 @@ } // 3. Rewrite the terminator without the in-place bufferizable values. - returnOp.operandsMutable().assign(returnValues); + returnOp.getOperandsMutable().assign(returnValues); // 4. Rewrite the FuncOp type to buffer form. 
funcOp.setType(FunctionType::get(op->getContext(), argTypes, diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1057,14 +1057,14 @@ FunctionType funType = function.getFunctionType(); - if (funType.getNumResults() != operands().size()) + if (funType.getNumResults() != getOperands().size()) return emitOpError() .append("expected ", funType.getNumResults(), " result operands") .attachNote(function.getLoc()) .append("return type declared here"); for (const auto &pair : llvm::enumerate( - llvm::zip(function.getFunctionType().getResults(), operands()))) { + llvm::zip(function.getFunctionType().getResults(), getOperands()))) { auto [type, operand] = pair.value(); if (type != operand.getType()) return emitOpError() << "unexpected type `" << operand.getType() diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1080,7 +1080,7 @@ auto minOp1 = v1.getDefiningOp(); auto minOp2 = v2.getDefiningOp(); if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() && - minOp1.operands() == minOp2.operands()) + minOp1.getOperands() == minOp2.getOperands()) continue; // Add additional cases as needed. 
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -192,8 +192,8 @@ return failure(); }; - return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), - op.operands(), IsMin, loopMatcher); + return scf::canonicalizeMinMaxOpInLoop( + rewriter, op, op.getAffineMap(), op.getOperands(), IsMin, loopMatcher); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -167,14 +167,14 @@ forOp.walk([&](OpTy affineOp) { AffineMap map = affineOp.getAffineMap(); (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.operands(), IsMin, mainIv, + affineOp.getOperands(), IsMin, mainIv, previousUb, step, /*insideLoop=*/true); }); partialIteration.walk([&](OpTy affineOp) { AffineMap map = affineOp.getAffineMap(); (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.operands(), IsMin, partialIv, + affineOp.getOperands(), IsMin, partialIv, previousUb, step, /*insideLoop=*/false); }); } diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -67,7 +67,7 @@ assumingOp.getDoRegion().front().getTerminator()); // Create new op and move over region. 
- TypeRange newResultTypes(yieldOp.operands()); + TypeRange newResultTypes(yieldOp.getOperands()); auto newOp = rewriter.create( op->getLoc(), newResultTypes, assumingOp.getWitness()); newOp.getDoRegion().takeBody(assumingOp.getRegion()); @@ -130,7 +130,7 @@ const BufferizationOptions &options) const { auto yieldOp = cast(op); SmallVector newResults; - for (Value value : yieldOp.operands()) { + for (Value value : yieldOp.getOperands()) { if (value.getType().isa()) { FailureOr buffer = getBuffer(rewriter, value, options); if (failed(buffer)) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4792,7 +4792,7 @@ auto isNotDefByConstant = [](Value operand) { return !isa_and_nonnull(operand.getDefiningOp()); }; - if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant)) + if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant)) return failure(); // CreateMaskOp for scalable vectors can be folded only if all dimensions @@ -4809,7 +4809,7 @@ // Gather constant mask dimension sizes. 
SmallVector maskDimSizes; - for (auto it : llvm::zip(createMaskOp.operands(), + for (auto it : llvm::zip(createMaskOp.getOperands(), createMaskOp.getType().getShape())) { auto *defOp = std::get<0>(it).getDefiningOp(); int64_t maxDimSize = std::get<1>(it); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -182,7 +182,7 @@ cast(newOpBody.getBlocks().begin()->getTerminator()); rewriter.updateRootInPlace( - yield, [&]() { yield.operandsMutable().assign(newYieldedValues); }); + yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); }); return newWarpOp; } @@ -348,7 +348,7 @@ SmallVector replacements; auto yieldOp = cast(ifOp.thenBlock()->getTerminator()); Location yieldLoc = yieldOp.getLoc(); - for (const auto &it : llvm::enumerate(yieldOp.operands())) { + for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { Value sequentialVal = it.value(); Value distributedVal = warpOp->getResult(it.index()); DistributedLoadStoreHelper helper(sequentialVal, distributedVal, @@ -378,7 +378,7 @@ } // Step 6. Insert sync after all the stores and before all the loads. 
- if (!yieldOp.operands().empty()) { + if (!yieldOp.getOperands().empty()) { rewriter.setInsertionPointAfter(ifOp); options.warpSyncronizationFn(loc, rewriter, warpOp); } diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -98,14 +98,6 @@ return def->getValueAsBit("useDefaultTypePrinterParser"); } -Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { - int prefix = def->getValueAsInt("emitAccessorPrefix"); - if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) - PrintFatalError(def->getLoc(), "Invalid accessor prefix value"); - - return static_cast(prefix); -} - bool Dialect::isExtensible() const { return def->getValueAsBit("isExtensible"); } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -69,6 +69,43 @@ return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); } +/// Assert the invariants of accessors generated for the given name. +static void assertAccessorInvariants(const Operator &op, StringRef name) { + std::string accessorName = + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); + + // Functor used to detect when an accessor will cause an overlap with an + // operation API. + // + // There are a little bit more invasive checks possible for cases where not + // all ops have the trait that would cause overlap. For many cases here, + // renaming would be better (e.g., we can only guard in limited manner + // against methods from traits and interfaces here, so avoiding these in op + // definition is safer). 
+  auto nameOverlapsWithOpAPI = [&](StringRef newName) {
+    if (newName == "AttributeNames" || newName == "Attributes" ||
+        newName == "Operation")
+      return true;
+    if (newName == "Operands")
+      return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
+    if (newName == "Regions")
+      return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
+    if (newName == "Type")
+      return op.getNumResults() == 1; // Only single-result ops have getType().
+    return false;
+  };
+  if (nameOverlapsWithOpAPI(accessorName)) {
+    // Checked against the camel-cased accessor (not the raw ODS name) since
+    // that is what collides with the Op API. This could be avoided when the
+    // final function is identical, but generic names should be avoided.
+    PrintFatalError(op.getLoc(),
+                    "generated accessor `" + accessorName +
+                        "` overlaps with a default one; please "
+                        "rename the offending field to avoid overlap (`" +
+                        name + "`)");
+  }
+}
+
 void Operator::assertInvariants() const {
   // Check that the name of arguments/results/regions/successors don't overlap.
   DenseMap<StringRef, StringRef> existingNames;
@@ -76,8 +113,11 @@
     if (name.empty())
       return;
     auto insertion = existingNames.insert({name, entity});
-    if (insertion.second)
+    if (insertion.second) {
+      // Assert invariants for accessors generated for this name.
+      assertAccessorInvariants(*this, name);
       return;
+    }
     if (entity == insertion.first->second)
       PrintFatalError(getLoc(), "op has a conflict with two " + entity +
                                     " having the same name '" + name + "'");
@@ -692,82 +732,10 @@
   return attrOrOperandMapping[index];
 }
 
-// Helper to return the names for accessor.
-static SmallVector<std::string, 2>
-getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
-  Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
-  std::string prefix;
-  if (prefixType != Dialect::EmitPrefix::Raw)
-    prefix = isGetter ? "get" : "set";
-
-  SmallVector<std::string, 2> names;
-  bool rawToo = prefixType == Dialect::EmitPrefix::Both;
-
-  // Whether to skip generating prefixed form for argument. This just does some
-  // basic checks.
- // - // There are a little bit more invasive checks possible for cases where not - // all ops have the trait that would cause overlap. For many cases here, - // renaming would be better (e.g., we can only guard in limited manner against - // methods from traits and interfaces here, so avoiding these in op definition - // is safer). - auto skip = [&](StringRef newName) { - bool shouldSkip = newName == "getAttributeNames" || - newName == "getAttributes" || newName == "getOperation"; - if (newName == "getOperands") { - // To reduce noise, skip generating the prefixed form and the warning if - // $operands correspond to single variadic argument. - if (op.getNumOperands() == 1 && op.getNumVariableLengthOperands() == 1) - return true; - shouldSkip = true; - } - if (newName == "getRegions") { - if (op.getNumRegions() == 1 && op.getNumVariadicRegions() == 1) - return true; - shouldSkip = true; - } - if (newName == "getType") { - if (op.getNumResults() != 1) - return false; - shouldSkip = true; - } - if (!shouldSkip) - return false; - - // This note could be avoided where the final function generated would - // have been identical. But preferably in the op definition avoiding using - // the generic name and then getting a more specialize type is better. - PrintNote(op.getLoc(), - "Skipping generation of prefixed accessor `" + newName + - "` as it overlaps with default one; generating raw form (`" + - name + "`) still"); - return true; - }; - - if (!prefix.empty()) { - names.push_back( - prefix + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true)); - // Skip cases which would overlap with default ones for now. 
- if (skip(names.back())) { - rawToo = true; - names.clear(); - } else if (rawToo) { - LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() - << "::" << name << "\")\n" - << "WITH_GETTER(\"" << op.getQualCppClassName() - << "Adaptor::" << name << "\")\n";); - } - } - - if (prefix.empty() || rawToo) - names.push_back(name.str()); - return names; -} - -SmallVector Operator::getGetterNames(StringRef name) const { - return getGetterOrSetterNames(/*isGetter=*/true, *this, name); +std::string Operator::getGetterName(StringRef name) const { + return "get" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); } -SmallVector Operator::getSetterNames(StringRef name) const { - return getGetterOrSetterNames(/*isGetter=*/false, *this, name); +std::string Operator::getSetterName(StringRef name) const { + return "set" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); } diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -109,9 +109,9 @@ // DEFS-LABEL: NS::AOp definitions // DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions) -// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions() // DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions() // DEFS-NEXT: return odsRegions.drop_front(1); +// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions() // Check AttrSizedOperandSegments // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -933,26 +933,24 @@ // users. 
const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) { - for (StringRef name : op.getGetterNames(attrIt.value())) { - std::string methodName = (name + "AttrName").str(); - - // Generate the non-static variant. - { - auto *method = - opClass.addInlineMethod("::mlir::StringAttr", methodName); - ERROR_IF_PRUNED(method, methodName, op); - method->body() << llvm::formatv(attrNameMethodBody, attrIt.index()); - } + std::string name = op.getGetterName(attrIt.value()); + std::string methodName = name + "AttrName"; + + // Generate the non-static variant. + { + auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << llvm::formatv(attrNameMethodBody, attrIt.index()); + } - // Generate the static variant. - { - auto *method = opClass.addStaticInlineMethod( - "::mlir::StringAttr", methodName, - MethodParameter("::mlir::OperationName", "name")); - ERROR_IF_PRUNED(method, methodName, op); - method->body() << llvm::formatv(attrNameMethodBody, - "name, " + Twine(attrIt.index())); - } + // Generate the static variant. 
+ { + auto *method = opClass.addStaticInlineMethod( + "::mlir::StringAttr", methodName, + MethodParameter("::mlir::OperationName", "name")); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << llvm::formatv(attrNameMethodBody, + "name, " + Twine(attrIt.index())); } } } @@ -1014,13 +1012,12 @@ }; for (const NamedAttribute &namedAttr : op.getAttributes()) { - for (StringRef name : op.getGetterNames(namedAttr.name)) { - if (namedAttr.attr.isDerivedAttr()) { - emitDerivedAttr(name, namedAttr.attr); - } else { - emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); - emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); - } + std::string name = op.getGetterName(namedAttr.name); + if (namedAttr.attr.isDerivedAttr()) { + emitDerivedAttr(name, namedAttr.attr); + } else { + emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); + emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); } } @@ -1165,12 +1162,10 @@ for (const NamedAttribute &namedAttr : op.getAttributes()) { if (namedAttr.attr.isDerivedAttr()) continue; - for (auto [setterName, getterName] : - llvm::zip(op.getSetterNames(namedAttr.name), - op.getGetterNames(namedAttr.name))) { - emitAttrWithStorageType(setterName, getterName, namedAttr.attr); - emitAttrWithReturnType(setterName, getterName, namedAttr.attr); - } + std::string setterName = op.getSetterName(namedAttr.name); + std::string getterName = op.getGetterName(namedAttr.name); + emitAttrWithStorageType(setterName, getterName, namedAttr.attr); + emitAttrWithReturnType(setterName, getterName, namedAttr.attr); } } @@ -1305,38 +1300,36 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - for (StringRef name : op.getGetterNames(operand.name)) { - if (operand.isOptional()) { - m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); - ERROR_IF_PRUNED(m, name, op); - m->body() << " auto operands = getODSOperands(" << i << ");\n" - << " return operands.empty() ? 
::mlir::Value() : " - "*operands.begin();"; - } else if (operand.isVariadicOfVariadic()) { - std::string segmentAttr = op.getGetterName( - operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); - if (isAdaptor) { - m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", - name); - ERROR_IF_PRUNED(m, name, op); - m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, - segmentAttr, i); - continue; - } - - m = opClass.addMethod("::mlir::OperandRangeRange", name); - ERROR_IF_PRUNED(m, name, op); - m->body() << " return getODSOperands(" << i << ").split(" - << segmentAttr << "Attr());"; - } else if (operand.isVariadic()) { - m = opClass.addMethod(rangeType, name); - ERROR_IF_PRUNED(m, name, op); - m->body() << " return getODSOperands(" << i << ");"; - } else { - m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); + std::string name = op.getGetterName(operand.name); + if (operand.isOptional()) { + m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " auto operands = getODSOperands(" << i << ");\n" + << " return operands.empty() ? 
::mlir::Value() : " + "*operands.begin();"; + } else if (operand.isVariadicOfVariadic()) { + std::string segmentAttr = op.getGetterName( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); + if (isAdaptor) { + m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", name); ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSOperands(" << i << ").begin();"; + m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, + segmentAttr, i); + continue; } + + m = opClass.addMethod("::mlir::OperandRangeRange", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr + << "Attr());"; + } else if (operand.isVariadic()) { + m = opClass.addMethod(rangeType, name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSOperands(" << i << ");"; + } else { + m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return *getODSOperands(" << i << ").begin();"; } } } @@ -1367,37 +1360,37 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - for (StringRef name : op.getGetterNames(operand.name)) { - auto *m = opClass.addMethod(operand.isVariadicOfVariadic() - ? "::mlir::MutableOperandRangeRange" - : "::mlir::MutableOperandRange", - (name + "Mutable").str()); - ERROR_IF_PRUNED(m, name, op); - auto &body = m->body(); - body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" - << " auto mutableRange = " - "::mlir::MutableOperandRange(getOperation(), " - "range.first, range.second"; - if (attrSizedOperands) { - body << formatv( - ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, - emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); - } - body << ");\n"; - - // If this operand is a nested variadic, we split the range into a - // MutableOperandRangeRange that provides a range over all of the - // sub-ranges. 
- if (operand.isVariadicOfVariadic()) { - body << " return " - "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" - << op.getGetterName( - operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) - << "AttrName()));\n"; - } else { - // Otherwise, we use the full range directly. - body << " return mutableRange;\n"; - } + std::string name = op.getGetterName(operand.name); + + auto *m = opClass.addMethod(operand.isVariadicOfVariadic() + ? "::mlir::MutableOperandRangeRange" + : "::mlir::MutableOperandRange", + name + "Mutable"); + ERROR_IF_PRUNED(m, name, op); + auto &body = m->body(); + body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " auto mutableRange = " + "::mlir::MutableOperandRange(getOperation(), " + "range.first, range.second"; + if (attrSizedOperands) { + body << formatv( + ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, + emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); + } + body << ");\n"; + + // If this operand is a nested variadic, we split the range into a + // MutableOperandRangeRange that provides a range over all of the + // sub-ranges. + if (operand.isVariadicOfVariadic()) { + body << " return " + "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" + << op.getGetterName( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) + << "AttrName()));\n"; + } else { + // Otherwise, we use the full range directly. + body << " return mutableRange;\n"; } } } @@ -1454,24 +1447,23 @@ const auto &result = op.getResult(i); if (result.name.empty()) continue; - for (StringRef name : op.getGetterNames(result.name)) { - if (result.isOptional()) { - m = opClass.addMethod( - generateTypeForGetter(/*isAdaptor=*/false, result), name); - ERROR_IF_PRUNED(m, name, op); - m->body() - << " auto results = getODSResults(" << i << ");\n" - << " return results.empty() ? 
::mlir::Value() : *results.begin();"; - } else if (result.isVariadic()) { - m = opClass.addMethod("::mlir::Operation::result_range", name); - ERROR_IF_PRUNED(m, name, op); - m->body() << " return getODSResults(" << i << ");"; - } else { - m = opClass.addMethod( - generateTypeForGetter(/*isAdaptor=*/false, result), name); - ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSResults(" << i << ").begin();"; - } + std::string name = op.getGetterName(result.name); + if (result.isOptional()) { + m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result), + name); + ERROR_IF_PRUNED(m, name, op); + m->body() + << " auto results = getODSResults(" << i << ");\n" + << " return results.empty() ? ::mlir::Value() : *results.begin();"; + } else if (result.isVariadic()) { + m = opClass.addMethod("::mlir::Operation::result_range", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSResults(" << i << ");"; + } else { + m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result), + name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return *getODSResults(" << i << ").begin();"; } } } @@ -1482,22 +1474,21 @@ const auto ®ion = op.getRegion(i); if (region.name.empty()) continue; + std::string name = op.getGetterName(region.name); - for (StringRef name : op.getGetterNames(region.name)) { - // Generate the accessors for a variadic region. - if (region.isVariadic()) { - auto *m = - opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); - ERROR_IF_PRUNED(m, name, op); - m->body() << formatv(" return (*this)->getRegions().drop_front({0});", - i); - continue; - } - - auto *m = opClass.addMethod("::mlir::Region &", name); + // Generate the accessors for a variadic region. 
+ if (region.isVariadic()) { + auto *m = + opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); ERROR_IF_PRUNED(m, name, op); - m->body() << formatv(" return (*this)->getRegion({0});", i); + m->body() << formatv(" return (*this)->getRegions().drop_front({0});", + i); + continue; } + + auto *m = opClass.addMethod("::mlir::Region &", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv(" return (*this)->getRegion({0});", i); } } @@ -1507,23 +1498,21 @@ const NamedSuccessor &successor = op.getSuccessor(i); if (successor.name.empty()) continue; - - for (StringRef name : op.getGetterNames(successor.name)) { - // Generate the accessors for a variadic successor list. - if (successor.isVariadic()) { - auto *m = opClass.addMethod("::mlir::SuccessorRange", name); - ERROR_IF_PRUNED(m, name, op); - m->body() << formatv( - " return {std::next((*this)->successor_begin(), {0}), " - "(*this)->successor_end()};", - i); - continue; - } - - auto *m = opClass.addMethod("::mlir::Block *", name); + std::string name = op.getGetterName(successor.name); + // Generate the accessors for a variadic successor list. 
+ if (successor.isVariadic()) { + auto *m = opClass.addMethod("::mlir::SuccessorRange", name); ERROR_IF_PRUNED(m, name, op); - m->body() << formatv(" return (*this)->getSuccessor({0});", i); + m->body() << formatv( + " return {std::next((*this)->successor_begin(), {0}), " + "(*this)->successor_end()};", + i); + continue; } + + auto *m = opClass.addMethod("::mlir::Block *", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv(" return (*this)->getSuccessor({0});", i); } } @@ -2992,11 +2981,6 @@ constructor->addMemberInitializer("odsOpName", "op->getName()"); } - { - auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"); - ERROR_IF_PRUNED(m, "getOperands", op); - m->body() << " return odsOperands;"; - } std::string sizeAttrInit; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, @@ -3009,6 +2993,11 @@ /*rangeSizeCall=*/"odsOperands.size()", /*getOperandCallPattern=*/"odsOperands[{0}]"); + // Any invalid overlap for `getOperands` will have been diagnosed before here + // already. 
+ if (auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands")) + m->body() << " return odsOperands;"; + FmtContext fctx; fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); @@ -3046,36 +3035,35 @@ const auto &attr = namedAttr.attr; if (attr.isDerivedAttr()) continue; - for (const auto &emitName : op.getGetterNames(name)) { - emitAttrWithStorageType(name, emitName, attr); - emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr); - } + std::string emitName = op.getGetterName(name); + emitAttrWithStorageType(name, emitName, attr); + emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr); } unsigned numRegions = op.getNumRegions(); - if (numRegions > 0) { - auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"); - ERROR_IF_PRUNED(m, "Adaptor::getRegions", op); - m->body() << " return odsRegions;"; - } for (unsigned i = 0; i < numRegions; ++i) { const auto ®ion = op.getRegion(i); if (region.name.empty()) continue; // Generate the accessors for a variadic region. - for (StringRef name : op.getGetterNames(region.name)) { - if (region.isVariadic()) { - auto *m = adaptor.addMethod("::mlir::RegionRange", name); - ERROR_IF_PRUNED(m, "Adaptor::" + name, op); - m->body() << formatv(" return odsRegions.drop_front({0});", i); - continue; - } - - auto *m = adaptor.addMethod("::mlir::Region &", name); + std::string name = op.getGetterName(region.name); + if (region.isVariadic()) { + auto *m = adaptor.addMethod("::mlir::RegionRange", name); ERROR_IF_PRUNED(m, "Adaptor::" + name, op); - m->body() << formatv(" return *odsRegions[{0}];", i); + m->body() << formatv(" return odsRegions.drop_front({0});", i); + continue; } + + auto *m = adaptor.addMethod("::mlir::Region &", name); + ERROR_IF_PRUNED(m, "Adaptor::" + name, op); + m->body() << formatv(" return *odsRegions[{0}];", i); + } + if (numRegions > 0) { + // Any invalid overlap for `getRegions` will have been diagnosed before here + // already. 
+ if (auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions")) + m->body() << " return odsRegions;"; } // Add verification function.