diff --git a/mlir/docs/Dialects/Builtin.md b/mlir/docs/Dialects/Builtin.md --- a/mlir/docs/Dialects/Builtin.md +++ b/mlir/docs/Dialects/Builtin.md @@ -30,3 +30,7 @@ ## Types [include "Dialects/BuiltinTypes.md"] + +## Type Interfaces + +[include "Dialects/BuiltinTypeInterfaces.md"] diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -192,6 +192,12 @@ #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.h.inc" +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypeInterfaces.h.inc" + namespace mlir { //===----------------------------------------------------------------------===// // MemRefType @@ -266,7 +272,8 @@ } inline bool BaseMemRefType::isValidElementType(Type type) { - return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>(); + return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() || + type.isa<MemRefElementTypeInterface>(); } inline bool FloatType::classof(Type type) { diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -248,6 +248,31 @@ }]; } +//===----------------------------------------------------------------------===// +// MemRefElementTypeInterface +//===----------------------------------------------------------------------===// + +def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Indicates that this type can be used as an element of memref types. + + Implementing this interface establishes a contract between this type and the + memref type indicating that this type can be used as element of ranked or + unranked memrefs.
The type is expected to: + + - model an entity stored in memory; + - have non-zero size. + + For example, scalar values such as integers can implement this interface, + but indicator types such as `void` or `unit` should not. + + The interface currently has no methods and is used by types to opt into + being memref elements. This may change in the future, in particular to + require types to provide their size or alignment given a data layout. + }]; +} + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// @@ -282,6 +307,14 @@ on the rank. Other uses of this type are disallowed or will have undefined behavior. + The following types are accepted as elements: + + - built-in integer types; + - built-in index type; + - built-in floating point types; + - built-in vector types with elements of the above types; + - any other type implementing `MemRefElementTypeInterface`. + ##### Codegen of Unranked Memref Using unranked memref in codegen besides the case mentioned above is highly diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -24,6 +24,8 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td) mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls) mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs) +mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRBuiltinTypesIncGen) set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) @@ -35,3 +37,4 @@ add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc) add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc) add_mlir_doc(BuiltinTypes BuiltinTypes Dialects/ -gen-typedef-doc) +add_mlir_doc(BuiltinTypes BuiltinTypeInterfaces Dialects/ -gen-type-interface-docs) diff 
--git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -349,6 +349,9 @@ // unpack the `sizes` and `strides` arrays. SmallVector<Type, 5> types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false); + // Fail the conversion if no descriptor fields could be produced. + if (types.empty()) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), types); } @@ -368,6 +371,9 @@ } Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { + // Reject memrefs whose element type cannot be converted. + if (!convertType(type.getElementType())) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), getUnrankedMemRefDescriptorFields()); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -31,6 +31,12 @@ #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.cpp.inc" +//===----------------------------------------------------------------------===// +// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" + //===----------------------------------------------------------------------===// // BuiltinDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -427,3 +427,23 @@ return } } + +// ----- + +// Should not convert memrefs with unsupported types in any convention.
+ +// CHECK: @unsupported_memref_element_type +// CHECK-SAME: memref< +// CHECK-NOT: !llvm.struct +// BAREPTR: @unsupported_memref_element_type +// BAREPTR-SAME: memref< +// BAREPTR-NOT: !llvm.ptr +func private @unsupported_memref_element_type() -> memref<42 x !test.memref_element> + +// CHECK: @unsupported_unranked_memref_element_type +// CHECK-SAME: memref< +// CHECK-NOT: !llvm.struct +// BAREPTR: @unsupported_unranked_memref_element_type +// BAREPTR-SAME: memref< +// BAREPTR-NOT: !llvm.ptr +func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element> diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -178,6 +178,9 @@ // CHECK: func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>) func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>) +// CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>) +func private @memref_with_custom_elem(memref<1x?x!test.memref_element>) + // CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>) func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>) diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -15,6 +15,7 @@ // To get the test dialect def. include "TestOps.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" // All of the types will extend this class. 
@@ -176,4 +177,9 @@ }]; } +def TestMemRefElementType : Test_Type<"TestMemRefElementType", + [MemRefElementTypeInterface]> { + let mnemonic = "memref_element"; +} + #endif // TEST_TYPEDEFS