Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Transforms/BufferPlacement.cpp
Show First 20 Lines • Show All 43 Lines • ▼ Show 20 Lines | |||||
// be invalid with respect to program semantics. The only thing that is | // be invalid with respect to program semantics. The only thing that is | ||||
// currently missing is a high-level loop analysis that allows us to move allocs | // currently missing is a high-level loop analysis that allows us to move allocs | ||||
// and deallocs outside of the loop blocks. Furthermore, it doesn't also accept | // and deallocs outside of the loop blocks. Furthermore, it doesn't also accept | ||||
// functions which return buffers already. | // functions which return buffers already. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "mlir/Transforms/BufferPlacement.h" | #include "mlir/Transforms/BufferPlacement.h" | ||||
#include "mlir/IR/Function.h" | |||||
#include "mlir/IR/Operation.h" | |||||
#include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||
#include "mlir/Transforms/Passes.h" | #include "mlir/Transforms/Passes.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
namespace { | namespace { | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
▲ Show 20 Lines • Show All 357 Lines • ▼ Show 20 Lines | |||||
/// Computes the actual position to place allocs for the given value. | /// Computes the actual position to place allocs for the given value. | ||||
OpBuilder::InsertPoint | OpBuilder::InsertPoint | ||||
BufferAssignmentPlacer::computeAllocPosition(OpResult result) { | BufferAssignmentPlacer::computeAllocPosition(OpResult result) { | ||||
Operation *owner = result.getOwner(); | Operation *owner = result.getOwner(); | ||||
return OpBuilder::InsertPoint(owner->getBlock(), Block::iterator(owner)); | return OpBuilder::InsertPoint(owner->getBlock(), Block::iterator(owner)); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// FunctionAndBlockSignatureConverter | |||||
//===----------------------------------------------------------------------===// | |||||
// Performs the actual signature rewriting step. | |||||
LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite( | |||||
FuncOp funcOp, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter &rewriter) const { | |||||
if (!converter) { | |||||
funcOp.emitError("The type converter has not been defined for " | |||||
"FunctionAndBlockSignatureConverter"); | |||||
return failure(); | |||||
} | |||||
auto funcType = funcOp.getType(); | |||||
// Convert function arguments using the provided TypeConverter. | |||||
TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); | |||||
for (auto argType : llvm::enumerate(funcType.getInputs())) | |||||
conversion.addInputs(argType.index(), | |||||
converter->convertType(argType.value())); | |||||
// If a function result type is not a memref but it would be a memref after | |||||
// type conversion, a new argument should be appended to the function | |||||
// arguments list for this result. Otherwise, it remains unchanged as a | |||||
// function result. | |||||
SmallVector<Type, 2> newResultTypes; | |||||
newResultTypes.reserve(funcOp.getNumResults()); | |||||
for (Type resType : funcType.getResults()) { | |||||
Type convertedType = converter->convertType(resType); | |||||
if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType, | |||||
resType)) | |||||
conversion.addInputs(convertedType); | |||||
else | |||||
newResultTypes.push_back(convertedType); | |||||
} | |||||
// Update the signature of the function. | |||||
rewriter.updateRootInPlace(funcOp, [&] { | |||||
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), | |||||
newResultTypes)); | |||||
rewriter.applySignatureConversion(&funcOp.getBody(), conversion); | |||||
}); | |||||
return success(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// BufferAssignmentCallOpConverter | |||||
//===----------------------------------------------------------------------===// | |||||
// Performs `CallOp` conversion to match its operands and results with the | |||||
// signature of the callee after rewriting the callee with | |||||
// FunctionAndBlockSignatureConverter. | |||||
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( | |||||
CallOp callOp, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter &rewriter) const { | |||||
Location loc = callOp.getLoc(); | |||||
SmallVector<Value, 2> newOperands, replacingValues; | |||||
SmallVector<Type, 2> newResultTypes; | |||||
unsigned numResults = callOp.getNumResults(); | |||||
newOperands.reserve(numResults + operands.size()); | |||||
newOperands.append(operands.begin(), operands.end()); | |||||
newResultTypes.reserve(numResults); | |||||
replacingValues.reserve(numResults); | |||||
// For each memref result of `CallOp` which has not been a memref before type | |||||
// conversion, a new buffer is allocated and passed to the operands list of | |||||
// the new `CallOp`. Otherwise, it remains as a caller result. | |||||
for (Value result : callOp.getResults()) { | |||||
Type currType = result.getType(); | |||||
Type newType = converter->convertType(result.getType()); | |||||
if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) { | |||||
OpBuilder::InsertionGuard guard(rewriter); | |||||
rewriter.restoreInsertionPoint( | |||||
bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>())); | |||||
Value alloc = | |||||
rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>()); | |||||
newOperands.push_back(alloc); | |||||
replacingValues.push_back(alloc); | |||||
} else { | |||||
newResultTypes.push_back(currType); | |||||
// No replacing is required. | |||||
replacingValues.push_back(nullptr); | |||||
} | |||||
} | |||||
// Creating the new `CallOp`. | |||||
rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes, newOperands); | |||||
// Replacing the results of the old `CallOp`. | |||||
rewriter.replaceOp(callOp, replacingValues); | |||||
return success(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// BufferAssignmentTypeConverter | // BufferAssignmentTypeConverter | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Registers conversions into BufferAssignmentTypeConverter | /// Registers conversions into BufferAssignmentTypeConverter | ||||
BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() { | BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() { | ||||
// Keep all types unchanged. | // Keep all types unchanged. | ||||
addConversion([](Type type) { return type; }); | addConversion([](Type type) { return type; }); | ||||
// A type conversion that converts ranked-tensor type to memref type. | // A type conversion that converts ranked-tensor type to memref type. | ||||
Show All 17 Lines |