diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -378,7 +378,6 @@
   let arguments = (ins Arg<AnyMemRef, "", [MemFree]>:$memref);
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -195,30 +195,36 @@
   }
 };
 
-/// Fold alloc operations with no uses. Alloc has side effects on the heap,
-/// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
-  using OpRewritePattern<AllocOp>::OpRewritePattern;
+/// Fold alloc operations with no users or only store and dealloc uses.
+template <typename T>
+struct SimplifyDeadAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(AllocOp alloc,
+  LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
-    if (alloc.use_empty()) {
-      rewriter.eraseOp(alloc);
-      return success();
-    }
-    return failure();
+    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
+          return !isa<StoreOp, DeallocOp>(op);
+        }))
+      return failure();
+
+    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
+      rewriter.eraseOp(user);
+
+    rewriter.eraseOp(alloc);
+    return success();
   }
 };
 } // end anonymous namespace.
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>>(context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,30 +543,6 @@
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
-namespace {
-/// Fold Dealloc operations that are deallocating an AllocOp that is only used
-/// by other Dealloc operations.
-struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
-  using OpRewritePattern<DeallocOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DeallocOp dealloc,
-                                PatternRewriter &rewriter) const override {
-    // Check that the memref operand's defining operation is an AllocOp.
-    Value memref = dealloc.memref();
-    if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
-      return failure();
-
-    // Check that all of the uses of the AllocOp are other DeallocOps.
-    for (auto *user : memref.getUsers())
-      if (!isa<DeallocOp>(user))
-        return failure();
-
-    // Erase the dealloc operation.
-    rewriter.eraseOp(dealloc);
-    return success();
-  }
-};
-} // end anonymous namespace.
 
 static LogicalResult verify(DeallocOp op) {
   if (!op.memref().getType().isa<MemRefType>())
@@ -568,11 +550,6 @@
   return success();
 }
 
-void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<SimplifyDeadDealloc>(context);
-}
-
 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dealloc(memrefcast) -> dealloc
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -93,11 +93,8 @@
 // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 4) * 4 - (d1 floordiv 4) * 4)>
 // CHECK-DAG: #[[$MAP8a:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 8) * 8 - (d1 floordiv 8) * 8)>
 
-// CHECK-LABEL: func @compose_affine_maps_2d_tile() {
-func @compose_affine_maps_2d_tile() {
-  %0 = memref.alloc() : memref<16x32xf32>
-  %1 = memref.alloc() : memref<16x32xf32>
-
+// CHECK-LABEL: func @compose_affine_maps_2d_tile
+func @compose_affine_maps_2d_tile(%0: memref<16x32xf32>, %1: memref<16x32xf32>) {
   %c4 = constant 4 : index
   %c8 = constant 8 : index
 
@@ -221,7 +218,7 @@
 // -----
 
 // CHECK-LABEL: func @arg_used_as_dim_and_symbol
-func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) {
+func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) -> (memref<100x100xf32, 1>, memref<1xi32>) {
   %c9 = constant 9 : index
   %1 = memref.alloc() : memref<100x100xf32, 1>
   %2 = memref.alloc() : memref<1xi32>
@@ -235,7 +232,7 @@
       memref.store %arg2, %1[%4, %arg1] : memref<100x100xf32, 1>
     }
   }
-  return
+  return %1, %2 : memref<100x100xf32, 1>, memref<1xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -207,9 +207,8 @@
 
 // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
 // CHECK-SAME: %[[IDX:.*]]: index
-func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
   %size = rank %tensor : tensor<*xf32>
-  %mem = memref.alloc(%size) : memref<?xindex>
   // CHECK: %[[DTENSOR:.*]] = tensor.generate
   %0 = tensor.generate %size {
   ^bb0(%arg0: index):
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -410,6 +410,27 @@
   return
 }
 
+// CHECK-LABEL: func @write_only_alloc_fold
+func @write_only_alloc_fold(%v: f32) {
+  // CHECK-NEXT: return
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) : memref<?xf32>
+  memref.store %v, %a[%c0] : memref<?xf32>
+  memref.dealloc %a: memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @write_only_alloca_fold
+func @write_only_alloca_fold(%v: f32) {
+  // CHECK-NEXT: return
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %a = memref.alloca(%c4) : memref<?xf32>
+  memref.store %v, %a[%c0] : memref<?xf32>
+  return
+}
+
 // CHECK-LABEL: func @dead_block_elim
 func @dead_block_elim() {
   // CHECK-NOT: ^bb
@@ -426,7 +447,7 @@
 }
 
 // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
-func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
+func @dyn_shape_fold(%L : index, %M : index) -> (memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
   // CHECK: %c0 = constant 0 : index
   %zero = constant 0 : index
   // The constants below disappear after they propagate into shapes.
@@ -434,13 +455,13 @@
   %N = constant 1024 : index
   %K = constant 512 : index
 
-  // CHECK-NEXT: memref.alloc(%arg0) : memref<?x1024xf32>
+  // CHECK: memref.alloc(%arg0) : memref<?x1024xf32>
   %a = memref.alloc(%L, %N) : memref<? x ? x f32>
 
-  // CHECK-NEXT: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
+  // CHECK: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
   %b = memref.alloc(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32>
 
-  // CHECK-NEXT: memref.alloc() : memref<512x1024xi32>
+  // CHECK: memref.alloc() : memref<512x1024xi32>
   %c = memref.alloc(%K, %N) : memref<? x ? x i32>
 
   // CHECK: memref.alloc() : memref<9x9xf32>
@@ -460,7 +481,7 @@
     }
   }
 
-  return %c, %d, %e : memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
+  return %b, %c, %d, %e : memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
 }
 
 #map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>