diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -205,93 +205,6 @@ ArrayRef replTypes) const; }; -//===----------------------------------------------------------------------===// -// LLVMFixedVectorType. -//===----------------------------------------------------------------------===// - -/// LLVM dialect fixed vector type, represents a sequence of elements of known -/// length that can be processed as one. -class LLVMFixedVectorType - : public Type::TypeBase { -public: - /// Inherit base constructor. - using Base::Base; - using Base::getChecked; - - /// Gets or creates a fixed vector type containing `numElements` of - /// `elementType` in the same context as `elementType`. - static LLVMFixedVectorType get(Type elementType, unsigned numElements); - static LLVMFixedVectorType - getChecked(function_ref emitError, Type elementType, - unsigned numElements); - - /// Checks if the given type can be used in a vector type. This type supports - /// only a subset of LLVM dialect types that don't have a built-in - /// counter-part, e.g., pointers. - static bool isValidElementType(Type type); - - /// Returns the element type of the vector. - Type getElementType() const; - - /// Returns the number of elements in the fixed vector. - unsigned getNumElements() const; - - /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verify(function_ref emitError, - Type elementType, unsigned numElements); - - void walkImmediateSubElements(function_ref walkAttrsFn, - function_ref walkTypesFn) const; - Type replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const; -}; - -//===----------------------------------------------------------------------===// -// LLVMScalableVectorType. 
-//===----------------------------------------------------------------------===// - -/// LLVM dialect scalable vector type, represents a sequence of elements of -/// unknown length that is known to be divisible by some constant. These -/// elements can be processed as one in SIMD context. -class LLVMScalableVectorType - : public Type::TypeBase { -public: - /// Inherit base constructor. - using Base::Base; - using Base::getChecked; - - /// Gets or creates a scalable vector type containing a non-zero multiple of - /// `minNumElements` of `elementType` in the same context as `elementType`. - static LLVMScalableVectorType get(Type elementType, unsigned minNumElements); - static LLVMScalableVectorType - getChecked(function_ref emitError, Type elementType, - unsigned minNumElements); - - /// Checks if the given type can be used in a vector type. - static bool isValidElementType(Type type); - - /// Returns the element type of the vector. - Type getElementType() const; - - /// Returns the scaling factor of the number of elements in the vector. The - /// vector contains at least the resulting number of elements, or any non-zero - /// multiple of this number. - unsigned getMinNumElements() const; - - /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verify(function_ref emitError, - Type elementType, unsigned minNumElements); - - void walkImmediateSubElements(function_ref walkAttrsFn, - function_ref walkTypesFn) const; - Type replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const; -}; - //===----------------------------------------------------------------------===// // Printing and parsing. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -167,4 +167,66 @@ }]; } +//===----------------------------------------------------------------------===// +// LLVMFixedVectorType +//===----------------------------------------------------------------------===// + +def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [ + DeclareTypeInterfaceMethods]> { + let summary = "LLVM fixed vector type"; + let description = [{ + LLVM dialect fixed vector type, represents a sequence of elements of + known length that can be processed as one. It is used for element types + that don't have a built-in vector counterpart, e.g., pointers. + }]; + + let parameters = (ins "Type":$elementType, "unsigned":$numElements); + let assemblyFormat = [{ + `<` $numElements `x` ` ` custom($elementType) `>` + }]; + + let genVerifyDecl = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType, + "unsigned":$numElements)> + ]; + + let extraClassDeclaration = [{ + /// Checks if the given type can be used in a vector type. + static bool isValidElementType(Type type); + }]; +} + +//===----------------------------------------------------------------------===// +// LLVMScalableVectorType +//===----------------------------------------------------------------------===// + +def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [ + DeclareTypeInterfaceMethods]> { + let summary = "LLVM scalable vector type"; + let description = [{ + LLVM dialect scalable vector type, represents a sequence of elements of + unknown length that is known to be divisible by some constant. These + elements can be processed as one in SIMD context.
+ }]; + + let parameters = (ins "Type":$elementType, "unsigned":$minNumElements); + let assemblyFormat = [{ + `<` `?` `x` $minNumElements `x` ` ` custom($elementType) `>` + }]; + + let genVerifyDecl = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType, + "unsigned":$minNumElements)> + ]; + + let extraClassDeclaration = [{ + /// Checks if the given type can be used in a vector type. + static bool isValidElementType(Type type); + }]; +} + #endif // LLVMTYPES_TD diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2573,8 +2573,6 @@ LLVMTokenType, LLVMLabelType, LLVMMetadataType, - LLVMFixedVectorType, - LLVMScalableVectorType, LLVMStructType>(); // clang-format on registerTypes(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -99,14 +99,6 @@ printer << '>'; } -/// Prints a type containing a fixed number of elements. -template -static void printVectorType(AsmPrinter &printer, TypeTy type) { - printer << '<' << type.getNumElements() << " x "; - dispatchPrint(printer, type.getElementType()); - printer << '>'; -} - /// Prints the given LLVM dialect type recursively. This leverages closedness of /// the LLVM dialect type system to avoid printing the dialect prefix /// repeatedly. 
For recursive structures, only prints the name of the structure @@ -124,26 +116,13 @@ printer << getTypeKeyword(type); - if (auto ptrType = type.dyn_cast()) - return ptrType.print(printer); - - if (auto arrayType = type.dyn_cast()) - return arrayType.print(printer); - if (auto vectorType = type.dyn_cast()) - return printVectorType(printer, vectorType); - - if (auto vectorType = type.dyn_cast()) { - printer << "'; - return; - } - - if (auto structType = type.dyn_cast()) - return printStructType(printer, structType); - - if (auto funcType = type.dyn_cast()) - return funcType.print(printer); + llvm::TypeSwitch(type) + .Case( + [&](auto type) { type.print(printer); }) + .Case([&](LLVMStructType structType) { + printStructType(printer, structType); + }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -735,14 +735,6 @@ numElements); } -Type LLVMFixedVectorType::getElementType() const { - return static_cast(impl)->elementType; -} - -unsigned LLVMFixedVectorType::getNumElements() const { - return getImpl()->numElements; -} - bool LLVMFixedVectorType::isValidElementType(Type type) { return type.isa(); } @@ -783,14 +775,6 @@ minNumElements); } -Type LLVMScalableVectorType::getElementType() const { - return static_cast(impl)->elementType; -} - -unsigned LLVMScalableVectorType::getMinNumElements() const { - return getImpl()->numElements; -} - bool LLVMScalableVectorType::isValidElementType(Type type) { if (auto intType = type.dyn_cast()) return intType.isSignless();