diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -619,6 +619,7 @@
 
 def Shape_AssumingOp : Shape_Op<"assuming",
     [SingleBlockImplicitTerminator<"AssumingYieldOp">,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      RecursiveSideEffects]> {
   let summary = "Execute the region";
   let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -233,6 +233,20 @@
   patterns.insert<AssumingWithTrue>(context);
 }
 
+void AssumingOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // AssumingOp has unconditional control flow into the region and back to the
+  // parent, so return the correct RegionSuccessor purely based on the index
+  // being None or 0.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  regions.push_back(RegionSuccessor(&doRegion()));
+}
+
 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                         PatternRewriter &rewriter) {
   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -1417,3 +1417,25 @@
 }
 
 // expected-error@+1 {{Structured control-flow loops are supported only}}
+
+// -----
+
+// Test Case: buffer allocations inside shape.assuming regions. AssumingOp
+// declares RegionBranchOpInterface, which models control flow from the parent
+// into its region and from the region back to the parent's results.
+// %0 is only used inside the first region; %2 escapes through
+// shape.assuming_yield as result %3.
+// TODO: only the label is checked; add CHECK lines for dealloc placement.
+
+// CHECK-LABEL: func @assumingOp
+func @assumingOp(%arg0: !shape.witness, %arg1: !shape.witness, %arg2: memref<2xf32>, %arg3: memref<2xf32>) {
+  %1 = shape.assuming %arg0 -> memref<2xf32> {
+    %0 = alloc() : memref<2xf32>
+    shape.assuming_yield %arg2 : memref<2xf32>
+  }
+  %3 = shape.assuming %arg0 -> memref<2xf32> {
+    %2 = alloc() : memref<2xf32>
+    shape.assuming_yield %2 : memref<2xf32>
+  }
+  return
+}