diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -619,6 +619,7 @@
 
 def Shape_AssumingOp : Shape_Op<"assuming",
                            [SingleBlockImplicitTerminator<"AssumingYieldOp">,
+                            DeclareOpInterfaceMethods<RegionBranchOpInterface>,
                             RecursiveSideEffects]> {
   let summary = "Execute the region";
   let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -233,6 +233,21 @@
   patterns.insert<AssumingWithTrue>(context);
 }
 
+// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
+void AssumingOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // AssumingOp has unconditional control flow into the region and back to the
+  // parent, so return the correct RegionSuccessor purely based on the index
+  // being None (entering from the parent) or 0 (returning from doRegion).
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  regions.push_back(RegionSuccessor(&doRegion()));
+}
+
 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                         PatternRewriter &rewriter) {
   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -1417,3 +1417,42 @@
 }
 
 // expected-error@+1 {{Structured control-flow loops are supported only}}
+
+// -----
+
+func @assumingOp(%arg0: !shape.witness, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  // Confirm the alloc will be dealloc'ed in the block.
+  %1 = shape.assuming %arg0 -> memref<2xf32> {
+    %0 = alloc() : memref<2xf32>
+    shape.assuming_yield %arg1 : memref<2xf32>
+  }
+  // Confirm the alloc will be returned and dealloc'ed after its use.
+  %3 = shape.assuming %arg0 -> memref<2xf32> {
+    %2 = alloc() : memref<2xf32>
+    shape.assuming_yield %2 : memref<2xf32>
+  }
+  "linalg.copy"(%3, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @assumingOp(
+//  CHECK-SAME:     %[[ARG0:.*]]: !shape.witness,
+//  CHECK-SAME:     %[[ARG1:.*]]: memref<2xf32>,
+//  CHECK-SAME:     %[[ARG2:.*]]: memref<2xf32>) {
+//       CHECK:   %[[UNUSED_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) {
+//       CHECK:     %[[ALLOC0:.*]] = alloc() : memref<2xf32>
+//       CHECK:     dealloc %[[ALLOC0]] : memref<2xf32>
+//       CHECK:     shape.assuming_yield %[[ARG1]] : memref<2xf32>
+//       CHECK:   }
+//       CHECK:   %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) {
+//       CHECK:     %[[TMP_ALLOC:.*]] = alloc() : memref<2xf32>
+//       CHECK:     %[[RETURNING_ALLOC:.*]] = alloc() : memref<2xf32>
+//       CHECK:     linalg.copy(%[[TMP_ALLOC]], %[[RETURNING_ALLOC]]) : memref<2xf32>, memref<2xf32>
+//       CHECK:     dealloc %[[TMP_ALLOC]] : memref<2xf32>
+//       CHECK:     shape.assuming_yield %[[RETURNING_ALLOC]] : memref<2xf32>
+//       CHECK:   }
+//       CHECK:   linalg.copy(%[[ASSUMING_RESULT]], %[[ARG2]]) : memref<2xf32>, memref<2xf32>
+//       CHECK:   dealloc %[[ASSUMING_RESULT]] : memref<2xf32>
+//       CHECK:   return
+//       CHECK: }
+