diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h @@ -1,4 +1,4 @@ -//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===// +//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -346,10 +346,10 @@ /// In the above example, Values with a star satisfy the condition. When /// starting the traversal from Value 1, the resulting SetVector is: /// { 2, 7, 8, 5 } - llvm::SetVector findValueInReverseUseDefChain( + SetVector findValueInReverseUseDefChain( Value value, llvm::function_ref condition) const; - /// Find the Value of the last preceding write of a given Value. + /// Find the Values of the last preceding write of a given Value. /// /// Note: Unknown ops are handled conservatively and assumed to be writes. /// Furthermore, BlockArguments are also assumed to be writes. There is no @@ -357,7 +357,7 @@ /// /// Note: When reaching an end of the reverse SSA use-def chain, that value /// is returned regardless of whether it is a memory write or not. - Value findLastPrecedingWrite(Value value) const; + SetVector findLastPrecedingWrite(Value value) const; /// Creates a memref allocation. FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -31,7 +31,7 @@ namespace std_ext { -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +void registerModuleBufferizationExternalModels(DialectRegistry ®istry); } // namespace std_ext } // namespace comprehensive_bufferize diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h copy from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h copy to mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h @@ -1,4 +1,4 @@ -//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===// +//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINE_INTERFACE_IMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINE_INTERFACE_IMPL_H +#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H +#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H namespace mlir { @@ -15,13 +15,13 @@ namespace linalg { namespace comprehensive_bufferize { -namespace affine_ext { +namespace std_ext { void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); -} // namespace affine_ext +} // namespace std_ext } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINE_INTERFACE_IMPL_H +#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -305,26 +305,18 @@ return result; } -// Find the Value of the last preceding write of a given Value. -Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - findLastPrecedingWrite(Value value) const { - SetVector result = - findValueInReverseUseDefChain(value, [&](Value value) { - Operation *op = value.getDefiningOp(); - if (!op) - return true; - auto bufferizableOp = options.dynCastBufferizableOp(op); - if (!bufferizableOp) - return true; - return bufferizableOp.isMemoryWrite(value.cast(), *this); - }); - - // To simplify the analysis, `scf.if` ops are considered memory writes. There - // are currently no other ops where one OpResult may alias with multiple - // OpOperands. Therefore, this function should return exactly one result at - // the moment. - assert(result.size() == 1 && "expected exactly one result"); - return result.front(); +// Find the Values of the last preceding write of a given Value. +llvm::SetVector mlir::linalg::comprehensive_bufferize:: + BufferizationState::findLastPrecedingWrite(Value value) const { + return findValueInReverseUseDefChain(value, [&](Value value) { + Operation *op = value.getDefiningOp(); + if (!op) + return true; + auto bufferizableOp = options.dynCastBufferizableOp(op); + if (!bufferizableOp) + return true; + return bufferizableOp.isMemoryWrite(value.cast(), *this); + }); } mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState( @@ -404,15 +396,19 @@ createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); if (failed(resultBuffer)) return failure(); - // Do not copy if the last preceding write of `operand` is an op that does + // Do not copy if the last preceding writes of `operand` are ops that do // not write (skipping ops that merely create aliases). E.g., InitTensorOp. // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. - Value lastWrite = findLastPrecedingWrite(operand); - if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) - if (!bufferizableOp.isMemoryWrite(lastWrite.cast(), *this)) - return resultBuffer; + SetVector lastWrites = findLastPrecedingWrite(operand); + if (llvm::none_of(lastWrites, [&](Value lastWrite) { + if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) + return bufferizableOp.isMemoryWrite(lastWrite.cast(), + *this); + return true; + })) + return resultBuffer; // Do not copy if the copied data is never read. OpResult aliasingOpResult = getAliasingOpResult(opOperand); if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) && diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -7,6 +7,7 @@ LinalgInterfaceImpl.cpp ModuleBufferization.cpp SCFInterfaceImpl.cpp + StdInterfaceImpl.cpp TensorInterfaceImpl.cpp VectorInterfaceImpl.cpp ) @@ -61,6 +62,14 @@ MLIRSCF ) +add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl + StdInterfaceImpl.cpp + + LINK_LIBS PUBLIC + MLIRBufferizableOpInterface + MLIRStandard +) + add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl TensorInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -219,7 +219,8 @@ for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); - // Find most recent write of uRead by following the SSA use-def chain. E.g.: + // Find most recent writes of uRead by following the SSA use-def chain. + // E.g.: // // %0 = "writing_op"(%t) : tensor -> tensor // %1 = "aliasing_op"(%0) : tensor -> tensor @@ -228,7 +229,7 @@ // In the above example, if uRead is the OpOperand of reading_op, lastWrite // is %0. Note that operations that create an alias but do not write (such // as ExtractSliceOp) are skipped. - Value lastWrite = state.findLastPrecedingWrite(uRead->get()); + SetVector lastWrites = state.findLastPrecedingWrite(uRead->get()); // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. @@ -265,35 +266,38 @@ if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) continue; - // No conflict if the conflicting write happens before the last - // write. - if (Operation *writingOp = lastWrite.getDefiningOp()) { - if (happensBefore(conflictingWritingOp, writingOp, domInfo)) - // conflictingWritingOp happens before writingOp. No conflict. - continue; - // No conflict if conflictingWritingOp is contained in writingOp. - if (writingOp->isProperAncestor(conflictingWritingOp)) - continue; - } else { - auto bbArg = lastWrite.cast(); - Block *block = bbArg.getOwner(); - if (!block->findAncestorOpInBlock(*conflictingWritingOp)) - // conflictingWritingOp happens outside of the block. No - // conflict. - continue; - } + // Check all possible last writes. + for (Value lastWrite : lastWrites) { + // No conflict if the conflicting write happens before the last + // write. + if (Operation *writingOp = lastWrite.getDefiningOp()) { + if (happensBefore(conflictingWritingOp, writingOp, domInfo)) + // conflictingWritingOp happens before writingOp. No conflict. + continue; + // No conflict if conflictingWritingOp is contained in writingOp. + if (writingOp->isProperAncestor(conflictingWritingOp)) + continue; + } else { + auto bbArg = lastWrite.cast(); + Block *block = bbArg.getOwner(); + if (!block->findAncestorOpInBlock(*conflictingWritingOp)) + // conflictingWritingOp happens outside of the block. No + // conflict. + continue; + } - // No conflict if the conflicting write and the last write are the same - // use. - if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite) - continue; + // No conflict if the conflicting write and the last write are the same + // use. + if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite) + continue; - // All requirements are met. Conflict found! + // All requirements are met. Conflict found! - if (options.printConflicts) - annotateConflict(uRead, uConflictingWrite, lastWrite); + if (options.printConflicts) + annotateConflict(uRead, uConflictingWrite, lastWrite); - return true; + return true; + } } } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -938,7 +938,7 @@ } // namespace mlir void mlir::linalg::comprehensive_bufferize::std_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registerModuleBufferizationExternalModels(DialectRegistry ®istry) { registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp @@ -0,0 +1,79 @@ +//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace linalg { +namespace comprehensive_bufferize { +namespace std_ext { + +/// Bufferization of std.select. Just replace the operands. +struct SelectOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return op->getOpResult(0) /*result*/; + } + + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + const BufferizationState &state) const { + return {&op->getOpOperand(1) /*true_value*/, + &op->getOpOperand(2) /*false_value*/}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationState &state) const { + auto selectOp = cast(op); + // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. + // TODO: It would be more efficient to copy the result of the `select` op + // instead of its OpOperands. In the worst case, 2 copies are inserted at + // the moment (one for each tensor). When copying the op result, only one + // copy would be needed. + Value trueBuffer = + *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); + Value falseBuffer = + *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); + replaceOpWithNewBufferizedOp( + rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); + return success(); + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo, + const BufferizationState &state) const { + return BufferRelation::None; + } +}; + +} // namespace std_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +void mlir::linalg::comprehensive_bufferize::std_ext:: + registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registry.addOpInterface(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -49,6 +49,7 @@ MLIRSCF MLIRSCFBufferizableOpInterfaceImpl MLIRSCFTransforms + MLIRStdBufferizableOpInterfaceImpl MLIRPass MLIRStandard MLIRStandardOpsTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -51,6 +52,7 @@ bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + std_ext::registerModuleBufferizationExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1710,3 +1710,84 @@ } return %1: tensor } + +// ----- + +// CHECK-LABEL: func @write_after_select_read_one +// CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor +func @write_after_select_read_one( + %t1 : tensor {linalg.inplaceable = true}, + %t2 : tensor {linalg.inplaceable = true}, + %c : i1) + -> (f32, tensor) +{ + %cst = arith.constant 0.0 : f32 + %idx = arith.constant 0 : index + + // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "true"]} + %s = std.select %c, %t1, %t2 : tensor + // CHECK: tensor.insert + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} + %w = tensor.insert %cst into %s[%idx] : tensor + // CHECK: tensor.extract + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]} + %f = tensor.extract %t1[%idx] : tensor + + return %f, %w : f32, tensor +} + +// ----- + +// CHECK-LABEL: func @write_after_select_read_both +// CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor +func @write_after_select_read_both( + %t1 : tensor {linalg.inplaceable = true}, + %t2 : tensor {linalg.inplaceable = true}, + %c : i1) + -> (f32, f32, tensor) +{ + %cst = arith.constant 0.0 : f32 + %idx = arith.constant 0 : index + + // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "false"]} + %s = std.select %c, %t1, %t2 : tensor + // CHECK: tensor.insert + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} + %w = tensor.insert %cst into %s[%idx] : tensor + // CHECK: tensor.extract + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]} + %f = tensor.extract %t1[%idx] : tensor + // CHECK: tensor.extract + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]} + %f2 = tensor.extract %t2[%idx] : tensor + + return %f, %f2, %w : f32, f32, tensor +} + +// ----- + +// CHECK-LABEL: func @write_after_select_no_conflict +// CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor +func @write_after_select_no_conflict( + %t1 : tensor {linalg.inplaceable = true}, + %t2 : tensor {linalg.inplaceable = true}, + %c : i1) + -> (f32, tensor) +{ + %cst = arith.constant 0.0 : f32 + %idx = arith.constant 0 : index + + // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + %s = std.select %c, %t1, %t2 : tensor + // CHECK: tensor.insert + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} + %w = tensor.insert %cst into %s[%idx] : tensor + // CHECK: tensor.extract + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]} + %f = tensor.extract %w[%idx] : tensor + + return %f, %w : f32, tensor +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1227,7 +1227,7 @@ // InitTensorOp elimination would produce SSA violations for the example below. //===----------------------------------------------------------------------===// -func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>) +func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>) -> tensor { %c0 = arith.constant 0 : index %c32 = arith.constant 32 : index @@ -1243,3 +1243,54 @@ } return %3 : tensor } + +// ----- + +// CHECK-LABEL: func @write_to_select_op_source +// CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref +func @write_to_select_op_source( + %t1 : tensor {linalg.inplaceable = true}, + %t2 : tensor {linalg.inplaceable = true}, + %c : i1) + -> (tensor, tensor) +{ + %cst = arith.constant 0.0 : f32 + %idx = arith.constant 0 : index + // CHECK: %[[alloc:.*]] = memref.alloc + // CHECK: linalg.copy(%[[t1]], %[[alloc]]) + // CHECK: memref.store %{{.*}}, %[[alloc]] + %w = tensor.insert %cst into %t1[%idx] : tensor + // CHECK: %[[select:.*]] = select %{{.*}}, %[[t1]], %[[t2]] + %s = std.select %c, %t1, %t2 : tensor + // CHECK: return %[[select]], %[[alloc]] + return %s, %w : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @write_after_select_read_one +// CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref +func @write_after_select_read_one( + %t1 : tensor {linalg.inplaceable = true}, + %t2 : tensor {linalg.inplaceable = true}, + %c : i1) + -> (f32, tensor) +{ + %cst = arith.constant 0.0 : f32 + %idx = arith.constant 0 : index + + // CHECK: %[[alloc:.*]] = memref.alloc + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK: linalg.copy(%[[t1]], %[[alloc]]) + // CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]] + %s = std.select %c, %t1, %t2 : tensor + + // CHECK: memref.store %{{.*}}, %[[select]] + %w = tensor.insert %cst into %s[%idx] : tensor + + // CHECK: %[[f:.*]] = memref.load %[[t1]] + %f = tensor.extract %t1[%idx] : tensor + + // CHECK: return %[[f]], %[[select]] + return %f, %w : f32, tensor +} diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -28,6 +28,7 @@ MLIRPass MLIRSCF MLIRSCFBufferizableOpInterfaceImpl + MLIRStdBufferizableOpInterfaceImpl MLIRStandard MLIRTensor MLIRTensorBufferizableOpInterfaceImpl diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -55,13 +56,14 @@ void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith_ext::registerBufferizableOpInterfaceExternalModels(registry); bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + std_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6484,6 +6484,24 @@ ], ) +cc_library( + name = "StdBufferizableOpInterfaceImpl", + srcs = [ + "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h", + ], + includes = ["include"], + deps = [ + ":BufferizableOpInterface", + ":IR", + ":StandardOps", + ":Support", + "//llvm:Support", + ], +) + cc_library( name = "TensorBufferizableOpInterfaceImpl", srcs = [ @@ -6743,6 +6761,7 @@ ":SCFTransforms", ":StandardOps", ":StandardOpsTransforms", + ":StdBufferizableOpInterfaceImpl", ":Support", ":TensorBufferizableOpInterfaceImpl", ":TensorDialect", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -399,6 +399,7 @@ "//mlir:SCFDialect", "//mlir:SCFTransforms", "//mlir:StandardOps", + "//mlir:StdBufferizableOpInterfaceImpl", "//mlir:TensorBufferizableOpInterfaceImpl", "//mlir:TensorDialect", "//mlir:TransformUtils",