diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -76,12 +76,12 @@ TypeConverter *converter; }; -/// This conversion adds an extra argument for each function result which makes -/// the converted function a void function. A type converter must be provided -/// for this conversion to convert a non-shaped type to memref. -/// BufferAssignmentTypeConverter is an helper TypeConverter for this -/// purpose. All the non-shaped type of the input function will be converted to -/// memref. +/// This conversion converts the signature of the function using the type +/// converter. For each function result whose converted type is a memref, it +/// appends an extra argument to the function arguments; the other results are +/// kept as results. A type converter must be provided for this conversion. +/// BufferAssignmentTypeConverter is a helper TypeConverter for this purpose; +/// it converts tensor types to memref types (other types, e.g. i1, are kept). class FunctionAndBlockSignatureConverter : public BufferAssignmentOpConversionPattern { public: @@ -94,12 +94,12 @@ ConversionPatternRewriter &rewriter) const final; }; -/// This pattern converter transforms a non-void ReturnOpSourceTy into a void -/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy -/// to copy the results to the output buffer. +/// This converter replaces a ReturnOpSourceTy with a ReturnOpTargetTy that +/// keeps only the non-memref operands. Each memref operand is instead copied, +/// via a CopyOpTy, into the matching trailing argument of the entry block.
template -class NonVoidToVoidReturnOpConverter +class NoBufferOperandsReturnOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< @@ -109,29 +109,38 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - unsigned numReturnValues = returnOp.getNumOperands(); Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); Location loc = returnOp.getLoc(); // Find the corresponding output buffer for each operand. - assert(numReturnValues <= numFuncArgs && + assert(returnOp.getNumOperands() <= numFuncArgs && "The number of operands of return operation is more than the " "number of function argument."); - unsigned firstReturnParameter = numFuncArgs - numReturnValues; - for (auto operand : llvm::enumerate(operands)) { - unsigned returnArgNumber = firstReturnParameter + operand.index(); - BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) - continue; - - // Insert the copy operation to copy before the return. - rewriter.setInsertionPoint(returnOp); - rewriter.create(loc, operand.value(), - entryBlock.getArgument(returnArgNumber)); - } - // Insert the new target return operation. - rewriter.replaceOpWithNewOp(returnOp); + + // The target Return operation keeps only the non-memref operands. + SmallVector newOperands; + newOperands.reserve(operands.size()); + for (Value operand : operands) + if (!operand.getType().isa()) + newOperands.push_back(operand); + + // Each memref operand of the source ReturnOp is copied, in order, into one + // of the trailing entry-block arguments. No copy is skipped, even when an + // operand already is its destination argument.
+ unsigned numBufferOperands = operands.size() - newOperands.size(); + unsigned destArgNum = numFuncArgs - numBufferOperands; + for (Value operand : operands) + if (operand.getType().isa()) { + // Copy the buffer into its destination argument before the Return. + rewriter.setInsertionPoint(returnOp); + rewriter.create(loc, operand, + entryBlock.getArgument(destArgNum)); + ++destArgNum; + } + + // Replace the source Return with a target Return on the kept operands. + rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } }; diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -411,13 +411,26 @@ for (auto argType : llvm::enumerate(funcType.getInputs())) conversion.addInputs(argType.index(), converter->convertType(argType.value())); - // Adding function results to the arguments of the converted function as - // memref type. The converted function will be a void function. - for (Type resType : funcType.getResults()) - conversion.addInputs(converter->convertType((resType))); + + // Add a function argument for each result that converts to a memref type. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resType : funcType.getResults()) { + Type convertedType = converter->convertType(resType); + + // If the result type is a memref after the type conversion, a new argument + // is appended to the function arguments list for this result. Otherwise, + // the converted type is kept as a result type of the function. + if (convertedType.isa()) + conversion.addInputs(convertedType); + else + newResultTypes.push_back(convertedType); + } + + // Update the function type and apply the signature conversion to the body.
rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); diff --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-prepration.mlir --- a/mlir/test/Transforms/buffer-placement-prepration.mlir +++ b/mlir/test/Transforms/buffer-placement-prepration.mlir @@ -8,6 +8,21 @@ // ----- +// CHECK-LABEL: func @func_signature_conversion_complex +func @func_signature_conversion_complex(%arg0: tensor<4x8xf32>, + %arg1: i1, + %arg2: tensor<5x5xf64>, + %arg3: f16) -> (i1, tensor<5x5xf64>, f16, tensor<4x8xf32>) { + return %arg1, %arg2, %arg3, %arg0 : i1, tensor<5x5xf64>, f16, tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5x5xf64>, %[[ARG3:.*]]: f16, +// CHECK-SAME: %[[RESULT1:.*]]: memref<5x5xf64>, %[[RESULT2:.*]]: memref<4x8xf32>) -> (i1, f16) { +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT2]]) +// CHECK-NEXT: return %[[ARG1]], %[[ARG3]] + +// ----- + // CHECK-LABEL: func @non_void_to_void_return_op_converter func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { return %arg0 : tensor<4x8xf32> diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -23,7 +23,7 @@ namespace { /// This pass tests the computeAllocPosition helper method and two provided /// operation converters, FunctionAndBlockSignatureConverter and -/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg +/// NoBufferOperandsReturnOpConverter.
Furthermore, this pass converts linalg /// operations on tensors to linalg operations on buffers to prepare them for /// the BufferPlacement pass that can be applied afterwards. struct TestBufferPlacementPreparationPass @@ -82,7 +82,6 @@ auto type = result.getType().cast(); entryBlock.addArgument(type.getElementType()); } - rewriter.eraseOp(op); return success(); } @@ -95,7 +94,7 @@ patterns->insert< FunctionAndBlockSignatureConverter, GenericOpConverter, - NonVoidToVoidReturnOpConverter< + NoBufferOperandsReturnOpConverter< ReturnOp, ReturnOp, linalg::CopyOp> >(context, placer, converter); // clang-format on @@ -105,8 +104,9 @@ auto &context = getContext(); ConversionTarget target(context); BufferAssignmentTypeConverter converter; - // Make all linalg operations illegal as long as they work on tensors. target.addLegalDialect(); + + // Make all linalg operations illegal as long as they work on tensors. target.addDynamicallyLegalDialect( Optional( [&](Operation *op) { @@ -117,9 +117,12 @@ llvm::none_of(op->getResultTypes(), isIllegalType); })); - // Mark return operations illegal as long as they return values. - target.addDynamicallyLegalOp( - [](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); + // Mark std.ReturnOp illegal while any operand is a buffer or a tensor. + target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { + return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) { + return type.isa() || !converter.isLegal(type); + }); + }); // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) {