diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -28,10 +28,11 @@ endfunction(whole_archive_link) # Declare a dialect in the include directory -function(add_mlir_dialect dialect dialect_doc_filename) +function(add_mlir_dialect dialect dialect_namespace dialect_doc_filename) set(LLVM_TARGET_DEFINITIONS ${dialect}.td) mlir_tablegen(${dialect}.h.inc -gen-op-decls) mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) add_public_tablegen_target(MLIR${dialect}IncGen) add_dependencies(mlir-headers MLIR${dialect}IncGen) diff --git a/mlir/docs/CreatingADialect.md b/mlir/docs/CreatingADialect.md --- a/mlir/docs/CreatingADialect.md +++ b/mlir/docs/CreatingADialect.md @@ -39,7 +39,7 @@ ```cmake -add_mlir_dialect(FooOps FooOps) +add_mlir_dialect(FooOps foo FooOps) ``` diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -36,17 +36,6 @@ /// symbol. bool isTopLevelValue(Value value); -class AffineOpsDialect : public Dialect { -public: - AffineOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "affine"; } - - /// Materialize a single constant operation from a given attribute value with - /// the desired resultant type. - Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, - Location loc) override; -}; - /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data /// from a source memref to a destination memref. The source and destination /// memref need not be of the same dimensionality, but need to have the same @@ -504,6 +493,8 @@ void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl *operands); +#include "mlir/Dialect/AffineOps/AffineOpsDialect.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AffineOps/AffineOps.h.inc" diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -17,14 +17,15 @@ include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffects.td" -def Affine_Dialect : Dialect { +def AffineOps_Dialect : Dialect { let name = "affine"; let cppNamespace = ""; + let hasConstantMaterializer = 1; } // Base class for Affine dialect ops. class Affine_Op traits = []> : - Op { + Op { // For every affine op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) // * LogicalResult verify(${C++ class of Op} op) @@ -290,7 +291,7 @@ } class AffineMinMaxOpBase traits = []> : - Op { + Op { let arguments = (ins AffineMapAttr:$map, Variadic:$operands); let results = (outs Index); diff --git a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(AffineOps AffineOps) +add_mlir_dialect(AffineOps affine AffineOps) diff --git a/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(FxpMathOps FxpMathOps) +add_mlir_dialect(FxpMathOps fxpmath FxpMathOps) diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h --- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h @@ -17,11 +17,7 @@ namespace mlir { namespace fxpmath { -/// Defines the 'FxpMathOps' dialect. -class FxpMathOpsDialect : public Dialect { -public: - FxpMathOpsDialect(MLIRContext *context); -}; +#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc" #define GET_OP_CLASSES #include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc" diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td --- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td @@ -15,10 +15,10 @@ #define DIALECT_FXPMATHOPS_FXPMATH_OPS_ include "mlir/IR/OpBase.td" -include "mlir/Dialect/QuantOps/QuantPredicates.td" +include "mlir/Dialect/QuantOps/QuantOpsBase.td" include "mlir/Interfaces/SideEffects.td" -def fxpmath_Dialect : Dialect { +def FxpMathOps_Dialect : Dialect { let name = "fxpmath"; } @@ -78,7 +78,7 @@ //===----------------------------------------------------------------------===// class fxpmath_Op traits> : - Op; + Op; //===----------------------------------------------------------------------===// // Fixed-point (fxp) arithmetic ops used by kernels. diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt --- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(GPUOps GPUOps) +add_mlir_dialect(GPUOps gpu GPUOps) diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -26,51 +26,6 @@ namespace gpu { -/// The dialect containing GPU kernel launching operations and related -/// facilities. -class GPUDialect : public Dialect { -public: - /// Create the dialect in the given `context`. - explicit GPUDialect(MLIRContext *context); - /// Get dialect namespace. - static StringRef getDialectNamespace() { return "gpu"; } - - /// Get the name of the attribute used to annotate the modules that contain - /// kernel modules. - static StringRef getContainerModuleAttrName() { - return "gpu.container_module"; - } - - /// Get the canonical string name of the dialect. - static StringRef getDialectName(); - - /// Get the name of the attribute used to annotate external kernel functions. - static StringRef getKernelFuncAttrName() { return "gpu.kernel"; } - - /// Get the name of the attribute used to annotate kernel modules. - static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; } - - /// Returns whether the given function is a kernel function, i.e., has the - /// 'gpu.kernel' attribute. - static bool isKernel(Operation *op); - - /// Returns the number of workgroup (thread, block) dimensions supported in - /// the GPU dialect. - // TODO(zinenko,herhut): consider generalizing this. - static unsigned getNumWorkgroupDimensions() { return 3; } - - /// Returns the numeric value used to identify the workgroup memory address - /// space. - static unsigned getWorkgroupAddressSpace() { return 3; } - - /// Returns the numeric value used to identify the private memory address - /// space. - static unsigned getPrivateAddressSpace() { return 5; } - - LogicalResult verifyOperationAttribute(Operation *op, - NamedAttribute attr) override; -}; - /// Utility class for the GPU dialect to represent triples of `Value`s /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. struct KernelDim3 { @@ -79,6 +34,8 @@ Value z; }; +#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/GPU/GPUOps.h.inc" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -28,6 +28,39 @@ def GPU_Dialect : Dialect { let name = "gpu"; + let extraClassDeclaration = [{ + /// Get the name of the attribute used to annotate the modules that contain + /// kernel modules. + static StringRef getContainerModuleAttrName() { + return "gpu.container_module"; + } + /// Get the name of the attribute used to annotate external kernel + /// functions. + static StringRef getKernelFuncAttrName() { return "gpu.kernel"; } + + /// Get the name of the attribute used to annotate kernel modules. + static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; } + + /// Returns whether the given function is a kernel function, i.e., has the + /// 'gpu.kernel' attribute. + static bool isKernel(Operation *op); + + /// Returns the number of workgroup (thread, block) dimensions supported in + /// the GPU dialect. + // TODO(zinenko,herhut): consider generalizing this. + static unsigned getNumWorkgroupDimensions() { return 3; } + + /// Returns the numeric value used to identify the workgroup memory address + /// space. + static unsigned getWorkgroupAddressSpace() { return 3; } + + /// Returns the numeric value used to identify the private memory address + /// space. + static unsigned getPrivateAddressSpace() { return 5; } + + LogicalResult verifyOperationAttribute(Operation *op, + NamedAttribute attr) override; + }]; } class GPU_Op traits = []> : diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -1,12 +1,13 @@ set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMOps.h.inc -gen-op-decls) mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) +mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls) mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRLLVMOpsIncGen) -add_mlir_dialect(NVVMOps NVVMOps) -add_mlir_dialect(ROCDLOps ROCDLOps) +add_mlir_dialect(NVVMOps nvvm NVVMOps) +add_mlir_dialect(ROCDLOps rocdl ROCDLOps) set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -201,32 +201,7 @@ #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" -class LLVMDialect : public Dialect { -public: - explicit LLVMDialect(MLIRContext *context); - ~LLVMDialect(); - static StringRef getDialectNamespace() { return "llvm"; } - - llvm::LLVMContext &getLLVMContext(); - llvm::Module &getLLVMModule(); - - /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser) const override; - - /// Print a type registered to this dialect. - void printType(Type type, DialectAsmPrinter &os) const override; - - /// Verify a region argument attribute registered to this dialect. - /// Returns failure if the verification failed, success otherwise. - LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx, - unsigned argIdx, - NamedAttribute argAttr) override; - -private: - friend LLVMType; - - std::unique_ptr impl; -}; +#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.h.inc" /// Create an LLVM global containing the string "value" at the module containing /// surrounding the insertion point of builder. Obtain the address of that diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -19,11 +19,28 @@ def LLVM_Dialect : Dialect { let name = "llvm"; let cppNamespace = "LLVM"; + let extraClassDeclaration = [{ + ~LLVMDialect(); + llvm::LLVMContext &getLLVMContext(); + llvm::Module &getLLVMModule(); + + /// Verify a region argument attribute registered to this dialect. + /// Returns failure if the verification failed, success otherwise. + LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx, + unsigned argIdx, + NamedAttribute argAttr) override; + + private: + friend LLVMType; + + std::unique_ptr impl; + }]; } // LLVM IR type wrapped in MLIR. -def LLVM_Type : Type()">, - "LLVM dialect type">; +def LLVM_Type : DialectType()">, + "LLVM dialect type">; // Type constraint accepting only wrapped LLVM integer types. def LLVMInt : TypeConstraint< diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -25,12 +25,7 @@ #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/NVVMOps.h.inc" -class NVVMDialect : public Dialect { -public: - explicit NVVMDialect(MLIRContext *context); - - static StringRef getDialectNamespace() { return "nvvm"; } -}; +#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc" } // namespace NVVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -33,12 +33,7 @@ #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc" -class ROCDLDialect : public Dialect { -public: - explicit ROCDLDialect(MLIRContext *context); - - static StringRef getDialectNamespace() { return "rocdl"; } -}; +#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc" } // namespace ROCDL } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect(LinalgOps LinalgDoc) +add_mlir_dialect(LinalgOps linalg LinalgDoc) set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td) mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls) mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -34,6 +34,6 @@ // Whether a type is a RangeType. def LinalgIsRangeTypePred : CPred<"$_self.isa()">; -def Range : Type; +def Range : DialectType; #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -21,17 +21,7 @@ LAST_USED_LINALG_TYPE = Range, }; -class LinalgDialect : public Dialect { -public: - explicit LinalgDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "linalg"; } - - /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser) const override; - - /// Print a type registered to this dialect. - void printType(Type type, DialectAsmPrinter &os) const override; -}; +#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" /// A RangeType represents a minimal range abstraction (min, max, step). /// It is constructed by calling the linalg.range op with three values index of diff --git a/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(LoopOps LoopOps) +add_mlir_dialect(LoopOps loop LoopOps) diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -25,11 +25,7 @@ class TerminatorOp; -class LoopOpsDialect : public Dialect { -public: - LoopOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "loop"; } -}; +#include "mlir/Dialect/LoopOps/LoopOpsDialect.h.inc" #define GET_OP_CLASSES #include "mlir/Dialect/LoopOps/LoopOps.h.inc" diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -16,14 +16,14 @@ include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffects.td" -def Loop_Dialect : Dialect { +def LoopOps_Dialect : Dialect { let name = "loop"; let cppNamespace = ""; } // Base class for Loop dialect ops. class Loop_Op traits = []> : - Op { + Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) // * LogicalResult verify(${C++ class of Op} op) diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt --- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(OpenMPOps OpenMPOps) +add_mlir_dialect(OpenMPOps omp OpenMPOps) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -22,13 +22,7 @@ #define GET_OP_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOps.h.inc" -class OpenMPDialect : public Dialect { -public: - explicit OpenMPDialect(MLIRContext *context); - - static StringRef getDialectNamespace() { return "omp"; } -}; - +#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc" } // namespace omp } // namespace mlir diff --git a/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(QuantOps QuantOps) +add_mlir_dialect(QuantOps quant QuantOps) diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h @@ -21,17 +21,7 @@ namespace mlir { namespace quant { -/// Defines the 'Quantization' dialect -class QuantizationDialect : public Dialect { -public: - QuantizationDialect(MLIRContext *context); - - /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser) const override; - - /// Print a type registered to this dialect. - void printType(Type type, DialectAsmPrinter &os) const override; -}; +#include "mlir/Dialect/QuantOps/QuantOpsDialect.h.inc" #define GET_OP_CLASSES #include "mlir/Dialect/QuantOps/QuantOps.h.inc" diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -13,20 +13,15 @@ #ifndef DIALECT_QUANTOPS_QUANT_OPS_ #define DIALECT_QUANTOPS_QUANT_OPS_ -include "mlir/IR/OpBase.td" -include "mlir/Dialect/QuantOps/QuantPredicates.td" +include "mlir/Dialect/QuantOps/QuantOpsBase.td" include "mlir/Interfaces/SideEffects.td" -def quant_Dialect : Dialect { - let name = "quant"; -} - //===----------------------------------------------------------------------===// // Base classes //===----------------------------------------------------------------------===// class quant_Op traits> : - Op; + Op; //===----------------------------------------------------------------------===// // Quantization casts diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td rename from mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td rename to mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td --- a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td @@ -1,4 +1,4 @@ -//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===// +//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,8 +10,14 @@ // //===----------------------------------------------------------------------===// -#ifndef DIALECT_QUANTOPS_QUANT_PREDICATES_ -#define DIALECT_QUANTOPS_QUANT_PREDICATES_ +#ifndef DIALECT_QUANTOPS_QUANT_OPS_BASE_ +#define DIALECT_QUANTOPS_QUANT_OPS_BASE_ + +include "mlir/IR/OpBase.td" + +def Quantization_Dialect : Dialect { + let name = "quant"; +} //===----------------------------------------------------------------------===// // Quantization type definitions @@ -54,10 +60,12 @@ // An implementation of UniformQuantizedType. def quant_UniformQuantizedType : - Type()">, "UniformQuantizedType">; + DialectType()">, + "UniformQuantizedType">; // Predicate for detecting a container or primitive of UniformQuantizedType. def quant_UniformQuantizedValueType : quant_TypedPrimitiveOrContainer; -#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_ +#endif // DIALECT_QUANTOPS_QUANT_OPS_BASE_ diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect(SPIRVOps SPIRVOps) +add_mlir_dialect(SPIRVOps spv SPIRVOps) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -22,7 +22,7 @@ // SPIR-V dialect definitions //===----------------------------------------------------------------------===// -def SPV_Dialect : Dialect { +def SPIRV_Dialect : Dialect { let name = "spv"; let summary = "The SPIR-V dialect in MLIR."; @@ -46,6 +46,43 @@ }]; let cppNamespace = "spirv"; + let hasConstantMaterializer = 1; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // Type + //===------------------------------------------------------------------===// + + /// Checks if the given `type` is valid in SPIR-V dialect. + static bool isValidType(Type type); + + /// Checks if the given `scalar type` is valid in SPIR-V dialect. + static bool isValidScalarType(Type type); + + //===------------------------------------------------------------------===// + // Attribute + //===------------------------------------------------------------------===// + + /// Returns the attribute name to use when specifying decorations on results + /// of operations. + static std::string getAttributeName(Decoration decoration); + + /// Provides a hook for verifying SPIR-V dialect attributes attached to the + /// given op. + LogicalResult verifyOperationAttribute(Operation *op, + NamedAttribute attribute) override; + + /// Provides a hook for verifying SPIR-V dialect attributes attached to the + /// given op's region argument. + LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex, + unsigned argIndex, + NamedAttribute attribute) override; + + /// Provides a hook for verifying SPIR-V dialect attributes attached to the + /// given op's region result. + LogicalResult verifyRegionResultAttribute( + Operation *op, unsigned regionIndex, unsigned resultIndex, + NamedAttribute attribute) override; + }]; } //===----------------------------------------------------------------------===// @@ -2953,7 +2990,8 @@ // SPIR-V attribute definitions //===----------------------------------------------------------------------===// -def SPV_VerCapExtAttr : Attr< +def SPV_VerCapExtAttr : DialectAttr< + SPIRV_Dialect, CPred<"$_self.isa<::mlir::spirv::VerCapExtAttr>()">, "version-capability-extension attribute"> { let storageType = "::mlir::spirv::VerCapExtAttr"; @@ -2993,10 +3031,14 @@ [SPV_Bool, SPV_Integer, SPV_Float]>; // Component type check is done in the type parser for the following SPIR-V // dialect-specific types so we use "Any" here. -def SPV_AnyPtr : Type; -def SPV_AnyArray : Type; -def SPV_AnyRTArray : Type; -def SPV_AnyStruct : Type; +def SPV_AnyPtr : DialectType; +def SPV_AnyArray : DialectType; +def SPV_AnyRTArray : DialectType; +def SPV_AnyStruct : DialectType; def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; @@ -3264,7 +3306,7 @@ // Base class for all SPIR-V ops. class SPV_Op traits = []> : - Op, StructFieldAttr<"binding", I32Attr>, StructFieldAttr<"storage_class", SPV_StorageClassAttr> @@ -38,7 +38,7 @@ // For entry functions, this attribute specifies information related to entry // points in the generated SPIR-V module: // 1) WorkGroup Size. -def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [ +def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPIRV_Dialect, [ StructFieldAttr<"local_size", I32ElementsAttr> ]>; @@ -54,7 +54,7 @@ // See https://renderdoc.org/vkspec_chunked/chap36.html#limits for the complete // list of limits and their explanation for the Vulkan API. The following ones // are those affecting SPIR-V CodeGen. -def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPV_Dialect, [ +def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [ StructFieldAttr<"max_compute_workgroup_invocations", I32Attr>, StructFieldAttr<"max_compute_workgroup_size", I32ElementsAttr> ]>; diff --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt @@ -1 +1 @@ -add_mlir_dialect(ShapeOps ShapeOps) +add_mlir_dialect(ShapeOps shape ShapeOps) diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -21,13 +21,6 @@ namespace mlir { namespace shape { -/// This dialect contains shape inference related operations and facilities. -class ShapeDialect : public Dialect { -public: - /// Create the dialect in the given `context`. - explicit ShapeDialect(MLIRContext *context); -}; - namespace ShapeTypes { enum Kind { Component = Type::FIRST_SHAPE_TYPE, @@ -112,6 +105,8 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Shape/IR/ShapeOps.h.inc" +#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.h.inc" + } // namespace shape } // namespace mlir diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_TARGET_DEFINITIONS Ops.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsDialect.h.inc -gen-dialect-decls) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRStandardOpsIncGen) diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -31,20 +31,11 @@ class FuncOp; class OpBuilder; -class StandardOpsDialect : public Dialect { -public: - StandardOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "std"; } - - /// Materialize a single constant operation from a given attribute value with - /// the desired resultant type. - Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, - Location loc) override; -}; - #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" +#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" + /// This is a refinement of the "constant" op for the case where it is /// returning a float value of FloatType. /// diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -18,14 +18,15 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffects.td" -def Std_Dialect : Dialect { +def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = ""; + let hasConstantMaterializer = 1; } // Base class for Standard dialect ops. class Std_Op traits = []> : - Op { + Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) // * LogicalResult verify(${C++ class of Op} op) @@ -63,7 +64,7 @@ // Base class for unary ops. Requires single operand and result. Individual // classes will have `operand` accessor. class UnaryOp traits = []> : - Op { + Op { let results = (outs AnyType); let printer = [{ return printStandardUnaryOp(this->getOperation(), p); @@ -86,7 +87,7 @@ // results to be of the same type, but does not constrain them to specific // types. Individual classes will have `lhs` and `rhs` accessor to operands. class ArithmeticOp traits = []> : - Op { let results = (outs AnyType); diff --git a/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect(VectorOps VectorOps) +add_mlir_dialect(VectorOps vector VectorOps) set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td) mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -24,18 +24,6 @@ class OwningRewritePatternList; namespace vector { -/// Dialect for Ops on higher-dimensional vector types. -class VectorOpsDialect : public Dialect { -public: - VectorOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "vector"; } - - /// Materialize a single constant operation from a given attribute value with - /// the desired resultant type. - Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, - Location loc) override; -}; - /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); @@ -75,6 +63,8 @@ #define GET_OP_CLASSES #include "mlir/Dialect/VectorOps/VectorOps.h.inc" +#include "mlir/Dialect/VectorOps/VectorOpsDialect.h.inc" + } // end namespace vector } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -16,14 +16,15 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td" include "mlir/Interfaces/SideEffects.td" -def Vector_Dialect : Dialect { +def VectorOps_Dialect : Dialect { let name = "vector"; let cppNamespace = "vector"; + let hasConstantMaterializer = 1; } // Base class for Vector dialect ops. class Vector_Op traits = []> : - Op { + Op { // For every vector op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) // * LogicalResult verify(${C++ class of Op} op) @@ -432,7 +433,7 @@ } def Vector_FMAOp : - Op]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>, Results<(outs AnyVector:$result)> { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -253,6 +253,13 @@ // the generated files are included into the dialect, you may want to specify // a full namespace path or a partial one. string cppNamespace = name; + + // An optional code block containing extra declarations to place in the + // dialect declaration. + code extraClassDeclaration = ""; + + // If this dialect overrides the hook for materializing constants. + bit hasConstantMaterializer = 0; } //===----------------------------------------------------------------------===// @@ -753,6 +760,12 @@ Attr baseAttr = ?; } +// An attribute of a specific dialect. +class DialectAttr : + Attr { + Dialect dialect = d; +} + //===----------------------------------------------------------------------===// // Attribute modifier definition diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -25,6 +25,7 @@ namespace mlir { namespace tblgen { +class Dialect; class Type; // Wrapper class with helper methods for accessing attribute constraints defined @@ -105,6 +106,9 @@ // Returns the code body for derived attribute. Aborts if this is not a // derived attribute. StringRef getDerivedCodeBody() const; + + // Returns the dialect for the attribute if defined. + Dialect getDialect() const; }; // Wrapper class providing helper methods for accessing MLIR constant attribute diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -32,6 +32,9 @@ // Returns the C++ namespaces that ops of this dialect should be placed into. StringRef getCppNamespace() const; + // Returns this dialect's C++ class name. + std::string getCppClassName() const; + // Returns the summary description of the dialect. Returns empty string if // none. StringRef getSummary() const; @@ -39,6 +42,12 @@ // Returns the description of the dialect. Returns empty string if none. StringRef getDescription() const; + // Returns the dialects extra class declaration code. + llvm::Optional getExtraClassDeclaration() const; + + // Returns if this dialect has a constant materializer or not. + bool hasConstantMaterializer() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. bool operator==(const Dialect &other) const; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -28,15 +28,13 @@ // GPUDialect //===----------------------------------------------------------------------===// -StringRef GPUDialect::getDialectName() { return "gpu"; } - bool GPUDialect::isKernel(Operation *op) { UnitAttr isKernelAttr = op->getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); } GPUDialect::GPUDialect(MLIRContext *context) - : Dialect(getDialectName(), context) { + : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST #include "mlir/Dialect/GPU/GPUOps.cpp.inc" diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -132,6 +132,10 @@ return def->getValueAsString("body"); } +tblgen::Dialect tblgen::Attribute::getDialect() const { + return Dialect(def->getValueAsDef("dialect")); +} + tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && "must be subclass of TableGen 'ConstantAttr' class"); diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -24,6 +24,13 @@ return def->getValueAsString("cppNamespace"); } +std::string tblgen::Dialect::getCppClassName() const { + // Simply use the name and remove any '_' tokens. + std::string cppName = def->getName().str(); + llvm::erase_if(cppName, [](char c) { return c == '_'; }); + return cppName; +} + static StringRef getAsStringOrEmpty(const llvm::Record &record, StringRef fieldName) { if (auto valueInit = record.getValueInit(fieldName)) { @@ -42,6 +49,15 @@ return getAsStringOrEmpty(*def, "description"); } +llvm::Optional tblgen::Dialect::getExtraClassDeclaration() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? llvm::Optional() : value; +} + +bool tblgen::Dialect::hasConstantMaterializer() const { + return def->getValueAsBit("hasConstantMaterializer"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -4,6 +4,7 @@ ) add_tablegen(mlir-tblgen MLIR + DialectGen.cpp EnumsGen.cpp LLVMIRConversionGen.cpp LLVMIRIntrinsicGen.cpp diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -0,0 +1,166 @@ +//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DialectGen uses the description of dialects to generate C++ definitions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/STLExtras.h" +#include "mlir/Support/StringExtras.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/OpClass.h" +#include "mlir/TableGen/OpInterfaces.h" +#include "mlir/TableGen/OpTrait.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-opdefgen" + +using namespace mlir; +using namespace mlir::tblgen; + +static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); +static llvm::cl::opt + selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), + llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); + +/// Given a set of records for a T, filter the ones that correspond to +/// the given dialect. +template +static auto filterForDialect(ArrayRef records, + Dialect &dialect) { + return llvm::make_filter_range(records, [&](const llvm::Record *record) { + return T(record).getDialect() == dialect; + }); +} + +//===----------------------------------------------------------------------===// +// GEN: Dialect declarations +//===----------------------------------------------------------------------===// + +/// The code block for the start of a dialect class declaration. +/// +/// {0}: The name of the dialect class. +/// {1}: The dialect namespace. +static const char *const dialectDeclBeginStr = R"( +class {0} : public ::mlir::Dialect { +public: + explicit {0}(::mlir::MLIRContext *context); + static ::llvm::StringRef getDialectNamespace() { return "{1}"; } +)"; + +/// The code block for the attribute parser/printer hooks. +static const char *const attrParserDecl = R"( + /// Parse an attribute registered to this dialect. + ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const override; + + /// Print an attribute registered to this dialect. + void printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &os) const override; +)"; + +/// The code block for the type parser/printer hooks. +static const char *const typeParserDecl = R"( + /// Parse a type registered to this dialect. + ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. + void printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &os) const override; +)"; + +/// The code block for the constant materializer hook. +static const char *const constantMaterializerDecl = R"( + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder, + ::mlir::Attribute value, + ::mlir::Type type, + ::mlir::Location loc) override; +)"; + +/// Generate the declaration for the given dialect class. +static void emitDialectDecl( + Dialect &dialect, + FunctionTraits)>::result_t + dialectAttrs, + FunctionTraits)>::result_t dialectTypes, + raw_ostream &os) { + // Emit the start of the decl. + std::string cppName = dialect.getCppClassName(); + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName()); + + // Check for any attributes/types registered to this dialect. If there are, + // add the hooks for parsing/printing. + if (!dialectAttrs.empty()) + os << attrParserDecl; + if (!dialectTypes.empty()) + os << typeParserDecl; + + // Add the decls for the various features of the dialect. + if (dialect.hasConstantMaterializer()) + os << constantMaterializerDecl; + if (llvm::Optional extraDecl = dialect.getExtraClassDeclaration()) + os << *extraDecl; + + // End the dialect decl. + os << "};\n"; +} + +static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, + raw_ostream &os) { + emitSourceFileHeader("Dialect Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Dialect"); + if (defs.empty()) + return false; + + // Select the dialect to gen for. + const llvm::Record *dialectDef = nullptr; + if (defs.size() == 1 && selectedDialect.getNumOccurrences() == 0) { + dialectDef = defs.front(); + } else if (selectedDialect.getNumOccurrences() == 0) { + llvm::errs() << "when more than 1 dialect is present, one must be selected " + "via '-dialect'"; + return true; + } else { + auto dialectIt = llvm::find_if(defs, [](const llvm::Record *def) { + return Dialect(def).getName() == selectedDialect; + }); + if (dialectIt == defs.end()) { + llvm::errs() << "selected dialect with '-dialect' does not exist"; + return true; + } + dialectDef = *dialectIt; + } + + auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr"); + auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType"); + Dialect dialect(dialectDef); + emitDialectDecl(dialect, filterForDialect(attrDefs, dialect), + filterForDialect(typeDefs, dialect), os); + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Dialect registration hooks +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genDialectDecls("gen-dialect-decls", "Generate dialect declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return emitDialectDecls(records, os); + });