diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -229,16 +229,17 @@
 /// Creates a MemRef type with the given rank and shape, a potentially empty
 /// list of affine layout maps, the given memory space and element type, in the
 /// same context as element type. The type is owned by the context.
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
-    MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
-    MlirAffineMap const *affineMaps, MlirAttribute memorySpace);
+MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType,
+                                              intptr_t rank,
+                                              const int64_t *shape,
+                                              MlirAttribute layout,
+                                              MlirAttribute memorySpace);
 
 /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on
 /// illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
     MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
-    intptr_t numMaps, MlirAffineMap const *affineMaps,
-    MlirAttribute memorySpace);
+    MlirAttribute layout, MlirAttribute memorySpace);
 
 /// Creates a MemRef type with the given rank, shape, memory space and element
 /// type in the same context as the element type. The type has no affine maps,
@@ -264,12 +265,11 @@
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
     MlirLocation loc, MlirType elementType, MlirAttribute memorySpace);
 
-/// Returns the number of affine layout maps in the given MemRef type.
-MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
+/// Returns the layout of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type);
 
-/// Returns the pos-th affine map of the given MemRef type.
-MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type,
-                                                            intptr_t pos);
+/// Returns the affine map of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,6 +10,7 @@
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/MemRefLayoutAttrInterfaces.h"
 #include "mlir/IR/SubElementInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -16,6 +16,7 @@
 
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/MemRefLayoutAttrInterfaces.td"
 include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the attributes defined in this file are prefixed with
@@ -34,7 +35,9 @@
 // AffineMapAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
+def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
+    MemRefLayoutAttrInterface
+  ]> {
   let summary = "An Attribute containing an AffineMap object";
   let description = [{
     Syntax:
@@ -56,7 +59,10 @@
       return $_get(value.getContext(), value);
     }]>
   ];
-  let extraClassDeclaration = "using ValueType = AffineMap;";
+  let extraClassDeclaration = [{
+    using ValueType = AffineMap;
+    AffineMap getAffineMap() const { return getValue(); }
+  }];
   let skipDefaultBuilders = 1;
   let typeBuilder = "IndexType::get($_value.getContext())";
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -209,12 +209,11 @@
     // Build from another MemRefType.
     explicit Builder(MemRefType other)
         : shape(other.getShape()), elementType(other.getElementType()),
-          affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) {
-    }
+          layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
 
     // Build from scratch.
     Builder(ArrayRef<int64_t> shape, Type elementType)
-        : shape(shape), elementType(elementType), affineMaps() {}
+        : shape(shape), elementType(elementType) {}
 
     Builder &setShape(ArrayRef<int64_t> newShape) {
       shape = newShape;
@@ -226,11 +225,13 @@
       return *this;
     }
 
-    Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
-      affineMaps = newAffineMaps;
+    Builder &setLayout(Attribute newLayout) {
+      layout = newLayout;
       return *this;
     }
 
+    Builder &setAffineMap(AffineMap newAffineMap);
+
     Builder &setMemorySpace(Attribute newMemorySpace) {
       memorySpace = newMemorySpace;
       return *this;
@@ -240,13 +241,13 @@
     Builder &setMemorySpace(unsigned newMemorySpace);
 
     operator MemRefType() {
-      return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+      return MemRefType::get(shape, elementType, layout, memorySpace);
     }
 
   private:
     ArrayRef<int64_t> shape;
     Type elementType;
-    ArrayRef<AffineMap> affineMaps;
+    Attribute layout;
     Attribute memorySpace;
   };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -278,8 +278,7 @@
     stride-list ::= `[` (dimension (`,` dimension)*)? `]`
     strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-    semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-    layout-specification ::= semi-affine-map-composition | strided-layout
+    layout-specification ::= semi-affine-map | strided-layout | attribute-value
     memory-space ::= attribute-value
     ```
 
@@ -486,27 +485,6 @@
     #layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
     ```
 
-    ##### Affine Map Composition
-
-    A memref specifies a semi-affine map composition as part of its type. A
-    semi-affine map composition is a composition of semi-affine maps beginning
-    with zero or more index maps, and ending with a layout map. The composition
-    must be conformant: the number of dimensions of the range of one map, must
-    match the number of dimensions of the domain of the next map in the
-    composition.
-
-    The semi-affine map composition specified in the memref type, maps from
-    accesses used to index the memref in load/store operations to other index
-    spaces (i.e. logical to physical index mapping). Each of the
-    [semi-affine maps](Affine.md/#semi-affine-maps) and thus its composition is required
-    to be one-to-one.
-
-    The semi-affine map composition can be used in dependence analysis, memory
-    access pattern analysis, and for performance optimizations like
-    vectorization, copy elision and in-place updates. If an affine map
-    composition is not specified for the memref, the identity affine map is
-    assumed.
-
     ##### Strided MemRef
 
     A memref may specify a strided layout as part of its type. A stride
@@ -544,35 +522,38 @@
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"AffineMap">:$affineMaps,
+    "Attribute":$layout,
     "Attribute":$memorySpace
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<AffineMap>", "{}">:$affineMaps,
+      CArg<"Attribute", "{}">:$layout,
       CArg<"Attribute", "{}">:$memorySpace
     ), [{
-      // Drop identity maps from the composition. This may lead to the
-      // composition becoming empty, which is interpreted as an implicit
-      // identity.
-      auto nonIdentityMaps = llvm::make_filter_range(affineMaps,
-        [](AffineMap map) { return !map.isIdentity(); });
       // Drop default memory space value and replace it with empty attribute.
       Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace);
-      return $_get(elementType.getContext(), shape, elementType,
-                   llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace);
+      return $_get(elementType.getContext(), shape, elementType, layout,
+                   nonDefaultMemorySpace);
+    }]>,
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      CArg<"AffineMap">:$map,
+      CArg<"Attribute", "{}">:$memorySpace
+    ), [{
+      Attribute layout = map ? AffineMapAttr::get(map) : nullptr;
+      return MemRefType::get(shape, elementType, layout, memorySpace);
     }]>,
     /// [deprecated] `Attribute`-based form should be used instead.
     TypeBuilderWithInferredContext<(ins
      "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      "ArrayRef<AffineMap>":$affineMaps,
+      "AffineMap":$map,
       "unsigned":$memorySpace
     ), [{
       // Convert deprecated integer-like memory space to Attribute.
      Attribute memorySpaceAttr =
          wrapIntegerMemorySpace(memorySpace, elementType.getContext());
-      return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr);
+      return MemRefType::get(shape, elementType, map, memorySpaceAttr);
    }]>
  ];
  let extraClassDeclaration = [{
@@ -590,9 +571,13 @@
    static int64_t getDynamicStrideOrOffset() {
      return ShapedType::kDynamicStrideOrOffset;
    }
+
+    AffineMap getAffineMap() const;
+    bool hasIdentityLayout() const;
  }];
  let skipDefaultBuilders = 1;
  let genVerifyDecl = 1;
+  let genStorageClass = 0;
 }
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -49,6 +49,11 @@
 mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
 add_public_tablegen_target(MLIRTensorEncodingIncGen)
 
+set(LLVM_TARGET_DEFINITIONS MemRefLayoutAttrInterfaces.td)
+mlir_tablegen(MemRefLayoutAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(MemRefLayoutAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRMemRefLayoutAttrInterfacesIncGen)
+
 add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.h b/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.h
new file
--- /dev/null
+++ b/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.h
@@ -0,0 +1,31 @@
+//===- MemRefLayoutAttrInterfaces.h - MemRef layout interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_MEMREFLAYOUTATTRINTERFACES_H
+#define MLIR_IR_MEMREFLAYOUTATTRINTERFACES_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+
+namespace mlir {
+
+class MemRefType;
+
+namespace detail {
+
+LogicalResult
+verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
+                        function_ref<InFlightDiagnostic()> emitError);
+
+} // end namespace detail
+
+} // end namespace mlir
+
+#include "mlir/IR/MemRefLayoutAttrInterfaces.h.inc"
+
+#endif // MLIR_IR_MEMREFLAYOUTATTRINTERFACES_H
diff --git a/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.td b/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.td
new file
--- /dev/null
+++ b/mlir/include/mlir/IR/MemRefLayoutAttrInterfaces.td
@@ -0,0 +1,40 @@
+//===-- MemRefLayoutAttrInterfaces.td - MemRef layout interfaces -*- td -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_MEMREFLAYOUTATTRINTERFACES_TD
+#define MLIR_IR_MEMREFLAYOUTATTRINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      "Get the MemRef layout as an AffineMap",
+      "::mlir::AffineMap", "getAffineMap"
+    >,
+
+    InterfaceMethod<
+      "Check if the current layout is applicable to the provided shape",
+      "::mlir::LogicalResult", "verifyLayout",
+      (ins "::llvm::ArrayRef<int64_t>":$shape,
+           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
+      [{}],
+      [{
+        return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
+                                                       shape, emitError);
+      }]
+    >
+  ];
+}
+
+#endif // MLIR_IR_MEMREFLAYOUTATTRINTERFACES_TD
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -219,13 +219,9 @@
   assert(memRefDim && "memRefDim == nullptr");
   auto memRefType = memoryOp.getMemRefType();
 
-  auto layoutMap = memRefType.getAffineMaps();
   // TODO: remove dependence on Builder once we support non-identity layout map.
   Builder b(memoryOp.getContext());
-  if (layoutMap.size() >= 2 ||
-      (layoutMap.size() == 1 &&
-       !(layoutMap[0] ==
-         b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
+  if (!memRefType.hasIdentityLayout()) {
     return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
   }
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -615,9 +615,7 @@
 Optional<int64_t> MemRefRegion::getRegionSize() {
   auto memRefType = memref.getType().cast<MemRefType>();
 
-  auto layoutMaps = memRefType.getAffineMaps();
-  if (layoutMaps.size() > 1 ||
-      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+  if (!memRefType.hasIdentityLayout()) {
     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
     return false;
   }
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -401,8 +401,6 @@
   }
 };
 
-class PyMemRefLayoutMapList;
-
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
@@ -410,26 +408,23 @@
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
-  PyMemRefLayoutMapList getLayout();
-
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
-           std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
+           PyAttribute *layout, PyAttribute *memorySpace,
           DefaultingPyLocation loc) {
-          SmallVector<MlirAffineMap> maps;
-          maps.reserve(layout.size());
-          for (PyAffineMap &map : layout)
-            maps.push_back(map);
+          MlirAttribute layoutAttr = {};
+          if (layout)
+            layoutAttr = *layout;
           MlirAttribute memSpaceAttr = {};
           if (memorySpace)
             memSpaceAttr = *memorySpace;
-          MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                                shape.data(), maps.size(),
-                                                maps.data(), memSpaceAttr);
+          MlirType t =
+              mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                       shape.data(), layoutAttr, memSpaceAttr);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -444,10 +439,22 @@
           return PyMemRefType(elementType.getContext(), t);
         },
         py::arg("shape"), py::arg("element_type"),
-        py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
+        py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
         py::arg("loc") = py::none(), "Create a memref type")
-        .def_property_readonly("layout", &PyMemRefType::getLayout,
-                               "The list of layout maps of the MemRef type.")
+        .def_property_readonly(
+            "layout",
+            [](PyMemRefType &self) -> PyAttribute {
+              MlirAttribute layout = mlirMemRefTypeGetLayout(self);
+              return PyAttribute(self.getContext(), layout);
+            },
+            "The layout of the MemRef type.")
+        .def_property_readonly(
+            "affine_map",
+            [](PyMemRefType &self) -> PyAffineMap {
+              MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+              return PyAffineMap(self.getContext(), map);
+            },
+            "The affine map of the MemRef type.")
        .def_property_readonly(
            "memory_space",
            [](PyMemRefType &self) -> PyAttribute {
@@ -458,41 +465,6 @@
   }
 };
 
-/// A list of affine layout maps in a memref type. Internally, these are stored
-/// as consecutive elements, random access is cheap. Both the type and the maps
-/// are owned by the context, no need to worry about lifetime extension.
-class PyMemRefLayoutMapList
-    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
-public:
-  static constexpr const char *pyClassName = "MemRefLayoutMapList";
-
-  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
-                        intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
-                  step),
-        memref(type) {}
-
-  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
-
-  PyAffineMap getElement(intptr_t index) {
-    return PyAffineMap(memref.getContext(),
-                       mlirMemRefTypeGetAffineMap(memref, index));
-  }
-
-  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
-                              intptr_t step) {
-    return PyMemRefLayoutMapList(memref, startIndex, length, step);
-  }
-
-private:
-  PyMemRefType memref;
-};
-
-PyMemRefLayoutMapList PyMemRefType::getLayout() {
-  return PyMemRefLayoutMapList(*this);
-}
-
 /// Unranked MemRef Type subclass - UnrankedMemRefType.
 class PyUnrankedMemRefType
     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
@@ -640,7 +612,6 @@
   PyRankedTensorType::bind(m);
   PyUnrankedTensorType::bind(m);
   PyMemRefType::bind(m);
-  PyMemRefLayoutMapList::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
   PyFunctionType::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -226,26 +226,20 @@
 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
 
 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
-                           const int64_t *shape, intptr_t numMaps,
-                           MlirAffineMap const *affineMaps,
+                           const int64_t *shape, MlirAttribute layout,
                            MlirAttribute memorySpace) {
-  SmallVector<AffineMap> maps;
-  (void)unwrapList(numMaps, affineMaps, maps);
-  return wrap(
-      MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), maps, unwrap(memorySpace)));
+  return wrap(MemRefType::get(
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      unwrap(layout), unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
                                   intptr_t rank, const int64_t *shape,
-                                  intptr_t numMaps,
-                                  MlirAffineMap const *affineMaps,
+                                  MlirAttribute layout,
                                   MlirAttribute memorySpace) {
-  SmallVector<AffineMap> maps;
-  (void)unwrapList(numMaps, affineMaps, maps);
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), maps, unwrap(memorySpace)));
+      unwrap(elementType), unwrap(layout), unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
@@ -253,7 +247,7 @@
                                      MlirAttribute memorySpace) {
   return wrap(
       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), llvm::None, unwrap(memorySpace)));
+                      unwrap(elementType), Attribute(), unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
@@ -262,16 +256,15 @@
                                             MlirAttribute memorySpace) {
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), llvm::None, unwrap(memorySpace)));
+      unwrap(elementType), Attribute(), unwrap(memorySpace)));
 }
 
-intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
-  return static_cast<intptr_t>(
-      unwrap(type).cast<MemRefType>().getAffineMaps().size());
+MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
+  return wrap(unwrap(type).cast<MemRefType>().getLayout());
 }
 
-MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
-  return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
+  return wrap(unwrap(type).cast<MemRefType>().getAffineMap());
 }
 
 MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -106,9 +106,7 @@
                                                 MemRefType type) const {
   if (!typeConverter->convertType(type.getElementType()))
     return false;
-  return type.getAffineMaps().empty() ||
-         llvm::all_of(type.getAffineMaps(),
-                      [](AffineMap map) { return map.isIdentity(); });
+  return type.hasIdentityLayout();
 }
 
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1142,7 +1142,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType dstType = reshapeOp.getResultType();
     MemRefType srcType = reshapeOp.getSrcType();
-    if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) {
+    if (!srcType.hasIdentityLayout() || !dstType.hasIdentityLayout()) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "only empty layout map is supported");
     }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -949,8 +949,7 @@
   if (!strides.empty() && strides.back() != 1)
     return None;
   // If no layout or identity layout, this is contiguous by definition.
-  if (memRefType.getAffineMaps().empty() ||
-      memRefType.getAffineMaps().front().isIdentity())
+  if (memRefType.hasIdentityLayout())
     return strides;
 
   // Otherwise, we must determine contiguity form shapes. This can only ever
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1045,8 +1045,7 @@
   auto srcMemrefType = srcType.cast<MemRefType>();
   auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
 
-  if (!srcMemrefType.getAffineMaps().empty() &&
-      !srcMemrefType.getAffineMaps().front().isIdentity())
+  if (!srcMemrefType.hasIdentityLayout())
     return op.emitError("expected identity layout map for source memref");
 
   if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
@@ -1072,9 +1071,7 @@
   auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
   auto dstMemrefType = dstType.cast<MemRefType>();
   auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
-
-  if (!dstMemrefType.getAffineMaps().empty() &&
-      !dstMemrefType.getAffineMaps().front().isIdentity())
+  if (!dstMemrefType.hasIdentityLayout())
     return op.emitError("expected identity layout map for destination memref");
 
   if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -156,9 +156,8 @@
   Value createWorkgroupBuffer() {
     int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
-    auto bufferType =
-        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
-                        workgroupMemoryAddressSpace);
+    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, {},
+                                      workgroupMemoryAddressSpace);
     return funcOp.addWorkgroupAttribution(bufferType);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1239,7 +1239,7 @@
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace`.
 static MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                          ArrayRef<AffineMap> layout = {},
+                                          AffineMap layout = {},
                                           unsigned addressSpace = 0) {
   if (RankedTensorType tensorType = shapedType.dyn_cast<RankedTensorType>())
     return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
@@ -1252,13 +1252,12 @@
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type getContiguousOrUnrankedMemRefType(Type type,
-                                              ArrayRef<AffineMap> layout = {},
+static Type getContiguousOrUnrankedMemRefType(Type type, AffineMap layout = {},
                                               unsigned addressSpace = 0) {
   if (type.isa<RankedTensorType, MemRefType>())
     return getContiguousMemRefType(type.cast<ShapedType>(), layout,
                                    addressSpace);
-  assert(layout.empty() && "expected empty layout with UnrankedMemRefType");
+  assert(!layout && "expected empty layout with UnrankedMemRefType");
   return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace);
 }
@@ -1677,12 +1676,11 @@
                        ? rankedMemRefType.getMemorySpaceAsInt()
                        : unrankedMemRefType.getMemorySpaceAsInt();
   TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
-  ArrayRef<AffineMap> affineMaps =
-      rankedMemRefType && tensorType.isa<RankedTensorType>()
-          ? rankedMemRefType.getAffineMaps()
-          : ArrayRef<AffineMap>{};
+  AffineMap affineMap = rankedMemRefType && tensorType.isa<RankedTensorType>()
+                            ? rankedMemRefType.getAffineMap()
+                            : AffineMap();
   Type memRefType = getContiguousOrUnrankedMemRefType(
-      castOp.getResult().getType(), affineMaps, memorySpace);
+      castOp.getResult().getType(), affineMap, memorySpace);
   Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
   map(bvm, castOp.getResult(), res);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -261,7 +261,7 @@
     // leave them unchanged.
     Type actualType = opOperand->get().getType();
     if (auto memref = actualType.dyn_cast<MemRefType>()) {
-      if (!memref.getAffineMaps().empty())
+      if (!memref.hasIdentityLayout())
         return llvm::None;
     }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -83,8 +83,8 @@
                              "dynamic dimension count");
 
   unsigned numSymbols = 0;
-  if (!memRefType.getAffineMaps().empty())
-    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
+  if (!memRefType.hasIdentityLayout())
+    numSymbols = memRefType.getAffineMap().getNumSymbols();
   if (op.symbolOperands().size() != numSymbols)
     return op.emitOpError("symbol operand count does not equal memref symbol "
                           "count: expected ")
@@ -490,7 +490,7 @@
   if (aT && bT) {
     if (aT.getElementType() != bT.getElementType())
       return false;
-    if (aT.getAffineMaps() != bT.getAffineMaps()) {
+    if (aT.getAffineMap() != bT.getAffineMap()) {
       int64_t aOffset, bOffset;
       SmallVector<int64_t, 4> aStrides, bStrides;
       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
@@ -1402,7 +1402,7 @@
 
   // Match offset and strides in static_offset and static_strides attributes if
   // result memref type has an affine map specified.
-  if (!resultType.getAffineMaps().empty()) {
+  if (!resultType.hasIdentityLayout()) {
     int64_t resultOffset;
     SmallVector<int64_t, 4> resultStrides;
     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
@@ -1520,8 +1520,8 @@
   }
 
   // Early-exit: if `type` is contiguous, the result must be contiguous.
-  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+  if (canonicalizeStridedLayout(type).hasIdentityLayout())
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMap({});
 
   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -1540,7 +1540,7 @@
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+      MemRefType::Builder(type).setShape(newSizes).setAffineMap(layout));
 }
 
 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1656,14 +1656,14 @@
          "types should be the same");
 
   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
-    if (!operandMemRefType.getAffineMaps().empty())
+    if (!operandMemRefType.hasIdentityLayout())
       return op.emitOpError(
           "source memref type should have identity affine map");
 
   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
   if (resultMemRefType) {
-    if (!resultMemRefType.getAffineMaps().empty())
+    if (!resultMemRefType.hasIdentityLayout())
       return op.emitOpError(
           "result memref type should have identity affine map");
     if (shapeSize == ShapedType::kDynamicSize)
@@ -1818,10 +1818,9 @@
     if (!dimsToProject.contains(pos))
       projectedShape.push_back(shape[pos]);
 
-  AffineMap map;
-  auto maps = inferredType.getAffineMaps();
-  if (!maps.empty() && maps.front())
-    map = getProjectedMap(maps.front(), dimsToProject);
+  AffineMap map = inferredType.getAffineMap();
+  if (map)
+    map = getProjectedMap(map, dimsToProject);
   inferredType =
       MemRefType::get(projectedShape, inferredType.getElementType(), map,
                       inferredType.getMemorySpace());
@@ -2272,7 +2271,7 @@
   auto map =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
-  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMap(map);
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -2380,15 +2379,11 @@
   auto viewType = op.getType();
 
   // The base memref should have identity layout map (or none).
-  if (baseType.getAffineMaps().size() > 1 ||
-      (baseType.getAffineMaps().size() == 1 &&
-       !baseType.getAffineMaps()[0].isIdentity()))
+  if (!baseType.hasIdentityLayout())
    return op.emitError("unsupported map for base memref type ") << baseType;
 
   // The result memref should have identity layout map (or none).
-  if (viewType.getAffineMaps().size() > 1 ||
-      (viewType.getAffineMaps().size() == 1 &&
-       !viewType.getAffineMaps()[0].isIdentity()))
+  if (!viewType.hasIdentityLayout())
    return op.emitError("unsupported map for result memref type ") << viewType;
 
   // The base memref and the view memref should be in the same memory space.
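A note on the recurring pattern above: the op verifiers and type-rebuilding helpers all replace list-based `getAffineMaps()` manipulation with the single-map `getAffineMap()` accessor and the new `Builder::setAffineMap`. As a reading aid, here is a minimal sketch of that Builder-based rewrite; `withEagerStrides` is a hypothetical helper, not part of the patch:

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rebuild a memref type so that its layout is the explicit strided map
// corresponding to its current (possibly implicit identity) layout.
static MemRefType withEagerStrides(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return t; // Layout is not strided; leave the type unchanged.
  AffineMap strided =
      makeStridedLinearLayoutMap(strides, offset, t.getContext());
  // setAffineMap wraps the map in an AffineMapAttr (or clears the layout
  // when passed a null map), per the new Builder method in this patch.
  return MemRefType::Builder(t).setShape(t.getShape()).setAffineMap(strided);
}
```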
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3656,15 +3656,15 @@
   VectorType vectorType =
       VectorType::get(extractShape(memRefType),
                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
-  result.addTypes(
-      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
+  result.addTypes(MemRefType::get({}, vectorType, Attribute(),
+                                  memRefType.getMemorySpace()));
 }
 
 static LogicalResult verify(TypeCastOp op) {
   MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
-  if (!canonicalType.getAffineMaps().empty())
+  if (!canonicalType.hasIdentityLayout())
     return op.emitOpError("expects operand to be a memref with no layout");
-  if (!op.getResultMemRefType().getAffineMaps().empty())
+  if (!op.getResultMemRefType().hasIdentityLayout())
     return op.emitOpError("expects result to be a memref with no layout");
   if (op.getResultMemRefType().getMemorySpace() !=
       op.getMemRefType().getMemorySpace())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1966,14 +1966,13 @@
           os << 'x';
         }
         printType(memrefTy.getElementType());
-        for (auto map : memrefTy.getAffineMaps()) {
+        if (Attribute layout = memrefTy.getLayout()) {
           os << ", ";
-          printAttribute(AffineMapAttr::get(map));
+          printAttribute(layout, AttrTypeElision::May);
         }
-        // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace()) {
+        if (Attribute memSpace = memrefTy.getMemorySpace()) {
           os << ", ";
-          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
+          printAttribute(memSpace, AttrTypeElision::May);
         }
         os << '>';
       })
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/MemRefLayoutAttrInterfaces.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/BitVector.h"
@@ -635,6 +636,11 @@
   return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
 }
 
+MemRefType::Builder &MemRefType::Builder::setAffineMap(AffineMap newAffineMap) {
+  setLayout(newAffineMap ? AffineMapAttr::get(newAffineMap) : nullptr);
+  return *this;
+}
+
 MemRefType::Builder &
 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
   memorySpace =
@@ -642,14 +648,33 @@
   return *this;
 }
 
+ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->shape; }
+
+Type MemRefType::getElementType() const { return getImpl()->elementType; }
+
+Attribute MemRefType::getLayout() const { return getImpl()->layout; }
+
+Attribute MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
+
+AffineMap MemRefType::getAffineMap() const {
+  if (auto layout = getLayout().dyn_cast_or_null<MemRefLayoutAttrInterface>())
+    return layout.getAffineMap();
+  return {};
+}
+
+bool MemRefType::hasIdentityLayout() const {
+  if (AffineMap m = getAffineMap())
+    return m.isIdentity();
+  return true;
+}
+
 unsigned MemRefType::getMemorySpaceAsInt() const {
   return detail::getMemorySpaceAsInt(getMemorySpace());
 }
 
 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<AffineMap> affineMapComposition,
-                                 Attribute memorySpace) {
+                                 Attribute layout, Attribute memorySpace) {
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError() << "invalid memref element type";
 
@@ -658,21 +683,14 @@
     if (s < -1)
       return emitError() << "invalid memref size";
 
-  // Check that the structure of the composition is valid, i.e. that each
-  // subsequent affine map has as many inputs as the previous map has results.
-  // Take the dimensionality of the MemRef for the first map.
-  size_t dim = shape.size();
-  for (auto it : llvm::enumerate(affineMapComposition)) {
-    AffineMap map = it.value();
-    if (map.getNumDims() == dim) {
-      dim = map.getNumResults();
-      continue;
+  if (layout) {
+    auto layoutIface = layout.dyn_cast<MemRefLayoutAttrInterface>();
+    if (!layoutIface) {
+      return emitError()
+             << "Layout attribute doesn't implement MemRefLayoutAttrInterface";
     }
-    return emitError() << "memref affine map dimension mismatch between "
-                       << (it.index() == 0 ? Twine("memref rank")
-                                           : "affine map " + Twine(it.index()))
-                       << " and affine map" << it.index() + 1 << ": " << dim
-                       << " != " << map.getNumDims();
+    if (failed(layoutIface.verifyLayout(shape, emitError)))
+      return failure();
   }
 
   if (!isSupportedMemorySpace(memorySpace)) {
@@ -686,9 +704,8 @@
     function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
   walkTypesFn(getElementType());
+  walkAttrsFn(getLayout());
   walkAttrsFn(getMemorySpace());
-  for (AffineMap map : getAffineMaps())
-    walkAttrsFn(AffineMapAttr::get(map));
 }
 
//===----------------------------------------------------------------------===//
@@ -775,16 +792,11 @@
 LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
-  auto affineMaps = t.getAffineMaps();
+  AffineMap m = t.getAffineMap();
 
-  if (affineMaps.size() > 1)
+  if (m && m.getNumResults() != 1)
     return failure();
 
-  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
-    return failure();
-
-  AffineMap m = affineMaps.empty() ? AffineMap() : affineMaps.back();
-
   auto zero = getAffineConstantExpr(0, t.getContext());
   auto one = getAffineConstantExpr(1, t.getContext());
   offset = zero;
@@ -938,21 +950,21 @@
 /// `t` with simplified layout.
 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
-  auto affineMaps = t.getAffineMaps();
   // Already in canonical form.
-  if (affineMaps.empty())
+  if (t.hasIdentityLayout())
     return t;
 
+  AffineMap m = t.getAffineMap();
+
   // Can't reduce to canonical identity form, return in canonical form.
-  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
+  if (m.getNumResults() > 1)
     return t;
 
   // Corner-case for 0-D affine maps.
-  auto m = affineMaps[0];
   if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
     if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
       if (cst.getValue() == 0)
-        return MemRefType::Builder(t).setAffineMaps({});
+        return MemRefType::Builder(t).setAffineMap({});
     return t;
   }
@@ -970,9 +982,9 @@
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
-        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
-  return MemRefType::Builder(t).setAffineMaps({});
+    return MemRefType::Builder(t).setAffineMap(AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr));
+  return MemRefType::Builder(t).setAffineMap({});
 }
 
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
@@ -1016,7 +1028,7 @@
 /// strides. This is used to erase the static layout.
 MemRefType mlir::eraseStridedLayout(MemRefType t) {
   auto val = ShapedType::kDynamicStrideOrOffset;
-  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
+  return MemRefType::Builder(t).setAffineMap(makeStridedLinearLayoutMap(
       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
 }
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -16,6 +16,7 @@
   FunctionSupport.cpp
   IntegerSet.cpp
   Location.cpp
+  MemRefLayoutAttrInterfaces.cpp
   MLIRContext.cpp
   Operation.cpp
   OperationSupport.cpp
@@ -25,8 +26,8 @@
   SubElementInterfaces.cpp
   SymbolTable.cpp
   TensorEncoding.cpp
-  Types.cpp
   TypeRange.cpp
+  Types.cpp
   TypeUtilities.cpp
   Value.cpp
   Verifier.cpp
@@ -52,6 +53,7 @@
   MLIRSubElementInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
+  MLIRMemRefLayoutAttrInterfacesIncGen
 
   LINK_LIBS PUBLIC
   MLIRSupport
diff --git a/mlir/lib/IR/MemRefLayoutAttrInterfaces.cpp b/mlir/lib/IR/MemRefLayoutAttrInterfaces.cpp
new file
--- /dev/null
+++ b/mlir/lib/IR/MemRefLayoutAttrInterfaces.cpp
@@ -0,0 +1,29 @@
+//===- MemRefLayoutAttrInterfaces.cpp - MemRef layout interfaces ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/MemRefLayoutAttrInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyAffineMapAsLayout(
+    AffineMap m, ArrayRef<int64_t> shape,
+    function_ref<InFlightDiagnostic()> emitError) {
+  if (m.getNumDims() != shape.size()) {
+    return emitError() << "memref layout mismatch between rank and affine map: "
+                       << shape.size() << " != " << m.getNumDims();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterfaces Tablegen definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/MemRefLayoutAttrInterfaces.cpp.inc"
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -13,6 +13,7 @@
 #define TYPEDETAIL_H_
 
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/MLIRContext.h"
@@ -132,6 +133,62 @@
   unsigned numElements;
 };
 
+struct MemRefTypeStorage final : public TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type, Attribute, Attribute>;
+
+  static Attribute skipDefaultLayout(Attribute layout) {
+    if (AffineMapAttr attr = layout.dyn_cast_or_null<AffineMapAttr>()) {
+      if (attr.getValue().isIdentity()) {
+        return nullptr;
+      }
+    }
+    return layout;
+  }
+
+  MemRefTypeStorage(ArrayRef<int64_t> shape, Type elementType, Attribute layout,
+                    Attribute memorySpace)
+      : shape(shape), elementType(elementType), layout(layout),
+        memorySpace(memorySpace) {}
+
+  bool operator==(const KeyTy &key) const {
+    if (shape != std::get<0>(key))
+      return false;
+    if (elementType != std::get<1>(key))
+      return false;
+    if (layout != skipDefaultLayout(std::get<2>(key)))
+      return false;
+    if (memorySpace != std::get<3>(key))
+      return false;
+    return true;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return hash_combine(std::get<0>(key), std::get<1>(key),
+                        skipDefaultLayout(std::get<2>(key)), std::get<3>(key));
+  }
+
+  /// Define a construction method for creating a new instance of this
+  /// storage.
+  static MemRefTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    auto shape = std::get<0>(key);
+    auto elementType = std::get<1>(key);
+    auto layout = std::get<2>(key);
+    auto memorySpace = std::get<3>(key);
+
+    shape = allocator.copyInto(shape);
+    layout = skipDefaultLayout(layout);
+
+    return new (allocator.allocate<MemRefTypeStorage>())
+        MemRefTypeStorage(shape, elementType, layout, memorySpace);
+  }
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+  Attribute layout;
+  Attribute memorySpace;
+};
+
 /// Checks if the memorySpace has supported Attribute type.
 bool isSupportedMemorySpace(Attribute memorySpace);
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -185,9 +185,8 @@
 ///
 ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-///   layout-specification ::= semi-affine-map-composition | strided-layout
-///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///   layout-specification ::= semi-affine-map | strided-layout | attribute
+///   memory-space ::= integer-literal | attribute
 ///
 Type Parser::parseMemRefType() {
   llvm::SMLoc loc = getToken().getLoc();
@@ -221,15 +220,10 @@
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 
-  // Parse semi-affine-map-composition.
-  SmallVector<AffineMap, 2> affineMapComposition;
+  Attribute layout;
   Attribute memorySpace;
-  unsigned numDims = dimensions.size();
 
   auto parseElt = [&]() -> ParseResult {
-    AffineMap map;
-    llvm::SMLoc mapLoc = getToken().getLoc();
-
     // Check for AffineMap as offset/strides.
     if (getToken().is(Token::kw_offset)) {
       int64_t offset;
@@ -237,15 +231,17 @@
       if (failed(parseStridedLayout(offset, strides)))
         return failure();
       // Construct strided affine map.
-      map = makeStridedLinearLayoutMap(strides, offset, state.context);
+      AffineMap map =
+          makeStridedLinearLayoutMap(strides, offset, state.context);
+      layout = AffineMapAttr::get(map);
     } else {
-      // Either it is AffineMapAttr or memory space attribute.
+      // Either it is MemRefLayoutAttrInterface or memory space attribute.
       Attribute attr = parseAttribute();
      if (!attr)
        return failure();
 
-      if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
-        map = affineMapAttr.getValue();
+      if (attr.isa<MemRefLayoutAttrInterface>()) {
+        layout = attr;
       } else if (memorySpace) {
        return emitError("multiple memory spaces specified in memref type");
      } else {
@@ -259,15 +255,6 @@
     if (memorySpace)
       return emitError("expected memory space to be last in memref type");
 
-    if (map.getNumDims() != numDims) {
-      size_t i = affineMapComposition.size();
-      return emitError(mapLoc, "memref affine map dimension mismatch between ")
-             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
-             << " and affine map" << i + 1 << ": " << numDims
-             << " != " << map.getNumDims();
-    }
-    numDims = map.getNumResults();
-    affineMapComposition.push_back(map);
     return success();
   };
 
@@ -284,8 +271,8 @@
   if (isUnranked)
     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
 
-  return getChecked<MemRefType>(loc, dimensions, elementType,
-                                affineMapComposition, memorySpace);
+  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
+                                memorySpace);
 }
 
 /// Parse any type except the function type.
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -225,7 +225,7 @@
         // memref type is normalized.
         // TODO: When selective normalization is implemented, handle multiple
        // results case where some are normalized, some aren't.
-        if (memrefType.getAffineMaps().empty())
+        if (memrefType.hasIdentityLayout())
          resultTypes[operandEn.index()] = memrefType;
      }
    });
@@ -269,7 +269,7 @@
      if (oldResult.getType() == newResult.getType())
        continue;
      AffineMap layoutMap =
-          oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+          oldResult.getType().dyn_cast<MemRefType>().getAffineMap();
      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                          /*extraIndices=*/{},
                                          /*indexRemap=*/layoutMap,
@@ -363,7 +363,7 @@
   BlockArgument newMemRef =
       funcOp.front().insertArgument(argIndex, newMemRefType);
   BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
-  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  AffineMap layoutMap = memrefType.getAffineMap();
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                       /*extraIndices=*/{},
@@ -412,7 +412,7 @@
     if (oldMemRefType == newMemRefType)
       continue;
     // TODO: Assume single layout map. Multiple maps not supported.
-    AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+    AffineMap layoutMap = oldMemRefType.getAffineMap();
     if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                         /*newMemRef=*/newMemRef,
                                         /*extraIndices=*/{},
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -76,7 +76,7 @@
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
     return MemRefType::Builder(oldMemRefType)
         .setShape(newShape)
-        .setAffineMaps({});
+        .setAffineMap({});
   };
 
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2639,9 +2639,7 @@
   auto memref = region.memref;
   auto memRefType = memref.getType().cast<MemRefType>();
 
-  auto layoutMaps = memRefType.getAffineMaps();
-  if (layoutMaps.size() > 1 ||
-      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+  if (!memRefType.hasIdentityLayout()) {
     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
     return failure();
   }
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -646,7 +646,7 @@
   Value oldMemRef = allocOp->getResult();
   SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
-  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  AffineMap layoutMap = memrefType.getAffineMap();
   memref::AllocOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
   // supported for normalizing dynamic memrefs.
@@ -694,13 +694,12 @@
   if (rank == 0)
     return memrefType;
 
-  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
-  if (layoutMaps.empty() ||
-      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+  if (memrefType.hasIdentityLayout()) {
     // Either no maps is associated with this memref or this memref has
     // a trivial (identity) map.
     return memrefType;
   }
+  AffineMap layoutMap = memrefType.getAffineMap();
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
@@ -709,7 +708,7 @@
   // for now.
   // TODO: Normalize the other types of dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
-  (void)getTileSizePos(layoutMaps.front(), tileSizePos);
+  (void)getTileSizePos(layoutMap, tileSizePos);
   if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
     return memrefType;
 
@@ -730,7 +729,6 @@
   }
   // We compose this map with the original index (logical) space to derive
   // the upper bounds for the new index space.
-  AffineMap layoutMap = layoutMaps.front();
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
     return memrefType;
@@ -762,7 +760,7 @@
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
-          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
+          .setAffineMap(b.getMultiDimIdentityMap(newRank));
 
   return newMemRefType;
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -726,7 +726,6 @@
   MlirType memRef = mlirMemRefTypeContiguousGet(
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
-      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
     return 18;
   mlirTypeDump(memRef);
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -128,7 +128,8 @@
 
 func @test_alloc_memref_map_rank_mismatch() {
 ^bb0:
-  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> // expected-error {{memref affine map dimension mismatch}}
+  // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}
+  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
   return
 }
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -61,13 +61,7 @@
 // The error must be emitted even for the trivial identity layout maps that are
 // dropped in type creation.
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @memrefs(memref<42xi8, #map0>) // expected-error {{memref affine map dimension mismatch}}
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0) -> (d0)>
-func @memrefs(memref<42x42xi8, #map0, #map1>) // expected-error {{memref affine map dimension mismatch}}
+func @memrefs(memref<42xi8, #map0>) // expected-error {{memref layout mismatch between rank and affine map: 1 != 2}}
 
 // -----
 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -12,9 +12,6 @@
 // CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 
-// CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-#map4 = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-
 // CHECK-DAG: #map{{[0-9]+}} = affine_map<()[s0] -> (0, s0 - 1)>
 #inline_map_minmax_loop1 = affine_map<()[s0] -> (0, s0 - 1)>
 
@@ -80,28 +77,19 @@
 // CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 
-// CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
-func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
-
 // Test memref affine map compositions.
+// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
+func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
+
 // CHECK: func private @memrefs2(memref<2x4x8xi8, 1>)
 func private @memrefs2(memref<2x4x8xi8, #map2, 1>)
 
-// CHECK: func private @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}>)
-func private @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
-
-// CHECK: func private @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
-func private @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>)
-
 // Test memref inline affine map compositions, minding that identity maps are removed.
 // CHECK: func private @memrefs3(memref<2x4x8xi8>)
 func private @memrefs3(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>>)
 
-// CHECK: func private @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
-func private @memrefs33(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 1>)
-
 // CHECK: func private @memrefs_drop_triv_id_inline(memref<2xi8>)
 func private @memrefs_drop_triv_id_inline(memref<2xi8, affine_map<(d0) -> (d0)>>)
 
@@ -111,35 +99,6 @@
 // CHECK: func private @memrefs_drop_triv_id_inline1(memref<2xi8, 1>)
 func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> (d0)>, 1>)
 
-// Identity maps should be dropped from the composition, but not the pair of
-// "interchange" maps that, if composed, would be also an identity.
-// CHECK: func private @memrefs_drop_triv_id_composition(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_composition(memref<2x2xi8,
-                                               affine_map<(d0, d1) -> (d1, d0)>,
-                                               affine_map<(d0, d1) -> (d0, d1)>,
-                                               affine_map<(d0, d1) -> (d1, d0)>,
-                                               affine_map<(d0, d1) -> (d0, d1)>,
-                                               affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, affine_map<(d0, d1) -> (d1, d0)>,
-                                            affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_middle(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_middle(memref<2x2xi8,
-                                          affine_map<(d0, d1) -> (d0, d1 + 1)>,
-                                          affine_map<(d0, d1) -> (d0, d1)>,
-                                          affine_map<(d0, d1) -> (d0 + 1, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_multiple(memref<2xi8>)
-func private @memrefs_drop_triv_id_multiple(memref<2xi8, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>>)
-
-// These maps appeared before, so they must be uniqued and hoisted to the beginning.
-// Identity map should be removed.
-// CHECK: func private @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
-                                      affine_map<(d0, d1) -> (d1, d0)>>)
-
 // Test memref with custom memory space
 
 // CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>)
@@ -202,9 +161,6 @@
 // CHECK: func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
 func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
 
-// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
-func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
-
 // CHECK-LABEL: func @simpleCFG(%{{.*}}: i32, %{{.*}}: f32) -> i1 {
 func @simpleCFG(%arg0: i32, %f: f32) -> i1 {
   // CHECK: %{{.*}} = "foo"() : () -> i64
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -372,18 +372,19 @@
     memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
-    # CHECK: number of affine layout maps: 0
-    print("number of affine layout maps:", len(memref.layout))
+    # CHECK: memref layout: <<NULL ATTRIBUTE>>
+    print("memref layout:", memref.layout)
     # CHECK: memory space: 2
     print("memory space:", memref.memory_space)
 
-    layout = AffineMap.get_permutation([1, 0])
-    memref_layout = MemRefType.get(shape, f32, [layout])
+    layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
+    memref_layout = MemRefType.get(shape, f32, layout=layout)
     # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
     print("memref type:", memref_layout)
-    assert len(memref_layout.layout) == 1
-    # CHECK: memref layout: (d0, d1) -> (d1, d0)
-    print("memref layout:", memref_layout.layout[0])
+    # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
+    print("memref layout:", memref_layout.layout)
+    # CHECK: memref affine map: (d0, d1) -> (d1, d0)
+    print("memref affine map:", memref_layout.affine_map)
     # CHECK: memory space: <<NULL ATTRIBUTE>>
     print("memory space:", memref_layout.memory_space)
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -32,26 +32,26 @@
   ShapedType memrefType =
       MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
           .setMemorySpace(memSpace)
-          .setAffineMaps(map);
+          .setAffineMap(map);
   // Update shape.
   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
   ASSERT_NE(memrefOriginalShape, memrefNewShape);
   ASSERT_EQ(memrefType.clone(memrefNewShape),
             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setAffineMap(map));
   // Update type.
   Type memrefNewType = f32;
   ASSERT_NE(memrefOriginalType, memrefNewType);
   ASSERT_EQ(memrefType.clone(memrefNewType),
             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setAffineMap(map));
   // Update both.
   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setAffineMap(map));
 
   // Test unranked memref cloning.
   ShapedType unrankedTensorType =
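For a sense of how the reworked API reads in client code, here is a minimal sketch that exercises the new single-layout builders and accessors introduced by this patch. It is not part of the patch itself; `buildExamples` is a hypothetical helper and assumes an initialized MLIRContext:

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void buildExamples(MLIRContext &ctx) {
  Type f32 = FloatType::getF32(&ctx);

  // Identity layout: pass a null Attribute; prints as memref<2x3xf32>.
  MemRefType contiguous = MemRefType::get({2, 3}, f32, Attribute());

  // Affine-map layout: wrap the map in an AffineMapAttr, which now
  // implements MemRefLayoutAttrInterface via its getAffineMap() hook.
  AffineMap perm = AffineMap::getPermutationMap({1u, 0u}, &ctx);
  MemRefType transposed =
      MemRefType::get({2, 3}, f32, AffineMapAttr::get(perm));

  // The single-map accessors replace the old getAffineMaps() list.
  assert(contiguous.hasIdentityLayout());
  assert(transposed.getAffineMap() == perm);
  (void)contiguous;
  (void)transposed;
}
```

Note that passing an explicit identity `AffineMapAttr` and passing no layout at all should produce the same uniqued type, since `MemRefTypeStorage::skipDefaultLayout` above drops identity layouts before hashing and equality comparison.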