diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -229,16 +229,17 @@ /// Creates a MemRef type with the given rank and shape, a potentially empty /// list of affine layout maps, the given memory space and element type, in the /// same context as element type. The type is owned by the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( - MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, MlirAttribute memorySpace); +MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType, + intptr_t rank, + const int64_t *shape, + MlirAttribute layout, + MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, - MlirAttribute memorySpace); + MlirAttribute layout, MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, @@ -264,12 +265,11 @@ MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); -/// Returns the number of affine layout maps in the given MemRef type. -MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); +/// Returns the layout of the given MemRef type. +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type); -/// Returns the pos-th affine map of the given MemRef type. -MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, - intptr_t pos); +/// Returns the affine map of the given MemRef type. 
+MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type); /// Returns the memory space of the given MemRef type. MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" @@ -227,6 +228,21 @@ ptrdiff_t index; }; } // namespace detail + +//===----------------------------------------------------------------------===// +// MemRefLayoutAttrInterface +//===----------------------------------------------------------------------===// + +namespace detail { + +// Verify the affine map 'm' can be used as a layout specification +// for memref with 'shape'. 
+LogicalResult +verifyAffineMapAsLayout(AffineMap m, ArrayRef shape, + function_ref emitError); + +} // namespace detail + } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -432,4 +432,52 @@ }] # ElementsAttrInterfaceAccessors; } +//===----------------------------------------------------------------------===// +// MemRefLayoutAttrInterface +//===----------------------------------------------------------------------===// + +def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> { + let cppNamespace = "::mlir"; + + let description = [{ + This interface is used for attributes that can represent the MemRef type's + layout semantics, such as dimension order in the memory, strides and offsets. + Such a layout attribute should be representable as a + [semi-affine map](Affine.md/#semi-affine-maps). + + Note: the MemRef type's layout is assumed to represent simple strided buffer + layout. For more complicated cases, like sparse storage buffers, + it is preferable to use a separate type with more specific layout, rather than + introducing extra complexity to the builtin MemRef type. 
+ }]; + + let methods = [ + InterfaceMethod< + "Get the MemRef layout as an AffineMap, the method must not return NULL", + "::mlir::AffineMap", "getAffineMap", (ins) + >, + + InterfaceMethod< + "Return true if this attribute represents the identity layout", + "bool", "isIdentity", (ins), + [{}], + [{ + return $_attr.getAffineMap().isIdentity(); + }] + >, + + InterfaceMethod< + "Check if the current layout is applicable to the provided shape", + "::mlir::LogicalResult", "verifyLayout", + (ins "::llvm::ArrayRef":$shape, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError), + [{}], + [{ + return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(), + shape, emitError); + }] + > + ]; +} + #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -34,7 +34,9 @@ // AffineMapAttr //===----------------------------------------------------------------------===// -def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> { +def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [ + MemRefLayoutAttrInterface + ]> { let summary = "An Attribute containing an AffineMap object"; let description = [{ Syntax: @@ -56,7 +58,10 @@ return $_get(value.getContext(), value); }]> ]; - let extraClassDeclaration = "using ValueType = AffineMap;"; + let extraClassDeclaration = [{ + using ValueType = AffineMap; + AffineMap getAffineMap() const { return getValue(); } + }]; let skipDefaultBuilders = 1; let typeBuilder = "IndexType::get($_value.getContext())"; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H +#include "BuiltinAttributeInterfaces.h" #include "SubElementInterfaces.h" namespace llvm 
{ @@ -209,12 +210,11 @@ // Build from another MemRefType. explicit Builder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), - affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) { - } + layout(other.getLayout()), memorySpace(other.getMemorySpace()) {} // Build from scratch. Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps() {} + : shape(shape), elementType(elementType) {} Builder &setShape(ArrayRef newShape) { shape = newShape; @@ -226,8 +226,8 @@ return *this; } - Builder &setAffineMaps(ArrayRef newAffineMaps) { - affineMaps = newAffineMaps; + Builder &setLayout(MemRefLayoutAttrInterface newLayout) { + layout = newLayout; return *this; } @@ -240,13 +240,13 @@ Builder &setMemorySpace(unsigned newMemorySpace); operator MemRefType() { - return MemRefType::get(shape, elementType, affineMaps, memorySpace); + return MemRefType::get(shape, elementType, layout, memorySpace); } private: ArrayRef shape; Type elementType; - ArrayRef affineMaps; + MemRefLayoutAttrInterface layout; Attribute memorySpace; }; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -278,8 +278,7 @@ stride-list ::= `[` (dimension (`,` dimension)*)? `]` strided-layout ::= `offset:` dimension `,` `strides: ` stride-list - semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map - layout-specification ::= semi-affine-map-composition | strided-layout + layout-specification ::= semi-affine-map | strided-layout | attribute-value memory-space ::= attribute-value ``` @@ -486,27 +485,6 @@ #layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64) ``` - ##### Affine Map Composition - - A memref specifies a semi-affine map composition as part of its type. 
A - semi-affine map composition is a composition of semi-affine maps beginning - with zero or more index maps, and ending with a layout map. The composition - must be conformant: the number of dimensions of the range of one map, must - match the number of dimensions of the domain of the next map in the - composition. - - The semi-affine map composition specified in the memref type, maps from - accesses used to index the memref in load/store operations to other index - spaces (i.e. logical to physical index mapping). Each of the - [semi-affine maps](Affine.md/#semi-affine-maps) and thus its composition is required - to be one-to-one. - - The semi-affine map composition can be used in dependence analysis, memory - access pattern analysis, and for performance optimizations like - vectorization, copy elision and in-place updates. If an affine map - composition is not specified for the memref, the identity affine map is - assumed. - ##### Strided MemRef A memref may specify a strided layout as part of its type. A stride @@ -544,36 +522,23 @@ let parameters = (ins ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, - ArrayRefParameter<"AffineMap">:$affineMaps, + "MemRefLayoutAttrInterface":$layout, "Attribute":$memorySpace ); let builders = [ TypeBuilderWithInferredContext<(ins "ArrayRef":$shape, "Type":$elementType, - CArg<"ArrayRef", "{}">:$affineMaps, - CArg<"Attribute", "{}">:$memorySpace - ), [{ - // Drop identity maps from the composition. This may lead to the - // composition becoming empty, which is interpreted as an implicit - // identity. - auto nonIdentityMaps = llvm::make_filter_range(affineMaps, - [](AffineMap map) { return !map.isIdentity(); }); - // Drop default memory space value and replace it with empty attribute. 
- Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace); - return $_get(elementType.getContext(), shape, elementType, - llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace); - }]>, + CArg<"MemRefLayoutAttrInterface", "{}">:$layout, + CArg<"Attribute", "{}">:$memorySpace)>, + TypeBuilderWithInferredContext<(ins + "ArrayRef":$shape, "Type":$elementType, + CArg<"AffineMap">:$map, + CArg<"Attribute", "{}">:$memorySpace)>, /// [deprecated] `Attribute`-based form should be used instead. TypeBuilderWithInferredContext<(ins "ArrayRef":$shape, "Type":$elementType, - "ArrayRef":$affineMaps, - "unsigned":$memorySpace - ), [{ - // Convert deprecated integer-like memory space to Attribute. - Attribute memorySpaceAttr = - wrapIntegerMemorySpace(memorySpace, elementType.getContext()); - return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr); - }]> + "AffineMap":$map, + "unsigned":$memorySpaceInd)> ]; let extraClassDeclaration = [{ /// This is a builder type that keeps local references to arguments. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -219,15 +219,8 @@ assert(memRefDim && "memRefDim == nullptr"); auto memRefType = memoryOp.getMemRefType(); - auto layoutMap = memRefType.getAffineMaps(); - // TODO: remove dependence on Builder once we support non-identity layout map. 
- Builder b(memoryOp.getContext()); - if (layoutMap.size() >= 2 || - (layoutMap.size() == 1 && - !(layoutMap[0] == - b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) { + if (!memRefType.getLayout().isIdentity()) return memoryOp.emitError("NYI: non-trivial layoutMap"), false; - } int uniqueVaryingIndexAlongIv = -1; auto accessMap = memoryOp.getAffineMap(); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -616,9 +616,7 @@ Optional MemRefRegion::getRegionSize() { auto memRefType = memref.getType().cast(); - auto layoutMaps = memRefType.getAffineMaps(); - if (layoutMaps.size() > 1 || - (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); return false; } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -401,8 +401,6 @@ } }; -class PyMemRefLayoutMapList; - /// Ranked MemRef Type subclass - MemRefType. class PyMemRefType : public PyConcreteType { public: @@ -410,26 +408,18 @@ static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; - PyMemRefLayoutMapList getLayout(); - static void bindDerived(ClassTy &c) { c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, PyAttribute *memorySpace, + PyAttribute *layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { - SmallVector maps; - maps.reserve(layout.size()); - for (PyAffineMap &map : layout) - maps.push_back(map); - - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memSpaceAttr); + MlirAttribute layoutAttr = layout ? 
*layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -444,10 +434,22 @@ return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), + py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly("layout", &PyMemRefType::getLayout, - "The list of layout maps of the MemRef type.") + .def_property_readonly( + "layout", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute layout = mlirMemRefTypeGetLayout(self); + return PyAttribute(self.getContext(), layout); + }, + "The layout of the MemRef type.") + .def_property_readonly( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", [](PyMemRefType &self) -> PyAttribute { @@ -458,41 +460,6 @@ } }; -/// A list of affine layout maps in a memref type. Internally, these are stored -/// as consecutive elements, random access is cheap. Both the type and the maps -/// are owned by the context, no need to worry about lifetime extension. -class PyMemRefLayoutMapList - : public Sliceable { -public: - static constexpr const char *pyClassName = "MemRefLayoutMapList"; - - PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirMemRefTypeGetNumAffineMaps(type) : length, - step), - memref(type) {} - - intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } - - PyAffineMap getElement(intptr_t index) { - return PyAffineMap(memref.getContext(), - mlirMemRefTypeGetAffineMap(memref, index)); - } - - PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyMemRefLayoutMapList(memref, startIndex, length, step); - } - -private: - PyMemRefType memref; -}; - -PyMemRefLayoutMapList PyMemRefType::getLayout() { - return PyMemRefLayoutMapList(*this); -} - /// Unranked MemRef Type subclass - UnrankedMemRefType. class PyUnrankedMemRefType : public PyConcreteType { @@ -640,7 +607,6 @@ PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); PyMemRefType::bind(m); - PyMemRefLayoutMapList::bind(m); PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); PyFunctionType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -226,34 +226,35 @@ bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, - const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, + const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, unwrap(memorySpace))); + return wrap(MemRefType::get( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + mlirAttributeIsNull(layout) + ? 
MemRefLayoutAttrInterface() + : unwrap(layout).cast(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, - MlirAffineMap const *affineMaps, + MlirAttribute layout, MlirAttribute memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, unwrap(memorySpace))); + unwrap(elementType), + mlirAttributeIsNull(layout) + ? MemRefLayoutAttrInterface() + : unwrap(layout).cast(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace) { - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, unwrap(memorySpace))); + return wrap(MemRefType::get( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + MemRefLayoutAttrInterface(), unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, @@ -262,16 +263,15 @@ MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, unwrap(memorySpace))); + unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); } -intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { - return static_cast( - unwrap(type).cast().getAffineMaps().size()); +MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { + return wrap(unwrap(type).cast().getLayout()); } -MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getAffineMaps()[pos]); +MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { + return wrap(unwrap(type).cast().getLayout().getAffineMap()); } MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { diff --git 
a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -106,9 +106,7 @@ MemRefType type) const { if (!typeConverter->convertType(type.getElementType())) return false; - return type.getAffineMaps().empty() || - llvm::all_of(type.getAffineMaps(), - [](AffineMap map) { return map.isIdentity(); }); + return type.getLayout().isIdentity(); } Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1142,7 +1142,8 @@ ConversionPatternRewriter &rewriter) const override { MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) { + if (!srcType.getLayout().isIdentity() || + !dstType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure(reshapeOp, "only empty layout map is supported"); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -950,8 +950,7 @@ if (!strides.empty() && strides.back() != 1) return None; // If no layout or identity layout, this is contiguous by definition. - if (memRefType.getAffineMaps().empty() || - memRefType.getAffineMaps().front().isIdentity()) + if (memRefType.getLayout().isIdentity()) return strides; // Otherwise, we must determine contiguity form shapes. 
This can only ever diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1047,8 +1047,7 @@ auto srcMemrefType = srcType.cast(); auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt(); - if (!srcMemrefType.getAffineMaps().empty() && - !srcMemrefType.getAffineMaps().front().isIdentity()) + if (!srcMemrefType.getLayout().isIdentity()) return op.emitError("expected identity layout map for source memref"); if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && @@ -1074,9 +1073,7 @@ auto srcMatrixType = srcType.cast(); auto dstMemrefType = dstType.cast(); auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); - - if (!dstMemrefType.getAffineMaps().empty() && - !dstMemrefType.getAffineMaps().front().isIdentity()) + if (!dstMemrefType.getLayout().isIdentity()) return op.emitError("expected identity layout map for destination memref"); if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -159,9 +159,8 @@ Value createWorkgroupBuffer() { int workgroupMemoryAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace(); - auto bufferType = - MemRefType::get({kSubgroupSize}, valueType, ArrayRef{}, - workgroupMemoryAddressSpace); + auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{}, + workgroupMemoryAddressSpace); return funcOp.addWorkgroupAttribution(bufferType); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ 
-1219,27 +1219,24 @@ /// with the same shape as `shapedType` and specified `layout` and /// `addressSpace`. static MemRefType getContiguousMemRefType(ShapedType shapedType, - ArrayRef layout = {}, - unsigned addressSpace = 0) { - if (RankedTensorType tensorType = shapedType.dyn_cast()) - return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), - layout, addressSpace); - MemRefType memrefType = shapedType.cast(); - return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), - layout, addressSpace); + MemRefLayoutAttrInterface layout = {}, + Attribute memorySpace = {}) { + return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), + layout, memorySpace); } /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) /// with the same shape as `shapedType` and specified `layout` and /// `addressSpace` or an UnrankedMemRefType otherwise. -static Type getContiguousOrUnrankedMemRefType(Type type, - ArrayRef layout = {}, - unsigned addressSpace = 0) { +static Type +getContiguousOrUnrankedMemRefType(Type type, + MemRefLayoutAttrInterface layout = {}, + Attribute memorySpace = {}) { if (type.isa()) return getContiguousMemRefType(type.cast(), layout, - addressSpace); - assert(layout.empty() && "expected empty layout with UnrankedMemRefType"); - return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace); + memorySpace); + assert(!layout && "expected empty layout with UnrankedMemRefType"); + return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace); } /// Return a MemRefType to which the `tensorType` can be bufferized in a @@ -1644,16 +1641,16 @@ auto rankedMemRefType = sourceType.dyn_cast(); auto unrankedMemRefType = sourceType.dyn_cast(); assert(rankedMemRefType || unrankedMemRefType); - unsigned memorySpace = rankedMemRefType - ? rankedMemRefType.getMemorySpaceAsInt() - : unrankedMemRefType.getMemorySpaceAsInt(); + Attribute memorySpace = rankedMemRefType + ? 
rankedMemRefType.getMemorySpace() + : unrankedMemRefType.getMemorySpace(); TensorType tensorType = castOp.getResult().getType().cast(); - ArrayRef affineMaps = + MemRefLayoutAttrInterface layout = rankedMemRefType && tensorType.isa() - ? rankedMemRefType.getAffineMaps() - : ArrayRef{}; + ? rankedMemRefType.getLayout() + : MemRefLayoutAttrInterface(); Type memRefType = getContiguousOrUnrankedMemRefType( - castOp.getResult().getType(), affineMaps, memorySpace); + castOp.getResult().getType(), layout, memorySpace); Value res = b.create(castOp.getLoc(), memRefType, resultBuffer); aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -258,7 +258,7 @@ // leave them unchanged. Type actualType = opOperand->get().getType(); if (auto memref = actualType.dyn_cast()) { - if (!memref.getAffineMaps().empty()) + if (!memref.getLayout().isIdentity()) return llvm::None; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -88,8 +88,8 @@ "dynamic dimension count"); unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) - numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); + if (!memRefType.getLayout().isIdentity()) + numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); if (op.symbolOperands().size() != numSymbols) return op.emitOpError("symbol operand count does not equal memref symbol " "count: expected ") @@ -496,7 +496,7 @@ if (aT && bT) { if (aT.getElementType() != bT.getElementType()) return false; - if (aT.getAffineMaps() != bT.getAffineMaps()) { + if (aT.getLayout() != bT.getLayout()) { int64_t aOffset, bOffset; SmallVector aStrides, bStrides; if 
(failed(getStridesAndOffset(aT, aStrides, aOffset)) || @@ -1408,7 +1408,7 @@ // Match offset and strides in static_offset and static_strides attributes if // result memref type has an affine map specified. - if (!resultType.getAffineMaps().empty()) { + if (!resultType.getLayout().isIdentity()) { int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) @@ -1526,8 +1526,8 @@ } // Early-exit: if `type` is contiguous, the result must be contiguous. - if (canonicalizeStridedLayout(type).getAffineMaps().empty()) - return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); + if (canonicalizeStridedLayout(type).getLayout().isIdentity()) + return MemRefType::Builder(type).setShape(newSizes).setLayout({}); // Convert back to int64_t because we don't have enough information to create // new strided layouts from AffineExpr only. This corresponds to a case where @@ -1546,7 +1546,8 @@ auto layout = makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); return canonicalizeStridedLayout( - MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); + MemRefType::Builder(type).setShape(newSizes).setLayout( + AffineMapAttr::get(layout))); } void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, @@ -1662,14 +1663,14 @@ "types should be the same"); if (auto operandMemRefType = operandType.dyn_cast()) - if (!operandMemRefType.getAffineMaps().empty()) + if (!operandMemRefType.getLayout().isIdentity()) return op.emitOpError( "source memref type should have identity affine map"); int64_t shapeSize = op.shape().getType().cast().getDimSize(0); auto resultMemRefType = resultType.dyn_cast(); if (resultMemRefType) { - if (!resultMemRefType.getAffineMaps().empty()) + if (!resultMemRefType.getLayout().isIdentity()) return op.emitOpError( "result memref type should have identity affine map"); if (shapeSize == ShapedType::kDynamicSize) @@ -1824,10 +1825,9 @@ if 
(!dimsToProject.contains(pos)) projectedShape.push_back(shape[pos]); - AffineMap map; - auto maps = inferredType.getAffineMaps(); - if (!maps.empty() && maps.front()) - map = getProjectedMap(maps.front(), dimsToProject); + AffineMap map = inferredType.getLayout().getAffineMap(); + if (!map.isIdentity()) + map = getProjectedMap(map, dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, inferredType.getMemorySpace()); @@ -2279,7 +2279,9 @@ auto map = makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); map = permutationMap ? map.compose(permutationMap) : map; - return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); + return MemRefType::Builder(memRefType) + .setShape(sizes) + .setLayout(AffineMapAttr::get(map)); } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, @@ -2387,15 +2389,11 @@ auto viewType = op.getType(); // The base memref should have identity layout map (or none). - if (baseType.getAffineMaps().size() > 1 || - (baseType.getAffineMaps().size() == 1 && - !baseType.getAffineMaps()[0].isIdentity())) + if (!baseType.getLayout().isIdentity()) return op.emitError("unsupported map for base memref type ") << baseType; // The result memref should have identity layout map (or none). - if (viewType.getAffineMaps().size() > 1 || - (viewType.getAffineMaps().size() == 1 && - !viewType.getAffineMaps()[0].isIdentity())) + if (!viewType.getLayout().isIdentity()) return op.emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. 
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3767,16 +3767,17 @@ VectorType vectorType = VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); - result.addTypes( - MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace())); + result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), + memRefType.getMemorySpace())); } static LogicalResult verify(TypeCastOp op) { MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType()); - if (!canonicalType.getAffineMaps().empty()) - return op.emitOpError("expects operand to be a memref with no layout"); - if (!op.getResultMemRefType().getAffineMaps().empty()) - return op.emitOpError("expects result to be a memref with no layout"); + if (!canonicalType.getLayout().isIdentity()) + return op.emitOpError( + "expects operand to be a memref with identity layout"); + if (!op.getResultMemRefType().getLayout().isIdentity()) + return op.emitOpError("expects result to be a memref with identity layout"); if (op.getResultMemRefType().getMemorySpace() != op.getMemRefType().getMemorySpace()) return op.emitOpError("expects result in same memory space"); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1971,9 +1971,9 @@ os << 'x'; } printType(memrefTy.getElementType()); - for (auto map : memrefTy.getAffineMaps()) { + if (!memrefTy.getLayout().isIdentity()) { os << ", "; - printAttribute(AffineMapAttr::get(map)); + printAttribute(memrefTy.getLayout(), AttrTypeElision::May); } // Only print the memory space if it is the non-default one. 
if (memrefTy.getMemorySpace()) { diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "llvm/ADT/Sequence.h" using namespace mlir; @@ -72,3 +73,17 @@ } return valueIndex; } + +//===----------------------------------------------------------------------===// +// MemRefLayoutAttrInterface +//===----------------------------------------------------------------------===// + +LogicalResult mlir::detail::verifyAffineMapAsLayout( + AffineMap m, ArrayRef shape, + function_ref emitError) { + if (m.getNumDims() != shape.size()) + return emitError() << "memref layout mismatch between rank and affine map: " + << shape.size() << " != " << m.getNumDims(); + + return success(); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -646,9 +646,118 @@ return detail::getMemorySpaceAsInt(getMemorySpace()); } +MemRefType MemRefType::get(ArrayRef shape, Type elementType, + MemRefLayoutAttrInterface layout, + Attribute memorySpace) { + // Use default layout for empty attribute. + if (!layout) + layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( + shape.size(), elementType.getContext())); + + // Drop default memory space value and replace it with empty attribute. + memorySpace = skipDefaultMemorySpace(memorySpace); + + return Base::get(elementType.getContext(), shape, elementType, layout, + memorySpace); +} + +MemRefType MemRefType::getChecked( + function_ref emitErrorFn, ArrayRef shape, + Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { + + // Use default layout for empty attribute. 
+ if (!layout) + layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( + shape.size(), elementType.getContext())); + + // Drop default memory space value and replace it with empty attribute. + memorySpace = skipDefaultMemorySpace(memorySpace); + + return Base::getChecked(emitErrorFn, elementType.getContext(), shape, + elementType, layout, memorySpace); +} + +MemRefType MemRefType::get(ArrayRef shape, Type elementType, + AffineMap map, Attribute memorySpace) { + + // Use default layout for empty map. + if (!map) + map = AffineMap::getMultiDimIdentityMap(shape.size(), + elementType.getContext()); + + // Wrap AffineMap into Attribute. + Attribute layout = AffineMapAttr::get(map); + + // Drop default memory space value and replace it with empty attribute. + memorySpace = skipDefaultMemorySpace(memorySpace); + + return Base::get(elementType.getContext(), shape, elementType, layout, + memorySpace); +} + +MemRefType +MemRefType::getChecked(function_ref emitErrorFn, + ArrayRef shape, Type elementType, AffineMap map, + Attribute memorySpace) { + + // Use default layout for empty map. + if (!map) + map = AffineMap::getMultiDimIdentityMap(shape.size(), + elementType.getContext()); + + // Wrap AffineMap into Attribute. + Attribute layout = AffineMapAttr::get(map); + + // Drop default memory space value and replace it with empty attribute. + memorySpace = skipDefaultMemorySpace(memorySpace); + + return Base::getChecked(emitErrorFn, elementType.getContext(), shape, + elementType, layout, memorySpace); +} + +MemRefType MemRefType::get(ArrayRef shape, Type elementType, + AffineMap map, unsigned memorySpaceInd) { + + // Use default layout for empty map. + if (!map) + map = AffineMap::getMultiDimIdentityMap(shape.size(), + elementType.getContext()); + + // Wrap AffineMap into Attribute. + Attribute layout = AffineMapAttr::get(map); + + // Convert deprecated integer-like memory space to Attribute. 
+ Attribute memorySpace = + wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); + + return Base::get(elementType.getContext(), shape, elementType, layout, + memorySpace); +} + +MemRefType +MemRefType::getChecked(function_ref emitErrorFn, + ArrayRef shape, Type elementType, AffineMap map, + unsigned memorySpaceInd) { + + // Use default layout for empty map. + if (!map) + map = AffineMap::getMultiDimIdentityMap(shape.size(), + elementType.getContext()); + + // Wrap AffineMap into Attribute. + Attribute layout = AffineMapAttr::get(map); + + // Convert deprecated integer-like memory space to Attribute. + Attribute memorySpace = + wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); + + return Base::getChecked(emitErrorFn, elementType.getContext(), shape, + elementType, layout, memorySpace); +} + LogicalResult MemRefType::verify(function_ref emitError, ArrayRef shape, Type elementType, - ArrayRef affineMapComposition, + MemRefLayoutAttrInterface layout, Attribute memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) return emitError() << "invalid memref element type"; @@ -658,26 +767,12 @@ if (s < -1) return emitError() << "invalid memref size"; - // Check that the structure of the composition is valid, i.e. that each - // subsequent affine map has as many inputs as the previous map has results. - // Take the dimensionality of the MemRef for the first map. - size_t dim = shape.size(); - for (auto it : llvm::enumerate(affineMapComposition)) { - AffineMap map = it.value(); - if (map.getNumDims() == dim) { - dim = map.getNumResults(); - continue; - } - return emitError() << "memref affine map dimension mismatch between " - << (it.index() == 0 ? 
Twine("memref rank") - : "affine map " + Twine(it.index())) - << " and affine map" << it.index() + 1 << ": " << dim - << " != " << map.getNumDims(); - } + assert(layout && "missing layout specification"); + if (failed(layout.verifyLayout(shape, emitError))) + return failure(); - if (!isSupportedMemorySpace(memorySpace)) { + if (!isSupportedMemorySpace(memorySpace)) return emitError() << "unsupported memory space Attribute"; - } return success(); } @@ -686,9 +781,9 @@ function_ref walkAttrsFn, function_ref walkTypesFn) const { walkTypesFn(getElementType()); + if (!getLayout().isIdentity()) + walkAttrsFn(getLayout()); walkAttrsFn(getMemorySpace()); - for (AffineMap map : getAffineMaps()) - walkAttrsFn(AffineMapAttr::get(map)); } //===----------------------------------------------------------------------===// @@ -775,23 +870,18 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t, SmallVectorImpl &strides, AffineExpr &offset) { - auto affineMaps = t.getAffineMaps(); - - if (affineMaps.size() > 1) - return failure(); + AffineMap m = t.getLayout().getAffineMap(); - if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) + if (m.getNumResults() != 1 && !m.isIdentity()) return failure(); - AffineMap m = affineMaps.empty() ? AffineMap() : affineMaps.back(); - auto zero = getAffineConstantExpr(0, t.getContext()); auto one = getAffineConstantExpr(1, t.getContext()); offset = zero; strides.assign(t.getRank(), zero); // Canonical case for empty map. - if (!m || m.isIdentity()) { + if (m.isIdentity()) { // 0-D corner case, offset is already 0. if (t.getRank() == 0) return success(); @@ -938,21 +1028,21 @@ /// `t` with simplified layout. /// If `t` has multiple layout maps or a multi-result layout, just return `t`. MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { - auto affineMaps = t.getAffineMaps(); + AffineMap m = t.getLayout().getAffineMap(); + // Already in canonical form. 
- if (affineMaps.empty()) + if (m.isIdentity()) return t; // Can't reduce to canonical identity form, return in canonical form. - if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) + if (m.getNumResults() > 1) return t; // Corner-case for 0-D affine maps. - auto m = affineMaps[0]; if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { if (auto cst = m.getResult(0).dyn_cast()) if (cst.getValue() == 0) - return MemRefType::Builder(t).setAffineMaps({}); + return MemRefType::Builder(t).setLayout({}); return t; } @@ -970,9 +1060,9 @@ auto simplifiedLayoutExpr = simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); if (expr != simplifiedLayoutExpr) - return MemRefType::Builder(t).setAffineMaps({AffineMap::get( - m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); - return MemRefType::Builder(t).setAffineMaps({}); + return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( + m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); + return MemRefType::Builder(t).setLayout({}); } AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, @@ -1016,8 +1106,9 @@ /// strides. This is used to erase the static layout. MemRefType mlir::eraseStridedLayout(MemRefType t) { auto val = ShapedType::kDynamicStrideOrOffset; - return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( - SmallVector(t.getRank(), val), val, t.getContext())); + return MemRefType::Builder(t).setLayout( + AffineMapAttr::get(makeStridedLinearLayoutMap( + SmallVector(t.getRank(), val), val, t.getContext()))); } AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -185,9 +185,8 @@ /// /// stride-list ::= `[` (dimension (`,` dimension)*)? 
`]` /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list -/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map -/// layout-specification ::= semi-affine-map-composition | strided-layout -/// memory-space ::= integer-literal /* | TODO: address-space-id */ +/// layout-specification ::= semi-affine-map | strided-layout | attribute +/// memory-space ::= integer-literal | attribute /// Type Parser::parseMemRefType() { llvm::SMLoc loc = getToken().getLoc(); @@ -221,15 +220,10 @@ if (!BaseMemRefType::isValidElementType(elementType)) return emitError(typeLoc, "invalid memref element type"), nullptr; - // Parse semi-affine-map-composition. - SmallVector affineMapComposition; + MemRefLayoutAttrInterface layout; Attribute memorySpace; - unsigned numDims = dimensions.size(); auto parseElt = [&]() -> ParseResult { - AffineMap map; - llvm::SMLoc mapLoc = getToken().getLoc(); - // Check for AffineMap as offset/strides. if (getToken().is(Token::kw_offset)) { int64_t offset; @@ -237,15 +231,17 @@ if (failed(parseStridedLayout(offset, strides))) return failure(); // Construct strided affine map. - map = makeStridedLinearLayoutMap(strides, offset, state.context); + AffineMap map = + makeStridedLinearLayoutMap(strides, offset, state.context); + layout = AffineMapAttr::get(map); } else { - // Either it is AffineMapAttr or memory space attribute. + // Either it is MemRefLayoutAttrInterface or memory space attribute. 
Attribute attr = parseAttribute(); if (!attr) return failure(); - if (AffineMapAttr affineMapAttr = attr.dyn_cast()) { - map = affineMapAttr.getValue(); + if (attr.isa()) { + layout = attr.cast(); } else if (memorySpace) { return emitError("multiple memory spaces specified in memref type"); } else { @@ -259,15 +255,6 @@ if (memorySpace) return emitError("expected memory space to be last in memref type"); - if (map.getNumDims() != numDims) { - size_t i = affineMapComposition.size(); - return emitError(mapLoc, "memref affine map dimension mismatch between ") - << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) - << " and affine map" << i + 1 << ": " << numDims - << " != " << map.getNumDims(); - } - numDims = map.getNumResults(); - affineMapComposition.push_back(map); return success(); }; @@ -284,8 +271,8 @@ if (isUnranked) return getChecked(loc, elementType, memorySpace); - return getChecked(loc, dimensions, elementType, - affineMapComposition, memorySpace); + return getChecked(loc, dimensions, elementType, layout, + memorySpace); } /// Parse any type except the function type. diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -225,7 +225,7 @@ // memref type is normalized. // TODO: When selective normalization is implemented, handle multiple // results case where some are normalized, some aren't. 
- if (memrefType.getAffineMaps().empty()) + if (memrefType.getLayout().isIdentity()) resultTypes[operandEn.index()] = memrefType; } }); @@ -269,7 +269,7 @@ if (oldResult.getType() == newResult.getType()) continue; AffineMap layoutMap = - oldResult.getType().dyn_cast().getAffineMaps().front(); + oldResult.getType().cast().getLayout().getAffineMap(); if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, /*extraIndices=*/{}, /*indexRemap=*/layoutMap, @@ -363,7 +363,7 @@ BlockArgument newMemRef = funcOp.front().insertArgument(argIndex, newMemRefType); BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); - AffineMap layoutMap = memrefType.getAffineMaps().front(); + AffineMap layoutMap = memrefType.getLayout().getAffineMap(); // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, /*extraIndices=*/{}, @@ -412,7 +412,7 @@ if (oldMemRefType == newMemRefType) continue; // TODO: Assume single layout map. Multiple maps not supported. 
- AffineMap layoutMap = oldMemRefType.getAffineMaps().front(); + AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, /*extraIndices=*/{}, diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -74,9 +74,7 @@ SmallVector newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); - return MemRefType::Builder(oldMemRefType) - .setShape(newShape) - .setAffineMaps({}); + return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; auto oldMemRefType = oldMemRef.getType().cast(); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -2648,9 +2648,7 @@ auto memref = region.memref; auto memRefType = memref.getType().cast(); - auto layoutMaps = memRefType.getAffineMaps(); - if (layoutMaps.size() > 1 || - (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); return failure(); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -647,7 +647,7 @@ Value oldMemRef = allocOp->getResult(); SmallVector symbolOperands(allocOp->symbolOperands()); - AffineMap layoutMap = memrefType.getAffineMaps().front(); + AffineMap layoutMap = memrefType.getLayout().getAffineMap(); memref::AllocOp newAlloc; // Check if `layoutMap` is a tiled layout. Only single layout map is // supported for normalizing dynamic memrefs. 
@@ -695,13 +695,12 @@ if (rank == 0) return memrefType; - ArrayRef layoutMaps = memrefType.getAffineMaps(); - if (layoutMaps.empty() || - layoutMaps.front() == b.getMultiDimIdentityMap(rank)) { + if (memrefType.getLayout().isIdentity()) { // Either no maps is associated with this memref or this memref has // a trivial (identity) map. return memrefType; } + AffineMap layoutMap = memrefType.getLayout().getAffineMap(); // We don't do any checks for one-to-one'ness; we assume that it is // one-to-one. @@ -710,7 +709,7 @@ // for now. // TODO: Normalize the other types of dynamic memrefs. SmallVector> tileSizePos; - (void)getTileSizePos(layoutMaps.front(), tileSizePos); + (void)getTileSizePos(layoutMap, tileSizePos); if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty()) return memrefType; @@ -731,7 +730,6 @@ } // We compose this map with the original index (logical) space to derive // the upper bounds for the new index space. - AffineMap layoutMap = layoutMaps.front(); unsigned newRank = layoutMap.getNumResults(); if (failed(fac.composeMatchingMap(layoutMap))) return memrefType; @@ -763,7 +761,7 @@ MemRefType newMemRefType = MemRefType::Builder(memrefType) .setShape(newShape) - .setAffineMaps(b.getMultiDimIdentityMap(newRank)); + .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank))); return newMemRefType; } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -726,7 +726,6 @@ MlirType memRef = mlirMemRefTypeContiguousGet( f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2); if (!mlirTypeIsAMemRef(memRef) || - mlirMemRefTypeGetNumAffineMaps(memRef) != 0 || !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2)) return 18; mlirTypeDump(memRef); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1095,7 +1095,7 @@ // ----- func 
@type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) { - // expected-error@+1 {{expects operand to be a memref with no layout}} + // expected-error@+1 {{expects operand to be a memref with identity layout}} %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref> } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -104,7 +104,8 @@ func @test_alloc_memref_map_rank_mismatch() { ^bb0: - %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> // expected-error {{memref affine map dimension mismatch}} + // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}} + %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> return } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -61,13 +61,7 @@ // The error must be emitted even for the trivial identity layout maps that are // dropped in type creation. 
#map0 = affine_map<(d0, d1) -> (d0, d1)> -func @memrefs(memref<42xi8, #map0>) // expected-error {{memref affine map dimension mismatch}} - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0) -> (d0)> -func @memrefs(memref<42x42xi8, #map0, #map1>) // expected-error {{memref affine map dimension mismatch}} +func @memrefs(memref<42xi8, #map0>) // expected-error {{memref layout mismatch between rank and affine map: 1 != 2}} // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -12,9 +12,6 @@ // CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)> #map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -// CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d2, d1, d0)> -#map4 = affine_map<(d0, d1, d2) -> (d2, d1, d0)> - // CHECK-DAG: #map{{[0-9]+}} = affine_map<()[s0] -> (0, s0 - 1)> #inline_map_minmax_loop1 = affine_map<()[s0] -> (0, s0 - 1)> @@ -80,28 +77,15 @@ // CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">) func private @tensor_encoding(tensor<16x32xf64, "sparse">) -// CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>) -func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>) - -// Test memref affine map compositions. 
+// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ()) +func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->()) // CHECK: func private @memrefs2(memref<2x4x8xi8, 1>) func private @memrefs2(memref<2x4x8xi8, #map2, 1>) -// CHECK: func private @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}>) -func private @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>) - -// CHECK: func private @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>) -func private @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>) - -// Test memref inline affine map compositions, minding that identity maps are removed. - // CHECK: func private @memrefs3(memref<2x4x8xi8>) func private @memrefs3(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>>) -// CHECK: func private @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, 1>) -func private @memrefs33(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 1>) - // CHECK: func private @memrefs_drop_triv_id_inline(memref<2xi8>) func private @memrefs_drop_triv_id_inline(memref<2xi8, affine_map<(d0) -> (d0)>>) @@ -111,35 +95,6 @@ // CHECK: func private @memrefs_drop_triv_id_inline1(memref<2xi8, 1>) func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> (d0)>, 1>) -// Identity maps should be dropped from the composition, but not the pair of -// "interchange" maps that, if composed, would be also an identity. 
-// CHECK: func private @memrefs_drop_triv_id_composition(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>) -func private @memrefs_drop_triv_id_composition(memref<2x2xi8, - affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>>) - -// CHECK: func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, #map{{[0-9]+}}>) -func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0, d1)>>) - -// CHECK: func private @memrefs_drop_triv_id_middle(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>) -func private @memrefs_drop_triv_id_middle(memref<2x2xi8, - affine_map<(d0, d1) -> (d0, d1 + 1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0 + 1, d1)>>) - -// CHECK: func private @memrefs_drop_triv_id_multiple(memref<2xi8>) -func private @memrefs_drop_triv_id_multiple(memref<2xi8, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>>) - -// These maps appeared before, so they must be uniqued and hoisted to the beginning. -// Identity map should be removed. 
-// CHECK: func private @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}>) -func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1, d0)>>) - // Test memref with custom memory space // CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>) @@ -202,9 +157,6 @@ // CHECK: func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>) func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>) -// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ()) -func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->()) - // CHECK-LABEL: func @simpleCFG(%{{.*}}: i32, %{{.*}}: f32) -> i1 { func @simpleCFG(%arg0: i32, %f: f32) -> i1 { // CHECK: %{{.*}} = "foo"() : () -> i64 diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -372,18 +372,21 @@ memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) - # CHECK: number of affine layout maps: 0 - print("number of affine layout maps:", len(memref.layout)) + # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> + print("memref layout:", memref.layout) + # CHECK: memref affine map: (d0, d1) -> (d0, d1) + print("memref affine map:", memref.affine_map) # CHECK: memory space: 2 print("memory space:", memref.memory_space) - layout = AffineMap.get_permutation([1, 0]) - memref_layout = MemRefType.get(shape, f32, [layout]) + layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) + memref_layout = MemRefType.get(shape, f32, layout=layout) # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> print("memref type:", memref_layout) - assert len(memref_layout.layout) == 1 - # CHECK: memref layout: (d0, d1) -> (d1, d0) - print("memref layout:", 
memref_layout.layout[0]) + # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)> + print("memref layout:", memref_layout.layout) + # CHECK: memref affine map: (d0, d1) -> (d1, d0) + print("memref affine map:", memref_layout.affine_map) # CHECK: memory space: <> print("memory space:", memref_layout.memory_space) diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -32,26 +32,26 @@ ShapedType memrefType = MemRefType::Builder(memrefOriginalShape, memrefOriginalType) .setMemorySpace(memSpace) - .setAffineMaps(map); + .setLayout(AffineMapAttr::get(map)); // Update shape. llvm::SmallVector memrefNewShape({30, 40}); ASSERT_NE(memrefOriginalShape, memrefNewShape); ASSERT_EQ(memrefType.clone(memrefNewShape), (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) .setMemorySpace(memSpace) - .setAffineMaps(map)); + .setLayout(AffineMapAttr::get(map))); // Update type. Type memrefNewType = f32; ASSERT_NE(memrefOriginalType, memrefNewType); ASSERT_EQ(memrefType.clone(memrefNewType), (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) .setMemorySpace(memSpace) - .setAffineMaps(map)); + .setLayout(AffineMapAttr::get(map))); // Update both. ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) .setMemorySpace(memSpace) - .setAffineMaps(map)); + .setLayout(AffineMapAttr::get(map))); // Test unranked memref cloning. ShapedType unrankedTensorType =