diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1468,7 +1468,10 @@
       Value x = b.create<linalg::IndexOp>(2);
       Value channel = b.create<linalg::IndexOp>(3);
 
-      Value zeroI32 = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
+      Value zeroI32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
+      Value zeroFp32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
       Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
       Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
 
@@ -1498,6 +1501,12 @@
       auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
+        // Size-1 input: the only source index is 0 and the delta is zero.
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroFp32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         // dx = x / scale_n - ix
@@ -1517,6 +1526,12 @@
       auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                      Value scaleN, Value scaleD, Value offset,
                                      int size, ImplicitLocOpBuilder &b) {
+        // Size-1 input: the only source index is 0 and the delta is zero.
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroI32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         // dx = x - ix * scale_n;
@@ -1606,7 +1621,11 @@
       if (floatingPointMode) {
         auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
         auto interpolate = [&](Value val0, Value val1, Value delta,
+                               int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
+          // Nothing to blend along a size-1 input dimension.
+          if (inputSize == 1)
+            return val0;
           Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
           Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
           Value mul1 = b.create<arith::MulFOp>(val1, delta);
@@ -1616,16 +1635,16 @@
         // Linalg equivalent to the section below:
         // topAcc = v00 * (unit_x - dx);
         // topAcc += v01 * dx;
-        Value topAcc = interpolate(y0x0, y0x1, dx, b);
+        Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
 
         // Linalg equivalent to the section below:
         // bottomAcc = v10 * (unit_x - dx);
         // bottomAcc += v11 * dx;
-        Value bottomAcc = interpolate(y1x0, y1x1, dx, b);
+        Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
 
         // Linalg equivalent to the section below:
         // result = topAcc * (unit_y - dy) + bottomAcc * dy
-        Value result = interpolate(topAcc, bottomAcc, dy, b);
+        Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
         b.create<linalg::YieldOp>(result);
       } else {
         // Perform in quantized space.
@@ -1650,22 +1669,23 @@
           xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
         }
 
-        auto interpolate = [](Value val0, Value val1, Value weight0,
-                              Value weight1,
+        auto interpolate = [](Value val0, Value val1, Value weight1,
+                              Value scale, int inputSize,
                               ImplicitLocOpBuilder &b) -> Value {
+          // A size-1 dimension has nothing to blend; scale the lone sample so
+          // the result keeps the same fixed-point scale as the blended path.
+          if (inputSize == 1)
+            return b.create<arith::MulIOp>(val0, scale);
+          Value weight0 = b.create<arith::SubIOp>(scale, weight1);
           Value mul0 = b.create<arith::MulIOp>(val0, weight0);
           Value mul1 = b.create<arith::MulIOp>(val1, weight1);
           return b.create<arith::AddIOp>(mul0, mul1);
         };
 
-        Value weight0 = b.create<arith::SubIOp>(xScaleNExt, dx);
-        Value weight1 = dx;
-        Value topAcc = interpolate(y0x0, y0x1, weight0, weight1, b);
-        Value bottomAcc = interpolate(y1x0, y1x1, weight0, weight1, b);
-
-        weight0 = b.create<arith::SubIOp>(yScaleNExt, dy);
-        weight1 = dy;
-        Value result = interpolate(topAcc, bottomAcc, weight0, weight1, b);
+        Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
+        Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
+        Value result =
+            interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
         b.create<linalg::YieldOp>(result);
       }
     }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -278,6 +278,7 @@
   // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
   // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X_EXT]]
   // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[NDX:.+]] = arith.subi %[[X_N_EXT]], %[[D_X_EXT]]
   // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
   // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X_EXT]]
   // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
@@ -492,3 +493,49 @@
   %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi16>) -> tensor<1x289x289x1xi48>
   return
 }
+
+// -----
+
+// CHECK-LABEL: skip_interpolate_bilinear_i8
+func.func @skip_interpolate_bilinear_i8(%arg0 : tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[BATCH:.+]] = linalg.index 0
+  // CHECK: %[[CHANNEL:.+]] = linalg.index 3
+  // CHECK-DAG: %[[C3:.+]] = arith.constant 3
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2
+  // CHECK: %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK: %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK: %[[EXT0:.+]] = arith.extsi %[[EXTRACT0]] : i8 to i32
+  // CHECK: %[[EXT1:.+]] = arith.extsi %[[EXTRACT1]] : i8 to i32
+  // CHECK: %[[SUB:.+]] = arith.subi %[[C3]], %[[DX:.+]]
+  // CHECK: %[[MUL0:.+]] = arith.muli %[[EXT0]], %[[SUB]]
+  // CHECK: %[[MUL1:.+]] = arith.muli %[[EXT1]], %[[DX]]
+  // CHECK: %[[ADD:.+]] = arith.addi %[[MUL0]], %[[MUL1]]
+  // CHECK: %[[RES:.+]] = arith.muli %[[ADD]], %[[C2]]
+  // CHECK: linalg.yield %[[RES]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32>
+
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: skip_interpolate_bilinear_f32
+func.func @skip_interpolate_bilinear_f32(%arg0 : tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[BATCH:.+]] = linalg.index 0 : index
+  // CHECK: %[[CHANNEL:.+]] = linalg.index 3 : index
+  // CHECK: %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK: %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK: %[[C1:.+]] = arith.constant 1.000000e+00
+  // CHECK: %[[SUB:.+]] = arith.subf %[[C1]], %[[DX:.+]]
+  // CHECK: %[[MUL0:.+]] = arith.mulf %[[EXTRACT0]], %[[SUB]]
+  // CHECK: %[[MUL1:.+]] = arith.mulf %[[EXTRACT1]], %[[DX]]
+  // CHECK: %[[ADD:.+]] = arith.addf %[[MUL0]], %[[MUL1]]
+  // CHECK: linalg.yield %[[ADD]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32>
+
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xf32>
+}