diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -526,12 +526,12 @@ assert(operation->getNumResults() == 1 && "All TOSA elementwise ops should only return a single result."); - auto results = operation->getResults(); - auto resultTy = dyn_cast(operation->getResult(0).getType()); + auto result = operation->getResult(0); + auto resultTy = dyn_cast(result.getType()); if (!resultTy) - return rewriter.notifyMatchFailure(operation, - "All results must be a shaped type"); + return rewriter.notifyMatchFailure( + operation, "All results must be a ranked tensor type"); unsigned rank = resultTy.getRank(); @@ -545,7 +545,7 @@ SmallVector emptyTensors; SmallVector dynDims; - dynDims.resize(cast(results.front().getType()).getRank()); + dynDims.resize(rank); for (auto arg : operation->getOperands()) { auto operandTy = cast(arg.getType()); @@ -557,12 +557,9 @@ SmallVector filteredDims = condenseValues(dynDims); - for (auto result : results) { - auto resultTy = cast(result.getType()); - emptyTensors.push_back(rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType(), filteredDims)); - opResultTypes.push_back(result.getType()); - } + emptyTensors.push_back( + rewriter.create(loc, resultTy, filteredDims)); + opResultTypes.push_back(result.getType()); auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range( emptyTensors, [](Value v) { return getElementTypeOrSelf(v); })); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { @@ -6,3 +6,12 @@ %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> return %0 : tensor<1x7x7x9x!quant.uniform> } + +// ----- + +// CHECK-LABEL: @tensor_with_unknown_rank +func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> { + // expected-error@+1 {{failed to legalize operation 'tosa.abs'}} + %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8> + return %0 : tensor<*xi8> +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -85,8 +85,20 @@ %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> } + // ----- +#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }> + +// CHECK-LABEL: @test_encoding_passthrough +func.func @test_encoding_passthrough(%arg0: tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> { + // CHECK: linalg.generic + // CHECK: sparse_tensor + %0 = "tosa.abs"(%arg0) : (tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> + return %0 : tensor<2xi8, #SparseVector> +} + +// ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()> // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>