diff --git a/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h b/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h
@@ -0,0 +1,27 @@
+//===- MemRefToMemRef.h - MemRef to MemRef dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
+#define MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+/// Collect a set of patterns to convert memory-related operations from the
+/// MemRef dialect to the MemRef dialect.
+void populateMemRefToMemRefConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass that lowers `memref.clone` operations to a `memref.alloc`
+/// followed by a `memref.copy`.
+std::unique_ptr<Pass> createMemRefToMemRefPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOMEMREF_MEMREFTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -28,6 +28,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -337,6 +337,16 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefToMemRef
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToMemRef : Pass<"convert-memref-to-memref", "ModuleOp"> {
+  let summary = "Convert operations from the MemRef dialect to the MemRef "
+                "dialect";
+  let constructor = "mlir::createMemRefToMemRefPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -18,6 +18,7 @@
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToLLVM)
+add_subdirectory(MemRefToMemRef)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(OpenACCToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt b/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToMemRef/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_conversion_library(MLIRMemRefToMemRef
+  MemRefToMemRef.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToMemRef
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRMemRef
+  )
diff --git a/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp b/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToMemRef/MemRefToMemRef.cpp
@@ -0,0 +1,89 @@
+//===- MemRefToMemRef.cpp - MemRef to MemRef dialect conversion -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToMemRef/MemRefToMemRef.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers `memref.clone` to a `memref.alloc` of the result type followed by a
+/// `memref.copy` from the clone's input. Unranked memrefs are rejected.
+struct CloneOpConversion : public OpConversionPattern<memref::CloneOp> {
+  using OpConversionPattern<memref::CloneOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CloneOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check for unranked memref types which are currently not supported.
+    Type type = op.getType();
+    if (type.isa<UnrankedMemRefType>()) {
+      op->emitError("UnrankedMemRefType is not supported.");
+      return failure();
+    }
+
+    // Transform a clone operation into alloc + copy operation and pay
+    // attention to the shape dimensions.
+    MemRefType memrefType = type.cast<MemRefType>();
+    Location loc = op->getLoc();
+    SmallVector<Value, 4> dynamicOperands;
+    for (int i = 0; i < memrefType.getRank(); ++i) {
+      if (memrefType.getShape()[i] != ShapedType::kDynamicSize)
+        continue;
+      // The alloc needs the runtime extent of each dynamic dimension, not the
+      // dimension's index, so query it from the cloned memref.
+      Value index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), index);
+      dynamicOperands.push_back(dim);
+    }
+    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
+                                                               dynamicOperands);
+    rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateMemRefToMemRefConversionPatterns(
+    RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<CloneOpConversion>(patterns.getContext());
+  // clang-format on
+}
+
+namespace {
+/// Pass that marks `memref.clone` illegal and rewrites it with the patterns
+/// above through a partial conversion.
+struct MemRefToMemRefPass
+    : public ConvertMemRefToMemRefBase<MemRefToMemRefPass> {
+  MemRefToMemRefPass() = default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToMemRefConversionPatterns(patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalOp<arith::ConstantOp>();
+    target.addIllegalOp<memref::CloneOp>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createMemRefToMemRefPass() {
+  return std::make_unique<MemRefToMemRefPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir b/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToMemRef/memref-to-memref.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -verify-diagnostics -convert-memref-to-memref -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conversion_static
+func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
+  %0 = memref.clone %arg0 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %arg0 : memref<2xf32>
+  return %0 : memref<2xf32>
+}
+
+// CHECK:      %[[ALLOC:.*]] = memref.alloc
+// CHECK-NEXT: memref.copy %[[ARG:.*]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+// CHECK-LABEL: @conversion_dynamic
+func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
+  %1 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %arg0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// CHECK:      %[[CONST:.*]] = arith.constant
+// CHECK:      %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
+// expected-error@+2 {{failed to legalize operation 'memref.clone' that was explicitly marked illegal}}
+// expected-error@+1 {{UnrankedMemRefType is not supported.}}
+  %1 = memref.clone %arg0 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %arg0 : memref<*xf32>
+  return %1 : memref<*xf32>
+}