diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -157,6 +157,21 @@
     return success();
   }
 };
+
+/// Converts `CallOp` to match its operands and results with the signature of
+/// the callee after rewriting the callee with
+/// FunctionAndBlockSignatureConverter.
+class BufferAssignmentCallOpConverter
+    : public BufferAssignmentOpConversionPattern<CallOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      CallOp>::BufferAssignmentOpConversionPattern;
+
+  /// Performs the actual `CallOp` conversion step.
+  LogicalResult
+  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
 } // end namespace mlir

 #endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -469,6 +469,57 @@
 }

 //===----------------------------------------------------------------------===//
+// BufferAssignmentCallOpConverter
+//===----------------------------------------------------------------------===//
+
+// Performs the `CallOp` conversion to match its operands and results with the
+// signature of the callee after rewriting the callee with
+// FunctionAndBlockSignatureConverter.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  Location loc = callOp.getLoc();
+  SmallVector<Value, 2> newOperands, replacingValues;
+  SmallVector<Type, 2> newResultTypes;
+  unsigned numResults = callOp.getNumResults();
+  newOperands.reserve(numResults + operands.size());
+  newOperands.append(operands.begin(), operands.end());
+  newResultTypes.reserve(numResults);
+  replacingValues.reserve(numResults);
+
+  // For each result of `CallOp` that becomes a memref only after type
+  // conversion, a new buffer is allocated and passed to the operand list of
+  // the new `CallOp`. Otherwise, the result remains a result of the caller.
+  for (Value result : callOp.getResults()) {
+    Type currType = result.getType();
+    Type newType = converter->convertType(result.getType());
+    if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.restoreInsertionPoint(
+          bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>()));
+      Value alloc =
+          rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+      newOperands.push_back(alloc);
+      replacingValues.push_back(alloc);
+    } else {
+      newResultTypes.push_back(currType);
+
+      // No replacement is required.
+      replacingValues.push_back(nullptr);
+    }
+  }
+
+  // Create the new `CallOp`.
+  rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes, newOperands);
+
+  // Replace the results of the old `CallOp`.
+  rewriter.replaceOp(callOp, replacingValues);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // BufferAssignmentTypeConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -195,3 +195,92 @@
 // CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]]
 // CHECK: %[[ALLOC7:.*]] = alloc()
 // CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]]
+
+// -----
+
+// Test case: Checks that BufferAssignmentCallOpConverter,
+// FunctionAndBlockSignatureConverter, and BufferAssignmentReturnOpConverter
+// work together. The signature of `callee` after signature conversion is:
+
+// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> ()
+
+// The operands and results of the callers and of the return operations must
+// be matched against this signature.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel"]
+  } %arg1 {
+  ^bb0(%gen1_arg0: f32):
+    %tmp1 = exp %gen1_arg0 : f32
+    linalg.yield %tmp1 : f32
+  }: tensor<5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: linalg.generic
+// CHECK: linalg.copy(%[[ALLOC]], %[[CALLEE_RESULT]])
+// CHECK: return
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %x = call @callee(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
+  %y = call @callee(%x) : (tensor<5xf32>) -> tensor<5xf32>
+  return %y : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK: call @callee(%[[CALLER_ARG]], %[[FIRST_ALLOC]])
+// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
+// CHECK: call @callee(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
+// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[CALLER_RESULT]])
+// CHECK: return
+
+// -----
+
+// Test case: Checks that BufferAssignmentCallOpConverter,
+// FunctionAndBlockSignatureConverter, and BufferAssignmentReturnOpConverter
+// work together on functions that also have memref-typed results. The
+// signature of `callee` after signature conversion is:
+
+// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> memref<2xf32>
+
+// where %arg0 is the input, %arg1 is the output buffer, and the result that
+// was already a memref remains a function result. The rewriter must match the
+// callers against this signature: two buffers are allocated in place of %x0
+// and %y0 and passed to the callers as output-buffer operands, while %x1 and
+// %y1 remain caller results.
+
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
+  %buff = alloc() : memref<2xf32>
+  return %arg1, %buff : tensor<5xf32>, memref<2xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
+// CHECK-SAME: memref<2xf32>
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: linalg.copy(%[[CALLEE_ARG]], %[[CALLEE_RESULT]])
+// CHECK: return %[[ALLOC]]
+
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %x0, %x1 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  %y0, %y1 = call @callee(%x0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  return %y0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[X0:.*]] = alloc()
+// CHECK: %[[X1:.*]] = call @callee(%[[CALLER_ARG]], %[[X0]])
+// CHECK: %[[Y0:.*]] = alloc()
+// CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
+// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
+// CHECK: return
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -106,6 +106,7 @@
       TypeConverter *converter, OwningRewritePatternList *patterns) {
     // clang-format off
     patterns->insert<
+                   BufferAssignmentCallOpConverter,
                    FunctionAndBlockSignatureConverter,
                    GenericOpConverter,
                    BufferAssignmentReturnOpConverter<
@@ -137,6 +138,12 @@
       return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
     });

+    // Mark the standard `CallOp` illegal as long as it operates on tensors.
+    target.addDynamicallyLegalOp<mlir::CallOp>([&](mlir::CallOp callOp) {
+      return llvm::none_of(callOp.getOperandTypes(), isIllegalType) &&
+             llvm::none_of(callOp.getResultTypes(), isIllegalType);
+    });
+
     // Mark the function whose arguments are in tensor-type illegal.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
       return converter.isSignatureLegal(funcOp.getType());
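For readers skimming the patch, here is a minimal sketch (not part of the diff) of what BufferAssignmentCallOpConverter does to a single call site, reconstructed from the CHECK lines in buffer-placement-preparation.mlir above. The SSA names %arg, %res, and %alloc are illustrative only, and the exact placement of the alloc is chosen by the placer's computeAllocPosition used in the pattern:

  // Before conversion, the callee returns a tensor:
  %res = call @callee(%arg) : (tensor<5xf32>) -> tensor<5xf32>

  // After conversion, the tensor result becomes an extra output-buffer
  // operand, and every use of %res is replaced by %alloc:
  %alloc = alloc() : memref<5xf32>
  call @callee(%arg, %alloc) : (memref<5xf32>, memref<5xf32>) -> ()

Results that were already memrefs before type conversion (the second test case) are left in the result list, which is why only the tensor results of @callee turn into allocated operands there.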