diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -352,7 +352,9 @@
 
 def AffineIfOp : Affine_Op<"if",
                            [ImplicitAffineTerminator, RecursiveSideEffects,
-                            NoRegionArguments]> {
+                            NoRegionArguments,
+                            DeclareOpInterfaceMethods<RegionBranchOpInterface>
+                           ]> {
   let summary = "if-then-else operation";
   let description = [{
     Syntax:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2363,6 +2363,35 @@
 };
 } // namespace
 
+/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
+/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp. When the `else`
+/// region is empty, control may also flow from the AffineIfOp straight back to
+/// its results without entering any region.
+void AffineIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // If the predecessor is the AffineIfOp itself, both the `then` and the
+  // `else` path are possible successors: the integer set condition is not
+  // inspected here, so either branch may be taken.
+  if (!index.hasValue()) {
+    regions.reserve(2);
+    regions.push_back(
+        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
+    // An empty `else` region means the false path skips both regions and
+    // branches directly back to the parent op.
+    if (getElseRegion().empty())
+      regions.push_back(RegionSuccessor(getResults()));
+    else
+      regions.push_back(
+          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
+    return;
+  }
+
+  // If the predecessor is the `then` or `else` region, control branches back
+  // to the parent op, forwarding the yielded values to the op's results.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
 LogicalResult AffineIfOp::verify() {
   // Verify that we have a condition attribute.
   // FIXME: This should be specified in the arguments list in ODS.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1298,3 +1298,107 @@
 // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// Memref allocated in the `then` region and yielded to the parent affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_1
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>) -> memref<10xf32> {
+func.func @test_affine_if_1(%arg0: memref<10xf32>) -> memref<10xf32> {
+  %0 = affine.if #set() -> memref<10xf32> {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  } else {
+    affine.yield %arg0 : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[IF:.*]] = affine.if
+// CHECK-NEXT: %[[MEMREF:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[CLONED:.*]] = bufferization.clone %[[MEMREF]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[MEMREF]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[CLONED]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ARG0_CLONE:.*]] = bufferization.clone %[[ARG0]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ARG0_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[IF]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before the parent affine.if op and yielded from the `then` region.
+// Expected result: it is cloned in `then` and deallocated after the affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_2() -> memref<10xf32> {
+func.func @test_affine_if_2() -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    affine.yield %alloc0 : memref<10xf32>
+  } else {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[IF_RES:.*]] = affine.if {{.*}} -> memref<10xf32> {
+// CHECK-NEXT: %[[ALLOC_CLONE:.*]] = bufferization.clone %[[ALLOC]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC_CLONE]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[ALLOC2_CLONE:.*]] = bufferization.clone %[[ALLOC2]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC2]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC2_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: return %[[IF_RES]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before the parent affine.if op and yielded from the `else` region.
+// Expected result: it is cloned in `else` and deallocated after the affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_3() -> memref<10xf32> {
+func.func @test_affine_if_3() -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  } else {
+    affine.yield %alloc0 : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[IFRES:.*]] = affine.if {{.*}} -> memref<10xf32> {
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[ALLOC2_CLONE:.*]] = bufferization.clone %[[ALLOC2]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC2]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC2_CLONE]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ALLOC_CLONE:.*]] = bufferization.clone %[[ALLOC]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: return %[[IFRES]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before the parent affine.if op and never used afterwards.
+// Expected result: it is deallocated immediately, before the affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_4({{.*}}: memref<10xf32>) -> memref<10xf32> {
+func.func @test_affine_if_4(%arg0 : memref<10xf32>) -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    affine.yield %arg0 : memref<10xf32>
+  } else {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: affine.if