diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -445,7 +445,13 @@
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setAffineMaps(b.getMultiDimIdentityMap(newRank));
-  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
+
+  // Carry over the alignment attribute, if any. alignmentAttr() is null when
+  // absent, in which case the builder omits it. The old operands are not
+  // forwarded: the new type has an identity layout with no symbols, so
+  // passing allocOp's symbol operands would produce an invalid alloc.
+  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
+                                       allocOp.alignmentAttr());
 
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/memref-normalize.mlir
--- a/mlir/test/Transforms/memref-normalize.mlir
+++ b/mlir/test/Transforms/memref-normalize.mlir
@@ -143,3 +143,18 @@
   }
   return
 }
+
+// CHECK-LABEL: func @alignment
+func @alignment() {
+  %A = alloc() {alignment = 32 : i64} : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
+  // CHECK: %{{.*}} = alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
+  return
+}
+
+// Alignment must be preserved without forwarding the symbol operands, which
+// the identity-layout result type no longer accepts.
+// CHECK-LABEL: func @alignment_symbolic_operands
+func @alignment_symbolic_operands(%s : index) {
+  %A = alloc()[%s] {alignment = 64 : i64} : memref<10x10xf32, affine_map<(d0, d1)[s0] -> (10 * d0 + d1)>>
+  // CHECK: %{{.*}} = alloc() {alignment = 64 : i64} : memref<100xf32>
+  return
+}