diff --git a/mlir/docs/DataLayout.md b/mlir/docs/DataLayout.md --- a/mlir/docs/DataLayout.md +++ b/mlir/docs/DataLayout.md @@ -18,24 +18,26 @@ types. Built-in types are handled specially to decrease the overall query cost. +Similarly, built-in `ModuleOp` supports data layouts without going through the +interface. ## Usage ### Scoping Following MLIR's nested structure, data layout properties are _scoped_ to -regions belonging to specific operations that implement the -`DataLayoutOpInterface`. Such scoping operations partially control the data -layout properties and may have attributes that affect them, typically organized -in a data layout specification. +regions belonging to either operations that implement the +`DataLayoutOpInterface` or `ModuleOp` operations. Such scoping operations +partially control the data layout properties and may have attributes that affect +them, typically organized in a data layout specification. Types may have a different data layout in different scopes, including scopes that are nested in other scopes such as modules contained in other modules. At the same time, within the given scope excluding any nested scope, a given type has fixed data layout properties. Types are also expected to have a default, "natural" data layout in case they are used outside of any operation that -provides data layout scope for them. This ensure data layout queries always have -a valid result. +provides data layout scope for them. This ensures that data layout queries +always have a valid result. ### Compatibility and Transformations @@ -180,20 +182,24 @@ The overall flow of a data layout property query is as follows. -- The user constructs a `DataLayout` at the given scope. The constructor +1. The user constructs a `DataLayout` at the given scope. The constructor fetches the data layout specification and combines it with those of enclosing scopes (layouts are expected to be compatible). -- The user calls `DataLayout::query(Type ty)`.
-- If `DataLayout` has a cached response, this response is returned +2. The user calls `DataLayout::query(Type ty)`. +3. If `DataLayout` has a cached response, this response is returned immediately. -- Otherwise, the query is handed down by `DataLayout` to - `DataLayoutOpInterface::query(ty, *this, relevantEntries)` where the - relevant entries are computed as described above. -- Unless the `query` hook is reimplemented by the op interface, the query is +4. Otherwise, the query is handed down by `DataLayout` to the closest layout + scoping operation. If it implements `DataLayoutOpInterface`, then the query + is forwarded to `DataLayoutOpInterface::query(ty, *this, relevantEntries)` + where the relevant entries are computed as described above. Otherwise it's a + `ModuleOp`, and the query is forwarded to + `DataLayoutTypeInterface::query(dataLayout, relevantEntries)` after casting + `ty` to the type interface. +5. Unless the `query` hook is reimplemented by the op interface, the query is handled further down to `DataLayoutTypeInterface::query(dataLayout, relevantEntries)` after casting `ty` to the type interface. If the type does not implement the interface, an unrecoverable fatal error is produced. -- The type is expected to always provide the response, which is returned up +6. The type is expected to always provide the response, which is returned up the call stack and cached by the `DataLayout.` ## Default Implementation @@ -201,6 +207,14 @@ The default implementation of the data layout interfaces directly handles queries for a subset of built-in types. +### Built-in Modules + +Built-in `ModuleOp` allows at most one attribute that implements +`DataLayoutSpecInterface`. It does not implement the entire interface for +efficiency and layering reasons. Instead, `DataLayout` can be constructed for +`ModuleOp` and handles modules transparently alongside other operations that +implement the interface.
+ ### Built-in Types The following describes the default properties of built-in types. diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -18,6 +18,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -18,6 +18,7 @@ include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Base class for Builtin dialect ops. @@ -198,6 +199,14 @@ /// A ModuleOp may optionally define a symbol. bool isOptionalSymbol() { return true; } + + //===------------------------------------------------------------------===// + // Data Layout Methods + //===------------------------------------------------------------------===// + + /// Returns the first attribute implementing `DataLayoutSpecInterface`, or a + /// null spec if this module has none. + DataLayoutSpecInterface getDataLayoutSpec(); }]; let verifier = [{ return ::verify(*this); }]; diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h @@ -29,6 +29,7 @@ using DataLayoutEntryListRef = llvm::ArrayRef; class DataLayoutOpInterface; class DataLayoutSpecInterface; +class ModuleOp; namespace detail { /// Default handler for the type size request.
Computes results for built-in @@ -60,10 +61,11 @@ DataLayoutEntryInterface filterEntryForIdentifier(DataLayoutEntryListRef entries, Identifier id); -/// Verifies that the operation implementing the data layout interface is valid. -/// This calls the verifier of the spec attribute and checks if the layout is -/// compatible with specs attached to the enclosing operations. -LogicalResult verifyDataLayoutOp(DataLayoutOpInterface op); +/// Verifies that the operation implementing the data layout interface, or a +/// module operation, is valid. This calls the verifier of the spec attribute +/// and checks if the layout is compatible with specs attached to the enclosing +/// operations. +LogicalResult verifyDataLayoutOp(Operation *op); /// Verifies that a data layout spec is valid. This dispatches to individual /// entry verifiers, and then to the verifiers implemented by the relevant type @@ -133,6 +135,7 @@ class DataLayout { public: explicit DataLayout(DataLayoutOpInterface op); + explicit DataLayout(ModuleOp op); /// Returns the size of the given type in the current scope. unsigned getTypeSize(Type t) const; @@ -159,7 +162,5 @@ /// Operation defining the scope of requests. - // TODO: this is mutable because the generated interface method are not const. - // Update the generator to support const methods and change this to const. - mutable DataLayoutOpInterface scope; + Operation *scope; /// Caches for individual requests.
mutable DenseMap sizes; diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -369,6 +370,8 @@ return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName << "' is expected to be a #dlti.dl_spec attribute"; } + if (isa(op)) + return detail::verifyDataLayoutOp(op); return success(); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -230,6 +230,17 @@ return builder.create(loc, name); } +DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() { + // Take the first and only (if present) attribute that implements the + // interface. This needs a linear search, but is called only once per data + // layout object construction that is used for repeated queries. + for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) { + if (auto spec = attr.dyn_cast()) + return spec; + } + return {}; +} + static LogicalResult verify(ModuleOp op) { // Check that none of the attributes are non-dialect attributes, except for // the symbol related attributes. @@ -244,6 +255,24 @@ << attr.first << "'"; } + // Check that there is at most one data layout spec attribute.
+ StringRef layoutSpecAttrName; + DataLayoutSpecInterface layoutSpec; + for (const NamedAttribute &na : op->getAttrs()) { + if (auto spec = na.second.dyn_cast()) { + if (layoutSpec) { + InFlightDiagnostic diag = + op.emitOpError() << "expects at most one data layout attribute"; + diag.attachNote() << "'" << layoutSpecAttrName + << "' is a data layout attribute"; + diag.attachNote() << "'" << na.first << "' is a data layout attribute"; + return diag; + } + layoutSpecAttrName = na.first.strref(); + layoutSpec = spec; + } + } + return success(); } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -40,6 +40,7 @@ MLIRBuiltinTypesIncGen MLIRCallInterfacesIncGen MLIRCastInterfacesIncGen + MLIRDataLayoutInterfacesIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSideEffectInterfacesIncGen diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -8,9 +8,12 @@ #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" +#include "llvm/ADT/TypeSwitch.h" + using namespace mlir; //===----------------------------------------------------------------------===// @@ -105,65 +108,96 @@ return it == entries.end() ? DataLayoutEntryInterface() : *it; } +/// Returns the data layout spec attached to `operation`, which must be a +/// `ModuleOp` or implement `DataLayoutOpInterface`. The result may be null. +static DataLayoutSpecInterface getSpec(Operation *operation) { + return llvm::TypeSwitch(operation) + .Case( + [&](auto op) { return op.getDataLayoutSpec(); }) + .Default([](Operation *) { + llvm_unreachable("expected an op with data layout spec"); + return DataLayoutSpecInterface(); + }); +} + -/// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that -/// implement the `DataLayoutOpInterface`.
-static void findProperAscendantsWithLayout( - Operation *leaf, SmallVectorImpl &opsWithLayout) { +/// Populates `specs` with the (possibly null) data layout specs of the proper +/// ancestors of `leaf` that are modules or implement `DataLayoutOpInterface`, +/// from the innermost to the outermost. A top-level module without a spec is +/// skipped. If `opLocations` is provided, it receives the matching locations. +static void +collectParentLayouts(Operation *leaf, + SmallVectorImpl &specs, + SmallVectorImpl *opLocations = nullptr) { if (!leaf) return; - while (auto opLayout = leaf->getParentOfType()) { - opsWithLayout.push_back(opLayout); - leaf = opLayout; + for (Operation *parent = leaf->getParentOp(); parent != nullptr; + parent = parent->getParentOp()) { + llvm::TypeSwitch(parent) + .Case([&](ModuleOp op) { + // Skip top-level module op unless it has a layout. + if (!op->getParentOp() && !op.getDataLayoutSpec()) + return; + specs.push_back(op.getDataLayoutSpec()); + if (opLocations) + opLocations->push_back(op.getLoc()); + }) + .Case([&](DataLayoutOpInterface op) { + specs.push_back(op.getDataLayoutSpec()); + if (opLocations) + opLocations->push_back(op.getLoc()); + }); } } /// Returns a layout spec that is a combination of the layout specs attached /// to the given operation and all its ancestors. -static DataLayoutSpecInterface -getCombinedDataLayout(DataLayoutOpInterface leaf) { +static DataLayoutSpecInterface getCombinedDataLayout(Operation *leaf) { if (!leaf) return {}; + assert((isa(leaf)) && + "expected an op with data layout spec"); + - SmallVector opsWithLayout; - findProperAscendantsWithLayout(leaf, opsWithLayout); + SmallVector specs; + collectParentLayouts(leaf, specs); // Fast track if there are no ancestors. - if (opsWithLayout.empty()) - return leaf.getDataLayoutSpec(); + if (specs.empty()) + return getSpec(leaf); // Create the list of non-null specs (null/missing specs can be safely // ignored) from the outermost to the innermost.
- SmallVector specs; - specs.reserve(opsWithLayout.size()); - for (DataLayoutOpInterface op : llvm::reverse(opsWithLayout)) - if (DataLayoutSpecInterface current = op.getDataLayoutSpec()) - specs.push_back(current); + auto nonNullSpecs = llvm::to_vector<2>(llvm::make_filter_range( + llvm::reverse(specs), + [](DataLayoutSpecInterface iface) { return iface != nullptr; })); // Combine the specs using the innermost as anchor. - if (DataLayoutSpecInterface current = leaf.getDataLayoutSpec()) - return current.combineWith(specs); - if (specs.empty()) + if (DataLayoutSpecInterface current = getSpec(leaf)) + return current.combineWith(nonNullSpecs); + if (nonNullSpecs.empty()) return {}; - return specs.back().combineWith(llvm::makeArrayRef(specs).drop_back()); + return nonNullSpecs.back().combineWith( + llvm::makeArrayRef(nonNullSpecs).drop_back()); } -LogicalResult mlir::detail::verifyDataLayoutOp(DataLayoutOpInterface op) { - DataLayoutSpecInterface spec = op.getDataLayoutSpec(); +LogicalResult mlir::detail::verifyDataLayoutOp(Operation *op) { + DataLayoutSpecInterface spec = getSpec(op); // The layout specification may be missing and it's fine.
if (!spec) return success(); - if (failed(spec.verifySpec(op.getLoc()))) + if (failed(spec.verifySpec(op->getLoc()))) return failure(); if (!getCombinedDataLayout(op)) { InFlightDiagnostic diag = - op.emitError() - << "data layout is not a refinement of the layouts in enclosing ops"; - SmallVector opsWithLayout; - findProperAscendantsWithLayout(op, opsWithLayout); - for (DataLayoutOpInterface parent : opsWithLayout) - diag.attachNote(parent.getLoc()) << "enclosing op with data layout"; + op->emitError() + << "data layout does not combine with layouts of enclosing ops"; + SmallVector specs; + SmallVector opLocations; + collectParentLayouts(op, specs, &opLocations); + for (Location loc : opLocations) + diag.attachNote(loc) << "enclosing op with data layout"; return diag; } return success(); @@ -173,33 +207,41 @@ // DataLayout //===----------------------------------------------------------------------===// -mlir::DataLayout::DataLayout(DataLayoutOpInterface op) - : originalLayout(getCombinedDataLayout(op)), scope(op) { +/// Asserts that the layout may be missing only if `op` is null or has no spec. +template +static void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) { if (!originalLayout) { assert((!op || !op.getDataLayoutSpec()) && "could not compute layout information for an op (failed to " "combine attributes?)"); } +} +mlir::DataLayout::DataLayout(DataLayoutOpInterface op) + : originalLayout(getCombinedDataLayout(op)), scope(op) { #ifndef NDEBUG - SmallVector opsWithLayout; - findProperAscendantsWithLayout(op, opsWithLayout); - layoutStack = llvm::to_vector<2>( - llvm::map_range(opsWithLayout, [](DataLayoutOpInterface iface) { - return iface.getDataLayoutSpec(); - })); + checkMissingLayout(originalLayout, op); + collectParentLayouts(op, layoutStack); +#endif +} + +mlir::DataLayout::DataLayout(ModuleOp op) + : originalLayout(getCombinedDataLayout(op)), scope(op) { +#ifndef NDEBUG + checkMissingLayout(originalLayout, op); + collectParentLayouts(op, layoutStack); +#endif } void mlir::DataLayout::checkValid() const { #ifndef NDEBUG -
SmallVector opsWithLayout; - findProperAscendantsWithLayout(scope, opsWithLayout); - assert(opsWithLayout.size() == layoutStack.size() && + SmallVector specs; + collectParentLayouts(scope, specs); + assert(specs.size() == layoutStack.size() && "data layout object used, but no longer valid due to the change in " "number of nested layouts"); - for (auto pair : llvm::zip(opsWithLayout, layoutStack)) { - Attribute newLayout = std::get<0>(pair).getDataLayoutSpec(); + for (auto pair : llvm::zip(specs, layoutStack)) { + Attribute newLayout = std::get<0>(pair); Attribute origLayout = std::get<1>(pair); assert(newLayout == origLayout && "data layout object used, but no longer valid " @@ -228,30 +270,39 @@ unsigned mlir::DataLayout::getTypeSize(Type t) const { checkValid(); return cachedLookup(t, sizes, [&](Type ty) { - return (scope && originalLayout) - ? scope.getTypeSize( - ty, *this, originalLayout.getSpecForType(ty.getTypeID())) - : detail::getDefaultTypeSize(ty, *this, {}); + if (originalLayout) { + DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast(scope)) + return iface.getTypeSize(ty, *this, list); + return detail::getDefaultTypeSize(ty, *this, list); + } + return detail::getDefaultTypeSize(ty, *this, {}); }); } unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const { checkValid(); return cachedLookup(t, abiAlignments, [&](Type ty) { - return (scope && originalLayout) - ?
scope.getTypeABIAlignment( - ty, *this, originalLayout.getSpecForType(ty.getTypeID())) - : detail::getDefaultABIAlignment(ty, *this, {}); + if (originalLayout) { + DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast(scope)) + return iface.getTypeABIAlignment(ty, *this, list); + return detail::getDefaultABIAlignment(ty, *this, list); + } + return detail::getDefaultABIAlignment(ty, *this, {}); }); } unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const { checkValid(); return cachedLookup(t, preferredAlignments, [&](Type ty) { - return (scope && originalLayout) - ? scope.getTypePreferredAlignment( - ty, *this, originalLayout.getSpecForType(ty.getTypeID())) - : detail::getDefaultPreferredAlignment(ty, *this, {}); + if (originalLayout) { + DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast(scope)) + return iface.getTypePreferredAlignment(ty, *this, list); + return detail::getDefaultPreferredAlignment(ty, *this, list); + } + return detail::getDefaultPreferredAlignment(ty, *this, {}); }); } diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir --- a/mlir/test/Dialect/DLTI/invalid.mlir +++ b/mlir/test/Dialect/DLTI/invalid.mlir @@ -55,7 +55,7 @@ // Mismatching entries don't combine.
"test.op_with_data_layout"() ({ - // expected-error@below {{data layout is not a refinement of the layouts in enclosing ops}} + // expected-error@below {{data layout does not combine with layouts of enclosing ops}} // expected-note@above {{enclosing op with data layout}} "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>> } : () -> () "test.maybe_terminator_op"() : () -> () @@ -71,3 +71,22 @@ // expected-error@below {{data layout specified for a type that does not support it}} "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> } : () -> () + +// ----- + +// Mismatching entries are checked on module ops as well. +module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 33>>} { + // expected-error@below {{data layout does not combine with layouts of enclosing ops}} + // expected-note@above {{enclosing op with data layout}} + module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>>} { + } +} + +// ----- + +// Mismatching entries are checked on a combination of modules and other ops.
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 33>>} { + // expected-error@below {{data layout does not combine with layouts of enclosing ops}} + // expected-note@above {{enclosing op with data layout}} + "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>>} : () -> () +} diff --git a/mlir/test/IR/module-op.mlir b/mlir/test/IR/module-op.mlir --- a/mlir/test/IR/module-op.mlir +++ b/mlir/test/IR/module-op.mlir @@ -55,3 +55,12 @@ } } } + +// ----- + +// expected-error@below {{expects at most one data layout attribute}} +// expected-note@below {{'test.another_attribute' is a data layout attribute}} +// expected-note@below {{'test.random_attribute' is a data layout attribute}} +module attributes { test.random_attribute = #dlti.dl_spec<>, + test.another_attribute = #dlti.dl_spec<>} { +} diff --git a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp --- a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp @@ -36,8 +36,15 @@ scope, scope ? cast(scope.getOperation()) : nullptr); } + auto module = op->getParentOfType(); + // try_emplace is a no-op if a layout for `module` is already cached. + layouts.try_emplace(module, module); - const DataLayout &layout = layouts.find(scope)->getSecond(); + Operation *closest = (scope && module && module->isProperAncestor(scope)) + ? scope.getOperation() + : module.getOperation(); + + const DataLayout &layout = layouts.find(closest)->getSecond(); unsigned size = layout.getTypeSize(op.getType()); unsigned alignment = layout.getTypeABIAlignment(op.getType()); unsigned preferred = layout.getTypePreferredAlignment(op.getType());