diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -844,6 +844,18 @@
       getReassociationIndicesAttribute(b, reassociation));
 }
 
+// Checks if types are the same, but ignoring encoding on ranked tensors.
+static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  // Default implementation.
+  return tp1 == tp2;
+}
+
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -856,7 +868,7 @@
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
       computeTensorReshapeCollapsedType(expandedType, maps);
-  if (collapsedType != expectedType)
+  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
   return success();
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// TODO: check lowering to an actual implementation
+
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// CHECK-LABEL: func.func @sparse_expand(
+// CHECK-SAME:  %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] :
+    tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
+  return %0 : tensor<10x10xf64, #SparseMatrix>
+}
+
+// CHECK-LABEL: func.func @sparse_collapse(
+// CHECK-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
+    tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
+  return %0 : tensor<100xf64, #SparseVector>
+}
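
Note (illustrative, not part of the patch): because the verifier now compares the collapsed type against the computed expected type while ignoring encodings, a reshape between tensors whose shapes and element types match but whose encodings differ also passes verification. A minimal sketch, reusing the `#sparse_tensor.encoding` syntax from the test above; the function name `@dense_to_sparse_collapse` is hypothetical:

```mlir
// Hypothetical example: collapse a dense 10x10 tensor into a sparse vector.
// The verifier accepts this because isSameTypesWithoutEncoding() compares only
// the shape and element type of the ranked tensor types, not their encodings.
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @dense_to_sparse_collapse(%arg0: tensor<10x10xf64>) -> tensor<100xf64, #SparseVector> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
    tensor<10x10xf64> into tensor<100xf64, #SparseVector>
  return %0 : tensor<100xf64, #SparseVector>
}
```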