diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -163,17 +163,9 @@ This operation returns a single ssa value of memref type, which can be used by subsequent load and store operations. - - The optional `alignment` attribute may be specified to ensure that the - region of memory that will be indexed is aligned at the specified byte - boundary. TODO(b/144281289) optional alignment attribute to MemRefType. - - %0 = alloc()[%s] {alignment = 8} : - memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> }]; - let arguments = (ins Variadic:$value, - Confined, [IntMinValue<0>]>:$alignment); + let arguments = (ins Variadic:$value); let results = (outs AnyMemRef); let builders = [OpBuilder< @@ -182,16 +174,12 @@ }]>, OpBuilder< "Builder *builder, OperationState &result, MemRefType memrefType, " # - "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ + "ArrayRef operands", [{ result.addOperands(operands); result.types.push_back(memrefType); - if (alignment) - result.addAttribute(getAlignmentAttrName(), alignment); }]>]; let extraClassDeclaration = [{ - static StringRef getAlignmentAttrName() { return "alignment"; } - MemRefType getType() { return getResult().getType().cast(); } /// Returns the number of symbolic operands (the ones in square brackets), diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -398,7 +398,7 @@ /// construction failures. static MemRefType get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition = {}, - unsigned memorySpace = 0); + unsigned memorySpace = 0, unsigned alignment = 0); /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space declared at the given location. 
@@ -408,7 +408,8 @@ /// the error stream) and returns nullptr. static MemRefType getChecked(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Location location); + unsigned memorySpace, unsigned alignment, + Location location); ArrayRef getShape() const; @@ -419,6 +420,9 @@ /// Returns the memory space in which data referred to by this memref resides. unsigned getMemorySpace() const; + /// Returns the byte alignment of the underlying buffer (0 if unspecified). + unsigned getAlignment() const; + // TODO(ntv): merge these two special values in a single one used everywhere. // Unfortunately, uses of `-1` have crept deep into the codebase now and are // hard to track. @@ -435,7 +439,8 @@ /// emit detailed error messages. static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Optional location); + unsigned memorySpace, unsigned alignment, + Optional location); using Base::getImpl; }; @@ -444,16 +449,18 @@ Type elementType; ArrayRef affineMaps; unsigned memorySpace; + unsigned alignment; public: // Builder that clones from another MemRefType. 
explicit MemRefTypeBuilder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), - affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) { - } + affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()), + alignment(other.getAlignment()) {} MemRefTypeBuilder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {} + : shape(shape), elementType(elementType), affineMaps(), memorySpace(0), + alignment(0) {} MemRefTypeBuilder &setShape(ArrayRef newShape) { shape = newShape; @@ -475,8 +482,14 @@ return *this; } + MemRefTypeBuilder &setAlignment(unsigned newAlignment) { + alignment = newAlignment; + return *this; + } + operator MemRefType() { - return MemRefType::get(shape, elementType, affineMaps, memorySpace); + return MemRefType::get(shape, elementType, affineMaps, memorySpace, + alignment); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -961,10 +961,8 @@ // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Value allocated = nullptr; - int alignment = 0; + int alignment = type.getAlignment(); Value alignmentValue = nullptr; - if (auto alignAttr = allocOp.alignment()) - alignment = alignAttr.getValue().getSExtValue(); if (useAlloca) { allocated = rewriter.create(loc, getVoidPtrType(), diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -359,8 +359,8 @@ newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. 
- auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, - newOperands, IntegerAttr()); + auto newAlloc = + rewriter.create(alloc.getLoc(), newMemRefType, newOperands); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1482,8 +1482,10 @@ printAttribute(AffineMapAttr::get(map)); } - // Only print the memory space if it is the non-default one. - if (v.getMemorySpace()) + // Print the memory space if non-default or if an alignment follows it. + if (v.getMemorySpace() || v.getAlignment()) os << ", " << v.getMemorySpace(); + if (v.getAlignment()) + os << ", " << v.getAlignment(); os << '>'; return; } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -302,9 +302,10 @@ /// construction failures. MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { - auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, - /*location=*/llvm::None); + unsigned memorySpace, unsigned alignment) { + auto result = + getImpl(shape, elementType, affineMapComposition, memorySpace, alignment, + /*location=*/llvm::None); assert(result && "Failed to construct instance of MemRefType."); return result; } @@ -317,9 +318,10 @@ /// the error stream) and returns nullptr. MemRefType MemRefType::getChecked(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Location location) { + unsigned memorySpace, unsigned alignment, + Location location) { return getImpl(shape, elementType, affineMapComposition, memorySpace, - location); + alignment, location); } /// Get or create a new MemRefType defined by the arguments. If the resulting @@ -328,7 +330,7 @@ /// pass in an instance of UnknownLoc. 
MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, + unsigned memorySpace, unsigned alignment, Optional location) { auto *context = elementType.getContext(); @@ -373,8 +375,13 @@ cleanedAffineMapComposition.push_back(map); } + if ((alignment & (alignment - 1)) != 0) { + return emitOptionalError(location, "alignment must be power of 2"), + MemRefType(); + } + return Base::get(context, StandardTypes::MemRef, shape, elementType, - cleanedAffineMapComposition, memorySpace); + cleanedAffineMapComposition, memorySpace, alignment); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } @@ -385,6 +392,8 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; } +unsigned MemRefType::getAlignment() const { return getImpl()->alignment; } + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -194,18 +194,20 @@ struct MemRefTypeStorage : public ShapedTypeStorage { MemRefTypeStorage(unsigned shapeSize, Type elementType, const int64_t *shapeElements, const unsigned numAffineMaps, - AffineMap const *affineMapList, const unsigned memorySpace) + AffineMap const *affineMapList, const unsigned memorySpace, + const unsigned alignment) : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements), numAffineMaps(numAffineMaps), affineMapList(affineMapList), - memorySpace(memorySpace) {} + memorySpace(memorySpace), alignment(alignment) {} /// The hash key used for uniquing. - // MemRefs are uniqued based on their shape, element type, affine map - // composition, and memory space. + // MemRefs are uniqued based on their shape, element type, affine map + // composition, memory space, and alignment. 
- using KeyTy = - std::tuple, Type, ArrayRef, unsigned>; + using KeyTy = std::tuple, Type, ArrayRef, + unsigned, unsigned>; bool operator==(const KeyTy &key) const { - return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace); + return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace, + alignment); } /// Construction. @@ -219,10 +221,10 @@ allocator.copyInto(std::get<2>(key)); // Initialize the memory using placement new. - return new (allocator.allocate()) - MemRefTypeStorage(shape.size(), std::get<1>(key), shape.data(), - affineMapComposition.size(), - affineMapComposition.data(), std::get<3>(key)); + return new (allocator.allocate()) MemRefTypeStorage( + shape.size(), std::get<1>(key), shape.data(), + affineMapComposition.size(), affineMapComposition.data(), + std::get<3>(key), std::get<4>(key)); } ArrayRef getShape() const { @@ -241,6 +243,8 @@ AffineMap const *affineMapList; /// Memory space in which data referenced by memref resides. const unsigned memorySpace; + /// Byte alignment of the underlying buffer; 0 if unspecified. + const unsigned alignment; }; /// Unranked MemRef is a MemRef with unknown rank. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1099,19 +1099,32 @@ // Parse semi-affine-map-composition. SmallVector affineMapComposition; unsigned memorySpace = 0; + unsigned alignment = 0; bool parsedMemorySpace = false; + bool parsedAlignment = false; auto parseElt = [&]() -> ParseResult { if (getToken().is(Token::integer)) { - // Parse memory space. + // Parse the memory space first, then an optional alignment. 
- if (parsedMemorySpace) - return emitError("multiple memory spaces specified in memref type"); - auto v = getToken().getUnsignedIntegerValue(); - if (!v.hasValue()) - return emitError("invalid memory space in memref type"); - memorySpace = v.getValue(); - consumeToken(Token::integer); - parsedMemorySpace = true; + if (!parsedMemorySpace) { + auto v = getToken().getUnsignedIntegerValue(); + if (!v.hasValue()) + return emitError("invalid memory space in memref type"); + memorySpace = v.getValue(); + consumeToken(Token::integer); + parsedMemorySpace = true; + } else if (!parsedAlignment) { + if (isUnranked) + return emitError("cannot have alignment for unranked memref type"); + auto v = getToken().getUnsignedIntegerValue(); + if (!v.hasValue()) + return emitError("invalid alignment in memref type"); + alignment = v.getValue(); + consumeToken(Token::integer); + parsedAlignment = true; + } else { + return emitError("excessive integers specified in memref type"); + } } else { if (isUnranked) return emitError("cannot have affine map for unranked memref type"); @@ -1158,7 +1171,8 @@ getEncodedSourceLocation(typeLoc)); return MemRefType::getChecked(dimensions, elementType, affineMapComposition, - memorySpace, getEncodedSourceLocation(typeLoc)); + memorySpace, alignment, + getEncodedSourceLocation(typeLoc)); } /// Parse any type except the function type. 
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir @@ -55,7 +55,7 @@ } // CHECK-LABEL: func @aligned_1d_alloc( -func @aligned_1d_alloc() -> memref<42xf32> { +func @aligned_1d_alloc() -> memref<42xf32, 0, 8> { // CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 @@ -78,8 +78,8 @@ // CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> - %0 = alloc() {alignment = 8} : memref<42xf32> - return %0 : memref<42xf32> + %0 = alloc() : memref<42xf32, 0, 8> + return %0 : memref<42xf32, 0, 8> } // CHECK-LABEL: func @mixed_alloc( diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -728,3 +728,13 @@ tensor_store %1, %0 : memref<4x4xi32> return } + +// CHECK-LABEL: func @aligned_memref_load_store +func @aligned_memref_load_store(%arr : memref<4xf16, 0, 8>) { + %c_0 = constant 0 : index + // CHECK: %[[VALUE:.*]] = load %[[MEMREF:.*]] : memref<4xf16, 0, 8> + %v = load %arr[%c_0] : memref<4xf16, 0, 8> + // CHECK: store %[[VALUE]], %[[MEMREF]] : memref<4xf16, 0, 8> + store %v, %arr[%c_0] : memref<4xf16, 0, 8> + return +} diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -46,9 +46,18 @@ func @memrefs(memref<2x4xi8, #map0, #map8>) // expected-error {{undefined symbol alias id 'map8'}} // ----- -// Test multiple memory space error. 
+// Test more than two integers (memory space, alignment) error. #map0 = affine_map<(d0, d1) -> (d0, d1)> -func @memrefs(memref<2x4xi8, #map0, 1, 2>) // expected-error {{multiple memory spaces specified in memref type}} +func @memrefs(memref<2x4xi8, #map0, 1, 2, 3>) // expected-error {{excessive integers specified in memref type}} + +// ----- +// Test non-power-of-2 alignment error. +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @memrefs(memref<2x4xi8, #map0, 1, 3>) // expected-error {{alignment must be power of 2}} + +// ----- +// Test alignment for unranked memref error. +func @memrefs(memref<*xi8, 1, 2>) // expected-error {{cannot have alignment for unranked memref type}} // ----- // Test affine map after memory space.