diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,23 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace memref { +/// Register the external models that implement BufferizableOpInterface for +/// MemRef dialect ops (e.g., memref.tensor_store) with `registry`. +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -45,6 +45,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" @@ -131,6 +132,7 @@ registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); + memref::registerBufferizableOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,67 @@ +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace { +/// Bufferization of memref.tensor_store. The source tensor is copied into the +/// destination memref via BufferizationOptions::createMemCpy (memref.copy +/// unless a custom memcpy function is configured). +struct TensorStoreOpInterface + : public BufferizableOpInterface::ExternalModel<TensorStoreOpInterface, + memref::TensorStoreOp> { + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(opOperand.getOperandNumber() == 0 && "expected src operand"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // The memref operand is written but not the tensor operand.
+ assert(opOperand.getOperandNumber() == 0 && "expected src operand"); + return false; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + auto tensorStoreOp = cast<memref::TensorStoreOp>(op); + auto srcBuffer = getBuffer(rewriter, tensorStoreOp.getTensor(), options); + if (failed(srcBuffer)) + return failure(); + // Copy via the options so that a custom memcpy function is honored; by + // default this emits a memref.copy. + if (failed(options.createMemCpy(rewriter, tensorStoreOp.getLoc(), + *srcBuffer, tensorStoreOp.getMemref()))) + return failure(); + rewriter.eraseOp(tensorStoreOp); + return success(); + } +}; + +} // namespace + +void mlir::memref::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) { + TensorStoreOp::attachInterface<TensorStoreOpInterface>(*ctx); + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRMemRefTransforms + BufferizableOpInterfaceImpl.cpp ComposeSubView.cpp ExpandOps.cpp ExpandStridedMetadata.cpp @@ -20,6 +21,7 @@ MLIRAffineUtils MLIRArithDialect MLIRArithTransforms + MLIRBufferizationDialect MLIRFuncDialect MLIRInferTypeOpInterface MLIRLoopLikeInterface diff --git a/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir b/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir --- a/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir +++ b/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir @@ -2,6 +2,11 @@ // RUN: -test-linalg-transform-patterns=test-bufferize-non-dps-ops-patterns \ // RUN: -canonicalize %s | FileCheck %s +// RUN: mlir-opt -split-input-file \ +// RUN: -test-linalg-transform-patterns=test-bufferize-non-dps-ops-patterns \ +// RUN: -canonicalize -one-shot-bufferize %s | \ +// RUN: FileCheck %s --check-prefix=CHECK-BUFFERIZE-ALL + // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> // CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 +
10)> // CHECK-LABEL: func @tensor_pad_constant( @@ -18,6 +23,14 @@ // CHECK: memref.tensor_store %[[t]], %[[subview]] // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref // CHECK: return %[[r]] + +// CHECK-BUFFERIZE-ALL-LABEL: func @tensor_pad_constant( +// CHECK-BUFFERIZE-ALL-SAME: %[[t:.*]]: tensor +// CHECK-BUFFERIZE-ALL: %[[src:.*]] = bufferization.to_memref %[[t]] +// CHECK-BUFFERIZE-ALL: %[[alloc:.*]] = memref.alloc +// CHECK-BUFFERIZE-ALL: %[[subview:.*]] = memref.subview %[[alloc]] +// CHECK-BUFFERIZE-ALL: memref.copy %[[src]], %[[subview]] +// CHECK-BUFFERIZE-ALL: bufferization.to_tensor %[[alloc]] restrict writable func.func @tensor_pad_constant(%t: tensor, %l2: index, %h1: index, %h2: index) -> tensor { %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] { diff --git a/mlir/test/Dialect/MemRef/bufferize.mlir b/mlir/test/Dialect/MemRef/bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/bufferize.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt -one-shot-bufferize %s | FileCheck %s + +// CHECK-LABEL: func @tensor_store( +// CHECK-SAME: %[[t:.*]]: tensor, %[[m:.*]]: memref +// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]] +// CHECK: memref.copy %[[src]], %[[m]] +// CHECK: return +func.func @tensor_store(%t: tensor, %m: memref) { + memref.tensor_store %t, %m : memref + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9898,6 +9898,7 @@ ":ArithDialect", ":ArithTransforms", ":ArithUtils", + ":BufferizationDialect", ":ControlFlowDialect", ":DialectUtils", ":FuncDialect",