[mlir][memref] Keep the alignment attribute when folding constant alloc sizes

When the alloc canonicalization folds constant dynamic sizes into the
memref type and re-creates the alloc op, it previously passed an empty
IntegerAttr() for the alignment, silently dropping any user-specified
alignment. Forward the original op's alignment attribute instead.

The new canonicalize.mlir test checks that an aligned alloc of a
constant-sized dynamic memref folds to a statically shaped alloc that
still carries {alignment = 4096 : i64}.

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -184,8 +184,8 @@
            newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
-    auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
-                                                 newOperands, IntegerAttr());
+    auto newAlloc = rewriter.create<AllocLikeOp>(
+        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -368,6 +368,17 @@
   return %a : memref<?xf32>
 }
 
+// CHECK-LABEL: func @alloc_alignment_const_fold
+func @alloc_alignment_const_fold() -> memref<?xf32> {
+  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
+
+  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+  // CHECK-NEXT: return %1 : memref<?xf32>
+  return %a : memref<?xf32>
+}
+
 // CHECK-LABEL: func @dead_alloc_fold
 func @dead_alloc_fold() {
   // CHECK-NEXT: return