Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Show First 20 Lines • Show All 1,112 Lines • ▼ Show 20 Lines | |||||
void mlir::populateVectorToLLVMMatrixConversionPatterns( | void mlir::populateVectorToLLVMMatrixConversionPatterns( | ||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { | LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { | ||||
MLIRContext *ctx = converter.getDialect()->getContext(); | MLIRContext *ctx = converter.getDialect()->getContext(); | ||||
patterns.insert<VectorMatmulOpConversion>(ctx, converter); | patterns.insert<VectorMatmulOpConversion>(ctx, converter); | ||||
} | } | ||||
namespace { | namespace { | ||||
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { | struct LowerVectorToLLVMPass | ||||
: public OperationPass<LowerVectorToLLVMPass, ModuleOp> { | |||||
/// Include the generated pass utilities. | /// Include the generated pass utilities. | ||||
#define GEN_PASS_ConvertVectorToLLVM | #define GEN_PASS_ConvertVectorToLLVM | ||||
#include "mlir/Conversion/Passes.h.inc" | #include "mlir/Conversion/Passes.h.inc" | ||||
void runOnModule() override; | void runOnOperation() override; | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
void LowerVectorToLLVMPass::runOnModule() { | void LowerVectorToLLVMPass::runOnOperation() { | ||||
// Perform progressive lowering of operations on slices and | // Perform progressive lowering of operations on slices and | ||||
// all contraction operations. Also applies folding and DCE. | // all contraction operations. Also applies folding and DCE. | ||||
{ | { | ||||
OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
populateVectorSlicesLoweringPatterns(patterns, &getContext()); | populateVectorSlicesLoweringPatterns(patterns, &getContext()); | ||||
populateVectorContractLoweringPatterns(patterns, &getContext()); | populateVectorContractLoweringPatterns(patterns, &getContext()); | ||||
applyPatternsGreedily(getModule(), patterns); | applyPatternsGreedily(getOperation(), patterns); | ||||
} | } | ||||
// Convert to the LLVM IR dialect. | // Convert to the LLVM IR dialect. | ||||
LLVMTypeConverter converter(&getContext()); | LLVMTypeConverter converter(&getContext()); | ||||
OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns); | populateVectorToLLVMMatrixConversionPatterns(converter, patterns); | ||||
populateVectorToLLVMConversionPatterns(converter, patterns); | populateVectorToLLVMConversionPatterns(converter, patterns); | ||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns); | populateVectorToLLVMMatrixConversionPatterns(converter, patterns); | ||||
populateStdToLLVMConversionPatterns(converter, patterns); | populateStdToLLVMConversionPatterns(converter, patterns); | ||||
LLVMConversionTarget target(getContext()); | LLVMConversionTarget target(getContext()); | ||||
target.addDynamicallyLegalOp<FuncOp>( | target.addDynamicallyLegalOp<FuncOp>( | ||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); | [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); | ||||
if (failed( | if (failed(applyPartialConversion(getOperation(), target, patterns, | ||||
applyPartialConversion(getModule(), target, patterns, &converter))) { | &converter))) { | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
} | } | ||||
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertVectorToLLVMPass() { | std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertVectorToLLVMPass() { | ||||
return std::make_unique<LowerVectorToLLVMPass>(); | return std::make_unique<LowerVectorToLLVMPass>(); | ||||
} | } |