diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -76,12 +76,11 @@ TypeConverter *converter; }; -/// This conversion adds an extra argument for each function result which makes -/// the converted function a void function. A type converter must be provided -/// for this conversion to convert a non-shaped type to memref. -/// BufferAssignmentTypeConverter is an helper TypeConverter for this -/// purpose. All the non-shaped type of the input function will be converted to -/// memref. +/// Converts the signature of the function using the type converter. +/// Each function result whose converted type is a memref is moved to a new +/// trailing function argument; all other results remain function results. +/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for this +/// purpose: it converts the tensor types of the function to memref types. class FunctionAndBlockSignatureConverter : public BufferAssignmentOpConversionPattern { public: @@ -94,12 +93,12 @@ ConversionPatternRewriter &rewriter) const final; }; -/// This pattern converter transforms a non-void ReturnOpSourceTy into a void -/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy -/// to copy the results to the output buffer. +/// Converts the source `ReturnOp` to a target `ReturnOp` without buffer +/// (memref) operands. Before the new return, a `CopyOp` is inserted that +/// copies each removed buffer operand into its matching trailing argument.
template -class NonVoidToVoidReturnOpConverter +class NoBufferOperandsReturnOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< @@ -109,29 +108,38 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - unsigned numReturnValues = returnOp.getNumOperands(); Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); Location loc = returnOp.getLoc(); - // Find the corresponding output buffer for each operand. - assert(numReturnValues <= numFuncArgs && - "The number of operands of return operation is more than the " - "number of function argument."); - unsigned firstReturnParameter = numFuncArgs - numReturnValues; - for (auto operand : llvm::enumerate(operands)) { - unsigned returnArgNumber = firstReturnParameter + operand.index(); - BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) - continue; - - // Insert the copy operation to copy before the return. - rewriter.setInsertionPoint(returnOp); - rewriter.create(loc, operand.value(), - entryBlock.getArgument(returnArgNumber)); - } - // Insert the new target return operation. - rewriter.replaceOpWithNewOp(returnOp); + // The target `ReturnOp` should not contain any memref operands. + SmallVector newOperands(operands.begin(), operands.end()); + llvm::erase_if(newOperands, [](Value operand) { + return operand.getType().isa(); + }); + + // Find the index of the first destination buffer. + unsigned numBufferOperands = operands.size() - newOperands.size(); + unsigned destArgNum = numFuncArgs - numBufferOperands; + + rewriter.setInsertionPoint(returnOp); + // Find the corresponding destination buffer for each memref operand.
+ for (Value operand : operands) + if (operand.getType().isa()) { + assert(destArgNum < numFuncArgs && + "The number of operands of return operation is more than the " + "number of function argument."); + + // Copy the buffer content of this memref operand into its destination + // argument. NOTE(review): the replaced code skipped the copy when the + // operand already was that argument; confirm dropping this is intended. + rewriter.create(loc, operand, + entryBlock.getArgument(destArgNum)); + ++destArgNum; + } + + // Replace the source return with a target return of the non-memref operands. + rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } }; diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -411,13 +411,26 @@ for (auto argType : llvm::enumerate(funcType.getInputs())) conversion.addInputs(argType.index(), converter->convertType(argType.value())); - // Adding function results to the arguments of the converted function as - // memref type. The converted function will be a void function. - for (Type resType : funcType.getResults()) - conversion.addInputs(converter->convertType((resType))); + + // Append a function argument for each result whose converted type is memref. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resType : funcType.getResults()) { + Type convertedType = converter->convertType(resType); + + // If the result type is memref after the type conversion, a new argument + // should be appended to the function arguments list for this result. + // Otherwise, the converted type is kept as a function result type. + if (convertedType.isa()) + conversion.addInputs(convertedType); + else + newResultTypes.push_back(convertedType); + } + + // Update the signature of the function.
rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); diff --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-prepration.mlir --- a/mlir/test/Transforms/buffer-placement-prepration.mlir +++ b/mlir/test/Transforms/buffer-placement-prepration.mlir @@ -8,6 +8,59 @@ // ----- +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) { + return +} +// CHECK: ({{.*}}: memref<4x8xf32>) { + +// ----- + +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){ + return %arg0, %arg1 : i1, f16 +} +// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16) +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +// CHECK-LABEL: func @memref_in_function_results +func @memref_in_function_results(%arg0: i1, %arg1: f16, %arg2:memref<2xf32>) -> (i1, memref<2xf32>, f16){ + return %arg0, %arg2, %arg1 : i1, memref<2xf32>, f16 +} +// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16, %[[ARG2:.*]]: memref<2xf32>, %[[RESULT:.*]]: memref<2xf32>) +// CHECK-SAME: (i1, f16) +// CHECK: linalg.copy(%[[ARG2]], %[[RESULT]]) +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +// CHECK-LABEL: func @memref_in_function_results +func @memref_in_function_results(%arg0: i1, %arg1: f16) -> (i1, memref<2xf32>, f16){ + %0 = alloc() : memref<2xf32> + return %arg0, %0, %arg1 : i1, memref<2xf32>, f16 +} +// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16, %[[RESULT:.*]]: memref<2xf32>) +// CHECK-SAME: (i1, f16) +// CHECK: %[[ALLOC:.*]] = alloc() +// CHECK: linalg.copy(%[[ALLOC]], %[[RESULT]]) +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +// CHECK-LABEL: func
@complex_signature_conversion +func @complex_signature_conversion(%arg0: tensor<4x8xf32>, %arg1: i1, %arg2: tensor<5x5xf64>,%arg3: f16) -> (i1, tensor<5x5xf64>, f16, tensor<4x8xf32>) { + return %arg1, %arg2, %arg3, %arg0 : i1, tensor<5x5xf64>, f16, tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5x5xf64>, %[[ARG3:.*]]: f16, +// CHECK-SAME: %[[RESULT1:.*]]: memref<5x5xf64>, %[[RESULT2:.*]]: memref<4x8xf32>) -> (i1, f16) { +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT2]]) +// CHECK-NEXT: return %[[ARG1]], %[[ARG3]] + +// ----- + // CHECK-LABEL: func @non_void_to_void_return_op_converter func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { return %arg0 : tensor<4x8xf32> diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -23,7 +23,7 @@ namespace { /// This pass tests the computeAllocPosition helper method and two provided /// operation converters, FunctionAndBlockSignatureConverter and -/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg +/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg /// operations on tensors to linalg operations on buffers to prepare them for /// the BufferPlacement pass that can be applied afterwards.
struct TestBufferPlacementPreparationPass @@ -82,7 +82,6 @@ auto type = result.getType().cast(); entryBlock.addArgument(type.getElementType()); } - rewriter.eraseOp(op); return success(); } @@ -95,7 +94,7 @@ patterns->insert< FunctionAndBlockSignatureConverter, GenericOpConverter, - NonVoidToVoidReturnOpConverter< + NoBufferOperandsReturnOpConverter< ReturnOp, ReturnOp, linalg::CopyOp> >(context, placer, converter); // clang-format on @@ -105,8 +104,9 @@ auto &context = getContext(); ConversionTarget target(context); BufferAssignmentTypeConverter converter; - // Make all linalg operations illegal as long as they work on tensors. target.addLegalDialect(); + + // Make all linalg operations illegal as long as they work on tensors. target.addDynamicallyLegalDialect( Optional( [&](Operation *op) { @@ -117,13 +117,18 @@ llvm::none_of(op->getResultTypes(), isIllegalType); })); - // Mark return operations illegal as long as they return values. - target.addDynamicallyLegalOp( - [](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); + // Mark std.ReturnOp illegal while any operand is a buffer or has a type the converter deems illegal (e.g. a tensor). + target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { + return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) { + return type.isa() || !converter.isLegal(type); + }); + }); // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()); + return converter.isSignatureLegal(funcOp.getType()) && + llvm::none_of(funcOp.getType().getResults(), + [&](Type type) { return type.isa(); }); }); // Walk over all the functions to apply buffer assignment.