diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -311,23 +311,32 @@
 
   LogicalResult matchAndRewrite(AllocaScopeOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
-      bool hasPotentialAlloca =
-          op->walk([&](Operation *alloc) {
-                if (alloc == op)
-                  return WalkResult::advance();
-                if (isOpItselfPotentialAutomaticAllocation(alloc))
-                  return WalkResult::interrupt();
+    // Determine whether `op` may allocate into its own scope. The walk is
+    // pre-order so that WalkResult::skip() prunes the bodies of nested
+    // automatic allocation scopes: their allocations are released when the
+    // nested scope exits, so they cannot outlive `op`.
+    bool hasPotentialAlloca =
+        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
+              if (alloc == op)
                 return WalkResult::advance();
-              }).wasInterrupted();
+              if (isOpItselfPotentialAutomaticAllocation(alloc))
+                return WalkResult::interrupt();
+              if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
+                return WalkResult::skip();
+              return WalkResult::advance();
+            }).wasInterrupted();
+
+    // If the parent isn't an allocation scope, and we contain an allocation
+    // we cannot inline the scope.
+    if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
       if (hasPotentialAlloca)
         return failure();
     }
 
-    // Only apply to if this is this last non-terminator
-    // op in the block (lest lifetime be extended) of a one
-    // block region
-    if (!lastNonTerminatorInRegion(op))
+    // If we contain an allocation, only inline when this is the last
+    // non-terminator op in the block of a one-block region (lest the
+    // allocation's lifetime be extended).
+    if (hasPotentialAlloca && !lastNonTerminatorInRegion(op))
       return failure();
 
     Block *block = &op.getRegion().front();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -644,6 +644,32 @@
 // CHECK:   return
 // CHECK: }
 
+func @scopeMerge5() {
+  "test.region"() ({
+    memref.alloca_scope {
+      affine.parallel (%arg) = (0) to (64) {
+        %a = memref.alloca(%arg) : memref<?xi64>
+        "test.use"(%a) : (memref<?xi64>) -> ()
+      }
+    }
+    "test.op"() : () -> ()
+    "test.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK: func @scopeMerge5() {
+// CHECK:   "test.region"() ({
+// CHECK:     affine.parallel (%[[cnt:.+]]) = (0) to (64) {
+// CHECK:       %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
+// CHECK:       "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
+// CHECK:     }
+// CHECK:     "test.op"() : () -> ()
+// CHECK:     "test.terminator"() : () -> ()
+// CHECK:   }) : () -> ()
+// CHECK:   return
+// CHECK: }
+
 func @scopeInline(%arg : memref<index>) {
   %cnt = "test.count"() : () -> index
   "test.region"() ({