diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -261,7 +261,7 @@
 
 /// Given an operation, return whether this op is guaranteed to
 /// allocate an AutomaticAllocationScopeResource
-static bool isGuaranteedAutomaticAllocationScope(Operation *op) {
+static bool isGuaranteedAutomaticAllocation(Operation *op) {
   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
   if (!interface)
     return false;
@@ -276,9 +276,15 @@
   return false;
 }
 
-/// Given an operation, return whether this op could to
-/// allocate an AutomaticAllocationScopeResource
-static bool isPotentialAutomaticAllocationScope(Operation *op) {
+/// Given an operation, return whether this op itself could
+/// allocate an AutomaticAllocationScopeResource. Note that
+/// this will not check whether an operation contained within
+/// the op can allocate.
+static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
+  // This op itself doesn't create a stack allocation,
+  // the inner allocation should be handled separately.
+  if (op->hasTrait<OpTrait::HasRecursiveSideEffects>())
+    return false;
   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
   if (!interface)
     return true;
@@ -312,9 +318,11 @@
     if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
       bool hasPotentialAlloca =
           op->walk([&](Operation *alloc) {
-              if (isPotentialAutomaticAllocationScope(alloc))
+              if (alloc == op)
+                return WalkResult::advance();
+              if (isOpItselfPotentialAutomaticAllocation(alloc))
                 return WalkResult::interrupt();
-              return WalkResult::skip();
+              return WalkResult::advance();
             }).wasInterrupted();
       if (hasPotentialAlloca)
         return failure();
@@ -383,7 +391,7 @@
 
     SmallVector<Operation *> toHoist;
     op->walk([&](Operation *alloc) {
-      if (!isGuaranteedAutomaticAllocationScope(alloc))
+      if (!isGuaranteedAutomaticAllocation(alloc))
         return WalkResult::skip();
 
       // If any operand is not defined before the location of
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -643,3 +643,17 @@
 // CHECK: }) : () -> ()
 // CHECK: return
 // CHECK: }
+
+func @scopeInline(%arg : memref<index>) {
+  %cnt = "test.count"() : () -> index
+  "test.region"() ({
+    memref.alloca_scope {
+      memref.store %cnt, %arg[] : memref<index>
+    }
+    "test.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK: func @scopeInline
+// CHECK-NOT: memref.alloca_scope