diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h @@ -0,0 +1,27 @@ +//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H +#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace linalg { +namespace comprehensive_bufferize { +namespace bufferization_ext { + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); + +} // namespace bufferization_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -416,10 +416,6 @@ BufferizationState &state) { OpBuilder b(op->getContext()); - // Skip ToMemrefOp and ToTensorOp. - if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op)) - return success(); - // Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -0,0 +1,101 @@ +//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace linalg; +using namespace comprehensive_bufferize; + +namespace mlir { +namespace linalg { +namespace comprehensive_bufferize { +namespace bufferization_ext { + +// TODO: These ops should implement BufferizableOpInterface directly when moved +// to the Bufferization dialect. + +// TODO: These implementations are conservative and will likely have to be +// loosened for partial bufferization. + +/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory +/// location of the incoming tensor once it will be bufferized. In the analysis, +/// the incoming tensor is assumed to bufferize to a memory read and to an +/// inplace memory write, since it is unknown what will happen to the resulting +/// memref.
+struct ToMemrefOpInterface + : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface, + bufferization::ToMemrefOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + // It is unknown whether the resulting MemRef will be read or not. + return true; + } + + SmallVector<OpOperand *> getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + return success(); + } +}; + +/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do +/// not lower any further, and they should have disappeared by the time the +/// input is fully bufferized. +/// +/// The analysis has no information about the memref that is loaded from by the +/// ToTensorOp. We have to assume that the loaded tensor may after bufferization +/// potentially alias with any other bufferized tensor. Since ToTensorOp and +/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded +/// directly in the analysis. However, declaring ToTensorOp results as not +/// writable also enforces a buffer copy and has the same effect. +struct ToTensorOpInterface + : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface, + bufferization::ToTensorOp> { + SmallVector<OpOperand *> getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto tensorLoadOp = cast<bufferization::ToTensorOp>(op); + state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref()); + return success(); + } + + bool isWritable(Operation *op, Value value) const { + // It is unknown whether the MemRef operand is writable or not.
+ return false; + } +}; + +} // namespace bufferization_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +void mlir::linalg::comprehensive_bufferize::bufferization_ext:: + registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { + registry.addOpInterface<bufferization::ToMemrefOp, + bufferization_ext::ToMemrefOpInterface>(); + registry.addOpInterface<bufferization::ToTensorOp, + bufferization_ext::ToTensorOpInterface>(); +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -2,6 +2,7 @@ AffineInterfaceImpl.cpp ArithInterfaceImpl.cpp BufferizableOpInterface.cpp + BufferizationInterfaceImpl.cpp ComprehensiveBufferize.cpp LinalgInterfaceImpl.cpp ModuleBufferization.cpp @@ -80,6 +81,7 @@ ) add_mlir_dialect_library(MLIRComprehensiveBufferize + BufferizationInterfaceImpl.cpp ComprehensiveBufferize.cpp ModuleBufferization.cpp diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -239,6 +239,12 @@ /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo) { + // The analysis does not know what happens to the result of a ToMemrefOp, so + // we assume that it is written to. + // TODO: This is a conservative implementation. This rule will have to be + // relaxed for partial bufferization. + if (isa<bufferization::ToMemrefOp>(opOperand.getOwner())) + return true; // OpOperands without an aliasing OpResult do not write.
OpResult opResult = getAliasingOpResult(opOperand); if (!opResult) @@ -453,14 +459,23 @@ /// If `checkConsistencyOnly` is true, this function checks if there is a /// read-after-write conflict without bufferizing `operand` inplace. This would /// indicate a problem with the current inplace bufferization decisions. +/// +/// Note: If `checkConsistencyOnly`, this function may be called with a null +/// OpResult. In that case, only the consistency of bufferization decisions +/// involving aliases of the given OpOperand is checked. bool wouldCreateReadAfterWriteInterference( OpOperand &operand, OpResult result, const DominanceInfo &domInfo, const BufferizationAliasInfo &aliasInfo, bool checkConsistencyOnly = false) { #ifndef NDEBUG - SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result); - assert(llvm::find(opOperands, &operand) != opOperands.end() && - "operand and result do not match"); + if (result) { + SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result); + assert(llvm::find(opOperands, &operand) != opOperands.end() && + "operand and result do not match"); + } else { + assert(checkConsistencyOnly && + "result not provided, can only check consistency"); + } #endif // NDEBUG // Helper function to iterate on aliases of `root` and capture the reads. @@ -486,9 +501,11 @@ // Collect reads and writes of all aliases of OpOperand and OpResult. DenseSet<OpOperand *> usesRead, usesWrite; getAliasingReads(usesRead, operand.get()); - getAliasingReads(usesRead, result); + if (result) + getAliasingReads(usesRead, result); getAliasingInplaceWrites(usesWrite, operand.get()); - getAliasingInplaceWrites(usesWrite, result); + if (result) + getAliasingInplaceWrites(usesWrite, result); if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -673,25 +690,38 @@ return res; } -#ifndef NDEBUG -/// Assert that the current bufferization decisions are consistent. +/// Check that the current bufferization decisions are consistent.
-static void checkAliasInfoConsistency(FuncOp funcOp, - const DominanceInfo &domInfo, - const BufferizationAliasInfo &aliasInfo) { - funcOp.walk([&](Operation *op) { +static LogicalResult +checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo, + const BufferizationAliasInfo &aliasInfo) { + Operation *inconsistentOp = nullptr; + WalkResult walkResult = funcOp.walk([&](Operation *op) { if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa<TensorType>()) - if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) - // If this assertion fails, there is probably an inconsistent - // combination of "mustBufferizeInPlace" decisions. - assert(!wouldCreateReadAfterWriteInterference( - opOperand, opResult, domInfo, aliasInfo, - /*checkConsistencyOnly=*/true) && - "found read after write conflict before running analysis"); + if (opOperand.get().getType().isa<TensorType>()) { + OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand); + if (wouldCreateReadAfterWriteInterference( + opOperand, opResult, domInfo, aliasInfo, + /*checkConsistencyOnly=*/true)) { + // This error can happen for two reasons. Either the input IR + // already has a read-after-write conflict. Or certain + // "mustBufferizeInPlace" interface methods are implemented + // incorrectly. + inconsistentOp = op; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); }); + + if (walkResult.wasInterrupted()) + // This can currently happen in one situation: When a tensor is passed into + // a ToMemrefOp and read by another op subsequently. ToMemrefOps are + // currently handled conservatively. Once a tensor is passed into a + // ToMemrefOp, it may no longer be read. + return inconsistentOp->emitError("input IR has RaW conflict"); + return success(); } -#endif /// Annotate the IR with the result of the analysis. For testing/debugging only.
static void @@ -720,9 +750,8 @@ if (funcOp.body().empty()) return success(); -#ifndef NDEBUG - checkAliasInfoConsistency(funcOp, domInfo, aliasInfo); -#endif // NDEBUG + if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo))) + return failure(); // If the analysis fails, just return. if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" @@ -47,6 +48,7 @@ arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith_ext::registerBufferizableOpInterfaceExternalModels(registry); + bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++
b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1492,3 +1492,44 @@ %0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>) return %0 : tensor<?xf32> } + +// ----- + +// CHECK-LABEL: func @to_tensor_op_not_writable +func @to_tensor_op_not_writable(%m: memref<?xf32>, %v: vector<5xf32>, + %idx1: index, %idx2: index) + -> vector<10xf32> { + %0 = bufferization.to_tensor %m : memref<?xf32> + + // Write to the tensor. Cannot be inplace due to to_tensor. + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32> + + // Read from the tensor and return result. + %cst = arith.constant 0.0 : f32 + %r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32> + return %r : vector<10xf32> +} + +// ----- + +// CHECK-LABEL: func @to_memref_op_is_reading +func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true}, + %idx1: index, %idx2: index, %idx3: index, + %v1: vector<5xf32>) + -> (vector<5xf32>, vector<5xf32>) { + // Write + read to/from tensor. + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32> + %cst = arith.constant 0.0 : f32 + %r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32> + + // Write + read to/from same memref.
+ %0 = bufferization.to_memref %t1 : memref<?xf32> + vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32> + %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32> + + return %r1, %r2 : vector<5xf32>, vector<5xf32> +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -167,3 +167,23 @@ } return %r: tensor<4xi32> } + +// ----- + +func @to_memref_op_is_writing( + %t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index, + %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) { + // This is a RaW conflict because to_memref is an inplace write and %t1 is + // read further down. This will likely have to change with partial + // bufferization. + + // expected-error @+1 {{input IR has RaW conflict}} + %0 = bufferization.to_memref %t1 : memref<?xf32> + + // Read from both.
+ %cst = arith.constant 0.0 : f32 + %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32> + %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32> + + return %r1, %r2 : vector<5xf32>, vector<5xf32> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6673,10 +6673,12 @@ cc_library( name = "ComprehensiveBufferize", srcs = [ + "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp", "lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp", "lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp", ], hdrs = [ + "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h", "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h", "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h", ],