Add a finalizing -func-bufferize module pass.

FuncBufferizePass converts std.func signatures and std.call ops from tensor
to memref types and, via the new
populateEliminateBufferizeMaterializationsPatterns, folds away the
tensor_load / tensor_to_memref materializations. A lit test covers
signatures, block arguments, calls and materialization elimination.

diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -38,6 +38,9 @@
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 
+/// Creates an instance of func bufferization pass.
+std::unique_ptr<Pass> createFuncBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,4 +22,33 @@
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
+  let summary = "Bufferize func/call/return ops";
+  let description = [{
+    A finalizing bufferize pass that bufferizes std.func and std.call ops.
+
+    Because this pass updates std.func ops, it must be a module pass. It is
+    useful to keep this pass separate from other bufferizations so that the
+    other ones can be run at function-level in parallel.
+
+    This pass must be done atomically for two reasons:
+    1. This pass changes func op signatures, which requires atomically updating
+       calls as well throughout the entire module.
+    2. This pass changes the type of block arguments, which requires that all
+       successor arguments of predecessors be converted. Terminators are not
+       a closed universe (and need not implement BranchOpInterface), and so we
+       cannot in general rewrite them.
+
+    Note, because this is a "finalizing" bufferize step, it can create
+    invalid IR because it will not create materializations. To avoid this
+    situation, the pass must only be run when the only SSA values of
+    tensor type are:
+    - block arguments
+    - the result of tensor_load
+    Other values of tensor type should be eliminated by earlier
+    bufferization passes.
+  }];
+  let constructor = "mlir::createFuncBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -150,9 +150,18 @@
 /// This function should be called by all bufferization passes using
 /// BufferizeTypeConverter so that materializations work proprely. One exception
 /// is bufferization passes doing "full" conversions, where it can be desirable
-/// for even the materializations to remain illegal so that they are eliminated.
+/// for even the materializations to remain illegal so that they are eliminated,
+/// such as via the patterns in
+/// populateEliminateBufferizeMaterializationsPatterns.
 void populateBufferizeMaterializationLegality(ConversionTarget &target);
 
+/// Populate patterns to eliminate bufferize materializations.
+///
+/// In particular, these are the tensor_load/tensor_to_memref ops.
+void populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename T>
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   ExpandAtomic.cpp
   ExpandMemRefReshape.cpp
   ExpandTanh.cpp
+  FuncBufferize.cpp
   FuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -0,0 +1,56 @@
+//===- FuncBufferize.cpp - Bufferization of func/call ops -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std.func's and std.call's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+                                                       patterns);
+    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+    // If all result types are legal, and all block arguments are legal (ensured
+    // by func conversion above), then all types in the program are legal.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getResultTypes());
+    });
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createFuncBufferizePass() {
+  return std::make_unique<FuncBufferizePass>();
+}
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -76,6 +76,45 @@
   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
 }
 
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorLoadOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.memref());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorToMemrefOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.tensor());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
+      typeConverter, context);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @identity(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @block_arguments(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK: return %[[BBARG]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+  br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+  return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL: func @eliminate_target_materialization(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
+  %0 = tensor_to_memref %arg0 : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @eliminate_source_materialization(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
+  %0 = tensor_load %arg0 : memref<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @source() -> memref<f32>
+// CHECK-LABEL: func @call_source() -> memref<f32> {
+// CHECK: %[[RET:.*]] = call @source() : () -> memref<f32>
+// CHECK: return %[[RET]] : memref<f32>
+func @source() -> tensor<f32>
+func @call_source() -> tensor<f32> {
+  %0 = call @source() : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @sink(memref<f32>)
+// CHECK-LABEL: func @call_sink(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
+// CHECK: call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK: return
+func @sink(tensor<f32>)
+func @call_sink(%arg0: tensor<f32>) {
+  call @sink(%arg0) : (tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @failed_to_legalize() -> tensor<f32> {
+  // expected-error @+1 {{failed to legalize operation 'test.source'}}
+  %0 = "test.source"() : () -> (tensor<f32>)
+  return %0 : tensor<f32>
+}