diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -76,11 +76,23 @@ TypeConverter *converter; }; -/// Converts the signature of the function using the type converter. -/// It adds an extra argument for each illegally-typed function -/// result to the function arguments. `BufferAssignmentTypeConverter` -/// is a helper `TypeConverter` for this purpose. All the non-shaped types -/// of the input function will be converted to memref. +/// A helper type converter class for using inside Buffer Assignment operation +/// conversion patterns. The default constructor keeps all the types intact +/// except for the ranked-tensor types which is converted to memref types. +class BufferAssignmentTypeConverter : public TypeConverter { +public: + BufferAssignmentTypeConverter(); + + /// A helper function to check if `type` has been converted from non-memref + /// type to memref. + static bool isConvertedMemref(Type type, Type before); +}; + +/// Converts the signature of the function using the type converter. It adds an +/// extra argument for each function result type which is going to be a memref +/// type after type conversion. The other function result types remain +/// unchanged. `BufferAssignmentTypeConverter` is a helper `TypeConverter` for +/// this purpose. class FunctionAndBlockSignatureConverter : public BufferAssignmentOpConversionPattern { public: @@ -93,12 +105,14 @@ ConversionPatternRewriter &rewriter) const final; }; -/// Converts the source `ReturnOp` to target `ReturnOp`, removes all -/// the buffer operands from the operands list, and inserts `CopyOp`s -/// for all buffer operands instead. +/// Rewrites the `ReturnOp` to conform with the changed function signature. +/// Operands that correspond to return values that have been rewritten from +/// tensor results to memref arguments are dropped. In their place, a +/// corresponding copy operation from the operand to the new function argument +/// is inserted. template -class NoBufferOperandsReturnOpConverter +class BufferAssignmentReturnOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< @@ -108,50 +122,41 @@ LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { + // Split the operands by their kinds whether they are converted memref or + // not. + SmallVector needCopyOperands, newOperands; + unsigned operandsSize = operands.size(); + needCopyOperands.reserve(operandsSize); + newOperands.reserve(operandsSize); + for (auto operand : llvm::enumerate(operands)) + if (BufferAssignmentTypeConverter::isConvertedMemref( + operand.value().getType(), + returnOp.getOperand(operand.index()).getType())) + needCopyOperands.push_back(operand.value()); + else + newOperands.push_back(operand.value()); + Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); - Location loc = returnOp.getLoc(); - - // The target `ReturnOp` should not contain any memref operands. - SmallVector newOperands(operands.begin(), operands.end()); - llvm::erase_if(newOperands, [](Value operand) { - return operand.getType().isa(); - }); // Find the index of the first destination buffer. - unsigned numBufferOperands = operands.size() - newOperands.size(); - unsigned destArgNum = numFuncArgs - numBufferOperands; - + assert(needCopyOperands.size() <= numFuncArgs && + "The number of operands of return operation is more than the " + "number of function arguments."); + unsigned destArgNum = numFuncArgs - needCopyOperands.size(); rewriter.setInsertionPoint(returnOp); - // Find the corresponding destination buffer for each memref operand. - for (Value operand : operands) - if (operand.getType().isa()) { - assert(destArgNum < numFuncArgs && - "The number of operands of return operation is more than the " - "number of function argument."); - - // For each memref type operand of the source `ReturnOp`, a new `CopyOp` - // is inserted that copies the buffer content from the operand to the - // target. - rewriter.create(loc, operand, - entryBlock.getArgument(destArgNum)); - ++destArgNum; - } + for (Value operand : needCopyOperands) { + // Insert a `CopyOp` for each converted memref-type operand. + rewriter.create(returnOp.getLoc(), operand, + entryBlock.getArgument(destArgNum)); + ++destArgNum; + } // Insert the new target Return operation. rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } }; - -/// A helper type converter class for using inside Buffer Assignment operation -/// conversion patterns. The default constructor keeps all the types intact -/// except for the ranked-tensor types which is converted to memref types. -class BufferAssignmentTypeConverter : public TypeConverter { -public: - BufferAssignmentTypeConverter(); -}; - } // end namespace mlir #endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -21,7 +21,7 @@ using namespace mlir; using ReturnOpConverter = - NoBufferOperandsReturnOpConverter; namespace { diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -389,7 +389,13 @@ // If there is an existing dealloc, move it to the right place. Operation *nextOp = positions.getDeallocPosition()->getNextNode(); - assert(nextOp && "Invalid Dealloc operation position"); + // If the Dealloc position is at the terminator operation of the block, + // then the value should escape from a deallocation. + if (!nextOp) { + assert(deallocs.size() == 0 && + "There should be no dealloc for the returned buffer"); + continue; + } if (deallocs.size()) { (*deallocs.begin())->moveBefore(nextOp); } else { @@ -431,11 +437,6 @@ return failure(); } auto funcType = funcOp.getType(); - TypeRange resultTypes = funcType.getResults(); - if (llvm::any_of(resultTypes, - [](Type type) { return type.isa(); })) - return funcOp.emitError("BufferAssignmentPlacer doesn't currently support " - "functions which return memref typed values"); // Convert function arguments using the provided TypeConverter. TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); @@ -443,17 +444,16 @@ conversion.addInputs(argType.index(), converter->convertType(argType.value())); - // Adding a function argument for each function result which is going to be a - // memref type after type conversion. + // If a function result type is not a memref but it would be a memref after + // type conversion, a new argument should be appended to the function + // arguments list for this result. Otherwise, it remains unchanged as a + // function result. SmallVector newResultTypes; newResultTypes.reserve(funcOp.getNumResults()); - for (Type resType : resultTypes) { + for (Type resType : funcType.getResults()) { Type convertedType = converter->convertType(resType); - - // If the result type is memref after the type conversion, a new argument - // should be appended to the function arguments list for this result. - // Otherwise, it remains unchanged as a function result. - if (convertedType.isa()) + if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType, + resType)) conversion.addInputs(convertedType); else newResultTypes.push_back(convertedType); @@ -482,6 +482,11 @@ }); } +/// Checks if `type` has been converted from non-memref type to memref. +bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) { + return type.isa() && !before.isa(); +} + //===----------------------------------------------------------------------===// // BufferPlacementPass construction //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure +// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @func_signature_conversion func @func_signature_conversion(%arg0: tensor<4x8xf32>) { @@ -8,12 +8,28 @@ // ----- -// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}} -// expected-error @below {{failed to legalize operation 'func'}} -func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) { - %0 = alloc() : memref<5xf32> - return %arg0, %0 : tensor<4x8xf32>, memref<5xf32> +// Only tensor typed function result should be converted to memref and move to the +// function arguments list. The other memref function results remain as function +// results. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @memref_in_function_results +func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) { + %0 = alloc() : memref<15xf32> + %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen1_arg0: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: tensor<5xf32> -> tensor<5xf32> + return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32> } +// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>) +// CHECK-SAME: (memref<10xf32>, memref<15xf32>) +// CHECK: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK: %[[LINALG_ALLOC:.*]] = alloc() +// CHECK: linalg.copy(%[[LINALG_ALLOC]], %[[RESULT]]) +// CHECK: return %[[ARG1]], %[[FIRST_ALLOC]] // ----- diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir --- a/mlir/test/Transforms/buffer-placement.mlir +++ b/mlir/test/Transforms/buffer-placement.mlir @@ -457,3 +457,32 @@ // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: linalg.copy // CHECK-NEXT: dealloc %[[GENERIC1_ALLOC]] + +// ----- + +// Test Case: buffer deallocation escaping +// BufferPlacement Expected Behaviour: It must not dealloc %arg1 and %x +// since they are operands of return operation and should escape from +// deallocating. It should dealloc %y after linalg.copy. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @memref_in_function_results +func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) { + %x = alloc() : memref<15xf32> + %y = alloc() : memref<5xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %y { + ^bb0(%arg3: f32, %arg4: f32): + %2 = exp %arg3 : f32 + linalg.yield %2 : f32 + }: memref<5xf32>, memref<5xf32> + linalg.copy(%y, %arg2) : memref<5xf32>, memref<5xf32> + return %arg1, %x : memref<10xf32>, memref<15xf32> +} +// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>) +// CHECK: %[[X:.*]] = alloc() +// CHECK: %[[Y:.*]] = alloc() +// CHECK: linalg.copy +// CHECK: dealloc %[[Y]] +// CHECK: return %[[ARG1]], %[[X]] + diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -23,7 +23,7 @@ namespace { /// This pass tests the computeAllocPosition helper method and two provided /// operation converters, FunctionAndBlockSignatureConverter and -/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg +/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg /// operations on tensors to linalg operations on buffers to prepare them for /// the BufferPlacement pass that can be applied afterwards. struct TestBufferPlacementPreparationPass @@ -41,16 +41,18 @@ LogicalResult matchAndRewrite(linalg::GenericOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto loc = op.getLoc(); - SmallVector args(operands.begin(), operands.end()); + Location loc = op.getLoc(); + ResultRange results = op.getOperation()->getResults(); + SmallVector newArgs, newResults; + newArgs.reserve(operands.size() + results.size()); + newArgs.append(operands.begin(), operands.end()); + newResults.reserve(results.size()); // Update all types to memref types. - auto results = op.getOperation()->getResults(); for (auto result : results) { - auto type = result.getType().cast(); - if (!type) - op.emitOpError() - << "tensor to buffer conversion expects ranked results"; + ShapedType type = result.getType().cast(); + assert(type && "Generic operations with non-shaped typed results are " + "not currently supported."); if (!type.hasStaticShape()) return rewriter.notifyMatchFailure( op, "dynamic shapes not currently supported"); @@ -62,27 +64,39 @@ rewriter.restoreInsertionPoint( bufferAssignment->computeAllocPosition(result)); auto alloc = rewriter.create(loc, memrefType); - result.replaceAllUsesWith(alloc); - args.push_back(alloc); + newArgs.push_back(alloc); + newResults.push_back(alloc); } // Generate a new linalg operation that works on buffers. auto linalgOp = rewriter.create( - loc, llvm::None, args, rewriter.getI64IntegerAttr(operands.size()), + loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), op.iterator_types(), op.docAttr(), op.library_callAttr()); - // Move regions from the old operation to the new one. - auto ®ion = linalgOp.region(); - rewriter.inlineRegionBefore(op.region(), region, region.end()); - - // TODO: verify the internal memref-based linalg functionality. - auto &entryBlock = region.front(); - for (auto result : results) { - auto type = result.getType().cast(); - entryBlock.addArgument(type.getElementType()); - } - rewriter.eraseOp(op); + // Create a new block in the region of the new Generic Op. + Block &oldBlock = op.getRegion().front(); + Region &newRegion = linalgOp.region(); + Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), + oldBlock.getArgumentTypes()); + + // Map the old block arguments to the new ones. + BlockAndValueMapping mapping; + mapping.map(oldBlock.getArguments(), newBlock->getArguments()); + + // Add the result arguments to the new block. + for (auto result : newResults) + newBlock->addArgument( + result.getType().cast().getElementType()); + + // Clone the body of the old block to the new block. + rewriter.setInsertionPointToEnd(newBlock); + for (auto &op : oldBlock.getOperations()) + rewriter.clone(op, mapping); + + // Replace the results of the old Generic Op with the results of the new + // one. + rewriter.replaceOp(op, newResults); return success(); } }; @@ -94,34 +108,33 @@ patterns->insert< FunctionAndBlockSignatureConverter, GenericOpConverter, - NoBufferOperandsReturnOpConverter< + BufferAssignmentReturnOpConverter< ReturnOp, ReturnOp, linalg::CopyOp> >(context, placer, converter); // clang-format on } void runOnOperation() override { - auto &context = getContext(); + MLIRContext &context = getContext(); ConversionTarget target(context); BufferAssignmentTypeConverter converter; + + // Mark all Standard operations legal. target.addLegalDialect(); - // Make all linalg operations illegal as long as they work on tensors. + // Mark all Linalg operations illegal as long as they work on tensors. + auto isIllegalType = [&](Type type) { return !converter.isLegal(type); }; + auto isLegalOperation = [&](Operation *op) { + return llvm::none_of(op->getOperandTypes(), isIllegalType) && + llvm::none_of(op->getResultTypes(), isIllegalType); + }; target.addDynamicallyLegalDialect( Optional( - [&](Operation *op) { - auto isIllegalType = [&](Type type) { - return !converter.isLegal(type); - }; - return llvm::none_of(op->getOperandTypes(), isIllegalType) && - llvm::none_of(op->getResultTypes(), isIllegalType); - })); - - // Mark std.ReturnOp illegal as long as an operand is tensor or buffer. + isLegalOperation)); + + // Mark Standard Return operations illegal as long as one operand is tensor. target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) { - return type.isa() || !converter.isLegal(type); - }); + return llvm::none_of(returnOp.getOperandTypes(), isIllegalType); }); // Mark the function whose arguments are in tensor-type illegal. @@ -130,16 +143,14 @@ }); // Walk over all the functions to apply buffer assignment. - getOperation().walk([&](FuncOp function) { + getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; BufferAssignmentPlacer placer(function); populateTensorLinalgToBufferLinalgConversionPattern( &context, &placer, &converter, &patterns); // Applying full conversion - return failed(applyFullConversion(function, target, patterns, &converter)) - ? WalkResult::interrupt() - : WalkResult::advance(); + return applyFullConversion(function, target, patterns, &converter); }); }; };