Skip the extra reassociation push_back when expandInputRank is 0.

For a non-negative axis, expandedDim would otherwise fall back to
expandInputRank - 1, which is not a valid index when expandInputRank is 0.
The new test reduces a dynamic rank-1 tensor along axis 0 and expects a
tensor.expand_shape with an empty reassociation list.

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -825,9 +825,12 @@
     int32_t dimToPush = i > axis ? i + 1 : i;
     reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
   }
-  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
-  reassociationMap[expandedDim].push_back(
-      rewriter.getAffineDimExpr(expandedDim + 1));
+
+  if (expandInputRank != 0) {
+    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+    reassociationMap[expandedDim].push_back(
+        rewriter.getAffineDimExpr(expandedDim + 1));
+  }
 
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
       op, resultTy, linalgOp.getResults()[0], reassociationMap);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -777,6 +777,26 @@
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL: @reduce_float_dyn_rank_1
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<?xf32>
+func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<f32>
+  // CHECK-DAG: %[[CST0:.+]] = arith.constant 0.0
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[ARG0]] : tensor<?xf32>) outs(%[[FILL]] : tensor<f32>)
+  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
+  // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
+  // CHECK: linalg.yield %[[RES]] : f32
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 