diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -10,14 +10,18 @@ #define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Bufferize.h" namespace mlir { +namespace bufferization { +class BufferizeTypeConverter; +} // end namespace bufferization + namespace arith { /// Add patterns to bufferize Arithmetic ops. -void populateArithmeticBufferizePatterns(BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateArithmeticBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); /// Create a pass to bufferize Arithmetic ops. std::unique_ptr createArithmeticBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h rename from mlir/include/mlir/Transforms/Bufferize.h rename to mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -20,19 +20,13 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TRANSFORMS_BUFFERIZE_H -#define MLIR_TRANSFORMS_BUFFERIZE_H - -#include "mlir/Analysis/BufferViewFlowAnalysis.h" -#include "mlir/Analysis/Liveness.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dominance.h" 
-#include "mlir/IR/Operation.h" +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H + #include "mlir/Transforms/DialectConversion.h" namespace mlir { +namespace bufferization { /// A helper type converter class that automatically populates the relevant /// materializations and type conversions for bufferization. @@ -58,6 +52,7 @@ void populateEliminateBufferizeMaterializationsPatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); +} // end namespace bufferization } // end namespace mlir -#endif // MLIR_TRANSFORMS_BUFFERIZE_H +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Bufferization) +add_public_tablegen_target(MLIRBufferizationPassIncGen) +add_dependencies(mlir-headers MLIRBufferizationPassIncGen) + +add_mlir_doc(Passes BufferizationPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -0,0 +1,32 @@ +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace bufferization { + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an instance of the BufferDeallocation pass to free all allocated +/// buffers. 
+std::unique_ptr createBufferDeallocationPass(); + +/// Creates a pass that finalizes a partial bufferization by removing remaining +/// bufferization.to_tensor and bufferization.to_memref operations. +std::unique_ptr createFinalizingBufferizePass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" + +} // end namespace bufferization +} // end namespace mlir + +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -0,0 +1,107 @@ +//===-- Passes.td - Bufferization passes definition file ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def BufferDeallocation : FunctionPass<"buffer-deallocation"> { + let summary = "Adds all required dealloc operations for all allocations in " + "the input program"; + let description = [{ + This pass implements an algorithm to automatically introduce all required + deallocation operations for all buffers in the input program. This ensures + that the resulting program does not have any memory leaks. 
+ + + Input + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 + ^bb1: + br ^bb3(%arg1 : memref<2xf32>) + ^bb2: + %0 = alloc() : memref<2xf32> + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + br ^bb3(%0 : memref<2xf32>) + ^bb3(%1: memref<2xf32>): + "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return + } + } + + ``` + + Output + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %0 = alloc() : memref<2xf32> + linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32> + br ^bb3(%0 : memref<2xf32>) + ^bb2: // pred: ^bb0 + %1 = alloc() : memref<2xf32> + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"]} %arg1, %1 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %4 = exp %arg3 : f32 + linalg.yield %4 : f32 + }: memref<2xf32>, memref<2xf32> + %2 = alloc() : memref<2xf32> + linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32> + dealloc %1 : memref<2xf32> + br ^bb3(%2 : memref<2xf32>) + ^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2 + linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32> + dealloc %3 : memref<2xf32> + return + } + + } + ``` + + }]; + let constructor = "mlir::bufferization::createBufferDeallocationPass()"; +} + +def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> { + let summary = "Finalize a partial bufferization"; + let description = [{ + A bufferize pass that finalizes a partial bufferization by removing + remaining `bufferization.to_tensor` and `bufferization.to_memref` operations.
+ + The removal of those operations is only possible if the operations only + exist in pairs, i.e., all uses of `bufferization.to_tensor` operations are + `bufferization.to_memref` operations. + + This pass will fail if not all operations can be removed or if any operation + with tensor typed operands remains. + }]; + let constructor = "mlir::bufferization::createFinalizingBufferizePass()"; +} + +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -18,12 +18,15 @@ #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/Bufferize.h" +#include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" namespace mlir { +namespace bufferization { class BufferizeTypeConverter; +} // namespace bufferization + class FrozenRewritePatternSet; namespace linalg { @@ -90,8 +93,9 @@ RewritePatternSet &patterns); /// Populates the given list with patterns to bufferize linalg ops. -void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter, - RewritePatternSet &patterns); +void populateLinalgBufferizePatterns( + bufferization::BufferizeTypeConverter &converter, + RewritePatternSet &patterns); /// Create linalg op on buffers given the original tensor-based operation and /// the buffers for the outputs.
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -15,16 +15,19 @@ #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Bufferize.h" namespace mlir { +namespace bufferization { +class BufferizeTypeConverter; +} // end namespace bufferization class GlobalCreator; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; -void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateStdBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); /// Creates an instance of std bufferization pass. std::unique_ptr createStdBufferizePass(); @@ -35,7 +38,8 @@ /// Add patterns to bufferize tensor constants into global memrefs to the given /// pattern list. void populateTensorConstantBufferizePatterns( - GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + GlobalCreator &globalCreator, + bufferization::BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); /// Creates an instance of tensor constant bufferization pass. 
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h @@ -10,15 +10,18 @@ #define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Bufferize.h" namespace mlir { +namespace bufferization { +class BufferizeTypeConverter; +} // end namespace bufferization class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; -void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateTensorBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); /// Creates an instance of `tensor` dialect bufferization pass. std::unique_ptr createTensorBufferizePass(); diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -54,6 +55,7 @@ registerAffinePasses(); registerAsyncPasses(); arith::registerArithmeticPasses(); + bufferization::registerBufferizationPasses(); registerGPUPasses(); registerGpuSerializeToCubinPass(); registerGpuSerializeToHsacoPass(); diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -33,10 +33,6 @@ // Passes //===----------------------------------------------------------------------===// -/// Creates an instance of the BufferDeallocation pass to 
free all allocated -/// buffers. -std::unique_ptr createBufferDeallocationPass(); - /// Creates a pass that moves allocations upwards to reduce the number of /// required copies that are inserted during the BufferDeallocation pass. std::unique_ptr createBufferHoistingPass(); @@ -58,10 +54,6 @@ std::unique_ptr createPromoteBuffersToStackPass(std::function isSmallAlloc); -/// Creates a pass that finalizes a partial bufferization by removing remaining -/// tensor_load and buffer_cast operations. -std::unique_ptr createFinalizingBufferizePass(); - /// Creates a pass that converts memref function results to out-params. std::unique_ptr createBufferResultsToOutParamsPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -217,83 +217,6 @@ let constructor = "mlir::createPipelineDataTransferPass()"; } -def BufferDeallocation : FunctionPass<"buffer-deallocation"> { - let summary = "Adds all required dealloc operations for all allocations in the " - "input program"; - let description = [{ - This pass implements an algorithm to automatically introduce all required - deallocation operations for all buffers in the input program. This ensures that - the resulting program does not have any memory leaks. 
- - - Input - - ```mlir - #map0 = affine_map<(d0) -> (d0)> - module { - func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cond_br %arg0, ^bb1, ^bb2 - ^bb1: - br ^bb3(%arg1 : memref<2xf32>) - ^bb2: - %0 = alloc() : memref<2xf32> - linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { - ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): - %tmp1 = exp %gen1_arg0 : f32 - linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> - br ^bb3(%0 : memref<2xf32>) - ^bb3(%1: memref<2xf32>): - "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - return - } - } - - ``` - - Output - - ```mlir - #map0 = affine_map<(d0) -> (d0)> - module { - func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cond_br %arg0, ^bb1, ^bb2 - ^bb1: // pred: ^bb0 - %0 = alloc() : memref<2xf32> - linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32> - br ^bb3(%0 : memref<2xf32>) - ^bb2: // pred: ^bb0 - %1 = alloc() : memref<2xf32> - linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %1 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %4 = exp %arg3 : f32 - linalg.yield %4 : f32 - }: memref<2xf32>, memref<2xf32> - %2 = alloc() : memref<2xf32> - linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32> - dealloc %1 : memref<2xf32> - br ^bb3(%2 : memref<2xf32>) - ^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2 - linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32> - dealloc %3 : memref<2xf32> - return - } - - } - ``` - - }]; - let constructor = "mlir::createBufferDeallocationPass()"; -} - def BufferHoisting : FunctionPass<"buffer-hoisting"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; @@ -416,22 +339,6 @@ ]; } -def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> { - let summary = "Finalize a 
partial bufferization"; - let description = [{ - A bufferize pass that finalizes a partial bufferization by removing - remaining `memref.tensor_load` and `memref.buffer_cast` operations. - - The removal of those operations is only possible if the operations only - exist in pairs, i.e., all uses of `memref.tensor_load` operations are - `memref.buffer_cast` operations. - - This pass will fail if not all operations can be removed or if any operation - with tensor typed operands remains. - }]; - let constructor = "mlir::createFinalizingBufferizePass()"; -} - def LocationSnapshot : Pass<"snapshot-op-locations"> { let summary = "Generate new locations from the current IR"; let description = [{ diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp @@ -6,10 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" + #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" using namespace mlir; @@ -35,7 +36,7 @@ struct ArithmeticBufferizePass : public ArithmeticBufferizeBase { void runOnFunction() override { - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); @@ -57,7 +58,8 @@ } // end anonymous namespace void mlir::arith::populateArithmeticBufferizePatterns( - BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); } diff --git 
a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ LINK_LIBS PUBLIC MLIRArithmetic + MLIRBufferizationTransforms MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -7,8 +7,11 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp rename from mlir/lib/Transforms/BufferDeallocation.cpp rename to mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -54,14 +54,9 @@ #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Operation.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include 
"mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/BufferUtils.h" -#include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetOperations.h" using namespace mlir; @@ -676,6 +671,6 @@ // BufferDeallocationPass construction //===----------------------------------------------------------------------===// -std::unique_ptr mlir::createBufferDeallocationPass() { +std::unique_ptr mlir::bufferization::createBufferDeallocationPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp rename from mlir/lib/Transforms/Bufferize.cpp rename to mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -6,20 +6,22 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" + #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/IR/Operation.h" -#include "mlir/Transforms/Passes.h" using namespace mlir; +using namespace mlir::bufferization; //===----------------------------------------------------------------------===// // BufferizeTypeConverter //===----------------------------------------------------------------------===// -static Value materializeTensorLoad(OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) { +static Value materializeToTensor(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); return builder.create(loc, type, inputs[0]); @@ -37,8 +39,8 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addArgumentMaterialization(materializeTensorLoad); - 
addSourceMaterialization(materializeTensorLoad); + addArgumentMaterialization(materializeToTensor); + addSourceMaterialization(materializeToTensor); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -47,14 +49,15 @@ }); } -void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) { +void mlir::bufferization::populateBufferizeMaterializationLegality( + ConversionTarget &target) { target.addLegalOp(); } namespace { // In a finalizing bufferize conversion, we know that all tensors have been // converted to memrefs, thus, this op becomes an identity. -class BufferizeTensorLoadOp +class BufferizeToTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -70,7 +73,8 @@ namespace { // In a finalizing bufferize conversion, we know that all tensors have been // converted to memrefs, thus, this op becomes an identity. -class BufferizeCastOp : public OpConversionPattern { +class BufferizeToMemrefOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult @@ -82,10 +86,10 @@ }; } // namespace -void mlir::populateEliminateBufferizeMaterializationsPatterns( +void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add(typeConverter, - patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); } namespace { @@ -121,6 +125,7 @@ }; } // namespace -std::unique_ptr mlir::createFinalizingBufferizePass() { +std::unique_ptr +mlir::bufferization::createFinalizingBufferizePass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -0,0 +1,15 @@ 
+add_mlir_dialect_library(MLIRBufferizationTransforms + Bufferize.cpp + BufferDeallocation.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization + + DEPENDS + MLIRBufferizationPassIncGen + + LINK_LIBS PUBLIC + MLIRBufferization + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h copy from mlir/lib/Transforms/PassDetail.h copy to mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h @@ -1,4 +1,4 @@ -//===- PassDetail.h - Transforms Pass class details -------------*- C++ -*-===// +//===- PassDetail.h - Bufferization Pass details ----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,30 +6,26 @@ // //===----------------------------------------------------------------------===// -#ifndef TRANSFORMS_PASSDETAIL_H_ -#define TRANSFORMS_PASSDETAIL_H_ +#ifndef DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_ #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Passes.h" namespace mlir { -class AffineDialect; -// Forward declaration from Dialect.h -template -void registerDialect(DialectRegistry &registry); +class StandardOpsDialect; -namespace arith { -class ArithmeticDialect; -} // end namespace arith +namespace bufferization { +class BufferizationDialect; +} // end namespace bufferization namespace memref { class MemRefDialect; } // end namespace memref #define GEN_PASS_CLASSES -#include "mlir/Transforms/Passes.h.inc" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" } // end namespace mlir -#endif // TRANSFORMS_PASSDETAIL_H_ +#endif // DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -6,10 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" + #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -313,7 +314,7 @@ void runOnOperation() override { MLIRContext &context = getContext(); ConversionTarget target(context); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; // Mark all Standard operations legal. target.addLegalDialect(typeConverter, patterns.getContext()); } @@ -50,7 +51,7 @@ struct StdBufferizePass : public StdBufferizeBase { void runOnFunction() override { auto *context = &getContext(); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ LINK_LIBS PUBLIC MLIRArithmeticTransforms + MLIRBufferizationTransforms MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -13,13 +13,14 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" + 
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -12,10 +12,10 @@ #include "PassDetail.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -27,7 +27,7 @@ auto module = getOperation(); auto *context = &getContext(); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -12,12 +12,12 @@ #include "PassDetail.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/BlockAndValueMapping.h" #include 
"mlir/Transforms/BufferUtils.h" -#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -25,7 +25,7 @@ memref::GlobalOp GlobalCreator::getGlobalFor(arith::ConstantOp constantOp) { auto type = constantOp.getType().cast(); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; // If we already have a global for this constant value, no need to do // anything else. @@ -91,7 +91,8 @@ } // namespace void mlir::populateTensorConstantBufferizePatterns( - GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + GlobalCreator &globalCreator, + bufferization::BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(globalCreator, typeConverter, patterns.getContext()); @@ -111,7 +112,7 @@ GlobalCreator globals(module, alignment); auto *context = &getContext(); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -153,7 +153,8 @@ } // namespace void mlir::populateTensorBufferizePatterns( - BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { patterns.add( typeConverter, patterns.getContext()); @@ -163,11 +164,11 @@ struct TensorBufferizePass : public TensorBufferizeBase { void 
runOnFunction() override { auto *context = &getContext(); - BufferizeTypeConverter typeConverter; + bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); - populateBufferizeMaterializationLegality(target); + bufferization::populateBufferizeMaterializationLegality(target); populateTensorBufferizePatterns(typeConverter, patterns); target.addIllegalOp