diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -44,6 +44,9 @@
 std::unique_ptr<Pass>
 createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024);
 
+/// Creates a pass that converts memref function results to out-params.
+std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
+
 /// Creates an instance of the Canonicalizer pass.
 std::unique_ptr<Pass> createCanonicalizerPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -217,6 +217,31 @@
   ];
 }
 
+def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp"> {
+  let summary = "Converts memref-typed function results to out-params";
+  let description = [{
+    Some calling conventions prefer to pass output memrefs as "out params". The
+    conversion to this calling convention must be done as an atomic
+    transformation of the entire program (hence this is a module pass).
+
+    For example, if a call is rewritten, the callee needs to be rewritten
+    as well, otherwise the IR will end up invalid. Thus, this transformation
+    requires an atomic change to the entire program (i.e. the whole module).
+
+    This pass is expected to run immediately after bufferization is finished.
+    At that point, tensor-typed results will have been converted to memref-typed
+    results, and can be consistently converted to out params.
+
+    All memref-typed results are appended to the function argument list.
+
+    The main issue with this pass (and the out-param calling convention) is that
+    buffers for results need to be allocated in the caller. This currently only
+    works for statically shaped memrefs.
+  }];
+  let constructor = "mlir::createBufferResultsToOutParamsPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def Canonicalizer : Pass<"canonicalize"> {
   let summary = "Canonicalize operations";
   let description = [{
diff --git a/mlir/lib/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Transforms/BufferResultsToOutParams.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/BufferResultsToOutParams.cpp
@@ -0,0 +1,159 @@
+//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+// Updates the func op and entry block.
+//
+// Every memref-typed result (ranked or unranked) is removed from the function
+// type and appended as a trailing argument, and its result attributes become
+// the attributes of that argument. For functions with a body, the new entry
+// block arguments are also added to `appendedEntryArgs`.
+static void updateFuncOp(FuncOp func,
+                         SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+  auto functionType = func.getType();
+
+  // Collect information about the results that will become appended arguments.
+  SmallVector<Type, 6> erasedResultTypes;
+  SmallVector<unsigned, 6> erasedResultIndices;
+  for (auto resultType : llvm::enumerate(functionType.getResults())) {
+    if (resultType.value().isa<BaseMemRefType>()) {
+      erasedResultIndices.push_back(resultType.index());
+      erasedResultTypes.push_back(resultType.value());
+    }
+  }
+
+  // Add the new arguments to the function type.
+  auto newArgTypes = llvm::to_vector<6>(
+      llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
+  auto newFunctionType = FunctionType::get(
+      newArgTypes, functionType.getResults(), func.getContext());
+  func.setType(newFunctionType);
+
+  // Transfer the result attributes to arg attributes.
+  for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
+    func.setArgAttrs(functionType.getNumInputs() + i,
+                     func.getResultAttrs(erasedResultIndices[i]));
+
+  // Erase the results.
+  func.eraseResults(erasedResultIndices);
+
+  // Add the new arguments to the entry block if the function is not external.
+  if (func.isExternal())
+    return;
+  auto newArgs = func.front().addArguments(erasedResultTypes);
+  appendedEntryArgs.append(newArgs.begin(), newArgs.end());
+}
+
+// Updates all ReturnOps in the scope of the given FuncOp by either keeping them
+// as return values or copying the associated buffer contents into the given
+// out-params.
+static void updateReturnOps(FuncOp func,
+                            ArrayRef<BlockArgument> appendedEntryArgs) {
+  func.walk([&](ReturnOp op) {
+    SmallVector<Value, 6> copyIntoOutParams;
+    SmallVector<Value, 6> keepAsReturnOperands;
+    for (Value operand : op.getOperands()) {
+      if (operand.getType().isa<BaseMemRefType>())
+        copyIntoOutParams.push_back(operand);
+      else
+        keepAsReturnOperands.push_back(operand);
+    }
+    OpBuilder builder(op);
+    // The i-th memref operand is copied into the i-th appended argument;
+    // updateFuncOp appends the out params in result order.
+    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
+      builder.create<linalg::CopyOp>(op.getLoc(), std::get<0>(t),
+                                     std::get<1>(t));
+    builder.create<ReturnOp>(op.getLoc(), keepAsReturnOperands);
+    op.erase();
+  });
+}
+
+// Updates all CallOps in the scope of the given ModuleOp by allocating
+// temporary buffers for newly introduced out params.
+//
+// Each call is checked before it is rewritten, so a call with a memref result
+// that the caller cannot allocate is reported and left untouched instead of
+// being partially rewritten.
+static LogicalResult updateCalls(ModuleOp module) {
+  bool didFail = false;
+  module.walk([&](CallOp op) {
+    SmallVector<Value, 6> replaceWithNewCallResults;
+    SmallVector<Value, 6> replaceWithOutParams;
+    for (OpResult result : op.getResults()) {
+      if (result.getType().isa<BaseMemRefType>())
+        replaceWithOutParams.push_back(result);
+      else
+        replaceWithNewCallResults.push_back(result);
+    }
+
+    // Out params are allocated by the caller, so each one needs a static
+    // shape. Query it through BaseMemRefType so that unranked memrefs are
+    // diagnosed here instead of asserting in a cast to MemRefType.
+    for (Value memref : replaceWithOutParams) {
+      if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) {
+        op.emitError()
+            << "cannot create out param for dynamically shaped result";
+        didFail = true;
+        return;
+      }
+    }
+
+    SmallVector<Value, 6> outParams;
+    OpBuilder builder(op);
+    for (Value memref : replaceWithOutParams) {
+      Value outParam = builder.create<AllocOp>(
+          op.getLoc(), memref.getType().cast<MemRefType>());
+      memref.replaceAllUsesWith(outParam);
+      outParams.push_back(outParam);
+    }
+
+    auto newOperands = llvm::to_vector<6>(op.getOperands());
+    newOperands.append(outParams.begin(), outParams.end());
+    auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
+        replaceWithNewCallResults, [](Value v) { return v.getType(); }));
+    auto newCall = builder.create<CallOp>(op.getLoc(), op.calleeAttr(),
+                                          newResultTypes, newOperands);
+    for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
+      std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
+    op.erase();
+  });
+
+  return failure(didFail);
+}
+
+namespace {
+struct BufferResultsToOutParamsPass
+    : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    for (auto func : module.getOps<FuncOp>()) {
+      SmallVector<BlockArgument, 6> appendedEntryArgs;
+      updateFuncOp(func, appendedEntryArgs);
+      if (func.isExternal())
+        continue;
+      updateReturnOps(func, appendedEntryArgs);
+    }
+    if (failed(updateCalls(module)))
+      return signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createBufferResultsToOutParamsPass() {
+  return std::make_unique<BufferResultsToOutParamsPass>();
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_mlir_library(MLIRTransforms
   BufferDeallocation.cpp
   BufferOptimizations.cpp
+  BufferResultsToOutParams.cpp
   Bufferize.cpp
   Canonicalizer.cpp
   CopyRemoval.cpp
diff --git a/mlir/test/Transforms/buffer-results-to-out-params.mlir b/mlir/test/Transforms/buffer-results-to-out-params.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt -buffer-results-to-out-params -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @basic(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
+// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK: linalg.copy(%[[RESULT]], %[[ARG]]) : memref<f32>, memref<f32>
+// CHECK: return
+// CHECK: }
+func @basic() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @presence_of_existing_arguments(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
+// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<2xf32>
+// CHECK: linalg.copy(%[[RESULT]], %[[ARG1]]) : memref<2xf32>, memref<2xf32>
+// CHECK: return
+// CHECK: }
+func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
+  %0 = "test.source"() : () -> (memref<2xf32>)
+  return %0 : memref<2xf32>
+}
+
+// CHECK-LABEL: func @multiple_results(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
+// CHECK: %[[RESULTS:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+// CHECK: linalg.copy(%[[RESULTS]]#0, %[[ARG0]]) : memref<1xf32>, memref<1xf32>
+// CHECK: linalg.copy(%[[RESULTS]]#1, %[[ARG1]]) : memref<2xf32>, memref<2xf32>
+// CHECK: return
+// CHECK: }
+func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+  %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+  return %0, %1 : memref<1xf32>, memref<2xf32>
+}
+
+// CHECK-LABEL: func @non_memref_types(
+// CHECK-SAME: %[[OUTPARAM:.*]]: memref<f32>) -> (i1, i32) {
+// CHECK: %[[RESULT1:.*]]:3 = "test.source"() : () -> (i1, memref<f32>, i32)
+// CHECK: linalg.copy(%[[RESULT1]]#1, %[[OUTPARAM]]) : memref<f32>, memref<f32>
+// CHECK: return %[[RESULT1]]#0, %[[RESULT1]]#2 : i1, i32
+// CHECK: }
+func @non_memref_types() -> (i1, memref<f32>, i32) {
+  %0, %1, %2 = "test.source"() : () -> (i1, memref<f32>, i32)
+  return %0, %1, %2 : i1, memref<f32>, i32
+}
+
+// CHECK: func @external_function(memref<f32>)
+func @external_function() -> (memref<f32>)
+// CHECK: func @result_attrs(memref<f32> {test.some_attr})
+func @result_attrs() -> (memref<f32> {test.some_attr})
+// CHECK: func @mixed_result_attrs(memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
+func @mixed_result_attrs() -> (memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
+
+// -----
+
+// CHECK-LABEL: func @callee(memref<1xf32>)
+func @callee() -> memref<1xf32>
+
+// CHECK-LABEL: func @call_basic() {
+// CHECK: %[[OUTPARAM:.*]] = alloc() : memref<1xf32>
+// CHECK: call @callee(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK: "test.sink"(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK: return
+// CHECK: }
+func @call_basic() {
+  %0 = call @callee() : () -> memref<1xf32>
+  "test.sink"(%0) : (memref<1xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @callee(memref<1xf32>, memref<2xf32>)
+func @callee() -> (memref<1xf32>, memref<2xf32>)
+
+// CHECK-LABEL: func @call_multiple_result() {
+// CHECK: %[[RESULT0:.*]] = alloc() : memref<1xf32>
+// CHECK: %[[RESULT1:.*]] = alloc() : memref<2xf32>
+// CHECK: call @callee(%[[RESULT0]], %[[RESULT1]]) : (memref<1xf32>, memref<2xf32>) -> ()
+// CHECK: "test.sink"(%[[RESULT0]], %[[RESULT1]]) : (memref<1xf32>, memref<2xf32>) -> ()
+// CHECK: return
+// CHECK: }
+func @call_multiple_result() {
+  %0, %1 = call @callee() : () -> (memref<1xf32>, memref<2xf32>)
+  "test.sink"(%0, %1) : (memref<1xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @callee(memref<1xf32>) -> (i1, i32)
+func @callee() -> (i1, memref<1xf32>, i32)
+
+// CHECK-LABEL: func @call_non_memref_result() {
+// CHECK: %[[RESULT0:.*]] = alloc() : memref<1xf32>
+// CHECK: %[[NON_MEMREF_RESULTS:.*]]:2 = call @callee(%[[RESULT0]]) : (memref<1xf32>) -> (i1, i32)
+// CHECK: "test.sink"(%[[NON_MEMREF_RESULTS]]#0, %[[RESULT0]], %[[NON_MEMREF_RESULTS]]#1) : (i1, memref<1xf32>, i32) -> ()
+// CHECK: return
+// CHECK: }
+func @call_non_memref_result() {
+  %0, %1, %2 = call @callee() : () -> (i1, memref<1xf32>, i32)
+  "test.sink"(%0, %1, %2) : (i1, memref<1xf32>, i32) -> ()
+  return
+}
+
+// -----
+
+func @callee() -> (memref<?xf32>)
+
+func @call_dynamically_shaped_result() {
+  // expected-error @+1 {{cannot create out param for dynamically shaped result}}
+  %0 = call @callee() : () -> (memref<?xf32>)
+  "test.sink"(%0) : (memref<?xf32>) -> ()
+  return
+}
+
+// -----
+
+func @callee() -> (memref<*xf32>)
+
+func @call_unranked_result() {
+  // expected-error @+1 {{cannot create out param for dynamically shaped result}}
+  %0 = call @callee() : () -> (memref<*xf32>)
+  "test.sink"(%0) : (memref<*xf32>) -> ()
+  return
+}