diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -839,6 +839,9 @@
     // Private variable with an initial value.
     memref.global "private" @x : memref<2xf32> = dense<0.0,2.0>
 
+    // Private variable with an initial value and an alignment (power of 2).
+    memref.global "private" @x : memref<2xf32> = dense<0.0,2.0> {alignment = 64}
+
     // Declaration of an external variable.
     memref.global "private" @y : memref<4xi32>
 
@@ -855,7 +858,8 @@
       OptionalAttr<StrAttr>:$sym_visibility,
       MemRefTypeAttr:$type,
       OptionalAttr<AnyAttr>:$initial_value,
-      UnitAttr:$constant
+      UnitAttr:$constant,
+      OptionalAttr<I64Attr>:$alignment
   );
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -451,9 +451,11 @@
         initialValue = elementsAttr.getValue({});
     }
 
+    uint64_t alignment = global.alignment().getValueOr(0);
+
     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.constant(), linkage, global.sym_name(),
-        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
+        initialValue, alignment, type.getMemorySpaceAsInt());
     if (!global.isExternal() && global.isUninitialized()) {
       Block *blk = new Block();
       newGlobal.getInitializerRegion().push_back(blk);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1176,6 +1176,14 @@
     }
   }
 
+  if (Optional<uint64_t> alignAttr = op.alignment()) {
+    uint64_t alignment = alignAttr.getValue();
+
+    if (!llvm::isPowerOf2_64(alignment))
+      return op->emitError() << "alignment attribute value " << alignment
+                             << " is not a power of 2";
+  }
+
   // TODO: verify visibility for declarations.
   return success();
 }
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -48,7 +48,8 @@
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
-      /*constant=*/true);
+      /*constant=*/true,
+      /*alignment=*/IntegerAttr());
   symbolTable.insert(global);
   // The symbol table inserts at the end of the module, but globals are a bit
   // nicer if they are at the beginning.
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -701,6 +701,10 @@
   return
 }
 
+// Test scalar memref with an alignment.
+// CHECK: llvm.mlir.global private @gv4(1.000000e+00 : f32) {alignment = 64 : i64} : f32
+memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
+
 // -----
 
 func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -345,6 +345,11 @@
 
 // -----
 
+// expected-error @+1 {{alignment attribute value 63 is not a power of 2}}
+memref.global "private" @gv : memref<4xf32> = dense<1.0> { alignment = 63 }
+
+// -----
+
 func @copy_different_shape(%arg0: memref<2xf32>, %arg1: memref<3xf32>) {
   // expected-error @+1 {{op requires the same shape for all operands}}
   memref.copy %arg0, %arg1 : memref<2xf32> to memref<3xf32>