diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -101,6 +101,11 @@
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
+/// Returns true iff `tp1` and `tp2` are the same type, except that two ranked
+/// tensor types also compare equal when their shapes and element types match
+/// but their encodings differ (i.e. the encoding is ignored).
+bool isSameTypesWithoutEncoding(Type tp1, Type tp2);
+
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -449,8 +449,16 @@
   return nullptr;
 }
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
+bool mlir::sparse_tensor::isSameTypesWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  return tp1 == tp2; // tp1 not a ranked tensor: exact type equality
+}
+
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/post_rewriting.mlir
old mode 100755
new mode 100644
rename from mlir/test/Dialect/SparseTensor/rewriting.mlir
rename to mlir/test/Dialect/SparseTensor/post_rewriting.mlir
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#Slice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %2 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_repair_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>
+// CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_fuse_slice(
+// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+  return %0 : tensor<1x3xi64, #SortedCOO>
+}