diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -13,6 +13,11 @@ -attrdefs-dialect=llvm) add_public_tablegen_target(MLIRLLVMOpsIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMTypes.td) +mlir_tablegen(LLVMTypes.h.inc -gen-typedef-decls -typedefs-dialect=llvm) +mlir_tablegen(LLVMTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=llvm) +add_public_tablegen_target(MLIRLLVMTypesIncGen) + set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td) mlir_tablegen(LLVMIntrinsicOps.h.inc -gen-op-decls) mlir_tablegen(LLVMIntrinsicOps.cpp.inc -gen-op-defs) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -27,7 +27,6 @@ let name = "llvm"; let cppNamespace = "::mlir::LLVM"; - let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; @@ -77,7 +76,13 @@ return "llvm.readnone"; } + Type parseType(DialectAsmParser &p) const override; + void printType(Type, DialectAsmPrinter &p) const override; + private: + /// Register all types. + void registerTypes(); + /// A cache storing compatible LLVM types that have been verified. This /// can save us lots of verification time if there are many occurrences /// of some deeply-nested aggregate types in the program. 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -40,8 +40,15 @@ } // namespace LLVM } // namespace mlir +//===----------------------------------------------------------------------===// +// ODS-Generated Declarations +//===----------------------------------------------------------------------===// + #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMTypes.h.inc" + namespace mlir { namespace LLVM { @@ -65,61 +72,6 @@ #undef DEFINE_TRIVIAL_LLVM_TYPE -//===----------------------------------------------------------------------===// -// LLVMArrayType. -//===----------------------------------------------------------------------===// - -/// LLVM dialect array type. It is an aggregate type representing consecutive -/// elements in memory, parameterized by the number of elements and the element -/// type. -class LLVMArrayType - : public Type::TypeBase { -public: - /// Inherit base constructors. - using Base::Base; - using Base::getChecked; - - /// Checks if the given type can be used inside an array type. - static bool isValidElementType(Type type); - - /// Gets or creates an instance of LLVM dialect array type containing - /// `numElements` of `elementType`, in the same context as `elementType`. - static LLVMArrayType get(Type elementType, unsigned numElements); - static LLVMArrayType getChecked(function_ref emitError, - Type elementType, unsigned numElements); - - /// Returns the element type of the array. - Type getElementType() const; - - /// Returns the number of elements in the array type. - unsigned getNumElements() const; - - /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verify(function_ref emitError, - Type elementType, unsigned numElements); - - /// Hooks for DataLayoutTypeInterface. 
Should not be called directly. Obtain a - /// DataLayout instance and query it instead. - unsigned getTypeSizeInBits(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const; - - unsigned getTypeSize(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const; - - unsigned getABIAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const; - - unsigned getPreferredAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const; - - void walkImmediateSubElements(function_ref walkAttrsFn, - function_ref walkTypesFn) const; - Type replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const; -}; - //===----------------------------------------------------------------------===// // LLVMFunctionType. //===----------------------------------------------------------------------===// @@ -484,6 +436,11 @@ void printType(Type type, AsmPrinter &printer); } // namespace detail +/// Parse any MLIR type or a concise syntax for LLVM types. +ParseResult parsePrettyLLVMType(AsmParser &p, FailureOr &type); +/// Print any MLIR type or a concise syntax for LLVM types. +void printPrettyLLVMType(AsmPrinter &p, Type type); + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// @@ -548,6 +505,7 @@ /// Currently only `PtrDLEntryPos::Index` is optional, and all other positions /// may be assumed to be present. Optional extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos); + } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -0,0 +1,61 @@ +//===-- LLVMTypes.td - LLVM IR dialect type definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMTYPES_TD +#define LLVMTYPES_TD + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/SubElementInterfaces.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" + +/// Base class for all LLVM dialect types. +class LLVMType traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// LLVMArrayType +//===----------------------------------------------------------------------===// + +def LLVMArrayType : LLVMType<"LLVMArray", "array", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods]> { + let summary = "LLVM array type"; + let description = [{ + The `!llvm.array` type represents a fixed-size array of element types. + It is an aggregate type representing consecutive elements in memory, + parameterized by the number of elements and the element type. + + Example: + + ```mlir + !llvm.array<4 x i32> + ``` + }]; + + let parameters = (ins "Type":$elementType, "unsigned":$numElements); + let assemblyFormat = [{ + `<` $numElements `x` ` ` custom($elementType) `>` + }]; + + let genVerifyDecl = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType, + "unsigned":$numElements)> + ]; + + let extraClassDeclaration = [{ + /// Checks if the given type can be used inside an array type. 
+ static bool isValidElementType(Type type); + }]; +} + +#endif // LLVMTYPES_TD diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -12,6 +12,7 @@ DEPENDS MLIRLLVMOpsIncGen + MLIRLLVMTypesIncGen MLIRLLVMIntrinsicOpsIncGen MLIRLLVMOpsInterfacesIncGen MLIROpenMPOpsIncGen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1430,10 +1430,10 @@ // following example: // ``` // %1 = llvm.insertvalue %f0, %0[0, 0] : - // !llvm.array<4 x !llvm.array<4xf32>> + // !llvm.array<4 x !llvm.array<4 x f32>> // %2 = llvm.insertvalue %arr, %1[0] : - // !llvm.array<4 x !llvm.array<4xf32>> - // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> + // !llvm.array<4 x !llvm.array<4 x f32>> + // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>> // ``` if (getPosition().take_front(min) == insertValueOp.getPosition().take_front(min)) @@ -2577,9 +2577,10 @@ LLVMPointerType, LLVMFixedVectorType, LLVMScalableVectorType, - LLVMArrayType, LLVMStructType>(); // clang-format on + registerTypes(); + addOperations< #define GET_OP_LIST #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" @@ -2595,16 +2596,6 @@ #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" -/// Parse a type registered to this dialect. -Type LLVMDialect::parseType(DialectAsmParser &parser) const { - return detail::parseType(parser); -} - -/// Print a type registered to this dialect. 
-void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { - return detail::printType(type, os); -} - LogicalResult LLVMDialect::verifyDataLayoutString( StringRef descr, llvm::function_ref reportError) { llvm::Expected maybeDataLayout = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -101,7 +101,7 @@ /// Prints a type containing a fixed number of elements. template -static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) { +static void printVectorType(AsmPrinter &printer, TypeTy type) { printer << '<' << type.getNumElements() << " x "; dispatchPrint(printer, type.getElementType()); printer << '>'; @@ -156,9 +156,9 @@ } if (auto arrayType = type.dyn_cast()) - return printArrayOrVectorType(printer, arrayType); + return arrayType.print(printer); if (auto vectorType = type.dyn_cast()) - return printArrayOrVectorType(printer, vectorType); + return printVectorType(printer, vectorType); if (auto vectorType = type.dyn_cast()) { printer << "(loc, elementType, dims[0]); } -/// Parses an LLVM dialect array type. -/// llvm-type ::= `array<` integer `x` llvm-type `>` -static LLVMArrayType parseArrayType(AsmParser &parser) { - SmallVector dims; - SMLoc sizePos; - Type elementType; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || - parser.parseDimensionList(dims, /*allowDynamic=*/false) || - dispatchParse(parser, elementType) || parser.parseGreater()) - return LLVMArrayType(); - - if (dims.size() != 1) { - parser.emitError(sizePos) << "expected ? x "; - return LLVMArrayType(); - } - - return parser.getChecked(loc, elementType, dims[0]); -} - /// Attempts to set the body of an identified structure type. Reports a parsing /// error at `subtypesLoc` in case of failure. 
static LLVMStructType trySetStructBody(LLVMStructType type, @@ -468,7 +448,7 @@ .Case("func", [&] { return parseFunctionType(parser); }) .Case("ptr", [&] { return parsePointerType(parser); }) .Case("vec", [&] { return parseVectorType(parser); }) - .Case("array", [&] { return parseArrayType(parser); }) + .Case("array", [&] { return LLVMArrayType::parse(parser); }) .Case("struct", [&] { return parseStructType(parser); }) .Default([&] { parser.emitError(keyLoc) << "unknown LLVM type: " << key; @@ -494,3 +474,12 @@ } return type; } + +ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, FailureOr &type) { + type.emplace(); + return dispatchParse(p, *type); +} + +void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) { + return dispatchPrint(p, type); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -28,7 +28,7 @@ constexpr const static unsigned kBitsInByte = 8; //===----------------------------------------------------------------------===// -// Array type. 
+// LLVMArrayType //===----------------------------------------------------------------------===// bool LLVMArrayType::isValidElementType(Type type) { @@ -49,12 +49,6 @@ numElements); } -Type LLVMArrayType::getElementType() const { return getImpl()->elementType; } - -unsigned LLVMArrayType::getNumElements() const { - return getImpl()->numElements; -} - LogicalResult LLVMArrayType::verify(function_ref emitError, Type elementType, unsigned numElements) { @@ -63,6 +57,9 @@ return success(); } +//===----------------------------------------------------------------------===// +// DataLayoutTypeInterface + unsigned LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { return kBitsInByte * getTypeSize(dataLayout, params); @@ -86,6 +83,9 @@ return dataLayout.getTypePreferredAlignment(getElementType()); } +//===----------------------------------------------------------------------===// +// SubElementTypeInterface + void LLVMArrayType::walkImmediateSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) const { @@ -1005,4 +1005,37 @@ }); } +//===----------------------------------------------------------------------===// +// ODS-Generated Definitions +//===----------------------------------------------------------------------===// + +/// These are unused for now. +/// TODO: Move over to these once more types have been migrated to TypeDef. 
+LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); +LLVM_ATTRIBUTE_UNUSED static LogicalResult +generatedTypePrinter(Type def, AsmPrinter &printer); + #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// LLVMDialect +//===----------------------------------------------------------------------===// + +void LLVMDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" + >(); +} + +Type LLVMDialect::parseType(DialectAsmParser &parser) const { + return detail::parseType(parser); +} + +void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { + return detail::printType(type, os); +} diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c --- a/mlir/test/CAPI/llvm.c +++ b/mlir/test/CAPI/llvm.c @@ -45,11 +45,11 @@ // CHECK: !llvm.void: 1 fprintf(stderr, "%s: %d\n", voidt_text, mlirTypeEqual(voidt, voidt_ref)); - const char *i32_4_text = "!llvm.array<4xi32>"; + const char *i32_4_text = "!llvm.array<4 x i32>"; MlirType i32_4 = mlirLLVMArrayTypeGet(i32, 4); MlirType i32_4_ref = mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i32_4_text)); - // CHECK: !llvm.array<4xi32>: 1 + // CHECK: !llvm.array<4 x i32>: 1 fprintf(stderr, "%s: %d\n", i32_4_text, mlirTypeEqual(i32_4, i32_4_ref)); const char *i8_i32_i64_text = "!llvm.func"; @@ -78,4 +78,3 @@ mlirContextDestroy(ctx); return 0; } - diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -25,16 +25,16 @@ // ----- // CHECK-LABEL: no_fold_extractvalue -llvm.func @no_fold_extractvalue(%arr: !llvm.array<4xf32>) -> f32 { +llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 
x f32>) -> f32 { %f0 = arith.constant 0.0 : f32 - %0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4xf32>> + %0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>> // CHECK: insertvalue // CHECK: insertvalue // CHECK: extractvalue - %1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> - %2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4xf32>> - %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> + %1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>> + %2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>> + %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>> llvm.return %3 : f32 @@ -42,12 +42,12 @@ // ----- // CHECK-LABEL: fold_unrelated_extractvalue -llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4xf32>) -> f32 { +llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 { %f0 = arith.constant 0.0 : f32 // CHECK-NOT: insertvalue // CHECK: extractvalue - %2 = llvm.insertvalue %f0, %arr[0] : !llvm.array<4xf32> - %3 = llvm.extractvalue %2[1] : !llvm.array<4xf32> + %2 = llvm.insertvalue %f0, %arr[0] : !llvm.array<4 x f32> + %3 = llvm.extractvalue %2[1] : !llvm.array<4 x f32> llvm.return %3 : f32 } @@ -144,7 +144,7 @@ // CHECK-NEXT: llvm.return llvm.func @load_dce(%x : !llvm.ptr) { %0 = llvm.load %x : !llvm.ptr - llvm.return + llvm.return } llvm.mlir.global external @fp() : !llvm.ptr @@ -153,7 +153,7 @@ // CHECK-NEXT: llvm.return llvm.func @addr_dce(%x : !llvm.ptr) { %0 = llvm.mlir.addressof @fp : !llvm.ptr> - llvm.return + llvm.return } // CHECK-LABEL: alloca_dce @@ -161,5 +161,5 @@ llvm.func @alloca_dce() { %c1_i64 = arith.constant 1 : i64 %0 = llvm.alloca %c1_i64 x i32 : (i64) -> !llvm.ptr - llvm.return + llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -181,7 +181,7 
@@ // CHECK-LABEL: @gep llvm.func @gep(%ptr: !llvm.ptr)>>, %idx: i64, - %ptr2: !llvm.ptr)>>) { + %ptr2: !llvm.ptr)>>) { // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -100,7 +100,7 @@ // CHECK: @common = common global i32 0 llvm.mlir.global common @common(0 : i32) : i32 // CHECK: @appending = appending global [3 x i32] [i32 1, i32 2, i32 3] -llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3xi32> +llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3 x i32> // CHECK: @extern_weak = extern_weak global i32 llvm.mlir.global extern_weak @extern_weak() : i32 // CHECK: @linkonce_odr = linkonce_odr global i32 42 @@ -993,11 +993,11 @@ // CHECK-LABEL: @gep llvm.func @gep(%ptr: !llvm.ptr)>>, %idx: i64, - %ptr2: !llvm.ptr)>>) { + %ptr2: !llvm.ptr)>>) { // CHECK: = getelementptr { i32, { i32, float } }, ptr %{{.*}}, i64 %{{.*}}, i32 1, i32 0 llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr // CHECK: = getelementptr { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} - llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr + llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr llvm.return }