diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1829,19 +1829,20 @@
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
+    auto valuesTy = op.getValues().getType();
     auto resultTy = op.getType().cast<ShapedType>();
 
-    auto dynamicDimsOr = checkHasDynamicBatchDims(
-        rewriter, op, {input, indices, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return rewriter.notifyMatchFailure(
-          op, "tosa.gather currently only supports dynamic batch dimensions");
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    if (!valuesTy.dyn_cast<ShapedType>().hasRank())
+      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
+    auto dynamicDims = gatherOutputDynamicSizes(
+        rewriter, op.getLoc(), adaptor.getValues(),
+        valuesTy.dyn_cast<RankedTensorType>(), adaptor.getIndices(),
+        op.getIndices().getType());
 
     auto resultElementTy = resultTy.getElementType();
 
     auto loc = op.getLoc();
-
     auto emptyTensor =
         rewriter
             .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
@@ -1872,6 +1873,34 @@
     rewriter.replaceOp(op, genericOp.getResult(0));
     return success();
   }
+
+  static llvm::SmallVector<Value>
+  gatherOutputDynamicSizes(OpBuilder &builder, Location loc, Value values,
+                           RankedTensorType valueType, Value indices,
+                           RankedTensorType indicesType) {
+    auto valueShape = valueType.getShape();
+    auto indicesShape = indicesType.getShape();
+
+    // The shape of the gather op result is n x w x c.
+    auto n = valueShape[0];
+    auto w = indicesShape[1];
+    auto c = valueShape[2];
+
+    llvm::SmallVector<Value> results;
+    if (ShapedType::isDynamic(n)) {
+      results.push_back(builder.create<tensor::DimOp>(loc, values, 0));
+    }
+
+    if (ShapedType::isDynamic(w)) {
+      results.push_back(builder.create<tensor::DimOp>(loc, indices, 1));
+    }
+
+    if (ShapedType::isDynamic(c)) {
+      results.push_back(builder.create<tensor::DimOp>(loc, values, 2));
+    }
+
+    return results;
+  }
 };
 
 // Lowerings the TableOp to a series of gathers and numerica operations.
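Illustrative sketch only, not part of the patch: the SSA names (%values, %indices, %empty, %result) are placeholders, and #map / #map1 are assumed to be the usual (d0, d1, d2) -> (d0, d1) map for the indices operand and the identity map for the output. With those assumptions, the IR this lowering produces for a fully dynamic tosa.gather looks roughly like the following, which is what the FileCheck test added below verifies:

    %c0 = arith.constant 0 : index
    %batch = tensor.dim %values, %c0 : tensor<?x?x?xf32>
    %c1 = arith.constant 1 : index
    %w = tensor.dim %indices, %c1 : tensor<?x?xi32>
    %c2 = arith.constant 2 : index
    %c = tensor.dim %values, %c2 : tensor<?x?x?xf32>
    // gatherOutputDynamicSizes supplies the dynamic extents consumed by tensor.empty.
    %empty = tensor.empty(%batch, %w, %c) : tensor<?x?x?xf32>
    %result = linalg.generic
        {indexing_maps = [#map, #map1],
         iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%indices : tensor<?x?xi32>) outs(%empty : tensor<?x?x?xf32>) {
    ^bb0(%idx: i32, %out: f32):
      // Gather along dimension 1 of %values using the index read from %indices.
      %i0 = linalg.index 0 : index
      %cast = arith.index_cast %idx : i32 to index
      %i2 = linalg.index 2 : index
      %extract = tensor.extract %values[%i0, %cast, %i2] : tensor<?x?x?xf32>
      linalg.yield %extract : f32
    } -> tensor<?x?x?xf32>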
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1267,6 +1267,31 @@
 
 // -----
 
+// CHECK-LABEL: @gather_float_all_dynamic
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2
+  // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]])
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>)
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
+  // CHECK: %[[IDX2:.+]] = linalg.index 2
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32>
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> (tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @gather_int
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]