diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -614,7 +614,7 @@ operands.push_back(operand); indexingMaps.push_back(AffineMap::get( - /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs, + /*dimCount=*/rank, /*symbolCount=*/0, affineExprs, rewriter.getContext())); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1990,3 +1990,23 @@ %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor) -> (tensor) return } + +// ----- + +// The i1 condition is collapsed to rank 3 (1x5x5) while the linalg.generic iterates over 4 dims; its indexing map must still be built over all 4 loop dims (regression test for the dimCount change in TosaToLinalg.cpp). +// CHECK-LABEL: func.func @select_fp32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x5x5xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x12x5x5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor<1x12x5x5xf32> +func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor) -> tensor<1x12x5x5xf32> { + // CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x5x5xf32> + // CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3]] : tensor<1x1x5x5xi1> into tensor<1x5x5xi1> + // CHECK: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_4]], %[[VAL_1]], %[[VAL_2]] : tensor<1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) outs(%[[VAL_3]] : tensor<1x12x5x5xf32>) { + // CHECK: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32): + // CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : f32 + // CHECK: linalg.yield %[[VAL_10]] : f32 + // CHECK: } -> tensor<1x12x5x5xf32> + // CHECK: return %[[VAL_5]] : 
tensor<1x12x5x5xf32> + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> + return %0 : tensor<1x12x5x5xf32> +}