diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace memref {
+/// Register the BufferizableOpInterface external models of MemRef dialect ops.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -131,6 +132,7 @@
       registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
+  memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,74 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Bufferization of memref.tensor_store. Replace with memref.copy.
+///
+/// The source tensor (operand #0) is the op's only tensor operand; the
+/// destination is already a memref, so the analysis hooks below assert that
+/// they are only queried for operand #0.
+struct TensorStoreOpInterface
+    : public BufferizableOpInterface::ExternalModel<TensorStoreOpInterface,
+                                                    memref::TensorStoreOp> {
+  /// The op has no results, so no OpResult aliases the source tensor.
+  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  /// The source tensor is read: its contents are copied into the memref.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    assert(opOperand.getOperandNumber() == 0 && "expected src operand");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // The memref operand is written but not the tensor operand.
+    assert(opOperand.getOperandNumber() == 0 && "expected src operand");
+    return false;
+  }
+
+  /// Copy the buffer of the source tensor into the destination memref via
+  /// `options.createMemCpy`, then erase the op. Fails if the source buffer
+  /// cannot be obtained or the copy cannot be created.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto tensorStoreOp = cast<memref::TensorStoreOp>(op);
+    auto srcBuffer = getBuffer(rewriter, tensorStoreOp.getTensor(), options);
+    if (failed(srcBuffer))
+      return failure();
+    if (failed(options.createMemCpy(rewriter, op->getLoc(), *srcBuffer,
+                                    tensorStoreOp.getMemref())))
+      return failure();
+    rewriter.eraseOp(tensorStoreOp);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Attach the external model to memref.tensor_store. The extension is applied
+/// to every context in which the MemRef dialect gets loaded.
+void mlir::memref::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) {
+    TensorStoreOp::attachInterface<TensorStoreOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
+  BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
@@ -20,6 +21,7 @@
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms
+  MLIRBufferizationDialect
   MLIRFuncDialect
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -38,6 +38,34 @@
 
 // -----
 
+// CHECK-LABEL: func @tensor_pad_constant(
+// CHECK-SAME: %[[t:.*]]: tensor
+// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]]
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[subview:.*]] = memref.subview %[[alloc]]
+// CHECK: memref.copy %[[src]], %[[subview]]
+// CHECK: bufferization.to_tensor %[[alloc]] restrict writable
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1
+  // Make sure that One-Shot Bufferize can bufferize the rest.
+  transform.bufferization.one_shot_bufferize %arg1
+}
+
+// -----
+
 // CHECK-LABEL: func @materialization_of_bbarg(
 // CHECK-SAME: %[[t:.*]]: tensor
 // CHECK: %[[c0:.*]] = arith.constant 0 : index
@@ -59,3 +87,26 @@
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
   %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }
+
+// -----
+
+// CHECK-LABEL: func @materialization_of_bbarg_one_shot_bufferize(
+// CHECK-SAME: %[[t:.*]]: tensor
+// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref
+// CHECK: memref.copy %[[m]], %[[alloc]]
+// CHECK: %[[r:.*]] = memref.load %[[alloc]]
+// CHECK: return %[[r]]
+func.func @materialization_of_bbarg_one_shot_bufferize(%t: tensor<?x10xindex>, %idx: index) -> index {
+  %r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  // Make sure that One-Shot Bufferize can bufferize the rest.
+  transform.bufferization.one_shot_bufferize %arg1
+}
diff --git a/mlir/test/Dialect/MemRef/bufferize.mlir b/mlir/test/Dialect/MemRef/bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/bufferize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -one-shot-bufferize %s | FileCheck %s
+
+// CHECK-LABEL: func @tensor_store(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[m:.*]]: memref<?xf32>
+// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]]
+// CHECK: memref.copy %[[src]], %[[m]]
+// CHECK: return
+func.func @tensor_store(%t: tensor<?xf32>, %m: memref<?xf32>) {
+  memref.tensor_store %t, %m : memref<?xf32>
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9904,6 +9904,7 @@
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",
+        ":BufferizationDialect",
         ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",