diff --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -0,0 +1,27 @@
+//===- BufferizationToMemRef.h - Bufferization to MemRef conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+#define MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+/// Collect a set of patterns to convert memory-related operations from the
+/// Bufferization dialect to the MemRef dialect.
+void populateBufferizationToMemRefConversionPatterns(
+    RewritePatternSet &patterns);
+
+/// Create a pass that lowers Bufferization operations to the MemRef dialect.
+std::unique_ptr<Pass> createBufferizationToMemRefPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -126,6 +126,22 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizationToMemRef
+//===----------------------------------------------------------------------===//
+
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
+  let summary = "Convert operations from the Bufferization dialect to the "
+                "MemRef dialect";
+  let description = [{
+    Converts `bufferization.clone` operations into a `memref.alloc` of the
+    same type followed by a `memref.copy`. Unranked memrefs and memrefs whose
+    layout map has symbols are not supported.
+  }];
+  let constructor = "mlir::createBufferizationToMemRefPass()";
+  let dependentDialects = ["memref::MemRefDialect", "arith::ArithmeticDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -0,0 +1,105 @@
+//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Bufferization dialect to MemRef
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Converts a `bufferization.clone` into a `memref.alloc` of the same type
+/// followed by a `memref.copy` from the source buffer into the new buffer.
+/// Dynamic sizes of the allocation are obtained with `memref.dim` on the
+/// source buffer.
+struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
+  using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check for unranked memref types which are currently not supported.
+    Type type = op.getType();
+    if (type.isa<UnrankedMemRefType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "UnrankedMemRefType is not supported.");
+    }
+
+    // `memref.alloc` needs one symbol operand per symbol of a non-identity
+    // layout map, and this lowering has no values to supply for them. Bail
+    // out before creating any IR instead of emitting an invalid allocation.
+    MemRefType memrefType = type.cast<MemRefType>();
+    MemRefLayoutAttrInterface layout = memrefType.getLayout();
+    if (!layout.isIdentity() && layout.getAffineMap().getNumSymbols() != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "layout maps with symbols are not supported.");
+    }
+
+    // Transform a clone operation into alloc + copy operation and pay
+    // attention to the shape dimensions. The source is read through the
+    // adaptor so that the pattern sees the remapped operand when it is used
+    // together with other conversion patterns.
+    Location loc = op->getLoc();
+    SmallVector<Value, 4> dynamicOperands;
+    for (int i = 0; i < memrefType.getRank(); ++i) {
+      if (!memrefType.isDynamicDim(i))
+        continue;
+      Value size = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value dim =
+          rewriter.createOrFold<memref::DimOp>(loc, adaptor.input(), size);
+      dynamicOperands.push_back(dim);
+    }
+    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
+                                                               dynamicOperands);
+    rewriter.create<memref::CopyOp>(loc, adaptor.input(), alloc);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateBufferizationToMemRefConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CloneOpConversion>(patterns.getContext());
+}
+
+namespace {
+/// Pass that applies the Bufferization-to-MemRef patterns as a partial
+/// conversion, treating the whole Bufferization dialect as illegal.
+struct BufferizationToMemRefPass
+    : public ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
+  BufferizationToMemRefPass() = default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBufferizationToMemRefConversionPatterns(patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalOp<arith::ConstantOp>();
+    target.addIllegalDialect<bufferization::BufferizationDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
+  return std::make_unique<BufferizationToMemRefPass>();
+}
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRBufferizationToMemRef
+  BufferizationToMemRef.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BufferizationToMemRef
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRBufferization
+  MLIRIR
+  MLIRMemRef
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_subdirectory(ArithmeticToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
+add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)
 add_subdirectory(GPUCommon)
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -verify-diagnostics -convert-bufferization-to-memref -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conversion_static
+func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
+  %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %arg0 : memref<2xf32>
+  return %0 : memref<2xf32>
+}
+
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-NEXT: memref.copy %[[ARG:.*]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+// CHECK-LABEL: @conversion_dynamic
+func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
+  %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %arg0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// CHECK: %[[CONST:.*]] = arith.constant
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+  %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %arg0 : memref<*xf32>
+  return %1 : memref<*xf32>
+}
+
+// -----
+
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+func @conversion_with_layout_map(%arg0 : memref<?xf32, #map>) -> memref<?xf32, #map> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+  %1 = bufferization.clone %arg0 : memref<?xf32, #map> to memref<?xf32, #map>
+  memref.dealloc %arg0 : memref<?xf32, #map>
+  return %1 : memref<?xf32, #map>
+}