diff --git a/mlir/docs/DataLayout.md b/mlir/docs/DataLayout.md --- a/mlir/docs/DataLayout.md +++ b/mlir/docs/DataLayout.md @@ -72,6 +72,7 @@ explicit DataLayout(DataLayoutOpInterface scope); unsigned getTypeSize(Type type) const; + unsigned getTypeSizeInBits(Type type) const; unsigned getTypeABIAlignment(Type type) const; unsigned getTypePreferredAlignment(Type type) const; }; @@ -178,6 +179,15 @@ provides hooks for verifying the validity of the entry value attributes and for and the compatibility of nested entries. +### Bits and Bytes + +Two versions of hooks are provided for sizes: in bits and in bytes. The version +in bytes has a default implementation that derives the size in bytes by dividing +the size in bits by 8 and rounding up. Types exclusively targeting architectures +with different assumptions can override this. Operations can also redefine it +for all types, providing scoped versions for byte sizes other than eight without +having to modify types, including built-in types. + ### Query Dispatch The overall flow of a data layout property query is as follows. @@ -243,6 +253,10 @@ [modeling of n-D vectors](https://mlir.llvm.org/docs/Dialects/Vector/#deeperdive). They **may change** in the future. +### Byte Size + +The default data layout assumes 8-bit bytes. + ### DLTI Dialect The [DLTI](Dialects/DLTI.md) dialect provides the attributes implementing diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h @@ -35,7 +35,13 @@ /// Default handler for the type size request. Computes results for built-in /// types and dispatches to the DataLayoutTypeInterface for other types. unsigned getDefaultTypeSize(Type type, const DataLayout &dataLayout, - ArrayRef params); + DataLayoutEntryListRef params); + +/// Default handler for the type size in bits request.
Computes results for +/// built-in types and dispatches to the DataLayoutTypeInterface for other +/// types. +unsigned getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params); /// Default handler for the required alignemnt request. Computes results for /// built-in types and dispatches to the DataLayoutTypeInterface for other @@ -140,6 +146,9 @@ /// Returns the size of the given type in the current scope. unsigned getTypeSize(Type t) const; + /// Returns the size in bits of the given type in the current scope. + unsigned getTypeSizeInBits(Type t) const; + /// Returns the required alignment of the given type in the current scope. unsigned getTypeABIAlignment(Type t) const; @@ -166,6 +175,7 @@ /// Caches for individual requests. mutable DenseMap sizes; + mutable DenseMap bitsizes; mutable DenseMap abiAlignments; mutable DenseMap preferredAlignments; }; diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td @@ -208,7 +208,23 @@ "::mlir::DataLayoutEntryListRef":$params), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::detail::getDefaultTypeSize(type, dataLayout, params); + unsigned bits = ConcreteOp::getTypeSizeInBits(type, dataLayout, params); + return ::llvm::divideCeil(bits, 8); + }] + >, + StaticInterfaceMethod< + /*description=*/"Returns the size of the given type in bits computed " + "using the relevant entries.
The data layout object can " + "be used for recursive queries.", + /*retTy=*/"unsigned", + /*methodName=*/"getTypeSizeInBits", + /*args=*/(ins "::mlir::Type":$type, + "const ::mlir::DataLayout &":$dataLayout, + "::mlir::DataLayoutEntryListRef":$params), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::detail::getDefaultTypeSizeInBits(type, dataLayout, + params); }] >, StaticInterfaceMethod< @@ -281,6 +297,18 @@ /*description=*/"Returns the size of this type in bytes.", /*retTy=*/"unsigned", /*methodName=*/"getTypeSize", + /*args=*/(ins "const ::mlir::DataLayout &":$dataLayout, + "::mlir::DataLayoutEntryListRef":$params), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + unsigned bits = $_type.getTypeSizeInBits(dataLayout, params); + return ::llvm::divideCeil(bits, 8); + }] + >, + InterfaceMethod< + /*description=*/"Returns the size of this type in bits.", + /*retTy=*/"unsigned", + /*methodName=*/"getTypeSizeInBits", /*args=*/(ins "const ::mlir::DataLayout &":$dataLayout, "::mlir::DataLayoutEntryListRef":$params) >, diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -34,18 +34,28 @@ unsigned mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout, ArrayRef params) { + unsigned bits = getDefaultTypeSizeInBits(type, dataLayout, params); + return llvm::divideCeil(bits, 8); +} + +unsigned mlir::detail::getDefaultTypeSizeInBits(Type type, + const DataLayout &dataLayout, + DataLayoutEntryListRef params) { if (type.isa()) - return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8); + return type.getIntOrFloatBitWidth(); // Sizes of vector types are rounded up to those of types with closest - // power-of-two number of elements. + // power-of-two number of elements in the innermost dimension.
We also assume + // no bit-packing for now: element sizes are taken in bytes and multiplied by + // 8 to obtain the size in bits. // TODO: make this extensible. if (auto vecType = type.dyn_cast()) - return llvm::PowerOf2Ceil(vecType.getNumElements()) * - dataLayout.getTypeSize(vecType.getElementType()); + return vecType.getNumElements() / vecType.getShape().back() * + llvm::PowerOf2Ceil(vecType.getShape().back()) * + dataLayout.getTypeSize(vecType.getElementType()) * 8; if (auto typeInterface = type.dyn_cast()) - return typeInterface.getTypeSize(dataLayout, params); + return typeInterface.getTypeSizeInBits(dataLayout, params); reportMissingDataLayout(type); } @@ -280,6 +290,19 @@ }); } +unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const { + checkValid(); + return cachedLookup(t, bitsizes, [&](Type ty) { + if (originalLayout) { + DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast(scope)) + return iface.getTypeSizeInBits(ty, *this, list); + return detail::getDefaultTypeSizeInBits(ty, *this, list); + } + return detail::getDefaultTypeSizeInBits(ty, *this, {}); + }); +} + unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const { checkValid(); return cachedLookup(t, abiAlignments, [&](Type ty) { diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir --- a/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir +++ b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir @@ -6,8 +6,9 @@ // CHECK-LABEL: @module_level_layout func @module_level_layout() { // CHECK: alignment = 32 + // CHECK: bitsize = 12 // CHECK: preferred = 1 - // CHECK: size = 12 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> return } diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir --- a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir +++
b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir @@ -3,10 +3,12 @@ // CHECK-LABEL: @no_layout_builtin func @no_layout_builtin() { // CHECK: alignment = 4 + // CHECK: bitsize = 32 // CHECK: preferred = 4 // CHECK: size = 4 "test.data_layout_query"() : () -> i32 // CHECK: alignment = 8 + // CHECK: bitsize = 64 // CHECK: preferred = 8 // CHECK: size = 8 "test.data_layout_query"() : () -> f64 @@ -16,6 +18,7 @@ // CHECK-LABEL: @no_layout_custom func @no_layout_custom() { // CHECK: alignment = 1 + // CHECK: bitsize = 1 // CHECK: preferred = 1 // CHECK: size = 1 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> @@ -26,6 +29,7 @@ func @layout_op_no_layout() { "test.op_with_data_layout"() ({ // CHECK: alignment = 1 + // CHECK: bitsize = 1 // CHECK: preferred = 1 // CHECK: size = 1 "test.data_layout_query"() : () -> !test.test_type_with_layout<1000> @@ -38,8 +42,9 @@ func @layout_op() { "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK: preferred = 1 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< @@ -55,8 +60,9 @@ "test.op_with_data_layout"() ({ "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK: preferred = 1 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< @@ -74,8 +80,9 @@ "test.op_with_data_layout"() ({ "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK: preferred = 1 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) : () -> () @@ -93,8 +100,9 @@ "test.op_with_data_layout"() ({ "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK:
preferred = 1 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) : () -> () @@ -114,8 +122,9 @@ "test.op_with_data_layout"() ({ "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK: preferred = 30 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) : () -> () @@ -125,8 +134,9 @@ #dlti.dl_entry, ["alignment", 20]> >} : () -> () // CHECK: alignment = 1 + // CHECK: bitsize = 42 // CHECK: preferred = 30 - // CHECK: size = 42 + // CHECK: size = 6 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< @@ -142,8 +152,9 @@ "test.op_with_data_layout"() ({ "test.op_with_data_layout"() ({ // CHECK: alignment = 20 + // CHECK: bitsize = 3 // CHECK: preferred = 30 - // CHECK: size = 3 + // CHECK: size = 1 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< @@ -151,8 +162,9 @@ #dlti.dl_entry, ["preferred", 30]> >} : () -> () // CHECK: alignment = 20 + // CHECK: bitsize = 10 // CHECK: preferred = 30 - // CHECK: size = 10 + // CHECK: size = 2 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< @@ -160,8 +172,9 @@ #dlti.dl_entry, ["alignment", 20]> >} : () -> () // CHECK: alignment = 1 + // CHECK: bitsize = 42 // CHECK: preferred = 30 - // CHECK: size = 42 + // CHECK: size = 6 "test.data_layout_query"() : () -> !test.test_type_with_layout<10> "test.maybe_terminator"() : () -> () }) { dlti.dl_spec = #dlti.dl_spec< diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@
-135,8 +135,8 @@ unsigned getKey() { return getImpl()->key; } - unsigned getTypeSize(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { + unsigned getTypeSizeInBits(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { return extractKind(params, "size"); } diff --git a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp --- a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp @@ -46,10 +46,12 @@ const DataLayout &layout = layouts.find(closest)->getSecond(); unsigned size = layout.getTypeSize(op.getType()); + unsigned bitsize = layout.getTypeSizeInBits(op.getType()); unsigned alignment = layout.getTypeABIAlignment(op.getType()); unsigned preferred = layout.getTypePreferredAlignment(op.getType()); op->setAttrs( {builder.getNamedAttr("size", builder.getIndexAttr(size)), + builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)), builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)), builder.getNamedAttr("preferred", builder.getIndexAttr(preferred))}); }); diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -71,8 +71,8 @@ static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); } - unsigned getTypeSize(const DataLayout &layout, - DataLayoutEntryListRef params) { + unsigned getTypeSizeInBits(const DataLayout &layout, + DataLayoutEntryListRef params) const { static bool executed = false; if (executed) llvm::report_fatal_error("repeated call"); @@ -121,19 +121,20 @@ return getOperation()->getAttrOfType(kAttrName); } - static unsigned getTypeSize(Type type, const DataLayout &dataLayout, - DataLayoutEntryListRef params) { + static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout, +
DataLayoutEntryListRef params) { // Make a recursive query. if (type.isa()) - return dataLayout.getTypeSize( + return dataLayout.getTypeSizeInBits( IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth())); // Handle built-in types that are not handled by the default process. if (auto iType = type.dyn_cast()) { for (DataLayoutEntryInterface entry : params) if (entry.getKey().dyn_cast() == type) - return entry.getValue().cast().getValue().getZExtValue(); - return iType.getIntOrFloatBitWidth(); + return 8 * + entry.getValue().cast().getValue().getZExtValue(); + return 8 * iType.getIntOrFloatBitWidth(); } // Use the default process for everything else. @@ -152,13 +153,30 @@ } }; +struct OpWith7BitByte + : public Op { + using Op::Op; + + static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; } + + DataLayoutSpecInterface getDataLayoutSpec() { + return getOperation()->getAttrOfType(kAttrName); + } + + // Bytes are assumed to be 7-bit here. + static unsigned getTypeSize(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params) { + return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7); + } +}; + /// A dialect putting all the above together.
struct DLTestDialect : Dialect { explicit DLTestDialect(MLIRContext *ctx) : Dialect(getDialectNamespace(), ctx, TypeID::get()) { ctx->getOrLoadDialect(); addAttributes(); - addOperations(); + addOperations(); addTypes(); } static StringRef getDialectNamespace() { return "dltest"; } @@ -222,6 +240,8 @@ DataLayout layout(op); EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); + EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u); EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u); EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u); @@ -243,6 +263,8 @@ DataLayout layout(op); EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u); EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u); + EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u); EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u); EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u); EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u); @@ -267,6 +289,8 @@ DataLayout layout(op); EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u); EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u); + EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u); EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u); EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u); @@ -274,6 +298,8 @@ EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u); EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)),
32u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u); + EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u); EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u); EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u); EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u); @@ -355,3 +381,23 @@ "neither the scoping op nor the type class provide data layout " "information"); } + +TEST(DataLayout, SevenBitByte) { + const char *ir = R"MLIR( +"dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> () + )MLIR"; + + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(registry); + + OwningModuleRef module = parseSourceString(ir, &ctx); + auto op = + cast(module->getBody()->getOperations().front()); + DataLayout layout(op); + + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u); + EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); + EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u); +}