diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -30,6 +30,10 @@ // We set this bit to generate a declaration of the `materializeConstant` // method so that we can materialize constants for our toy operations. let hasConstantMaterializer = 1; + + // We set this bit to generate the declarations for the dialect's type parsing + // and printing hooks. + let useDefaultTypePrinterParser = 1; } // Base class for toy dialect operations. This operation inherits from the base diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td @@ -21,20 +21,22 @@ def AsyncDialect : Dialect { let name = "async"; + let cppNamespace = "::mlir::async"; let summary = "Types and operations for async dialect"; let description = [{ This dialect contains operations for modeling asynchronous execution. }]; - let cppNamespace = "::mlir::async"; + let useDefaultTypePrinterParser = 1; let extraClassDeclaration = [{ - // The name of a unit attribute on funcs that are allowed to have a blocking - // async.runtime.await ops. Only useful in combination with - // 'eliminate-blocking-await-ops' option, which in absence of this attribute - // might convert a func to a coroutine. - static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block"; + /// The name of a unit attribute on funcs that are allowed to have + /// blocking async.runtime.await ops. Only useful in combination with + /// 'eliminate-blocking-await-ops' option, which in absence of this + /// attribute might convert a func to a coroutine. 
+ static constexpr StringRef kAllowedToBlockAttrName = + "async.allowed_to_block"; }]; } diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td --- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td +++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td @@ -35,6 +35,8 @@ constexpr const static ::llvm::StringLiteral kDataLayoutEndiannessLittle = "little"; }]; + + let useDefaultAttributePrinterParser = 1; } def DLTI_DataLayoutEntryAttr : DialectAttr< diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td @@ -30,6 +30,7 @@ let hasConstantMaterializer = 1; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td @@ -54,6 +54,7 @@ let dependentDialects = ["arith::ArithmeticDialect"]; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } def GPU_AsyncToken : DialectType< diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -26,9 +26,12 @@ let name = "llvm"; let cppNamespace = "::mlir::LLVM"; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; let hasOperationAttrVerify = 1; + let extraClassDeclaration = [{ /// Name of the data layout attributes. 
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -40,6 +40,7 @@ "memref::MemRefDialect", "tensor::TensorDialect", ]; + let useDefaultAttributePrinterParser = 1; let hasCanonicalizer = 1; let hasOperationAttrVerify = 1; let hasConstantMaterializer = 1; diff --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td @@ -32,7 +32,8 @@ representing PTX specific operations while using MLIR high level concepts like memref and 2-D vector. }]; - let useDefaultAttributePrinterParser = 1; + + let useDefaultTypePrinterParser = 1; } /// Device-side synchronization token. diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td @@ -20,8 +20,8 @@ //===----------------------------------------------------------------------===// def PDL_Dialect : Dialect { - string summary = "High level pattern definition dialect"; - string description = [{ + let summary = "High level pattern definition dialect"; + let description = [{ PDL presents a high level abstraction for the rewrite pattern infrastructure available in MLIR. This abstraction allows for representing patterns transforming MLIR, as MLIR. 
This allows for applying all of the benefits @@ -64,6 +64,8 @@ let name = "pdl"; let cppNamespace = "::mlir::pdl"; + + let useDefaultTypePrinterParser = 1; let extraClassDeclaration = [{ void registerTypes(); }]; diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td @@ -18,6 +18,8 @@ def Quantization_Dialect : Dialect { let name = "quant"; let cppNamespace = "::mlir::quant"; + + let useDefaultTypePrinterParser = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -47,6 +47,8 @@ }]; let cppNamespace = "::mlir::spirv"; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -37,6 +37,7 @@ let cppNamespace = "::mlir::shape"; let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"]; + let useDefaultTypePrinterParser = 1; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1,4 +1,4 @@ -//===- Shape.td - Shape operations definition --------------*- tablegen -*-===// +//===- ShapeOps.td - Shape operations definition 
-----------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td @@ -72,6 +72,8 @@ * [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation. PhD thesis, MIT, February, 2020. }]; + + let useDefaultAttributePrinterParser = 1; } #endif // SPARSETENSOR_BASE diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -23,6 +23,8 @@ def Vector_Dialect : Dialect { let name = "vector"; let cppNamespace = "::mlir::vector"; + + let useDefaultAttributePrinterParser = 1; let hasConstantMaterializer = 1; let dependentDialects = ["arith::ArithmeticDialect"]; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td --- a/mlir/include/mlir/IR/DialectBase.td +++ b/mlir/include/mlir/IR/DialectBase.td @@ -73,13 +73,17 @@ // If this dialect overrides the hook for op interface fallback. bit hasOperationInterfaceFallback = 0; - // If this dialect should use default generated attribute parser boilerplate: - // it'll dispatch the parsing to every individual attributes directly. - bit useDefaultAttributePrinterParser = 1; + // If this dialect should use default generated attribute parser boilerplate. + // When set, ODS will generate declarations for the attribute parsing and + // printing hooks in the dialect and default implementations that dispatch to + // each individual attribute directly. 
+ bit useDefaultAttributePrinterParser = 0; // If this dialect should use default generated type parser boilerplate: - // it'll dispatch the parsing to every individual types directly. - bit useDefaultTypePrinterParser = 1; + // When set, ODS will generate declarations for the type parsing and printing + // hooks in the dialect and default implementations that dispatch to each + // individual type directly. + bit useDefaultTypePrinterParser = 0; // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -17,6 +17,9 @@ def Python_Test_Dialect : Dialect { let name = "python_test"; let cppNamespace = "python_test"; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } class TestType diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -574,11 +574,9 @@ protected: DefGenerator(std::vector &&defs, raw_ostream &os, - StringRef defType, StringRef valueType, bool isAttrGenerator, - bool needsDialectParserPrinter) + StringRef defType, StringRef valueType, bool isAttrGenerator) : defRecords(std::move(defs)), os(os), defType(defType), - valueType(valueType), isAttrGenerator(isAttrGenerator), - needsDialectParserPrinter(needsDialectParserPrinter) {} + valueType(valueType), isAttrGenerator(isAttrGenerator) {} /// Emit the list of def type names. void emitTypeDefList(ArrayRef defs); @@ -597,30 +595,19 @@ /// Flag indicating if this generator is for Attributes. False if the /// generator is for types. bool isAttrGenerator; - /// Track if we need to emit the printAttribute/parseAttribute - /// implementations. - bool needsDialectParserPrinter; }; /// A specialized generator for AttrDefs. 
struct AttrDefGenerator : public DefGenerator { AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os, - "Attr", "Attribute", - /*isAttrGenerator=*/true, - /*needsDialectParserPrinter=*/ - !records.getAllDerivedDefinitions("DialectAttr").empty()) { - } + "Attr", "Attribute", /*isAttrGenerator=*/true) {} }; /// A specialized generator for TypeDefs. struct TypeDefGenerator : public DefGenerator { TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os, - "Type", "Type", - /*isAttrGenerator=*/false, - /*needsDialectParserPrinter=*/ - !records.getAllDerivedDefinitions("DialectType").empty()) { - } + "Type", "Type", /*isAttrGenerator=*/false) {} }; } // namespace @@ -879,10 +866,9 @@ } Dialect firstDialect = defs.front().getDialect(); - // Emit the default parser/printer for Attributes if the dialect asked for - // it. - if (valueType == "Attribute" && needsDialectParserPrinter && - firstDialect.useDefaultAttributePrinterParser()) { + + // Emit the default parser/printer for Attributes if the dialect asked for it. + if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) { NamespaceEmitter nsEmitter(os, firstDialect); if (firstDialect.isExtensible()) { os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, @@ -896,8 +882,7 @@ } // Emit the default parser/printer for Types if the dialect asked for it. 
- if (valueType == "Type" && needsDialectParserPrinter && - firstDialect.useDefaultTypePrinterParser()) { + if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) { NamespaceEmitter nsEmitter(os, firstDialect); if (firstDialect.isExtensible()) { os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -182,11 +182,7 @@ )"; /// Generate the declaration for the given dialect class. -static void -emitDialectDecl(Dialect &dialect, - const iterator_range &dialectAttrs, - const iterator_range &dialectTypes, - raw_ostream &os) { +static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { // Emit all nested namespaces. { NamespaceEmitter nsEmitter(os, dialect); @@ -198,11 +194,13 @@ os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), superClassName); - // Check for any attributes/types registered to this dialect. If there are, - // add the hooks for parsing/printing. - if (!dialectAttrs.empty() && dialect.useDefaultAttributePrinterParser()) + // If the dialect requested the default attribute printer and parser, emit + // the declarations for the hooks. + if (dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; - if (!dialectTypes.empty() && dialect.useDefaultTypePrinterParser()) + // If the dialect requested the default type printer and parser, emit the + // declarations for the hooks. + if (dialect.useDefaultTypePrinterParser()) os << typeParserDecl; // Add the decls for the various features of the dialect. 
@@ -242,10 +240,7 @@ Optional dialect = findDialectToGenerate(dialects); if (!dialect) return true; - auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr"); - auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType"); - emitDialectDecl(*dialect, filterForDialect(attrDefs, *dialect), - filterForDialect(typeDefs, *dialect), os); + emitDialectDecl(*dialect, os); return false; }