diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -0,0 +1,149 @@ +//===- BufferPlacement.h - Buffer Assignment Utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines buffer assignment helper methods to compute correct +// and valid positions for placing Alloc and Dealloc operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_BUFFERPLACEMENT_H +#define MLIR_TRANSFORMS_BUFFERPLACEMENT_H + +#include "mlir/Analysis/Dominance.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// Prepares a buffer placement phase. It can place (user-defined) alloc +/// nodes. This simplifies the integration of the actual buffer-placement +/// pass. Sample usage: +/// BufferAssignmentPlacer baHelper(regionOp); +/// -> determine alloc positions +/// auto allocPosition = baHelper.computeAllocPosition(value); +/// -> place alloc +/// allocBuilder.setInsertionPoint(positions.getAllocPosition()); +/// +/// Note: this class is intended to be used during legalization. In order +/// to move alloc and dealloc nodes into the right places you can use the +/// createBufferPlacementPass() function. +class BufferAssignmentPlacer { +public: + /// Creates a new assignment builder. + explicit BufferAssignmentPlacer(Operation *op); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Computes the actual position to place allocs for the given result. + OpBuilder::InsertPoint computeAllocPosition(OpResult result); + +private: + /// The operation this analysis was constructed from. + Operation *operation; +}; + +/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer +/// instance. Sample usage: +/// class CustomConversionPattern : public +/// BufferAssignmentOpConversionPattern +/// { +/// ... matchAndRewrite(...) { +/// -> Access stored BufferAssignmentPlacer +/// bufferAssignment->computeAllocPosition(resultOp); +/// } +/// }; +template +class BufferAssignmentOpConversionPattern + : public OpConversionPattern { +public: + explicit BufferAssignmentOpConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, + TypeConverter *converter = nullptr, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), + bufferAssignment(bufferAssignment), converter(converter) {} + +protected: + BufferAssignmentPlacer *bufferAssignment; + TypeConverter *converter; +}; + +/// This conversion adds an extra argument for each function result which makes +/// the converted function a void function. A type converter must be provided +/// for this conversion to convert a non-shaped type to memref. +/// BufferAssignmentTypeConverter is an helper TypeConverter for this +/// purpose. All the non-shaped type of the input function will be converted to +/// memref. 
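+/// As an illustrative sketch (argument names are made up; the exact output
+/// depends on the registered type converter), a function such as
+///   func @foo(%arg0: tensor<4xf32>) -> tensor<4xf32>
+/// becomes a void function of the form
+///   func @foo(%arg0: memref<4xf32>, %out: memref<4xf32>)
+/// where the former result is written into the appended memref argument.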
+class FunctionAndBlockSignatureConverter + : public BufferAssignmentOpConversionPattern { +public: + using BufferAssignmentOpConversionPattern< + FuncOp>::BufferAssignmentOpConversionPattern; + + /// Performs the actual signature rewriting step. + LogicalResult + matchAndRewrite(FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final; +}; + +/// This pattern converter transforms a non-void ReturnOpSourceTy into a void +/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy +/// to copy the results to the output buffer. +template +class NonVoidToVoidReturnOpConverter + : public BufferAssignmentOpConversionPattern { +public: + using BufferAssignmentOpConversionPattern< + ReturnOpSourceTy>::BufferAssignmentOpConversionPattern; + + /// Performs the actual return-op conversion step. + LogicalResult + matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + unsigned numReturnValues = returnOp.getNumOperands(); + Block &entryBlock = returnOp.getParentRegion()->front(); + unsigned numFuncArgs = entryBlock.getNumArguments(); + Location loc = returnOp.getLoc(); + + // Find the corresponding output buffer for each operand. + assert(numReturnValues <= numFuncArgs && + "The number of operands of return operation is more than the " + "number of function argument."); + unsigned firstReturnParameter = numFuncArgs - numReturnValues; + for (auto operand : llvm::enumerate(operands)) { + unsigned returnArgNumber = firstReturnParameter + operand.index(); + BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber); + if (dstBuffer == operand.value()) + continue; + + // Insert the copy operation to copy before the return. + rewriter.setInsertionPoint(returnOp); + rewriter.create(loc, operand.value(), + entryBlock.getArgument(returnArgNumber)); + } + // Insert the new target return operation. + rewriter.replaceOpWithNewOp(returnOp); + return success(); + } +}; + +/// A helper type converter class for using inside Buffer Assignment operation +/// conversion patterns. The default constructor keeps all the types intact +/// except for the ranked-tensor types which is converted to memref types. +class BufferAssignmentTypeConverter : public TypeConverter { +public: + BufferAssignmentTypeConverter(); +}; + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -26,6 +26,9 @@ class Pass; template class OperationPass; +/// Creates an instance of the BufferPlacement pass. +std::unique_ptr createBufferPlacementPass(); + /// Creates an instance of the Canonicalizer pass. std::unique_ptr createCanonicalizerPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -102,6 +102,68 @@ let constructor = "mlir::createPipelineDataTransferPass()"; } +def BufferPlacement : Pass<"buffer-placement"> { + let summary = "Optimizes placement of alloc and dealloc operations"; + let description = [{ + This pass implements an algorithm to optimize the placement of alloc and + dealloc operations. This pass also inserts missing dealloc operations + automatically to reclaim memory. 
+ + + Input + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 + ^bb1: + br ^bb3(%arg1 : memref<2xf32>) + ^bb2: + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + br ^bb3(%0 : memref<2xf32>) + ^bb3(%1: memref<2xf32>): + "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return + } + } + + ``` + + Output + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + cond_br %arg0, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + br ^bb3(%arg1 : memref<2xf32>) + ^bb2: // pred: ^bb0 + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %2 = exp %arg3 : f32 + linalg.yield %2 : f32 + }: memref<2xf32>, memref<2xf32> + br ^bb3(%0 : memref<2xf32>) + ^bb3(%1: memref<2xf32>): // 2 preds: ^bb1, ^bb2 + linalg.copy(%1, %arg2) : memref<2xf32>, memref<2xf32> + dealloc %0 : memref<2xf32> + return + } + } + ``` + + }]; + let constructor = "mlir::createBufferPlacementPass()"; +} + def Canonicalizer : Pass<"canonicalize"> { let summary = "Canonicalize operations"; let description = [{ diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -0,0 +1,446 @@ +//===- BufferPlacement.cpp - the impl for buffer placement ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for computing correct alloc and dealloc positions. +// The main class is the BufferPlacementPass class that implements the +// underlying algorithm. In order to put allocations and deallocations at safe +// positions, it is significantly important to put them into the correct blocks. +// However, the liveness analysis does not pay attention to aliases, which can +// occur due to branches (and their associated block arguments) in general. For +// this purpose, BufferPlacement firstly finds all possible aliases for a single +// value (using the BufferPlacementAliasAnalysis class). Consider the following +// example: +// +// ^bb0(%arg0): +// cond_br %cond, ^bb1, ^bb2 +// ^bb1: +// br ^exit(%arg0) +// ^bb2: +// %new_value = ... +// br ^exit(%new_value) +// ^exit(%arg1): +// return %arg1; +// +// Using liveness information on its own would cause us to place the allocs and +// deallocs in the wrong block. This is due to the fact that %new_value will not +// be liveOut of its block. Instead, we have to place the alloc for %new_value +// in bb0 and its associated dealloc in exit. Using the class +// BufferPlacementAliasAnalysis, we will find out that %new_value has a +// potential alias %arg1. In order to find the dealloc position we have to find +// all potential aliases, iterate over their uses and find the common +// post-dominator block. 
In this block we can safely be sure that %new_value +// will die and can use liveness information to determine the exact operation +// after which we have to insert the dealloc. Finding the alloc position is +// highly similar and non- obvious. Again, we have to consider all potential +// aliases and find the common dominator block to place the alloc. +// +// TODO: +// The current implementation does not support loops and the resulting code will +// be invalid with respect to program semantics. The only thing that is +// currently missing is a high-level loop analysis that allows us to move allocs +// and deallocs outside of the loop blocks. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// BufferPlacementAliasAnalysis +//===----------------------------------------------------------------------===// + +/// A straight-forward alias analysis which ensures that all aliases of all +/// values will be determined. This is a requirement for the BufferPlacement +/// class since you need to determine safe positions to place alloc and +/// deallocs. +class BufferPlacementAliasAnalysis { +public: + using ValueSetT = SmallPtrSet; + +public: + /// Constructs a new alias analysis using the op provided. + BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); } + + /// Finds all immediate and indirect aliases this value could potentially + /// have. Note that the resulting set will also contain the value provided as + /// it is an alias of itself. + ValueSetT resolve(Value value) const { + ValueSetT result; + resolveRecursive(value, result); + return result; + } + +private: + /// Recursively determines alias information for the given value. It stores + /// all newly found potential aliases in the given result set. + void resolveRecursive(Value value, ValueSetT &result) const { + if (!result.insert(value).second) + return; + auto it = aliases.find(value); + if (it == aliases.end()) + return; + for (Value alias : it->second) + resolveRecursive(alias, result); + } + + /// This function constructs a mapping from values to its immediate aliases. + /// It iterates over all blocks, gets their predecessors, determines the + /// values that will be passed to the corresponding block arguments and + /// inserts them into the underlying map. + void build(MutableArrayRef regions) { + for (Region ®ion : regions) { + for (Block &block : region) { + // Iterate over all predecessor and get the mapped values to their + // corresponding block arguments values. + for (auto it = block.pred_begin(), e = block.pred_end(); it != e; + ++it) { + unsigned successorIndex = it.getSuccessorIndex(); + // Get the terminator and the values that will be passed to our block. + auto branchInterface = + dyn_cast((*it)->getTerminator()); + if (!branchInterface) + continue; + // Query the branch op interace to get the successor operands. + auto successorOperands = + branchInterface.getSuccessorOperands(successorIndex); + if (successorOperands.hasValue()) { + // Build the actual mapping of values to their immediate aliases. 
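+            // For the branch example in the file header, this step records
+            // (sketch) aliases[%arg0] += {%arg1} and
+            // aliases[%new_value] += {%arg1}, since both values are forwarded
+            // into the block argument of ^exit.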
+ for (auto argPair : llvm::zip(block.getArguments(), + successorOperands.getValue())) { + aliases[std::get<1>(argPair)].insert(std::get<0>(argPair)); + } + } + } + } + } + } + + /// Maps values to all immediate aliases this value can have. + llvm::DenseMap aliases; +}; + +//===----------------------------------------------------------------------===// +// BufferPlacementPositions +//===----------------------------------------------------------------------===// + +/// Stores correct alloc and dealloc positions to place dialect-specific alloc +/// and dealloc operations. +struct BufferPlacementPositions { +public: + BufferPlacementPositions() + : allocPosition(nullptr), deallocPosition(nullptr) {} + + /// Creates a new positions tuple including alloc and dealloc positions. + BufferPlacementPositions(Operation *allocPosition, Operation *deallocPosition) + : allocPosition(allocPosition), deallocPosition(deallocPosition) {} + + /// Returns the alloc position before which the alloc operation has to be + /// inserted. + Operation *getAllocPosition() const { return allocPosition; } + + /// Returns the dealloc position after which the dealloc operation has to be + /// inserted. + Operation *getDeallocPosition() const { return deallocPosition; } + +private: + Operation *allocPosition; + Operation *deallocPosition; +}; + +//===----------------------------------------------------------------------===// +// BufferPlacementAnalysis +//===----------------------------------------------------------------------===// + +// The main buffer placement analysis used to place allocs and deallocs. +class BufferPlacementAnalysis { +public: + using DeallocSetT = SmallPtrSet; + +public: + BufferPlacementAnalysis(Operation *op) + : operation(op), liveness(op), dominators(op), postDominators(op), + aliases(op) {} + + /// Computes the actual positions to place allocs and deallocs for the given + /// value. + BufferPlacementPositions + computeAllocAndDeallocPositions(OpResult result) const { + if (result.use_empty()) + return BufferPlacementPositions(result.getOwner(), result.getOwner()); + // Get all possible aliases. + auto possibleValues = aliases.resolve(result); + return BufferPlacementPositions(getAllocPosition(result, possibleValues), + getDeallocPosition(result, possibleValues)); + } + + /// Finds all associated dealloc nodes for the alloc nodes using alias + /// information. + DeallocSetT findAssociatedDeallocs(AllocOp alloc) const { + DeallocSetT result; + auto possibleValues = aliases.resolve(alloc); + for (Value alias : possibleValues) + for (Operation *user : alias.getUsers()) { + if (isa(user)) + result.insert(user); + } + return result; + } + + /// Dumps the buffer placement information to the given stream. + void print(raw_ostream &os) const { + os << "// ---- Buffer Placement -----\n"; + + for (Region ®ion : operation->getRegions()) + for (Block &block : region) + for (Operation &operation : block) + for (OpResult result : operation.getResults()) { + BufferPlacementPositions positions = + computeAllocAndDeallocPositions(result); + os << "Positions for "; + result.print(os); + os << "\n Alloc: "; + positions.getAllocPosition()->print(os); + os << "\n Dealloc: "; + positions.getDeallocPosition()->print(os); + os << "\n"; + } + } + +private: + /// Finds a correct placement block to store alloc/dealloc node according to + /// the algorithm described at the top of the file. It supports dominator and + /// post-dominator analyses via template arguments. 
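+  /// getAllocPosition below instantiates it with DominanceInfo, while
+  /// getDeallocPosition passes PostDominanceInfo.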
+ template + Block * + findPlacementBlock(OpResult result, + const BufferPlacementAliasAnalysis::ValueSetT &aliases, + const DominatorT &doms) const { + // Start with the current block the value is defined in. + Block *dom = result.getOwner()->getBlock(); + // Iterate over all aliases and their uses to find a safe placement block + // according to the given dominator information. + for (Value alias : aliases) + for (Operation *user : alias.getUsers()) { + // Move upwards in the dominator tree to find an appropriate + // dominator block that takes the current use into account. + dom = doms.findNearestCommonDominator(dom, user->getBlock()); + } + return dom; + } + + /// Finds a correct alloc position according to the algorithm described at + /// the top of the file. + Operation *getAllocPosition( + OpResult result, + const BufferPlacementAliasAnalysis::ValueSetT &aliases) const { + // Determine the actual block to place the alloc and get liveness + // information. + Block *placementBlock = findPlacementBlock(result, aliases, dominators); + const LivenessBlockInfo *livenessInfo = + liveness.getLiveness(placementBlock); + + // We have to ensure that the alloc will be before the first use of all + // aliases of the given value. We first assume that there are no uses in the + // placementBlock and that we can safely place the alloc before the + // terminator at the end of the block. + Operation *startOperation = placementBlock->getTerminator(); + // Iterate over all aliases and ensure that the startOperation will point to + // the first operation of all potential aliases in the placementBlock. + for (Value alias : aliases) { + Operation *aliasStartOperation = livenessInfo->getStartOperation(alias); + // Check whether the aliasStartOperation lies in the desired block and + // whether it is before the current startOperation. If yes, this will be + // the new startOperation. + if (aliasStartOperation->getBlock() == placementBlock && + aliasStartOperation->isBeforeInBlock(startOperation)) + startOperation = aliasStartOperation; + } + // startOperation is the first operation before which we can safely store + // the alloc taking all potential aliases into account. + return startOperation; + } + + /// Finds a correct dealloc position according to the algorithm described at + /// the top of the file. + Operation *getDeallocPosition( + OpResult result, + const BufferPlacementAliasAnalysis::ValueSetT &aliases) const { + // Determine the actual block to place the dealloc and get liveness + // information. + Block *placementBlock = findPlacementBlock(result, aliases, postDominators); + const LivenessBlockInfo *livenessInfo = + liveness.getLiveness(placementBlock); + + // We have to ensure that the dealloc will be after the last use of all + // aliases of the given value. We first assume that there are no uses in the + // placementBlock and that we can safely place the dealloc at the beginning. + Operation *endOperation = &placementBlock->front(); + // Iterate over all aliases and ensure that the endOperation will point to + // the last operation of all potential aliases in the placementBlock. + for (Value alias : aliases) { + Operation *aliasEndOperation = + livenessInfo->getEndOperation(alias, endOperation); + // Check whether the aliasEndOperation lies in the desired block and + // whether it is behind the current endOperation. If yes, this will be the + // new endOperation. 
+ if (aliasEndOperation->getBlock() == placementBlock && + endOperation->isBeforeInBlock(aliasEndOperation)) + endOperation = aliasEndOperation; + } + // endOperation is the last operation behind which we can safely store the + // dealloc taking all potential aliases into account. + return endOperation; + } + + /// The operation this transformation was constructed from. + Operation *operation; + + /// The underlying liveness analysis to compute fine grained information about + /// alloc and dealloc positions. + Liveness liveness; + + /// The dominator analysis to place allocs in the appropriate blocks. + DominanceInfo dominators; + + /// The post dominator analysis to place deallocs in the appropriate blocks. + PostDominanceInfo postDominators; + + /// The internal alias analysis to ensure that allocs and deallocs take all + /// their potential aliases into account. + BufferPlacementAliasAnalysis aliases; +}; + +//===----------------------------------------------------------------------===// +// BufferPlacementPass +//===----------------------------------------------------------------------===// + +/// The actual buffer placement pass that moves alloc and dealloc nodes into +/// the right positions. It uses the algorithm described at the top of the file. +// TODO: create a templated version that allows to match dialect-specific +// alloc/dealloc nodes and to insert dialect-specific dealloc node. +struct BufferPlacementPass + : mlir::PassWrapper { + void runOnFunction() override { + // Get required analysis information first. + auto &analysis = getAnalysis(); + + // Compute an initial placement of all nodes. + llvm::SmallDenseMap placements; + getFunction().walk([&](AllocOp alloc) { + placements[alloc] = analysis.computeAllocAndDeallocPositions( + alloc.getOperation()->getResult(0)); + return WalkResult::advance(); + }); + + // Move alloc (and dealloc - if any) nodes into the right places + // and insert dealloc nodes if necessary. + getFunction().walk([&](AllocOp alloc) { + // Find already associated dealloc nodes. + auto deallocs = analysis.findAssociatedDeallocs(alloc); + if (deallocs.size() > 1) { + emitError(alloc.getLoc(), + "Not supported number of associated dealloc operations"); + return WalkResult::interrupt(); + } + + // Move alloc node to the right place. + BufferPlacementPositions &positions = placements[alloc]; + Operation *allocOperation = alloc.getOperation(); + allocOperation->moveBefore(positions.getAllocPosition()); + + // If there is an existing dealloc, move it to the right place. + if (deallocs.size()) { + Operation *nextOp = positions.getDeallocPosition()->getNextNode(); + assert(nextOp && "Invalid Dealloc operation position"); + (*deallocs.begin())->moveBefore(nextOp); + } else { + // If there is no dealloc node, insert one in the right place. + OpBuilder builder(alloc); + builder.setInsertionPointAfter(positions.getDeallocPosition()); + builder.create(allocOperation->getLoc(), alloc); + } + return WalkResult::advance(); + }); + }; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// BufferAssignmentPlacer +//===----------------------------------------------------------------------===// + +/// Creates a new assignment placer. +BufferAssignmentPlacer::BufferAssignmentPlacer(Operation *op) : operation(op) {} + +/// Computes the actual position to place allocs for the given value. 
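+/// Currently this is simply the insertion point immediately before the
+/// operation that defines the given result.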
+OpBuilder::InsertPoint
+BufferAssignmentPlacer::computeAllocPosition(OpResult result) {
+  Operation *owner = result.getOwner();
+  return OpBuilder::InsertPoint(owner->getBlock(), Block::iterator(owner));
+}
+
+//===----------------------------------------------------------------------===//
+// FunctionAndBlockSignatureConverter
+//===----------------------------------------------------------------------===//
+
+// Performs the actual signature rewriting step.
+LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
+    FuncOp funcOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!converter) {
+    funcOp.emitError("The type converter has not been defined for "
+                     "FunctionAndBlockSignatureConverter");
+    return failure();
+  }
+  // Convert the shaped-type function arguments to memref type.
+  auto funcType = funcOp.getType();
+  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+  for (auto argType : llvm::enumerate(funcType.getInputs()))
+    conversion.addInputs(argType.index(),
+                         converter->convertType(argType.value()));
+  // Add the function results to the arguments of the converted function as
+  // memref types. The converted function will be a void function.
+  for (Type resType : funcType.getResults())
+    conversion.addInputs(converter->convertType(resType));
+  rewriter.updateRootInPlace(funcOp, [&] {
+    funcOp.setType(
+        rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
+    rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
+  });
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentTypeConverter
+//===----------------------------------------------------------------------===//
+
+/// Registers conversions into BufferAssignmentTypeConverter.
+BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
+  // Keep all types unchanged.
+  addConversion([](Type type) { return type; });
+  // A type conversion that converts ranked-tensor types to memref types.
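+  // For example, tensor<4x8xf32> is mapped to memref<4x8xf32> (see the
+  // buffer-placement preparation tests).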
+ addConversion([](RankedTensorType type) { + return (Type)MemRefType::get(type.getShape(), type.getElementType()); + }); +} + +//===----------------------------------------------------------------------===// +// BufferPlacementPass construction +//===----------------------------------------------------------------------===// + +std::unique_ptr mlir::createBufferPlacementPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Utils) add_mlir_library(MLIRTransforms + BufferPlacement.cpp Canonicalizer.cpp CSE.cpp DialectConversion.cpp diff --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-prepration.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/buffer-placement-prepration.mlir @@ -0,0 +1,143 @@ +// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure + +// CHECK-LABEL: func @func_signature_conversion +func @func_signature_conversion(%arg0: tensor<4x8xf32>) { + return +} +// CHECK: ({{.*}}: memref<4x8xf32>) { + +// ----- + +// CHECK-LABEL: func @non_void_to_void_return_op_converter +func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + return %arg0 : tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>, %[[RESULT:.*]]: [[TYPE]]<[[RANK]]>) { +// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT]]) +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @func_and_block_signature_conversion +func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{ + cond_br %cond, ^bb1, ^bb2 + ^bb1: + br ^exit(%arg0 : tensor<2xf32>) + ^bb2: + br ^exit(%arg0 : tensor<2xf32>) + ^exit(%arg2: tensor<2xf32>): + return %arg1 : tensor<4x4xf32> +} +// CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]], %[[RESULT:.*]]: [[RESULT_TYPE:.*]]) { +// CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]]) +// CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]]) +// CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]]) +// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT]]) +// CHECK-NEXT: return + +// ----- + +// Test Case: Simple case for checking if BufferAssignmentPlacer creates AllocOps right before GenericOps. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @compute_allocs_position_simple +func @compute_allocs_position_simple(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{ + %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen1_arg0: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: tensor<2xf32> -> tensor<2xf32> + %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 { + ^bb0(%gen2_arg0: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: tensor<2xf32> -> tensor<2xf32> + return %1 : tensor<2xf32> +} +// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>, +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]] +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] + +// ----- + +// Test Case: if-else case for checking if BufferAssignmentPlacer creates AllocOps right before GenericOps. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @compute_allocs_position +func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{ + %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen1_arg0: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: tensor<2xf32> -> tensor<2xf32> + %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 { + ^bb0(%gen2_arg0: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: tensor<2xf32> -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), + ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): + %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen3_arg0: f32): + %tmp3 = exp %gen3_arg0 : f32 + linalg.yield %tmp3 : f32 + }: tensor<2xf32> -> tensor<2xf32> + %3 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %2 { + ^bb0(%gen4_arg0: f32): + %tmp4 = exp %gen4_arg0 : f32 + linalg.yield %tmp4 : f32 + }: tensor<2xf32> -> tensor<2xf32> + br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) + ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): + %4 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen5_arg0: f32): + %tmp5 = exp %gen5_arg0 : f32 + linalg.yield %tmp5 : f32 + }: tensor<2xf32> -> tensor<2xf32> + %5 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %4 { + ^bb0(%gen6_arg0: f32): + %tmp6 = exp %gen6_arg0 : f32 + linalg.yield %tmp6 : f32 + }: tensor<2xf32> -> tensor<2xf32> + br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) + ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): + %6 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + ^bb0(%gen7_arg0: f32): + %tmp7 = exp %gen7_arg0 : f32 + linalg.yield %tmp7 : f32 + }: tensor<2xf32> -> tensor<2xf32> + %7 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = 
["parallel"]} %6 { + ^bb0(%gen8_arg0: f32): + %tmp8 = exp %gen8_arg0 : f32 + linalg.yield %tmp8 : f32 + }: tensor<2xf32> -> tensor<2xf32> + return %7 : tensor<2xf32> +} +// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>, +// CHECK-NEXT: %[[ALLOC0:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC0]] +// CHECK: %[[ALLOC1:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC0]], %[[ALLOC1]] +// CHECK: cond_br %{{.*}}, ^[[BB0:.*]]({{.*}}), ^[[BB1:.*]]( +// CHECK-NEXT: ^[[BB0]] +// CHECK-NEXT: %[[ALLOC2:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC2]] +// CHECK: %[[ALLOC3:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC2]], %[[ALLOC3]] +// CHECK: br ^[[EXIT:.*]]({{.*}}) +// CHECK-NEXT: ^[[BB1]] +// CHECK-NEXT: %[[ALLOC4:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC4]] +// CHECK: %[[ALLOC5:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC4]], %[[ALLOC5]] +// CHECK: br ^[[EXIT]] +// CHECK-NEXT: ^[[EXIT]] +// CHECK-NEXT: %[[ALLOC6:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]] +// CHECK: %[[ALLOC7:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]] diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/buffer-placement.mlir @@ -0,0 +1,412 @@ +// RUN: mlir-opt -buffer-placement -split-input-file %s | FileCheck %s -dump-input-on-failure + +// This file checks the behaviour of BufferPlacement pass for moving Alloc and Dealloc +// operations and inserting the missing the DeallocOps in their correct positions. + +// Test Case: +// bb0 +// / \ +// bb1 bb2 <- Initial position of AllocOp +// \ / +// bb3 +// BufferPlacement Expected Behaviour: It should move the existing AllocOp to the entry block, +// and insert a DeallocOp at the exit block after CopyOp since %1 is an alias for %0 and %arg1. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @condBranch +func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 +^bb1: + br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: cond_br +// CHECK: linalg.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return + +// ----- + +// Test Case: Existing AllocOp with no users. +// BufferPlacement Expected Behaviour: It should insert a DeallocOp right before ReturnOp. + +// CHECK-LABEL: func @emptyUsesValue +func @emptyUsesValue(%arg0: memref<4xf32>) { + %0 = alloc() : memref<4xf32> + return +} +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return + +// ----- + +// Test Case: +// bb0 +// / \ +// | bb1 <- Initial position of AllocOp +// \ / +// bb2 +// BufferPlacement Expected Behaviour: It should move the existing AllocOp to the entry block +// and insert a DeallocOp at the exit block after CopyOp since %1 is an alias for %0 and %arg1. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @criticalEdge +func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>) +^bb1: + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + br ^bb2(%0 : memref<2xf32>) +^bb2(%1: memref<2xf32>): + "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: cond_br +// CHECK: linalg.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return + +// ----- + +// Test Case: +// bb0 <- Initial position of AllocOp +// / \ +// | bb1 +// \ / +// bb2 +// BufferPlacement Expected Behaviour: It shouldn't move the alloc position. It only inserts +// a DeallocOp at the exit block after CopyOp since %1 is an alias for %0 and %arg1. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @invCriticalEdge +func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>) +^bb1: + br ^bb2(%0 : memref<2xf32>) +^bb2(%1: memref<2xf32>): + "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK: dealloc +// CHECK-NEXT: return + +// ----- + +// Test Case: +// bb0 <- Initial position of the first AllocOp +// / \ +// bb1 bb2 +// \ / +// bb3 <- Initial position of the second AllocOp +// BufferPlacement Expected Behaviour: It shouldn't move the AllocOps. It only inserts two missing DeallocOps in the exit block. +// %5 is an alias for %0. Therefore, the DeallocOp for %0 should occur after the last GenericOp. The Dealloc for %7 should +// happen after the CopyOp. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @ifElse +func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) +^bb1(%1: memref<2xf32>, %2: memref<2xf32>): + br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>) +^bb2(%3: memref<2xf32>, %4: memref<2xf32>): + br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>) +^bb3(%5: memref<2xf32>, %6: memref<2xf32>): + %7 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %5, %7 { + ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: memref<2xf32>, memref<2xf32> + "linalg.copy"(%7, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic +// CHECK: dealloc %[[FIRST_ALLOC]] +// CHECK-NEXT: linalg.copy +// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] +// CHECK-NEXT: return + +// ----- + +// Test Case: No users for buffer in if-else CFG +// bb0 <- Initial position of AllocOp +// / \ +// bb1 bb2 +// \ / +// bb3 +// BufferPlacement Expected Behaviour: It shouldn't move the AllocOp. It only inserts a missing DeallocOp +// in the exit block since %5 or %6 are the latest aliases of %0. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @ifElseNoUsers +func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) +^bb1(%1: memref<2xf32>, %2: memref<2xf32>): + br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>) +^bb2(%3: memref<2xf32>, %4: memref<2xf32>): + br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>) +^bb3(%5: memref<2xf32>, %6: memref<2xf32>): + "linalg.copy"(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK: dealloc +// CHECK-NEXT: return + +// ----- + +// Test Case: +// bb0 <- Initial position of the first AllocOp +// / \ +// bb1 bb2 +// | / \ +// | bb3 bb4 +// \ \ / +// \ / +// bb5 <- Initial position of the second AllocOp +// BufferPlacement Expected Behaviour: AllocOps shouldn't be moved. +// Two missing DeallocOps should be inserted in the exit block. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @ifElseNested +func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) +^bb1(%1: memref<2xf32>, %2: memref<2xf32>): + br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>) +^bb2(%3: memref<2xf32>, %4: memref<2xf32>): + cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>) +^bb3(%5: memref<2xf32>): + br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>) +^bb4(%6: memref<2xf32>): + br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>) +^bb5(%7: memref<2xf32>, %8: memref<2xf32>): + %9 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %7, %9 { + ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: memref<2xf32>, memref<2xf32> + "linalg.copy"(%9, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic +// CHECK: dealloc %[[FIRST_ALLOC]] +// CHECK-NEXT: linalg.copy +// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] +// CHECK-NEXT: return + +// ----- + +// Test Case: Dead operations in a single block. +// BufferPlacement Expected Behaviour: It shouldn't move the AllocOps. It only inserts the two missing DeallocOps +// after the last GenericOp. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @redundantOperations +func @redundantOperations(%arg0: memref<2xf32>) { + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + %1 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0, %1 { + ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: memref<2xf32>, memref<2xf32> + return +} + +// CHECK: (%[[ARG0:.*]]: {{.*}}) +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]] +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] +// CHECK: dealloc +// CHECK-NEXT: dealloc +// CHECK-NEXT: return + +// ----- + +// Test Case: +// bb0 +// / \ +// Initial position of the first AllocOp -> bb1 bb2 <- Initial position of the second AllocOp +// \ / +// bb3 +// BufferPlacement Expected Behaviour: Both AllocOps should be moved to the entry block. Both missing DeallocOps should be moved to +// the exit block after CopyOp since %arg2 is an alias for %0 and %1. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @moving_alloc_and_inserting_missing_dealloc +func @moving_alloc_and_inserting_missing_dealloc(%cond: i1, %arg0: memref<2xf32>, %arg1: memref<2xf32>){ + cond_br %cond, ^bb1, ^bb2 +^bb1: + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + br ^exit(%0 : memref<2xf32>) +^bb2: + %1 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %1 { + ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): + %tmp2 = exp %gen2_arg0 : f32 + linalg.yield %tmp2 : f32 + }: memref<2xf32>, memref<2xf32> + br ^exit(%1 : memref<2xf32>) +^exit(%arg2: memref<2xf32>): + "linalg.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %{{.*}} = alloc() +// CHECK-NEXT: %{{.*}} = alloc() +// CHECK: linalg.copy +// CHECK-NEXT: dealloc +// CHECK-NEXT: dealloc +// CHECK-NEXT: return + +// ----- + +// Test Case: Invalid position of the DeallocOp. There is a user after deallocation. +// bb0 +// / \ +// bb1 bb2 <- Initial position of AllocOp +// \ / +// bb3 +// BufferPlacement Expected Behaviour: It should move the AllocOp to the entry block. +// The existing DeallocOp should be moved to exit block. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @moving_invalid_dealloc_op_complex +func @moving_invalid_dealloc_op_complex(%cond: i1, %arg0: memref<2xf32>, %arg1: memref<2xf32>){ + cond_br %cond, ^bb1, ^bb2 +^bb1: + br ^exit(%arg0 : memref<2xf32>) +^bb2: + %1 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %1 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + dealloc %1 : memref<2xf32> + br ^exit(%1 : memref<2xf32>) +^exit(%arg2: memref<2xf32>): + "linalg.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-NEXT: %{{.*}} = alloc() +// CHECK: linalg.copy +// CHECK-NEXT: dealloc +// CHECK-NEXT: return + +// ----- + +// Test Case: Iserting missing DeallocOp in a single block. + +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @inserting_missing_dealloc_simple +func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK: linalg.copy +// CHECK-NEXT: dealloc + +// ----- + +// Test Case: Moving invalid DeallocOp (there is a user after deallocation) in a single block. 
+ +#map0 = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @moving_invalid_dealloc_op +func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + %0 = alloc() : memref<2xf32> + linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + dealloc %0 : memref<2xf32> + "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK: linalg.copy +// CHECK-NEXT: dealloc \ No newline at end of file diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_library(MLIRTestTransforms TestAllReduceLowering.cpp + TestBufferPlacement.cpp TestCallGraph.cpp TestConstantFold.cpp TestConvertGPUKernelToCubin.cpp diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -0,0 +1,153 @@ +//===- TestBufferPlacement.cpp - Test for buffer placement 0----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for testing buffer placement including its +// utility converters. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/BufferPlacement.h" + +using namespace mlir; + +namespace { +/// This pass tests the computeAllocPosition helper method and two provided +/// operation converters, FunctionAndBlockSignatureConverter and +/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg +/// operations on tensors to linalg operations on buffers to prepare them for +/// the BufferPlacement pass that can be applied afterwards. +struct TestBufferPlacementPreparationPass + : mlir::PassWrapper> { + + /// Converts tensor-type generic linalg operations to memref ones using buffer + /// assignment. + class GenericOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + linalg::GenericOp>::BufferAssignmentOpConversionPattern; + + LogicalResult + matchAndRewrite(linalg::GenericOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + SmallVector args(operands.begin(), operands.end()); + + // Update all types to memref types. + auto results = op.getOperation()->getResults(); + for (auto result : results) { + auto type = result.getType().cast(); + if (!type) + op.emitOpError() + << "tensor to buffer conversion expects ranked results"; + if (!type.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "dynamic shapes not currently supported"); + auto memrefType = + MemRefType::get(type.getShape(), type.getElementType()); + + // Compute alloc position and insert a custom allocation node. 
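+        // Note: computeAllocPosition currently returns an insertion point
+        // directly in front of the operation that owns `result`, so the
+        // AllocOp ends up immediately before the converted linalg.generic.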
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter.create(loc, memrefType); + result.replaceAllUsesWith(alloc); + args.push_back(alloc); + } + + // Generate a new linalg operation that works on buffers. + auto linalgOp = rewriter.create( + loc, llvm::None, args, rewriter.getI64IntegerAttr(operands.size()), + rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), + op.iterator_types(), op.docAttr(), op.library_callAttr()); + + // Move regions from the old operation to the new one. + auto ®ion = linalgOp.region(); + rewriter.inlineRegionBefore(op.region(), region, region.end()); + + // TODO: verify the internal memref-based linalg functionality. + auto &entryBlock = region.front(); + for (auto result : results) { + auto type = result.getType().cast(); + entryBlock.addArgument(type.getElementType()); + } + + rewriter.eraseOp(op); + return success(); + } + }; + + void populateTensorLinalgToBufferLinalgConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *placer, + TypeConverter *converter, OwningRewritePatternList *patterns) { + // clang-format off + patterns->insert< + FunctionAndBlockSignatureConverter, + GenericOpConverter, + NonVoidToVoidReturnOpConverter< + ReturnOp, ReturnOp, linalg::CopyOp> + >(context, placer, converter); + // clang-format on + } + + void runOnOperation() override { + getOperation().walk([&](FuncOp function) { + OwningRewritePatternList patterns; + auto &context = getContext(); + ConversionTarget target(context); + + // Make all linalg operations illegal as long as they work on tensors. + BufferAssignmentTypeConverter converter; + target.addLegalDialect(); + target.addDynamicallyLegalDialect( + Optional( + [&](Operation *op) { + auto isIllegalType = [&](Type type) { + return !converter.isLegal(type); + }; + return llvm::none_of(op->getOperandTypes(), isIllegalType) && + llvm::none_of(op->getResultTypes(), isIllegalType); + })); + + // Mark return operations illegal as long as they return values. + target.addDynamicallyLegalOp([](mlir::ReturnOp returnOp) { + return returnOp.getNumOperands() == 0; + }); + + // Mark the function whose arguments are in tensor-type illegal. + target.addDynamicallyLegalOp([&](FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()); + }); + + BufferAssignmentPlacer placer(function); + populateTensorLinalgToBufferLinalgConversionPattern( + function.getContext(), &placer, &converter, &patterns); + + // Applying full conversion + return failed(applyFullConversion(function, target, patterns, &converter)) + ? 
WalkResult::interrupt()
+                 : WalkResult::advance();
+    });
+  };
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestBufferPlacementPreparationPass() {
+  PassRegistration<TestBufferPlacementPreparationPass>(
+      "test-buffer-placement-preparation",
+      "Tests buffer placement helper methods including its "
+      "operation-conversion patterns");
+}
+} // end namespace mlir
\ No newline at end of file
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -41,6 +41,7 @@
 void registerTestAffineDataCopyPass();
 void registerTestAllReduceLoweringPass();
 void registerTestAffineLoopUnswitchingPass();
+void registerTestBufferPlacementPreparationPass();
 void registerTestLinalgMatmulToVectorPass();
 void registerTestLoopPermutationPass();
 void registerTestCallGraphPass();
@@ -112,6 +113,7 @@
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   registerTestConvertGPUKernelToCubinPass();
 #endif
+  registerTestBufferPlacementPreparationPass();
   registerTestDominancePass();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();