diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -129,10 +129,14 @@
     // If allowMemrefFunctionResults is false and a function result type is not
     // a memref but it would be a memref after type conversion, a new argument
     // should be appended to the function arguments list for this result.
-    // Otherwise, it remains unchanged as a function result.
+    // Similarly, we also deal with a tuple result type where the elements of
+    // the tuple are not memrefs but memrefs after type conversion. Otherwise,
+    // the result type remains unchanged as a function result.
     SmallVector<Type, 2> newResultTypes;
     newResultTypes.reserve(funcOp.getNumResults());
-    for (Type resType : funcType.getResults()) {
+
+    // Use the converted type if the conversion happened to a memref.
+    auto processResult = [&](Type resType) {
       Type convertedType = converter->convertType(resType);
       if (!allowMemrefFunctionResults &&
           BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
@@ -140,7 +144,20 @@
         conversion.addInputs(convertedType);
       else
         newResultTypes.push_back(convertedType);
+    };
+
+    TupleType tupleType;
+    if (funcType.getNumResults() == 1 &&
+        (tupleType = funcType.getResult(0).dyn_cast<TupleType>())) {
+      // Deal with a single tuple type output by unpacking contained types.
+      for (Type tupleEltType : tupleType.getTypes())
+        processResult(tupleEltType);
+    } else {
+      // Deal with other list of types.
+      for (Type resType : funcType.getResults())
+        processResult(resType);
     }
+
     if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
                                            &conversion)))
       return failure();
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -8,9 +8,9 @@

 // -----

-// Only tensor typed function result should be converted to memref and move to the
-// function arguments list. The other memref function results remain as function
-// results.
+// Only tensor typed function result (or tuples of such types) should be
+// converted to memref and moved to the function arguments list. The other memref
+// function results remain as function results.

 #map0 = affine_map<(d0) -> (d0)>

@@ -290,3 +290,15 @@
   return
 }
 // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
+
+// Test return of tuple tensor types. Use a test op that creates a tuple.
+
+// CHECK-LABEL: func @main(%arg0: memref<2xf32>, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+func @main(%input: tensor<2xf32>) -> tuple<tensor<2xf32>, tensor<2xf32>> {
+  %tuple = "test.tensor_tuple"(%input, %input) : (tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2xf32>, tensor<2xf32>>
+  return %tuple : tuple<tensor<2xf32>, tensor<2xf32>>
+}
+// CHECK-NEXT: linalg.copy(%arg0, %arg1) : memref<2xf32>, memref<2xf32>
+// CHECK-NEXT: linalg.copy(%arg0, %arg2) : memref<2xf32>, memref<2xf32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -62,6 +62,11 @@
   let results = (outs NestedTupleOf<[I32, F32]>);
 }

+def TensorTupleOp : TEST_Op<"tensor_tuple"> {
+  let arguments = (ins AnyTensor:$in0, AnyTensor:$in1);
+  let results = (outs TupleOf<[AnyTensor, AnyTensor]>);
+}
+
 def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
   let arguments = (ins AnyStaticShapeMemRef:$x);
 }
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "TestDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
@@ -106,13 +107,64 @@
     }
   };

+  /// Converts the test tensor_tuple op by propagating its operands (which are
+  /// tensor types converted to memref types) to the use of its result whenever
+  /// there is a single use at a return op. This is to test function signature
+  /// conversion when the return type is a tuple type of tensors.
+  class TestTupleOpConverter
+      : public BufferAssignmentOpConversionPattern<TensorTupleOp> {
+  public:
+    using BufferAssignmentOpConversionPattern<
+        TensorTupleOp>::BufferAssignmentOpConversionPattern;
+
+    LogicalResult
+    matchAndRewrite(TensorTupleOp tupleOp, ArrayRef<Value> operands,
+                    ConversionPatternRewriter &rewriter) const final {
+
+      // We will convert only if the op's result is used in ReturnOp, and the
+      // return op's only operand is that result.
+      if (!tupleOp.getOperation()->hasOneUse())
+        return failure();
+
+      auto returnOp =
+          dyn_cast<mlir::ReturnOp>(*tupleOp.getResult().user_begin());
+      if (!returnOp) {
+        return failure();
+      }
+
+      // Iterate over the element types of this tuple to check all have been
+      // converted to memrefs.
+      if (!llvm::all_of(TypeRange(operands),
+                        [](Type type) { return type.isa<MemRefType>(); }))
+        return failure();
+
+      // Replace the return op with one that takes all of the tuple's operands.
+      // The function signature would also be updated that way.
+      {
+        OpBuilder::InsertionGuard regionGuard(rewriter);
+        rewriter.setInsertionPoint(returnOp);
+        rewriter.create<mlir::ReturnOp>(tupleOp.getLoc(),
+                                        tupleOp.getOperands());
+      }
+      // Erase the tuple op.
+      rewriter.eraseOp(tupleOp);
+
+      // Erase the return op with the tuple type. This can't be erased
+      // before erasing tupleOp since dialect conversion would process
+      // replacement / erasure ops in reverse.
+      rewriter.eraseOp(returnOp);
+      return success();
+    }
+  };
+
   void populateTensorLinalgToBufferLinalgConversionPattern(
       MLIRContext *context, BufferAssignmentPlacer *placer,
       TypeConverter *converter, OwningRewritePatternList *patterns) {
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
         allowMemrefFunctionResults>(context, placer, converter, patterns);
-    patterns->insert<GenericOpConverter>(context, placer, converter);
+    patterns->insert<GenericOpConverter, TestTupleOpConverter>(context, placer,
+                                                               converter);
   }

   void runOnOperation() override {