Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
Show First 20 Lines • Show All 67 Lines • ▼ Show 20 Lines | public: | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(linalg::SparseTensorToPointersMemRefOp op, | matchAndRewrite(linalg::SparseTensorToPointersMemRefOp op, | ||||
ArrayRef<Value> operands, | ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
Type resType = op.getType(); | Type resType = op.getType(); | ||||
Type eltType = resType.cast<ShapedType>().getElementType(); | Type eltType = resType.cast<ShapedType>().getElementType(); | ||||
StringRef name; | StringRef name; | ||||
if (eltType.isIndex() || eltType.isInteger(64)) | if (eltType.isIndex() || eltType.isInteger(64)) | ||||
name = "sparsePtrsI64"; | name = "sparsePointers64"; | ||||
else if (eltType.isInteger(32)) | |||||
name = "sparsePointers32"; | |||||
else | else | ||||
return failure(); | return failure(); | ||||
rewriter.replaceOpWithNewOp<CallOp>( | rewriter.replaceOpWithNewOp<CallOp>( | ||||
op, resType, getFunc(op, name, resType, operands), operands); | op, resType, getFunc(op, name, resType, operands), operands); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
/// Sparse conversion rule for index accesses. | /// Sparse conversion rule for index accesses. | ||||
class TensorToIndicesConverter | class TensorToIndicesConverter | ||||
: public OpConversionPattern<linalg::SparseTensorToIndicesMemRefOp> { | : public OpConversionPattern<linalg::SparseTensorToIndicesMemRefOp> { | ||||
public: | public: | ||||
using OpConversionPattern::OpConversionPattern; | using OpConversionPattern::OpConversionPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(linalg::SparseTensorToIndicesMemRefOp op, | matchAndRewrite(linalg::SparseTensorToIndicesMemRefOp op, | ||||
ArrayRef<Value> operands, | ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
Type resType = op.getType(); | Type resType = op.getType(); | ||||
Type eltType = resType.cast<ShapedType>().getElementType(); | Type eltType = resType.cast<ShapedType>().getElementType(); | ||||
StringRef name; | StringRef name; | ||||
if (eltType.isIndex() || eltType.isInteger(64)) | if (eltType.isIndex() || eltType.isInteger(64)) | ||||
name = "sparseIndxsI64"; | name = "sparseIndices64"; | ||||
else if (eltType.isInteger(32)) | |||||
name = "sparseIndices32"; | |||||
else | else | ||||
return failure(); | return failure(); | ||||
rewriter.replaceOpWithNewOp<CallOp>( | rewriter.replaceOpWithNewOp<CallOp>( | ||||
op, resType, getFunc(op, name, resType, operands), operands); | op, resType, getFunc(op, name, resType, operands), operands); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
/// Sparse conversion rule for value accesses. | /// Sparse conversion rule for value accesses. | ||||
class TensorToValuesConverter | class TensorToValuesConverter | ||||
: public OpConversionPattern<linalg::SparseTensorToValuesMemRefOp> { | : public OpConversionPattern<linalg::SparseTensorToValuesMemRefOp> { | ||||
public: | public: | ||||
using OpConversionPattern::OpConversionPattern; | using OpConversionPattern::OpConversionPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(linalg::SparseTensorToValuesMemRefOp op, | matchAndRewrite(linalg::SparseTensorToValuesMemRefOp op, | ||||
ArrayRef<Value> operands, | ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
Type resType = op.getType(); | Type resType = op.getType(); | ||||
Type eltType = resType.cast<ShapedType>().getElementType(); | Type eltType = resType.cast<ShapedType>().getElementType(); | ||||
StringRef name; | StringRef name; | ||||
if (eltType.isF64()) | if (eltType.isF64()) | ||||
name = "sparseValsF64"; | name = "sparseValuesF64"; | ||||
else if (eltType.isF32()) | |||||
name = "sparseValuesF32"; | |||||
else | else | ||||
return failure(); | return failure(); | ||||
rewriter.replaceOpWithNewOp<CallOp>( | rewriter.replaceOpWithNewOp<CallOp>( | ||||
op, resType, getFunc(op, name, resType, operands), operands); | op, resType, getFunc(op, name, resType, operands), operands); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
Show All 10 Lines |