diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -353,7 +353,9 @@
 
 def AffineIfOp : Affine_Op<"if",
                            [ImplicitAffineTerminator, RecursivelySpeculatable,
-                            RecursiveMemoryEffects, NoRegionArguments]> {
+                            RecursiveMemoryEffects, NoRegionArguments,
+                            DeclareOpInterfaceMethods<RegionBranchOpInterface>
+                           ]> {
   let summary = "if-then-else operation";
   let description = [{
     Syntax:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2513,6 +2513,36 @@
 };
 } // namespace
 
+/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
+/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp. When the `else`
+/// region is empty, control may also flow from AffineIfOp straight to its
+/// results.
+void AffineIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // If the predecessor is the AffineIfOp itself, control may enter the `then`
+  // region, and either the `else` region or, when that region is empty, the
+  // parent op directly (the condition is not inspected here).
+  if (!index.has_value()) {
+    regions.reserve(2);
+    regions.push_back(
+        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
+    if (getElseRegion().empty()) {
+      // A false condition skips both regions; control flows straight back to
+      // the parent op.
+      regions.push_back(RegionSuccessor(getResults()));
+    } else {
+      regions.push_back(
+          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
+    }
+    return;
+  }
+
+  // If the predecessor is the `else`/`then` region, then branching into parent
+  // op is valid.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
 LogicalResult AffineIfOp::verify() {
   // Verify that we have a condition attribute.
   // FIXME: This should be specified in the arguments list in ODS.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1314,3 +1314,107 @@
   test.copy(%2, %arg1) : (memref<?xi8>, memref<?xi8>)
   return
 }
+
+// -----
+
+// Memref allocated in `then` region and passed back to the parent if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_1
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>) -> memref<10xf32> {
+func.func @test_affine_if_1(%arg0: memref<10xf32>) -> memref<10xf32> {
+  %0 = affine.if #set() -> memref<10xf32> {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  } else {
+    affine.yield %arg0 : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[IF:.*]] = affine.if
+// CHECK-NEXT: %[[MEMREF:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[CLONED:.*]] = bufferization.clone %[[MEMREF]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[MEMREF]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[CLONED]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ARG0_CLONE:.*]] = bufferization.clone %[[ARG0]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ARG0_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[IF]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before parent IfOp and used in `then` region.
+// Expected result: deallocation should happen after affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_2() -> memref<10xf32> {
+func.func @test_affine_if_2() -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    affine.yield %alloc0 : memref<10xf32>
+  } else {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[IF_RES:.*]] = affine.if {{.*}} -> memref<10xf32> {
+// CHECK-NEXT: %[[ALLOC_CLONE:.*]] = bufferization.clone %[[ALLOC]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC_CLONE]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[ALLOC2_CLONE:.*]] = bufferization.clone %[[ALLOC2]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC2]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC2_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: return %[[IF_RES]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before the parent affine.if op and yielded from its `else`
+// region. Expected result: deallocation should happen after the affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_3() -> memref<10xf32> {
+func.func @test_affine_if_3() -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  } else {
+    affine.yield %alloc0 : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[IFRES:.*]] = affine.if {{.*}} -> memref<10xf32> {
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: %[[ALLOC2_CLONE:.*]] = bufferization.clone %[[ALLOC2]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC2]] : memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC2_CLONE]] : memref<10xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[ALLOC_CLONE:.*]] = bufferization.clone %[[ALLOC]] : memref<10xf32> to memref<10xf32>
+// CHECK-NEXT: affine.yield %[[ALLOC_CLONE]] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: return %[[IFRES]] : memref<10xf32>
+
+// -----
+
+// Memref allocated before the parent affine.if op and never used.
+// Expected result: deallocation should happen before the affine.if op.
+#set = affine_set<() : (0 >= 0)>
+// CHECK-LABEL: func @test_affine_if_4({{.*}}: memref<10xf32>) -> memref<10xf32> {
+func.func @test_affine_if_4(%arg0 : memref<10xf32>) -> memref<10xf32> {
+  %alloc0 = memref.alloc() : memref<10xf32>
+  %0 = affine.if #set() -> memref<10xf32> {
+    affine.yield %arg0 : memref<10xf32>
+  } else {
+    %alloc = memref.alloc() : memref<10xf32>
+    affine.yield %alloc : memref<10xf32>
+  }
+  return %0 : memref<10xf32>
+}
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
+// CHECK-NEXT: affine.if