diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -615,16 +615,56 @@
 
   SmallVector<Type> opResultTypes;
   SmallVector<Value> initTensors;
+
+  // The rank queries below require ranked types. Reject unranked
+  // tensors before any IR is created so a failed match leaves no ops.
+  auto isRanked = [](Value value) {
+    return value.getType().cast<ShapedType>().hasRank();
+  };
+  if (!llvm::all_of(operation->getOperands(), isRanked) ||
+      !llvm::all_of(results, isRanked))
+    return rewriter.notifyMatchFailure(
+        operation, "tosa to linalg conversion expects ranked tensors");
+
+  // dimSources[i] is the first operand that is dynamic in dimension i;
+  // its runtime extent sizes dynamic result dimension i.
+  int64_t rank = results.front().getType().cast<ShapedType>().getRank();
+  SmallVector<Value> dimSources(rank);
+  for (auto arg : operation->getOperands()) {
+    auto operandTy = arg.getType().cast<ShapedType>();
+    for (int64_t i = 0, e = std::min(rank, operandTy.getRank()); i < e; ++i)
+      if (operandTy.isDynamicDim(i) && !dimSources[i])
+        dimSources[i] = arg;
+  }
+
+  // Every dynamic result dimension needs an operand to read its extent
+  // from; otherwise linalg.init_tensor would get too few dynamic sizes.
+  // Check this before creating any tensor.dim ops.
+  for (auto result : results) {
+    auto resultTy = result.getType().cast<ShapedType>();
+    for (int64_t i = 0, e = resultTy.getRank(); i < e; ++i)
+      if (resultTy.isDynamicDim(i) && (i >= rank || !dimSources[i]))
+        return rewriter.notifyMatchFailure(
+            operation,
+            "unable to infer dynamic result dimensions from operands");
+  }
+
+  // Extents materialized so far; shared by all results.
+  SmallVector<Value> dynDims(rank);
+
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
-    if (!resultTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          operation,
-          "tosa to linalg conversion expects statically shaped tensors");
-
+    // Only dimensions that are dynamic in *this* result contribute a size,
+    SmallVector<Value> filteredDims;
+    for (int64_t i = 0, e = resultTy.getRank(); i < e; ++i) {
+      if (!resultTy.isDynamicDim(i))
+        continue;
+      if (!dynDims[i])
+        dynDims[i] = rewriter.create<tensor::DimOp>(loc, dimSources[i], i);
+      filteredDims.push_back(dynDims[i]);
+    }
     initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
-        loc, ArrayRef<Value>({}), resultTy.getShape(),
-        resultTy.getElementType()));
+        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
     opResultTypes.push_back(result.getType());
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -55,6 +55,34 @@
 
 // -----
 
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]]]
+  // CHECK: linalg.generic
+  // CHECK: absf
+  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs_dyn
+func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DIM]]]
+  // CHECK: linalg.generic
+  // CHECK: absf
+  %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
 
@@ -111,14 +139,6 @@
 
 // -----
 
-func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.abs'}}
-  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @test_simple_f32
 func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic