Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Show First 20 Lines • Show All 175 Lines • ▼ Show 20 Lines | struct SparseTensorConversionPass | ||||
} | } | ||||
}; | }; | ||||
struct SparseTensorCodegenPass | struct SparseTensorCodegenPass | ||||
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> { | : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> { | ||||
SparseTensorCodegenPass() = default; | SparseTensorCodegenPass() = default; | ||||
SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; | SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; | ||||
SparseTensorCodegenPass(bool enableInit) { | SparseTensorCodegenPass(bool createDeallocs, bool enableInit) { | ||||
createSparseDeallocs = createDeallocs; | |||||
enableBufferInitialization = enableInit; | enableBufferInitialization = enableInit; | ||||
} | } | ||||
void runOnOperation() override { | void runOnOperation() override { | ||||
auto *ctx = &getContext(); | auto *ctx = &getContext(); | ||||
RewritePatternSet patterns(ctx); | RewritePatternSet patterns(ctx); | ||||
SparseTensorTypeToBufferConverter converter; | SparseTensorTypeToBufferConverter converter; | ||||
ConversionTarget target(*ctx); | ConversionTarget target(*ctx); | ||||
Show All 34 Lines | target.addLegalDialect< | ||||
arith::ArithDialect, bufferization::BufferizationDialect, | arith::ArithDialect, bufferization::BufferizationDialect, | ||||
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); | complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); | ||||
target.addLegalOp<UnrealizedConversionCastOp>(); | target.addLegalOp<UnrealizedConversionCastOp>(); | ||||
// Populate with rules and apply rewriting rules. | // Populate with rules and apply rewriting rules. | ||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, | ||||
converter); | converter); | ||||
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, | scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, | ||||
target); | target); | ||||
populateSparseTensorCodegenPatterns(converter, patterns, | populateSparseTensorCodegenPatterns( | ||||
enableBufferInitialization); | converter, patterns, createSparseDeallocs, enableBufferInitialization); | ||||
if (failed(applyPartialConversion(getOperation(), target, | if (failed(applyPartialConversion(getOperation(), target, | ||||
std::move(patterns)))) | std::move(patterns)))) | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
}; | }; | ||||
struct SparseBufferRewritePass | struct SparseBufferRewritePass | ||||
: public impl::SparseBufferRewriteBase<SparseBufferRewritePass> { | : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> { | ||||
▲ Show 20 Lines • Show All 128 Lines • ▼ Show 20 Lines | std::unique_ptr<Pass> mlir::createSparseTensorConversionPass( | ||||
return std::make_unique<SparseTensorConversionPass>(options); | return std::make_unique<SparseTensorConversionPass>(options); | ||||
} | } | ||||
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() { | std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() { | ||||
return std::make_unique<SparseTensorCodegenPass>(); | return std::make_unique<SparseTensorCodegenPass>(); | ||||
} | } | ||||
std::unique_ptr<Pass> | std::unique_ptr<Pass> | ||||
mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) { | mlir::createSparseTensorCodegenPass(bool createSparseDeallocs, | ||||
return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization); | bool enableBufferInitialization) { | ||||
return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs, | |||||
enableBufferInitialization); | |||||
} | } | ||||
std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() { | std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() { | ||||
return std::make_unique<SparseBufferRewritePass>(); | return std::make_unique<SparseBufferRewritePass>(); | ||||
} | } | ||||
std::unique_ptr<Pass> | std::unique_ptr<Pass> | ||||
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) { | mlir::createSparseBufferRewritePass(bool enableBufferInitialization) { | ||||
Show All 18 Lines |