diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -76,12 +76,11 @@ TypeConverter *converter; }; -/// This conversion adds an extra argument for each function result which makes -/// the converted function a void function. A type converter must be provided -/// for this conversion to convert a non-shaped type to memref. -/// BufferAssignmentTypeConverter is an helper TypeConverter for this -/// purpose. All the non-shaped type of the input function will be converted to -/// memref. +/// Converts the signature of the function using the type converter. +/// It adds an extra argument for each illegally-typed function +/// result to the function arguments. `BufferAssignmentTypeConverter` +/// is a helper `TypeConverter` for this purpose. All the non-shaped types +/// of the input function will be converted to memref. class FunctionAndBlockSignatureConverter : public BufferAssignmentOpConversionPattern { public: @@ -94,12 +93,12 @@ ConversionPatternRewriter &rewriter) const final; }; -/// This pattern converter transforms a non-void ReturnOpSourceTy into a void -/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy -/// to copy the results to the output buffer. +/// Converts the source `ReturnOp` to target `ReturnOp`, removes all +/// the buffer operands from the operands list, and inserts `CopyOp`s +/// for all buffer operands instead. 
template -class NonVoidToVoidReturnOpConverter +class NoBufferOperandsReturnOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< @@ -109,29 +108,38 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - unsigned numReturnValues = returnOp.getNumOperands(); Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); Location loc = returnOp.getLoc(); - // Find the corresponding output buffer for each operand. - assert(numReturnValues <= numFuncArgs && - "The number of operands of return operation is more than the " - "number of function argument."); - unsigned firstReturnParameter = numFuncArgs - numReturnValues; - for (auto operand : llvm::enumerate(operands)) { - unsigned returnArgNumber = firstReturnParameter + operand.index(); - BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) - continue; - - // Insert the copy operation to copy before the return. - rewriter.setInsertionPoint(returnOp); - rewriter.create(loc, operand.value(), - entryBlock.getArgument(returnArgNumber)); - } - // Insert the new target return operation. - rewriter.replaceOpWithNewOp(returnOp); + // The target `ReturnOp` should not contain any memref operands. + SmallVector newOperands(operands.begin(), operands.end()); + llvm::erase_if(newOperands, [](Value operand) { + return operand.getType().isa(); + }); + + // Find the index of the first destination buffer. + unsigned numBufferOperands = operands.size() - newOperands.size(); + unsigned destArgNum = numFuncArgs - numBufferOperands; + + rewriter.setInsertionPoint(returnOp); + // Find the corresponding destination buffer for each memref operand. 
+ for (Value operand : operands) + if (operand.getType().isa()) { + assert(destArgNum < numFuncArgs && + "The number of operands of return operation is more than the " + "number of function argument."); + + // For each memref type operand of the source `ReturnOp`, a new `CopyOp` + // is inserted that copies the buffer content from the operand to the + // target. + rewriter.create(loc, operand, + entryBlock.getArgument(destArgNum)); + ++destArgNum; + } + + // Insert the new target Return operation. + rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -21,8 +21,8 @@ using namespace mlir; using ReturnOpConverter = - NonVoidToVoidReturnOpConverter; + NoBufferOperandsReturnOpConverter; namespace { /// A pattern to convert Generic Linalg operations which work on tensors to @@ -132,30 +132,6 @@ Optional( isLegalOperation)); - // TODO: Considering the following dynamic legality checks, the current - // implementation of FunctionAndBlockSignatureConverter of Buffer Assignment - // will convert the function signature incorrectly. This converter moves - // all the return values of the function to the input argument list without - // considering the return value types and creates a void function. However, - // the NonVoidToVoidReturnOpConverter doesn't change the return operation if - // its operands are not tensors. The following example leaves the IR in a - // broken state. 
- // - // @function(%arg0: f32, %arg1: tensor<4xf32>) -> (f32, f32) { - // %0 = mulf %arg0, %arg0 : f32 - // return %0, %0 : f32, f32 - // } - // - // broken IR after conversion: - // - // func @function(%arg0: f32, %arg1: memref<4xf32>, f32, f32) { - // %0 = mulf %arg0, %arg0 : f32 - // return %0, %0 : f32, f32 - // } - // - // This issue must be fixed in FunctionAndBlockSignatureConverter and - // NonVoidToVoidReturnOpConverter. - // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { return llvm::none_of(returnOp.getOperandTypes(), isIllegalType); diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -43,7 +43,8 @@ // The current implementation does not support loops and the resulting code will // be invalid with respect to program semantics. The only thing that is // currently missing is a high-level loop analysis that allows us to move allocs -// and deallocs outside of the loop blocks. +// and deallocs outside of the loop blocks. Furthermore, it does not yet accept +// functions that already return buffers. // //===----------------------------------------------------------------------===// @@ -429,19 +430,39 @@ "FunctionAndBlockSignatureConverter"); return failure(); } - // Converting shaped type arguments to memref type. auto funcType = funcOp.getType(); + TypeRange resultTypes = funcType.getResults(); + if (llvm::any_of(resultTypes, + [](Type type) { return type.isa(); })) + return funcOp.emitError("BufferAssignmentPlacer doesn't currently support " + "functions which return memref typed values"); + + // Convert function arguments using the provided TypeConverter.
TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); for (auto argType : llvm::enumerate(funcType.getInputs())) conversion.addInputs(argType.index(), converter->convertType(argType.value())); - // Adding function results to the arguments of the converted function as - // memref type. The converted function will be a void function. - for (Type resType : funcType.getResults()) - conversion.addInputs(converter->convertType((resType))); + + // Adding a function argument for each function result which is going to be a + // memref type after type conversion. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resType : resultTypes) { + Type convertedType = converter->convertType(resType); + + // If the result type is memref after the type conversion, a new argument + // should be appended to the function arguments list for this result. + // Otherwise, it remains unchanged as a function result. + if (convertedType.isa()) + conversion.addInputs(convertedType); + else + newResultTypes.push_back(convertedType); + } + + // Update the signature of the function. 
rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); diff --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-prepration.mlir --- a/mlir/test/Transforms/buffer-placement-prepration.mlir +++ b/mlir/test/Transforms/buffer-placement-prepration.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure +// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @func_signature_conversion func @func_signature_conversion(%arg0: tensor<4x8xf32>) { @@ -8,6 +8,44 @@ // ----- +// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}} +// expected-error @below {{failed to legalize operation 'func'}} +func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) { + %0 = alloc() : memref<5xf32> + return %arg0, %0 : tensor<4x8xf32>, memref<5xf32> +} + +// ----- + +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) { + return +} +// CHECK: ({{.*}}: memref<4x8xf32>) { + +// ----- + +// CHECK-LABEL: func @no_signature_conversion_is_needed +func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){ + return %arg0, %arg1 : i1, f16 +} +// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16) +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +// CHECK-LABEL: func @complex_signature_conversion +func @complex_signature_conversion(%arg0: tensor<4x8xf32>, %arg1: i1, %arg2: tensor<5x5xf64>,%arg3: f16) -> (i1, tensor<5x5xf64>, f16, 
tensor<4x8xf32>) { + return %arg1, %arg2, %arg3, %arg0 : i1, tensor<5x5xf64>, f16, tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5x5xf64>, %[[ARG3:.*]]: f16, +// CHECK-SAME: %[[RESULT1:.*]]: memref<5x5xf64>, %[[RESULT2:.*]]: memref<4x8xf32>) -> (i1, f16) { +// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) +// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT2]]) +// CHECK-NEXT: return %[[ARG1]], %[[ARG3]] + +// ----- + // CHECK-LABEL: func @non_void_to_void_return_op_converter func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { return %arg0 : tensor<4x8xf32> diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -23,7 +23,7 @@ namespace { /// This pass tests the computeAllocPosition helper method and two provided /// operation converters, FunctionAndBlockSignatureConverter and -/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg +/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg /// operations on tensors to linalg operations on buffers to prepare them for /// the BufferPlacement pass that can be applied afterwards. struct TestBufferPlacementPreparationPass @@ -82,7 +82,6 @@ auto type = result.getType().cast(); entryBlock.addArgument(type.getElementType()); } - rewriter.eraseOp(op); return success(); } @@ -95,7 +94,7 @@ patterns->insert< FunctionAndBlockSignatureConverter, GenericOpConverter, - NonVoidToVoidReturnOpConverter< + NoBufferOperandsReturnOpConverter< ReturnOp, ReturnOp, linalg::CopyOp> >(context, placer, converter); // clang-format on @@ -105,8 +104,9 @@ auto &context = getContext(); ConversionTarget target(context); BufferAssignmentTypeConverter converter; - // Make all linalg operations illegal as long as they work on tensors. 
target.addLegalDialect(); + + // Make all linalg operations illegal as long as they work on tensors. target.addDynamicallyLegalDialect( Optional( [&](Operation *op) { @@ -117,9 +117,12 @@ llvm::none_of(op->getResultTypes(), isIllegalType); })); - // Mark return operations illegal as long as they return values. - target.addDynamicallyLegalOp( - [](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); + // Mark std.ReturnOp illegal as long as an operand is tensor or buffer. + target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { + return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) { + return type.isa() || !converter.isLegal(type); + }); + }); // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) {