diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -101,6 +101,11 @@
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
+/// Tests if types are the same, but ignoring encoding on ranked tensors.
+/// Two ranked tensor types match when shape and element type agree; a ranked
+/// tensor never matches a non-ranked-tensor type; otherwise plain equality.
+bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -449,8 +449,18 @@
   return nullptr;
 }
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
+bool mlir::sparse_tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
+  // Ranked tensors are compared on shape and element type only, so that a
+  // difference in encoding alone does not make the types differ.
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  return tp1 == tp2; // default test
+}
+
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -346,6 +346,45 @@
   }
 };
 
+// Fuse a tensor cast into producing operation. Note that a tensor.cast
+// should really not be used to convert between sparse encodings. Since
+// the pattern currently appears as a result of some prior rewriting
+// we make an attempt to repair very obvious cases.
+// TODO: audit the pure tensor dialect rewriting rules
+struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+public:
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = op.getSource().getType();
+    Type dstType = op.getDest().getType();
+    // A nop cast simply folds away; its users are rewired to the source.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op.getSource());
+      return success();
+    }
+    // See if a sparsity changing cast can be fused into producer.
+    if (isSameTypeWithoutEncoding(srcType, dstType)) {
+      if (Operation *def = op.getSource().getDefiningOp()) {
+        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
+          def->getResult(0).setType(op->getResultTypes()[0]);
+          rewriter.replaceOp(op, def->getResult(0));
+          return success();
+        }
+      }
+    }
+    // Repair tensor casts with at least one sparse operand into
+    // the properly supported sparse_tensor.convert.
+    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
+      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
+      return success();
+    }
+    // Fail otherwise.
+    return failure();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
template struct Sparse2SparseReshapeRewriter : public OpRewritePattern {
@@ -1125,7 +1164,7 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add(
+  patterns.add(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/post_rewriting.mlir
old mode 100755
new mode 100644
rename from mlir/test/Dialect/SparseTensor/rewriting.mlir
rename to mlir/test/Dialect/SparseTensor/post_rewriting.mlir
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#Slice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+//       CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %2 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_repair_cast(
+//  CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>
+//       CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_fuse_slice(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+//       CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+//       CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+  return %0 : tensor<1x3xi64, #SortedCOO>
+}