diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -311,25 +311,33 @@
 
   LogicalResult matchAndRewrite(AllocaScopeOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
-      bool hasPotentialAlloca =
-          op->walk([&](Operation *alloc) {
-              if (alloc == op)
-                return WalkResult::advance();
-              if (isOpItselfPotentialAutomaticAllocation(alloc))
-                return WalkResult::interrupt();
+    // Walk in pre-order: WalkResult::skip() cannot prune regions that were
+    // already visited, so a post-order walk would reach allocas nested in an
+    // inner allocation scope before seeing (and skipping) that scope.
+    bool hasPotentialAlloca =
+        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
+            if (alloc == op)
               return WalkResult::advance();
-            }).wasInterrupted();
-      if (hasPotentialAlloca)
+            if (isOpItselfPotentialAutomaticAllocation(alloc))
+              return WalkResult::interrupt();
+            // Allocations inside a nested automatic allocation scope belong to
+            // that scope, so inlining this op cannot extend their lifetime.
+            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
+              return WalkResult::skip();
+            return WalkResult::advance();
+          }).wasInterrupted();
+
+    // If this contains no potential allocation, it is always legal to
+    // inline. Otherwise, consider two conditions:
+    if (hasPotentialAlloca) {
+      // If the parent isn't an allocation scope, or we are not the last
+      // non-terminator op in the parent, we will extend the lifetime.
+      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
+        return failure();
+      if (!lastNonTerminatorInRegion(op))
         return failure();
     }
 
-    // Only apply to if this is this last non-terminator
-    // op in the block (lest lifetime be extended) of a one
-    // block region
-    if (!lastNonTerminatorInRegion(op))
-      return failure();
-
     Block *block = &op.getRegion().front();
     Operation *terminator = block->getTerminator();
     ValueRange results = terminator->getOperands();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -644,6 +644,34 @@
 // CHECK:     return
 // CHECK:   }
 
+// The alloca lives in affine.parallel's own allocation scope, so the outer
+// alloca_scope is inlined even though it is not the last op in the region.
+func @scopeMerge5() {
+  "test.region"() ({
+    memref.alloca_scope {
+      affine.parallel (%arg) = (0) to (64) {
+        %a = memref.alloca(%arg) : memref<?xi64>
+        "test.use"(%a) : (memref<?xi64>) -> ()
+      }
+    }
+    "test.op"() : () -> ()
+    "test.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK:   func @scopeMerge5() {
+// CHECK:     "test.region"() ({
+// CHECK:       affine.parallel (%[[cnt:.+]]) = (0) to (64) {
+// CHECK:         %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
+// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
+// CHECK:       }
+// CHECK:       "test.op"() : () -> ()
+// CHECK:       "test.terminator"() : () -> ()
+// CHECK:     }) : () -> ()
+// CHECK:     return
+// CHECK:   }
+
 func @scopeInline(%arg : memref<index>) {
   %cnt = "test.count"() : () -> index
   "test.region"() ({