diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -418,6 +418,22 @@ } }; +/// Sparse conversion rule for trivial tensor casts. +class SparseCastConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only rewrite identically annotated source/dest. + auto encDst = getSparseTensorEncoding(op.getType()); + auto encSrc = getSparseTensorEncoding(op.source().getType()); + if (!encDst || encDst != encSrc) + return failure(); + rewriter.replaceOp(op, adaptor.getOperands()); + return success(); + } +}; + /// Sparse conversion rule for the new operator. class SparseTensorNewConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -719,9 +735,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + SparseCastConverter, SparseTensorNewConverter, + SparseTensorInitConverter, SparseTensorConvertConverter, + SparseTensorReleaseConverter, SparseTensorToPointersConverter, + SparseTensorToIndicesConverter, SparseTensorToValuesConverter, + SparseTensorToTensorConverter>(typeConverter, + patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -97,11 +97,11 @@ RewritePatternSet patterns(ctx); SparseTensorTypeConverter converter; ConversionTarget target(*ctx); - 
target.addIllegalOp(); - // All dynamic rules below accept new function, call, return, and dimop - // operations as legal output of the rewriting provided that all sparse - // tensor types have been fully rewritten. + // All operations of the sparse tensor dialect must be rewritten. + target.addIllegalDialect(); + // All dynamic rules below accept new function, call, return, and tensor + // dim and cast operations as legal output of the rewriting provided that + // all sparse tensor types have been fully rewritten. target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); target.addDynamicallyLegalOp([&](CallOp op) { @@ -112,10 +112,15 @@ target.addDynamicallyLegalOp([&](tensor::DimOp op) { return converter.isLegal(op.getOperandTypes()); }); + target.addDynamicallyLegalOp([&](tensor::CastOp op) { + // Check result types too, so no cast producing a sparse type stays legal. + return converter.isLegal(op.getOperandTypes()) && + converter.isLegal(op.getResultTypes()); + }); // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. target.addLegalOp(); + arith::IndexCastOp, tensor::ExtractOp>(); target.addLegalDialect(); // Populate with rules and apply rewriting rules. diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -150,6 +150,14 @@ return %0 : tensor<64xf32, #SparseVector> } +// CHECK-LABEL: func @sparse_nop_cast( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: return %[[A]] : !llvm.ptr +func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { + %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor + return %0 : tensor +} + // CHECK-LABEL: func @sparse_convert_1d( // CHECK-SAME: %[[A:.*]]: tensor) -> !llvm.ptr // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index