diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
--- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/BufferViewFlowAnalysis.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
@@ -51,9 +52,9 @@
 /// successor regions and branch-like return operations from nested regions.
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
-  auto registerDependencies = [&](auto values, auto dependencies) {
-    for (auto entry : llvm::zip(values, dependencies))
-      this->dependencies[std::get<0>(entry)].insert(std::get<1>(entry));
+  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
+    for (auto [value, dep] : llvm::zip(values, dependencies))
+      this->dependencies[value].insert(dep);
   };
 
   // Add additional dependencies created by view changes to the alias list.
@@ -119,4 +120,12 @@
       }
     }
   });
+
+  // arith.select yields either operand 1 or operand 2 (operand 0 is the
+  // condition), so its result may alias either one of them.
+  // TODO: This should be an interface.
+  op->walk([&](arith::SelectOp selectOp) {
+    registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
+    registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
+  });
 }
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@
   mlir-headers
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1298,3 +1298,21 @@
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// Both allocations can reach %2 through arith.select, so both deallocs must
+// be placed after test.copy reads %2.
+func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: arith.select
+  // CHECK: test.copy
+  // CHECK: memref.dealloc
+  // CHECK: memref.dealloc
+  %0 = memref.alloc(%arg0) : memref<?xi8>
+  %1 = memref.alloc(%arg0) : memref<?xi8>
+  %2 = arith.select %arg2, %0, %1 : memref<?xi8>
+  test.copy(%2, %arg1) : (memref<?xi8>, memref<?xi8>)
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5879,6 +5879,7 @@
     ),
     includes = ["include"],
     deps = [
+        ":ArithmeticDialect",
         ":CallOpInterfaces",
         ":ControlFlowInterfaces",
         ":DataLayoutInterfaces",