diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1829,19 +1829,20 @@ auto input = adaptor.getOperands()[0]; auto indices = adaptor.getOperands()[1]; + auto valuesTy = + op.getValues().getType().dyn_cast_or_null<RankedTensorType>(); auto resultTy = op.getType().cast<ShapedType>(); - auto dynamicDimsOr = checkHasDynamicBatchDims( - rewriter, op, {input, indices, op.getOutput()}); - if (!dynamicDimsOr.has_value()) - return rewriter.notifyMatchFailure( - op, "tosa.gather currently only supports dynamic batch dimensions"); - SmallVector<Value> dynamicDims = *dynamicDimsOr; + if (!valuesTy) + return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); + + auto dynamicDims = + outputDynamicSizes(rewriter, op.getLoc(), adaptor.getValues(), valuesTy, + adaptor.getIndices(), op.getIndices().getType()); auto resultElementTy = resultTy.getElementType(); auto loc = op.getLoc(); - auto emptyTensor = rewriter .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, @@ -1872,6 +1873,27 @@ rewriter.replaceOp(op, genericOp.getResult(0)); return success(); } + + static void addDynamicDimension(OpBuilder &builder, Location loc, + Value source, ShapedType type, + int64_t dimension, + llvm::SmallVectorImpl<Value> &dynamicDims) { + if (type.isDynamicDim(dimension)) { + auto dim = builder.create<tensor::DimOp>(loc, source, dimension); + dynamicDims.push_back(dim); + } + } + + static llvm::SmallVector<Value> + outputDynamicSizes(OpBuilder &builder, Location loc, Value values, + RankedTensorType valueType, Value indices, + RankedTensorType indicesType) { + llvm::SmallVector<Value> results; + addDynamicDimension(builder, loc, values, valueType, 0, results); + addDynamicDimension(builder, loc, indices, indicesType, 1, results); + addDynamicDimension(builder, loc, values, valueType, 2, results); + return results; + } }; // Lowerings the TableOp to a series of gathers and 
numerica operations. This diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1267,6 +1267,30 @@ // ----- +// CHECK-LABEL: @gather_float_all_dynamic +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]] +func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]] + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]] + // CHECK: %[[C2:.+]] = arith.constant 2 + // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]] + // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]]) + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>) + // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32) + // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[CAST:.+]] = arith.index_cast %[[BBARG0]] + // CHECK: %[[IDX2:.+]] = linalg.index 2 + // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32> + // CHECK: linalg.yield %[[EXTRACT]] + %0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> (tensor<?x?x?xf32>) + return +} + +// ----- + // CHECK-LABEL: @gather_int // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]