diff --git a/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h b/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h
@@ -0,0 +1,26 @@
+//===- MemRefToMemRef.h - MemRef to MemRef dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
+#define MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+/// Collect a set of patterns that rewrite `memref.clone` into an explicit
+/// `memref.alloc` followed by a `memref.copy`.
+void populateMemRefToMemRefConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass that applies the MemRef-to-MemRef conversion patterns.
+std::unique_ptr<Pass> createMemRefToMemRefPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -28,6 +28,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -337,6 +337,16 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefToMemRef
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToMemRef : Pass<"convert-memref-to-memref", "ModuleOp"> {
+  let summary = "Convert operations from the MemRef dialect to the MemRef "
+                "dialect";
+  let constructor = "mlir::createMemRefToMemRefPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -18,6 +18,7 @@
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToLLVM)
+add_subdirectory(MemRefToMemRef)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(OpenACCToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt b/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRMemRefToMemRef
+  MemRefToMemRef.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToMemRef
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRMemRef
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp b/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp
@@ -0,0 +1,77 @@
+//===- MemRefToMemRef.cpp - MemRef to MemRef dialect conversion -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites `memref.clone` into a `memref.alloc` of the result type followed
+/// by a `memref.copy` from the cloned buffer into the new allocation.
+struct CloneOpConversion : public OpConversionPattern<memref::CloneOp> {
+  using OpConversionPattern<memref::CloneOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CloneOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use a checked cast: a result that is not a ranked memref (e.g. an
+    // unranked one) would trip the assertion inside `cast`.
+    auto memrefType = op.getType().dyn_cast<MemRefType>();
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(
+          op, "only ranked memref results are supported");
+    // `memref.alloc` needs one size operand per dynamic dimension; none are
+    // materialized here, so bail out instead of building an invalid alloc.
+    if (!memrefType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "dynamically shaped results are not supported");
+
+    Location loc = op->getLoc();
+    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType);
+    rewriter.create<memref::CopyOp>(loc, adaptor.input(), alloc);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateMemRefToMemRefConversionPatterns(
+    RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<CloneOpConversion>(patterns.getContext());
+  // clang-format on
+}
+
+namespace {
+/// Module pass that applies the conversion patterns and fails if any
+/// `memref.clone` is left unconverted.
+struct MemRefToMemRefPass
+    : public ConvertMemRefToMemRefBase<MemRefToMemRefPass> {
+  MemRefToMemRefPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // Collect the patterns that rewrite MemRef ops into other MemRef ops.
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToMemRefConversionPatterns(patterns);
+
+    // The MemRef dialect is legal as a whole; only `memref.clone` must go.
+    ConversionTarget target(getContext());
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addIllegalOp<memref::CloneOp>();
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createMemRefToMemRefPass() {
+  return std::make_unique<MemRefToMemRefPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir b/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -convert-memref-to-memref %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @conversion
+// CHECK-SAME: (%[[ARG:.*]]: memref<2xf32>)
+func @conversion(%arg0 : memref<2xf32>) -> memref<2xf32> {
+  %0 = memref.clone %arg0 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %arg0 : memref<2xf32>
+  return %0 : memref<2xf32>
+}
+
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]