Changeset View
Changeset View
Standalone View
Standalone View
mlir/test/lib/Transforms/TestBufferPlacement.cpp
Show All 15 Lines | |||||
#include "mlir/IR/Operation.h" | #include "mlir/IR/Operation.h" | ||||
#include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||
#include "mlir/Pass/PassManager.h" | #include "mlir/Pass/PassManager.h" | ||||
#include "mlir/Transforms/BufferPlacement.h" | #include "mlir/Transforms/BufferPlacement.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
namespace { | namespace { | ||||
/// This pass tests the computeAllocPosition helper method and two provided | /// This pass tests the computeAllocPosition helper method and buffer assignment | ||||
/// operation converters, FunctionAndBlockSignatureConverter and | /// operation converters. Furthermore, this pass converts linalg operations on | ||||
/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg | /// tensors to linalg operations on buffers to prepare them for the | ||||
/// operations on tensors to linalg operations on buffers to prepare them for | /// BufferPlacement pass that can be applied afterwards. | ||||
/// the BufferPlacement pass that can be applied afterwards. | /// `allowMemrefFunctionResults` informs the buffer placement to allow functions | ||||
/// that have memref typed results. Buffer assignment operation converters will | |||||
/// be adapted respectively. It will also allow memref typed results to escape | |||||
/// from the deallocation. | |||||
template <bool allowMemrefFunctionResults> | |||||
struct TestBufferPlacementPreparationPass | struct TestBufferPlacementPreparationPass | ||||
: mlir::PassWrapper<TestBufferPlacementPreparationPass, | : mlir::PassWrapper< | ||||
TestBufferPlacementPreparationPass<allowMemrefFunctionResults>, | |||||
OperationPass<ModuleOp>> { | OperationPass<ModuleOp>> { | ||||
/// Converts tensor-type generic linalg operations to memref ones using buffer | /// Converts tensor-type generic linalg operations to memref ones using | ||||
/// assignment. | /// buffer assignment. | ||||
class GenericOpConverter | class GenericOpConverter | ||||
: public BufferAssignmentOpConversionPattern<linalg::GenericOp> { | : public BufferAssignmentOpConversionPattern<linalg::GenericOp> { | ||||
public: | public: | ||||
pifon2a: Maybe move
```
using BAFuncOpConverter =
… | |||||
Replaced with populateWithBufferAssignmentOpConversionPatterns dfki-ehna: Replaced with `populateWithBufferAssignmentOpConversionPatterns` | |||||
using BufferAssignmentOpConversionPattern< | using BufferAssignmentOpConversionPattern< | ||||
linalg::GenericOp>::BufferAssignmentOpConversionPattern; | linalg::GenericOp>::BufferAssignmentOpConversionPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands, | matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const final { | ConversionPatternRewriter &rewriter) const final { | ||||
Location loc = op.getLoc(); | Location loc = op.getLoc(); | ||||
ResultRange results = op.getOperation()->getResults(); | ResultRange results = op.getOperation()->getResults(); | ||||
▲ Show 20 Lines • Show All 53 Lines • ▼ Show 20 Lines | matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands, | ||||
rewriter.replaceOp(op, newResults); | rewriter.replaceOp(op, newResults); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
void populateTensorLinalgToBufferLinalgConversionPattern( | void populateTensorLinalgToBufferLinalgConversionPattern( | ||||
MLIRContext *context, BufferAssignmentPlacer *placer, | MLIRContext *context, BufferAssignmentPlacer *placer, | ||||
TypeConverter *converter, OwningRewritePatternList *patterns) { | TypeConverter *converter, OwningRewritePatternList *patterns) { | ||||
// clang-format off | populateWithBufferAssignmentOpConversionPatterns< | ||||
patterns->insert< | mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, | ||||
BufferAssignmentCallOpConverter, | allowMemrefFunctionResults>(context, placer, converter, patterns); | ||||
FunctionAndBlockSignatureConverter, | patterns->insert<GenericOpConverter>(context, placer, converter); | ||||
GenericOpConverter, | |||||
BufferAssignmentReturnOpConverter< | |||||
ReturnOp, ReturnOp, linalg::CopyOp> | |||||
>(context, placer, converter); | |||||
// clang-format on | |||||
} | } | ||||
void runOnOperation() override { | void runOnOperation() override { | ||||
MLIRContext &context = getContext(); | MLIRContext &context = this->getContext(); | ||||
ConversionTarget target(context); | ConversionTarget target(context); | ||||
BufferAssignmentTypeConverter converter; | BufferAssignmentTypeConverter converter; | ||||
// Mark all Standard operations legal. | // Mark all Standard operations legal. | ||||
target.addLegalDialect<StandardOpsDialect>(); | target.addLegalDialect<StandardOpsDialect>(); | ||||
// Mark all Linalg operations illegal as long as they work on tensors. | // Mark all Linalg operations illegal as long as they work on tensors. | ||||
auto isIllegalType = [&](Type type) { return !converter.isLegal(type); }; | auto isIllegalType = [&](Type type) { return !converter.isLegal(type); }; | ||||
Show All 17 Lines | void runOnOperation() override { | ||||
}); | }); | ||||
// Mark the function whose arguments are in tensor-type illegal. | // Mark the function whose arguments are in tensor-type illegal. | ||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) { | target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) { | ||||
return converter.isSignatureLegal(funcOp.getType()); | return converter.isSignatureLegal(funcOp.getType()); | ||||
}); | }); | ||||
// Walk over all the functions to apply buffer assignment. | // Walk over all the functions to apply buffer assignment. | ||||
getOperation().walk([&](FuncOp function) -> WalkResult { | this->getOperation().walk([&](FuncOp function) -> WalkResult { | ||||
just out of curiosity: pifon2a: just out of curiosity:
will replacing `Super::` with `this->` work? | |||||
Replaced. dfki-ehna: Replaced. | |||||
OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
BufferAssignmentPlacer placer(function); | BufferAssignmentPlacer placer(function); | ||||
populateTensorLinalgToBufferLinalgConversionPattern( | populateTensorLinalgToBufferLinalgConversionPattern( | ||||
&context, &placer, &converter, &patterns); | &context, &placer, &converter, &patterns); | ||||
// Applying full conversion | // Applying full conversion | ||||
return applyFullConversion(function, target, patterns, &converter); | return applyFullConversion(function, target, patterns, &converter); | ||||
}); | }); | ||||
}; | }; | ||||
}; | }; | ||||
} // end anonymous namespace | } // end anonymous namespace | ||||
namespace mlir { | namespace mlir { | ||||
void registerTestBufferPlacementPreparationPass() { | void registerTestBufferPlacementPreparationPass() { | ||||
PassRegistration<TestBufferPlacementPreparationPass>( | PassRegistration< | ||||
TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/false>>( | |||||
"test-buffer-placement-preparation", | "test-buffer-placement-preparation", | ||||
"Tests buffer placement helper methods including its " | "Tests buffer placement helper methods including its " | ||||
"operation-conversion patterns"); | "operation-conversion patterns"); | ||||
} | } | ||||
} // end namespace mlir | |||||
No newline at end of file | void registerTestPreparationPassWithAllowedMemrefResults() { | ||||
PassRegistration< | |||||
TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/true>>( | |||||
"test-buffer-placement-preparation-with-allowed-memref-results"? to make it consistent with the pass above. pifon2a: "test-buffer-placement-preparation-with-allowed-memref-results"? to make it consistent with the… | |||||
"test-buffer-placement-preparation-with-allowed-memref-results", | |||||
"Tests the helper operation converters of buffer placement for allowing " | |||||
"functions to have memref typed results."); | |||||
} | |||||
} // end namespace mlir |
Maybe move
closer to insert<> call. In that case, just remove BA part: