Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Show First 20 Lines • Show All 774 Lines • ▼ Show 20 Lines | private: | ||||
bool enableBufferInitialization; | bool enableBufferInitialization; | ||||
}; | }; | ||||
/// Sparse codegen rule for the dealloc operator. | /// Sparse codegen rule for the dealloc operator. | ||||
class SparseTensorDeallocConverter | class SparseTensorDeallocConverter | ||||
: public OpConversionPattern<bufferization::DeallocTensorOp> { | : public OpConversionPattern<bufferization::DeallocTensorOp> { | ||||
public: | public: | ||||
using OpConversionPattern::OpConversionPattern; | using OpConversionPattern::OpConversionPattern; | ||||
SparseTensorDeallocConverter(TypeConverter &typeConverter, | |||||
MLIRContext *context, bool createDeallocs) | |||||
: OpConversionPattern(typeConverter, context), | |||||
createDeallocs(createDeallocs) {} | |||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, | matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
auto enc = getSparseTensorEncoding(op.getTensor().getType()); | auto enc = getSparseTensorEncoding(op.getTensor().getType()); | ||||
if (!enc) | if (!enc) | ||||
return failure(); | return failure(); | ||||
// If user requests not to deallocate sparse tensors, simply erase the | |||||
// operation. | |||||
if (createDeallocs) { | |||||
// Replace the sparse tensor deallocation with field deallocations. | // Replace the sparse tensor deallocation with field deallocations. | ||||
Location loc = op.getLoc(); | Location loc = op.getLoc(); | ||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); | ||||
for (auto input : desc.getMemRefFields()) | for (auto input : desc.getMemRefFields()) | ||||
// Deallocate every buffer used to store the sparse tensor handler. | // Deallocate every buffer used to store the sparse tensor handler. | ||||
rewriter.create<memref::DeallocOp>(loc, input); | rewriter.create<memref::DeallocOp>(loc, input); | ||||
} | |||||
rewriter.eraseOp(op); | rewriter.eraseOp(op); | ||||
return success(); | return success(); | ||||
} | } | ||||
private: | |||||
bool createDeallocs; | |||||
}; | }; | ||||
/// Sparse codegen rule for tensor rematerialization. | /// Sparse codegen rule for tensor rematerialization. | ||||
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { | class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { | ||||
public: | public: | ||||
using OpConversionPattern::OpConversionPattern; | using OpConversionPattern::OpConversionPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(LoadOp op, OpAdaptor adaptor, | matchAndRewrite(LoadOp op, OpAdaptor adaptor, | ||||
▲ Show 20 Lines • Show All 679 Lines • ▼ Show 20 Lines | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Public method for populating conversion rules. | // Public method for populating conversion rules. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Populates the given patterns list with conversion rules required for | /// Populates the given patterns list with conversion rules required for | ||||
/// the sparsification of linear algebra operations. | /// the sparsification of linear algebra operations. | ||||
void mlir::populateSparseTensorCodegenPatterns( | void mlir::populateSparseTensorCodegenPatterns( | ||||
TypeConverter &typeConverter, RewritePatternSet &patterns, | TypeConverter &typeConverter, RewritePatternSet &patterns, | ||||
bool enableBufferInitialization) { | bool createSparseDeallocs, bool enableBufferInitialization) { | ||||
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter, | patterns.add<SparsePackOpConverter, SparseUnpackOpConverter, | ||||
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter, | SparseReturnConverter, SparseCallConverter, SparseDimOpConverter, | ||||
SparseCastConverter, SparseTensorDeallocConverter, | SparseCastConverter, SparseExtractSliceConverter, | ||||
SparseExtractSliceConverter, SparseTensorLoadConverter, | SparseTensorLoadConverter, SparseExpandConverter, | ||||
SparseExpandConverter, SparseCompressConverter, | SparseCompressConverter, SparseInsertConverter, | ||||
SparseInsertConverter, | |||||
SparseSliceGetterOpConverter<ToSliceOffsetOp, | SparseSliceGetterOpConverter<ToSliceOffsetOp, | ||||
StorageSpecifierKind::DimOffset>, | StorageSpecifierKind::DimOffset>, | ||||
SparseSliceGetterOpConverter<ToSliceStrideOp, | SparseSliceGetterOpConverter<ToSliceStrideOp, | ||||
StorageSpecifierKind::DimStride>, | StorageSpecifierKind::DimStride>, | ||||
SparseToPositionsConverter, SparseToCoordinatesConverter, | SparseToPositionsConverter, SparseToCoordinatesConverter, | ||||
SparseToCoordinatesBufferConverter, SparseToValuesConverter, | SparseToCoordinatesBufferConverter, SparseToValuesConverter, | ||||
SparseConvertConverter, SparseNewOpConverter, | SparseConvertConverter, SparseNewOpConverter, | ||||
SparseNumberOfEntriesConverter>(typeConverter, | SparseNumberOfEntriesConverter>(typeConverter, | ||||
patterns.getContext()); | patterns.getContext()); | ||||
patterns.add<SparseTensorDeallocConverter>( | |||||
aartbik: how about not adding rule when disabled? | |||||
my mistake, we still erase... aartbik: my mistake, we still erase... | |||||
typeConverter, patterns.getContext(), createSparseDeallocs); | |||||
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(), | patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(), | ||||
enableBufferInitialization); | enableBufferInitialization); | ||||
} | } |
how about not adding rule when disabled?