diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -280,52 +280,48 @@
 unsigned mlir::DataLayout::getTypeSize(Type t) const {
   checkValid();
   return cachedLookup(t, sizes, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeSize(ty, *this, list);
-      return detail::getDefaultTypeSize(ty, *this, list);
-    }
-    return detail::getDefaultTypeSize(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeSize(ty, *this, list);
+    return detail::getDefaultTypeSize(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const {
   checkValid();
   return cachedLookup(t, bitsizes, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeSizeInBits(ty, *this, list);
-      return detail::getDefaultTypeSizeInBits(ty, *this, list);
-    }
-    return detail::getDefaultTypeSizeInBits(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeSizeInBits(ty, *this, list);
+    return detail::getDefaultTypeSizeInBits(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, abiAlignments, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeABIAlignment(ty, *this, list);
-      return detail::getDefaultABIAlignment(ty, *this, list);
-    }
-    return detail::getDefaultABIAlignment(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeABIAlignment(ty, *this, list);
+    return detail::getDefaultABIAlignment(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, preferredAlignments, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypePreferredAlignment(ty, *this, list);
-      return detail::getDefaultPreferredAlignment(ty, *this, list);
-    }
-    return detail::getDefaultPreferredAlignment(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypePreferredAlignment(ty, *this, list);
+    return detail::getDefaultPreferredAlignment(ty, *this, list);
   });
 }
 
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -227,7 +227,7 @@
 
 TEST(DataLayout, FallbackDefault) {
   const char *ir = R"MLIR(
-"dltest.op_with_layout"() : () -> ()
+module {}
   )MLIR";
 
   DialectRegistry registry;
@@ -235,9 +235,7 @@
   MLIRContext ctx(registry);
 
   OwningModuleRef module = parseSourceString(ir, &ctx);
-  auto op =
-      cast<OpWithLayout>(module->getBody()->getOperations().front());
-  DataLayout layout(op);
+  DataLayout layout(module.get());
   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
@@ -248,6 +246,29 @@
   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
 }
 
+TEST(DataLayout, NullSpec) {
+  const char *ir = R"MLIR(
+"dltest.op_with_layout"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<DLTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  auto op =
+      cast<OpWithLayout>(module->getBody()->getOperations().front());
+  DataLayout layout(op);
+  EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
+  EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
+  EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
+  EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
+  EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
+  EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
+  EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
+  EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
+}
+
 TEST(DataLayout, EmptySpec) {
   const char *ir = R"MLIR(
 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()