diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -174,8 +174,10 @@ LogicalResult matchAndRewrite(T alloc, PatternRewriter &rewriter) const override { - if (llvm::any_of(alloc->getUsers(), [](Operation *op) { - return !isa(op); + if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { + if (auto storeOp = dyn_cast(op)) + return storeOp.value() == alloc; + return !isa(op); })) return failure(); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -420,3 +420,14 @@ %0 = memref.alloc(%c1)[%c1, %c1] : memref return %0 : memref } + +// ----- +// An alloc whose result is itself stored into another memref must be kept by canonicalization. +// CHECK-LABEL: func @allocator +// CHECK: %[[alloc:.+]] = memref.alloc +// CHECK: memref.store %[[alloc]], %arg0 +func @allocator(%arg0 : memref>, %arg1 : index) { + %0 = memref.alloc(%arg1) : memref + memref.store %0, %arg0[] : memref> + return +}