# Patch summary: lower tosa.rescale with any dynamic dimension, not only a
# dynamic batch dimension.
#
#  * TosaToLinalg.cpp (tosa.rescale lowering, next to the
#    "tosa.rescale requires scale32 for double_round to be true" check):
#    - The call to checkHasDynamicBatchDims and its `return failure()` path
#      are removed.
#    - For every index i where outputTy.isDynamicDim(i) holds, a dimension op
#      is created on `input` at index i and collected into dynDims.
#    - dynDims is then passed as the dynamic sizes when the init tensor is
#      created.
#    - The created op type is not visible in this copy. The new test expects
#      `tensor.dim`, so it is presumably tensor::DimOp -- TODO confirm.
#  * tosa-to-linalg.mlir:
#    - @rescale_i8_dyn is renamed to @rescale_i8_dyn_batch.
#    - A new @rescale_dyn test lowers a 1x?x?x32 i32 -> i8 rescale. It checks
#      tensor.dim on %arg0 for dims 1 and 2, and
#      linalg.init_tensor [1, %DIM1, %DIM2, 32].
#
# NOTE(review):
#  - This copy looks damaged, so it will not apply or compile as-is;
#    regenerate it from the source commit.
#    - Line breaks are collapsed.
#    - Text inside angle brackets appears stripped, e.g. `SmallVector dynDims;`,
#      `rewriter.create(loc, input, i)`, `ArrayRef({dynDims})`, and the
#      `tensor)` argument type of @rescale_i8_dyn_batch.
#  - Sizes are read from `input`, but the loop is driven by `outputTy`.
#    - This presumes input and output have the same rank; confirm.
#    - The identity indexing maps used for both ins and outs in the new test
#      are consistent with that.
#  - The removed line passed the dims vector straight to the init-tensor
#    builder. Wrapping it as `ArrayRef<...>({dynDims})` looks redundant;
#    passing `dynDims` directly would keep that hunk to a single line.
#  - LLVM style omits braces around single-statement `for`/`if` bodies, such
#    as the new dynDims loop.
#
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1126,11 +1126,12 @@ return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); - auto dynamicDimsOr = - checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); - if (!dynamicDimsOr.has_value()) - return failure(); - SmallVector dynamicDims = dynamicDimsOr.value(); + SmallVector dynDims; + for (int i = 0; i < outputTy.getRank(); i++) { + if (outputTy.isDynamicDim(i)) { + dynDims.push_back(rewriter.create(loc, input, i)); + } + } // The shift and multiplier values. SmallVector multiplierValues; @@ -1206,7 +1207,8 @@ // Construct the indexing maps needed for linalg.generic ops. Value initTensor = rewriter.create( - loc, dynamicDims, outputTy.getShape(), outputTy.getElementType()); + loc, ArrayRef({dynDims}), outputTy.getShape(), + outputTy.getElementType()); auto linalgOp = rewriter.create( loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1004,8 +1004,8 @@ // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: @rescale_i8_dyn -func.func @rescale_i8_dyn(%arg0 : tensor) -> () { +// CHECK-LABEL: @rescale_i8_dyn_batch +func.func @rescale_i8_dyn_batch(%arg0 : tensor) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2] @@ -1020,6 +1020,23 @@ return } + +// ----- + +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: @rescale_dyn +func.func @rescale_dyn(%arg0 : 
tensor<1x?x?x32xi32>) -> () { + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]] + // CHECK: %[[C2:.+]] = arith.constant 2 + // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, %[[DIM1]], %[[DIM2]], 32] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>) + %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1376784203 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [38 : i32]} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8> + return +} + // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>