diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -558,9 +558,10 @@
   SmallVector<Value> filteredDims = condenseValues(dynDims);
 
   for (auto result : results) {
-    auto resultTy = cast<ShapedType>(result.getType());
-    emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
+    // Build from the full ranked type so any encoding (e.g. sparse) is kept.
+    auto resultTy = cast<RankedTensorType>(result.getType());
+    emptyTensors.push_back(
+        rewriter.create<tensor::EmptyOp>(loc, resultTy, filteredDims));
     opResultTypes.push_back(result.getType());
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -85,7 +85,20 @@
   %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
 // -----
+#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+// CHECK-LABEL: @test_encoding_passthrough
+func.func @test_encoding_passthrough() -> tensor<8xi32, #SparseVector> {
+  %0 = tensor.empty() : tensor<8xi32, #SparseVector>
+  // CHECK: linalg.generic
+  // CHECK: sparse_tensor
+  %1 = "tosa.abs"(%0) : (tensor<8xi32, #SparseVector>) -> tensor<8xi32, #SparseVector>
+  return %1 : tensor<8xi32, #SparseVector>
+}
+
+// -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>