diff --git a/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -0,0 +1,31 @@
+//===- BufferizationTransformOps.h - Buff. transf. ops ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// Bufferization Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace bufferization {
+/// Registers the bufferization transform ops as a transform dialect extension.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -0,0 +1,58 @@
+//===- BufferizationTransformOps.td - Buff. transf. ops ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_TRANSFORM_OPS
+#define BUFFERIZATION_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def OneShotBufferizeOp
+    : Op<Transform_Dialect, "bufferization.one_shot_bufferize",
+        [DeclareOpInterfaceMethods<TransformOpInterface>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Indicates that the given `target` op should be bufferized with One-Shot
+    Bufferize. The bufferization can be configured with various attributes that
+    correspond to options in `BufferizationOptions` and the
+    `one-shot-bufferize` pass. More information can be found in the pass
+    documentation.
+
+    If `target_is_module` is set, `target` must be a module. In that case the
+    `target` handle can be reused by other transform ops. When bufferizing other
+    ops, the `target` handle is freed after bufferization and can no longer be
+    used.
+
+    Note: Only ops that implement `BufferizableOpInterface` are bufferized. All
+    other ops are ignored if `allow_unknown_ops` is set. If `allow_unknown_ops`
+    is unset, this transform fails when an unknown/non-bufferizable op is found.
+    Many ops implement `BufferizableOpInterface` via an external model. These
+    external models must be registered when applying this transform op;
+    otherwise, said ops would be considered non-bufferizable.
+  }];
+
+  let arguments = (
+      ins PDL_Operation:$target,
+      DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
+      DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
+      DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
+      DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
+      DefaultValuedAttr<BoolAttr, "true">:$target_is_module,
+      DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
+      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
+
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict";
+}
+
+#endif // BUFFERIZATION_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS BufferizationTransformOps.td)
+mlir_tablegen(BufferizationTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(BufferizationTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRBufferizationTransformOpsIncGen)
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -107,6 +108,7 @@
   // clang-format on
 
   // Register all dialect extensions.
+  bufferization::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
 
diff --git a/mlir/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
--- a/mlir/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -0,0 +1,98 @@
+//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// OneShotBufferizeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
+                                     TransformState &state) {
+  OneShotBufferizationOptions options;
+  options.allowReturnAllocs = getAllowReturnAllocs();
+  options.allowUnknownOps = getAllowUnknownOps();
+  options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
+  options.createDeallocs = getCreateDeallocs();
+  options.testAnalysisOnly = getTestAnalysisOnly();
+  options.printConflicts = getPrintConflicts();
+
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  for (Operation *target : payloadOps) {
+    auto moduleOp = dyn_cast<ModuleOp>(target);
+    if (getTargetIsModule() && !moduleOp)
+      return emitError("expected ModuleOp target");
+    if (options.bufferizeFunctionBoundaries) {
+      if (!moduleOp)
+        return emitError("expected ModuleOp target");
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+        return emitError("bufferization failed");
+    } else {
+      if (failed(bufferization::runOneShotBufferize(target, options)))
+        return emitError("bufferization failed");
+    }
+  }
+
+  return success();
+}
+
+void transform::OneShotBufferizeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       TransformMappingResource::get());
+
+  // Handles that are not modules are no longer usable after bufferization.
+  if (!getTargetIsModule())
+    effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
+                         TransformMappingResource::get());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Registers new ops and declares dependent dialects: PDL because the ops use
+/// PDL types for their operands, and Bufferization/MemRef because bufferizing
+/// the payload IR creates ops from those dialects.
+class BufferizationTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          BufferizationTransformDialectExtension> {
+public:
+  BufferizationTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect<bufferization::BufferizationDialect>();
+    declareDependentDialect<memref::MemRefDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
+void mlir::bufferization::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<BufferizationTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRBufferizationTransformOps
+  BufferizationTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/TransformOps
+
+  DEPENDS
+  MLIRBufferizationTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRBufferization
+  MLIRBufferizationTransforms
+  MLIRMemRefDialect
+  MLIRParser
+  MLIRPDL
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  )
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// Test One-Shot Bufferize.
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    transform.bufferization.one_shot_bufferize %0
+        {target_is_module = false}
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %0 = operation "func.func"
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test analysis of One-Shot Bufferize only.
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    transform.bufferization.one_shot_bufferize %0
+        {target_is_module = false, test_analysis_only = true}
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %0 = operation "func.func"
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: func @test_function_analysis(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
+  // CHECK-SAME: tensor<?xf32>
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test One-Shot Bufferize transform failure with an unknown op. This would be
+// allowed with `allow_unknown_ops`.
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @+1 {{bufferization failed}}
+    transform.bufferization.one_shot_bufferize %0 {target_is_module = false}
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %0 = operation "func.func"
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
+  // expected-error @+1 {{op was not bufferized}}
+  %0 = "test.dummy_op"() : () -> (tensor<?xf32>)
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test One-Shot Bufferize on a module op (handle stays valid by default).
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    // %arg1 is the module
+    transform.bufferization.one_shot_bufferize %arg1
+  }
+}
+
+module {
+  // CHECK-LABEL: func @test_function(
+  //  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+    %c0 = arith.constant 0 : index
+
+    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+    // CHECK: memref.dealloc %[[alloc]]
+    // CHECK: return %[[res_tensor]]
+    return %0 : tensor<?xf32>
+  }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6295,6 +6295,7 @@
         ":AsyncToLLVM",
         ":AsyncTransforms",
         ":BufferizationDialect",
+        ":BufferizationTransformOps",
         ":BufferizationTransforms",
         ":ComplexDialect",
         ":ComplexToLLVM",
@@ -8899,6 +8900,63 @@
     deps = [":BufferizationOpsTdFiles"],
 )
 
+td_library(
+    name = "BufferizationTransformOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+        ":PDLDialectTdFiles",
+        ":SideEffectInterfacesTdFiles",
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "BufferizationTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td",
+    deps = [
+        ":BufferizationTransformOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "BufferizationTransformOps",
+    srcs = [
+        "lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BufferizationDialect",
+        ":BufferizationTransformOpsIncGen",
+        ":BufferizationTransforms",
+        ":IR",
+        ":MemRefDialect",
+        ":PDLDialect",
+        ":Parser",
+        ":SideEffectInterfaces",
+        ":TransformDialect",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "BufferizationOpsIncGen",
     strip_include_prefix = "include",