diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1616,19 +1616,18 @@ let description = [{ Resizes a tensor. Resize is only allowed in the H and W dimensions. In - expected use, stride_y is approximately (IH<(loc, rewriter.getI32Type(), x); - int32_t shift = op.getShift(); - bool floatingPointMode = shift == 0; - - Value yStride, xStride, yOffset, xOffset; - if (floatingPointMode) { - yStride = rewriter.create(loc, op.getStrideFp()[0]); - xStride = rewriter.create(loc, op.getStrideFp()[1]); - yOffset = rewriter.create(loc, op.getOffsetFp()[0]); - xOffset = rewriter.create(loc, op.getOffsetFp()[1]); - } else { - SmallVector stride, offset; - getValuesFromIntArrayAttribute(op.getStride(), stride); - getValuesFromIntArrayAttribute(op.getOffset(), offset); - - yStride = rewriter.create( - loc, rewriter.getI32IntegerAttr(stride[0])); - xStride = rewriter.create( - loc, rewriter.getI32IntegerAttr(stride[1])); - yOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[1])); - } + bool floatingPointMode = resultElementTy.isF32(); + + Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder, + xBorder; + SmallVector scale, offset, border; + getValuesFromIntArrayAttribute(op.getScale(), scale); + getValuesFromIntArrayAttribute(op.getOffset(), offset); + getValuesFromIntArrayAttribute(op.getBorder(), border); + + yScaleN = rewriter.create( + loc, rewriter.getI32IntegerAttr(scale[0])); + yScaleD = rewriter.create( + loc, rewriter.getI32IntegerAttr(scale[1])); + xScaleN = rewriter.create( + loc, rewriter.getI32IntegerAttr(scale[2])); + xScaleD = rewriter.create( + loc, rewriter.getI32IntegerAttr(scale[3])); + yOffset = rewriter.create( + loc, rewriter.getI32IntegerAttr(offset[0])); + xOffset = rewriter.create( + loc, rewriter.getI32IntegerAttr(offset[1])); + yBorder = rewriter.create( + loc, rewriter.getI32IntegerAttr(border[0])); + xBorder = rewriter.create( + loc, rewriter.getI32IntegerAttr(border[1])); // Compute the the integer index and partial offset. - // x = x * stride + offset; - // ix = floor(x) - // dx = x - ix Value ix, iy, dx, dy; + // x = x * scale_d + offset; + // ix = floor(x / scale_n) if (floatingPointMode) { + // dx = x / scale_n - ix Value y = rewriter.create(loc, rewriter.getF32Type(), inY); Value x = rewriter.create(loc, rewriter.getF32Type(), inX); - y = rewriter.create(loc, y, yStride); - x = rewriter.create(loc, x, xStride); + yScaleN = + rewriter.create(loc, rewriter.getF32Type(), yScaleN); + yScaleD = + rewriter.create(loc, rewriter.getF32Type(), yScaleD); + xScaleN = + rewriter.create(loc, rewriter.getF32Type(), xScaleN); + xScaleD = + rewriter.create(loc, rewriter.getF32Type(), xScaleD); + yOffset = + rewriter.create(loc, rewriter.getF32Type(), yOffset); + xOffset = + rewriter.create(loc, rewriter.getF32Type(), xOffset); + + y = rewriter.create(loc, y, yScaleD); + x = rewriter.create(loc, x, xScaleD); y = rewriter.create(loc, y, yOffset); x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yScaleN); + x = rewriter.create(loc, x, xScaleN); + iy = rewriter.create(loc, y); ix = rewriter.create(loc, x); @@ -1409,27 +1427,30 @@ iy = rewriter.create(loc, rewriter.getI32Type(), iy); ix = rewriter.create(loc, rewriter.getI32Type(), ix); } else { - Value shiftVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(shift)); - - Value y = rewriter.create(loc, inY, yStride); - Value x = rewriter.create(loc, inX, xStride); + // dx = x - ix * scale_n; + Value y = rewriter.create(loc, inY, yScaleD); + Value x = rewriter.create(loc, inX, xScaleD); y = rewriter.create(loc, y, yOffset); x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y, shiftVal); - ix = rewriter.create(loc, x, shiftVal); + iy = rewriter.create(loc, y, yScaleN); + ix = rewriter.create(loc, x, xScaleN); - Value yTrunc = rewriter.create(loc, iy, shiftVal); - Value xTrunc = rewriter.create(loc, ix, shiftVal); + Value temp_y = rewriter.create(loc, iy, yScaleN); + Value temp_x = rewriter.create(loc, ix, xScaleN); - dy = rewriter.create(loc, y, yTrunc); - dx = rewriter.create(loc, x, xTrunc); + dy = rewriter.create(loc, y, temp_y); + dx = rewriter.create(loc, x, temp_x); } if (op.getMode() == "NEAREST_NEIGHBOR") { Value yPred, xPred; + auto zeroVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + // Round the index position towards the closest pixel location. if (floatingPointMode) { auto halfVal = rewriter.create( @@ -1439,19 +1460,16 @@ xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, dx, halfVal); } else { - auto halfVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); + Value yScaleNHalfVal = + rewriter.create(loc, yScaleN, oneVal); + Value xScaleNHalfVal = + rewriter.create(loc, xScaleN, oneVal); yPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dy, halfVal); + dy, yScaleNHalfVal); xPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dx, halfVal); + dx, xScaleNHalfVal); } - auto zeroVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - auto yOffset = rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = @@ -1477,9 +1495,8 @@ rewriter.create(loc, result); return success(); - } - - if (op.getMode() == "BILINEAR") { + } else { + // The mode here must be BILINEAR. This has been checked above. Value y0 = iy; Value x0 = ix; @@ -1513,10 +1530,8 @@ loc, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { - auto oneVal = rewriter.create( - loc, rewriter.getF32FloatAttr(1.f)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, oneVal, dx); + Value leftPart = rewriter.create(loc, xScaleN, dx); y0x0 = rewriter.create(loc, y0x0, leftPart); y0x1 = rewriter.create(loc, y0x1, rightPart); @@ -1527,46 +1542,46 @@ Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); + Value topPart = rewriter.create(loc, yScaleN, dy); topAcc = rewriter.create(loc, topAcc, topPart); bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); - } - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = rewriter.create(loc, resultElementTy, y1x1); - - if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); - } + } else { + y0x0 = rewriter.create(loc, resultElementTy, y0x0); + y0x1 = rewriter.create(loc, resultElementTy, y0x1); + y1x0 = rewriter.create(loc, resultElementTy, y1x0); + y1x1 = rewriter.create(loc, resultElementTy, y1x1); + + if (resultElementTy.getIntOrFloatBitWidth() > 32) { + dx = rewriter.create(loc, resultElementTy, dx); + dy = rewriter.create(loc, resultElementTy, dy); + } - auto unitVal = rewriter.create( - loc, rewriter.getIntegerAttr(resultElementTy, 1LL << shift)); - Value rightPart = dx; - Value leftPart = rewriter.create(loc, unitVal, dx); + Value rightPart = dx; + Value leftPart = rewriter.create(loc, xScaleN, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); - Value bottomPart = dy; - Value topPart = rewriter.create(loc, unitVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value bottomPart = dy; + Value topPart = rewriter.create(loc, yScaleN, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); - rewriter.create(loc, result); - return success(); + rewriter.create(loc, result); + return success(); + } } + return failure(); } }; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -800,64 +800,36 @@ llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamicSize); - int32_t inHeight = ShapedType::kDynamicSize; - int32_t inWidth = ShapedType::kDynamicSize; - ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); - if (inputShape.hasRank()) { - outputShape[0] = inputShape.getDimSize(0); - outputShape[3] = inputShape.getDimSize(3); + if (!inputShape.hasRank()) + return failure(); - inHeight = inputShape.getDimSize(1); - inWidth = inputShape.getDimSize(2); - } + outputShape[0] = inputShape.getDimSize(0); + outputShape[3] = inputShape.getDimSize(3); + int32_t inputHeight = inputShape.getDimSize(1); + int32_t inputWidth = inputShape.getDimSize(2); - int32_t shift = adaptor.getShift(); - llvm::SmallVector newShape; - getI64Values(adaptor.getOutputSize(), newShape); - outputShape[1] = newShape[0]; - outputShape[2] = newShape[1]; + if ((inputHeight == ShapedType::kDynamicSize) || + (inputWidth == ShapedType::kDynamicSize)) + return failure(); - llvm::SmallVector strideInt; + llvm::SmallVector scaleInt; llvm::SmallVector offsetInt; - llvm::SmallVector strideFp; - llvm::SmallVector offsetFp; + llvm::SmallVector borderInt; + getI64Values(adaptor.getScale(), scaleInt); getI64Values(adaptor.getOffset(), offsetInt); - getF64Values(adaptor.getOffsetFp(), offsetFp); - getI64Values(adaptor.getStride(), strideInt); - getF64Values(adaptor.getStrideFp(), strideFp); - - // If we have a 0 zero in integers we know that the resize indexing needs to - // be performed in floating point. Use the floating point varient to compute - // the resize shape. - bool fpMode = strideInt[0] == 0; - - // We can compute the output shape if attribute specifies unknown dimensions - // based on the offset and stride. If we perfectly line up to the last index - // we need to round up the size to include it. - if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) { - float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0]; - float round = std::floor(sizeFp) == sizeFp ? 1 : 0; - outputShape[1] = std::ceil(sizeFp) + round; - } - - if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) { - float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1]; - float round = std::floor(sizeFp) == sizeFp ? 1 : 0; - outputShape[2] = std::ceil(sizeFp) + round; - } - - if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) { - int64_t size = (inHeight - 1); - size = ((size << shift) - offsetInt[0]) / strideInt[0]; - outputShape[1] = size + 1; - } - - if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) { - int64_t size = (inWidth - 1); - size = ((size << shift) - offsetInt[1]) / strideInt[1]; - outputShape[2] = size + 1; - } + getI64Values(adaptor.getBorder(), borderInt); + + // Compute the output shape based on attributes: scale, offset, and border. + outputShape[1] = + (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / + scaleInt[1]) + + 1; + + outputShape[2] = + (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / + scaleInt[3]) + + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1559,123 +1559,141 @@ // ----- -// CHECK-LABEL: @resize_nearest -func.func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] +// CHECK-LABEL: @resize_nearest_int +func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 23, 179, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX1:.+]] = linalg.index 1 - // CHECK: %[[IDX2:.+]] = linalg.index 2 - // CHECK: %[[IDX3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XYMIN:.+]] = arith.constant 0 - // CHECK-DAG: %[[YMAX:.+]] = arith.constant 1 - // CHECK-DAG: %[[XMAX:.+]] = arith.constant 1 - // CHECK-DAG: %[[Y:.+]] = arith.index_cast %[[IDX1]] - // CHECK-DAG: %[[X:.+]] = arith.index_cast %[[IDX2]] - // CHECK-DAG: %[[STRIDEY:.+]] = arith.constant 5.000000e-01 - // CHECK-DAG: %[[STRIDEX:.+]] = arith.constant 5.000000e-01 - // CHECK-DAG: %[[OFFSETY:.+]] = arith.constant 1.000000e-01 - // CHECK-DAG: %[[OFFSETX:.+]] = arith.constant 2.000000e-01 - // CHECK-DAG: %[[VAL4:.+]] = arith.uitofp %[[Y]] - // CHECK-DAG: %[[VAL5:.+]] = arith.uitofp %[[X]] - // CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL4]], %[[STRIDEY]] - // CHECK-DAG: %[[VAL7:.+]] = arith.mulf %[[VAL5]], %[[STRIDEX]] - // CHECK-DAG: %[[VAL8:.+]] = arith.addf %[[VAL6]], %[[OFFSETY]] - // CHECK-DAG: %[[VAL9:.+]] = arith.addf %[[VAL7]], %[[OFFSETX]] - - // Find the remainder and integer component of the target index. - - // CHECK-DAG: %[[VAL10:.+]] = math.floor %[[VAL8]] - // CHECK-DAG: %[[VAL11:.+]] = math.floor %[[VAL9]] - // CHECK-DAG: %[[VAL12:.+]] = arith.subf %[[VAL8]], %[[VAL10]] - // CHECK-DAG: %[[VAL13:.+]] = arith.subf %[[VAL9]], %[[VAL11]] - // CHECK-DAG: %[[VAL14:.+]] = arith.fptosi %[[VAL10]] - // CHECK-DAG: %[[VAL15:.+]] = arith.fptosi %[[VAL11]] - - // Round to the nearest index. - - // CHECK-DAG: %[[ROUND:.+]] = arith.constant 5.000000e-01 - // CHECK-DAG: %[[VAL16:.+]] = arith.cmpf oge, %[[VAL12]], %[[ROUND]] - // CHECK-DAG: %[[VAL17:.+]] = arith.cmpf oge, %[[VAL13]], %[[ROUND]] - // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 - // CHECK-DAG: %[[ONE:.+]] = arith.constant 1 - // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL19:.+]] = arith.select %[[VAL17]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL14]], %[[VAL18]] - // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL15]], %[[VAL19]] + // CHECK: %[[IDX_0:.+]] = linalg.index 0 + // CHECK: %[[IDX_1:.+]] = linalg.index 1 + // CHECK: %[[IDX_2:.+]] = linalg.index 2 + // CHECK: %[[IDX_3:.+]] = linalg.index 3 + // CHECK: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK: %[[Y_MAX:.+]] = arith.constant 14 + // CHECK: %[[X_MAX:.+]] = arith.constant 12 + + // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] + // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] + // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 11 + // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 7 + // CHECK: %[[SCALE_X_N:.*]] = arith.constant 89 + // CHECK: %[[SCALE_X_D:.*]] = arith.constant 6 + // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0 + // CHECK: %[[OFFSET_X:.*]] = arith.constant 0 + // CHECK: %[[BORDER_Y:.*]] = arith.constant 0 + // CHECK: %[[BORDER_X:.*]] = arith.constant 0 + + // find the remainder and integer component of the target index. + + // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]] + // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] + // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] + // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] + + // Round to the nearest neighor. + + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]] + // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]] + // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]] + // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]] + // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]] + // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]] // This section applies bound checking to be within the input image. - // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]] - // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]] - // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]] - // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]] - // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]] - // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]] - // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]] - // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]] + // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]] + // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]] + // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]] + // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]] + // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]] + // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]] + // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]] + // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]] // Extract the nearest value using the computed indices. - // CHECK-DAG: %[[IDY:.+]] = arith.index_cast %[[VAL25]] - // CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[VAL29]] - // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]] + // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]] + // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]] // CHECK: linalg.yield %[[EXTRACT]] - %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>) - return + // Round to the nearest index. + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8> + return } // ----- -// CHECK-LABEL: @resize_bilinear -func.func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] +// CHECK-LABEL: @resize_bilinear_int +func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 289, 289, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX1:.+]] = linalg.index 1 - // CHECK: %[[IDX2:.+]] = linalg.index 2 - // CHECK: %[[IDX3:.+]] = linalg.index 3 - // CHECK: %[[XYMIN:.+]] = arith.constant 0 - // CHECK: %[[YMAX:.+]] = arith.constant 1 - // CHECK: %[[XMAX:.+]] = arith.constant 1 - - // CHECK: %[[VAL10:.+]] = math.floor %[[VAL8:.+]] - // CHECK: %[[VAL11:.+]] = math.floor %[[VAL9:.+]] - - // CHECK: %[[DY:.+]] = arith.subf %[[VAL8:.+]], %[[VAL10]] - // CHECK: %[[DX:.+]] = arith.subf %[[VAL9:.+]], %[[VAL11]] - - // CHECK: %[[Y0:.+]] = arith.fptosi %[[VAL10]] - // CHECK: %[[X0:.+]] = arith.fptosi %[[VAL11]] + // CHECK: ^bb0(%arg1: i32): + // CHECK: %[[IDX_0:.+]] = linalg.index 0 + // CHECK: %[[IDX_1:.+]] = linalg.index 1 + // CHECK: %[[IDX_2:.+]] = linalg.index 2 + // CHECK: %[[IDX_3:.+]] = linalg.index 3 + // CHECK: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK: %[[Y_MAX:.+]] = arith.constant 18 + // CHECK: %[[X_MAX:.+]] = arith.constant 18 + // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] + // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] + // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 16 + // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 1 + // CHECK: %[[SCALE_X_N:.*]] = arith.constant 16 + // CHECK: %[[SCALE_X_D:.*]] = arith.constant 1 + // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0 + // CHECK: %[[OFFSET_X:.*]] = arith.constant 0 + // CHECK: %[[BORDER_Y:.*]] = arith.constant 0 + // CHECK: %[[BORDER_X:.*]] = arith.constant 0 + + // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]] + // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] + // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] + // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] // Compute the left, right, and top indices for the bilinear interpolation. - // CHECK: %[[ONE:.+]] = arith.constant 1 - // CHECK: %[[Y1:.+]] = arith.addi %[[Y0]], %[[ONE]] - // CHECK: %[[X1:.+]] = arith.addi %[[X0]], %[[ONE]] + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] // Bound check each dimension. - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]] - // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] + // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]] - // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] + // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]] - // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] + // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]] - // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] + // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] // Extract each corner of the bilinear interpolation. @@ -1684,180 +1702,217 @@ // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] - // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YLOI]], %[[XLOI]], %[[IDX3]]] - // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YLOI]], %[[XHII]], %[[IDX3]]] - // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YHII]], %[[XLOI]], %[[IDX3]]] - // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YHII]], %[[XHII]], %[[IDX3]]] + // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]] + // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]] + // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]] + // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]] + + // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]] + // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]] + // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]] + // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]] // Compute the bilinear interpolation. - // CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00 - // CHECK: %[[NDX:.+]] = arith.subf %[[ONE]], %[[DX]] - // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]] - // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[DX]] - // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]] - // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]] - // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[DX]] - // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]] - // CHECK: %[[NDY:.+]] = arith.subf %[[ONE]], %[[DY]] - // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]] - // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[DY]] - // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]] + // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]] + // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]] + // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]] + // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]] + // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]] + // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]] + // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]] + // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]] + // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]] + // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]] + // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]] // CHECK: linalg.yield %[[RESULT]] - %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>) - return + + // Round to the nearest index. + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32> + return } // ----- -// CHECK-LABEL: @resize_nearest_int -func.func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] +// CHECK-LABEL: @resize_nearest_fp +func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 1600, 1536, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK: ^bb0(%arg1: f32): // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX1:.+]] = linalg.index 1 // CHECK: %[[IDX2:.+]] = linalg.index 2 // CHECK: %[[IDX3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XYMIN:.+]] = arith.constant 0 - // CHECK-DAG: %[[YMAX:.+]] = arith.constant 1 - // CHECK-DAG: %[[XMAX:.+]] = arith.constant 1 - // CHECK-DAG: %[[Y:.+]] = arith.index_cast %[[IDX1]] - // CHECK-DAG: %[[X:.+]] = arith.index_cast %[[IDX2]] - // CHECK-DAG: %[[STRIDEY:.+]] = arith.constant 128 - // CHECK-DAG: %[[STRIDEX:.+]] = arith.constant 128 - // CHECK-DAG: %[[OFFSETY:.+]] = arith.constant 1 - // CHECK-DAG: %[[OFFSETX:.+]] = arith.constant 2 - // CHECK-DAG: %[[EIGHT:.+]] = arith.constant 8 - // CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[Y]], %[[STRIDEY]] - // CHECK-DAG: %[[VAL5:.+]] = arith.muli %[[X]], %[[STRIDEX]] - // CHECK-DAG: %[[VAL6:.+]] = arith.addi %[[VAL4]], %[[OFFSETY]] - // CHECK-DAG: %[[VAL7:.+]] = arith.addi %[[VAL5]], %[[OFFSETX]] + // CHECK: %[[XYMIN:.*]] = arith.constant 0 + // CHECK: %[[YMAX:.*]] = arith.constant 49 + // CHECK: %[[XMAX:.*]] = arith.constant 47 + // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]] + // CHECK: %[[X:.+]] = arith.index_cast %[[IDX2]] + // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 64 + // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 2 + // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 64 + // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 2 + // CHECK: %[[IOFFSET_Y:.*]] = arith.constant -31 + // CHECK: %[[IOFFSET_X:.*]] = arith.constant -31 + // CHECK: %[[IBORDER_Y:.*]] = arith.constant 31 + // CHECK: %[[IBORDER_X:.*]] = arith.constant 31 + + // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] + // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] + // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] + // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] + + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] // Find the remainder and integer component of the target index. - - // CHECK-DAG: %[[VAL8:.+]] = arith.shrsi %[[VAL6]], %[[EIGHT]] - // CHECK-DAG: %[[VAL9:.+]] = arith.shrsi %[[VAL7]], %[[EIGHT]] - // CHECK-DAG: %[[VAL10:.+]] = arith.shli %[[VAL8]], %[[EIGHT]] - // CHECK-DAG: %[[VAL11:.+]] = arith.shli %[[VAL9]], %[[EIGHT]] - // CHECK-DAG: %[[VAL12:.+]] = arith.subi %[[VAL6]], %[[VAL10]] - // CHECK-DAG: %[[VAL13:.+]] = arith.subi %[[VAL7]], %[[VAL11]] - - // Round to the nearest index. - - // CHECK-DAG: %[[ROUND:.+]] = arith.constant 128 - // CHECK-DAG: %[[VAL16:.+]] = arith.cmpi sge, %[[VAL12]], %[[ROUND]] - // CHECK-DAG: %[[VAL17:.+]] = arith.cmpi sge, %[[VAL13]], %[[ROUND]] - // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 - // CHECK-DAG: %[[ONE:.+]] = arith.constant 1 - // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL19:.+]] = arith.select %[[VAL17]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL8]], %[[VAL18]] - // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL9]], %[[VAL19]] - - // This section applies bound checking to be within the input image. - - // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]] - // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]] - // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]] - // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]] - // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]] - // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]] - // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]] - // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]] - - // Extract the nearest value using the computed indices. - - // CHECK-DAG: %[[IDY:.+]] = arith.index_cast %[[VAL25]] - // CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[VAL29]] + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] + // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] + // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] + // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]] + // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]] + + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 + // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]] + // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]] + // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] + // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]] + // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]] + + // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]] + // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]] + // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]] + // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]] + // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]] + // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]] + // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]] + // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]] + + // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]] + // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]] // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]] // CHECK: linalg.yield %[[EXTRACT]] - %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xi32>) -> (tensor<1x4x4x1xi32>) + + %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32> + return } // ----- -// CHECK-LABEL: @resize_bilinear_int -func.func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () { - // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1] +// CHECK-LABEL: @resize_bilinear_fp +func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 89, 89, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - - // CHECK: %[[IDX0:.+]] = linalg.index 0 - // CHECK: %[[IDX3:.+]] = linalg.index 3 - - // CHECK: %[[XYMIN:.+]] = arith.constant 0 - // CHECK: %[[YMAX:.+]] = arith.constant 1 - // CHECK: %[[XMAX:.+]] = arith.constant 1 - - // CHECK: %[[Y0:.+]] = arith.shrsi - // CHECK: %[[X0:.+]] = arith.shrsi - // CHECK: %[[ROUNDY:.+]] = arith.shli %[[Y0]] - // CHECK: %[[ROUNDX:.+]] = arith.shli %[[X0]] - // CHECK: %[[DY:.+]] = arith.subi %10, %[[ROUNDY]] - // CHECK: %[[DX:.+]] = arith.subi %11, %[[ROUNDX]] + // CHECK: ^bb0(%arg1: f32): + // CHECK: %[[IDX_0:.+]] = linalg.index 0 + // CHECK: %[[IDX_1:.+]] = linalg.index 1 + // CHECK: %[[IDX_2:.+]] = linalg.index 2 + // CHECK: %[[IDX_3:.+]] = linalg.index 3 + // CHECK: %[[XY_MIN:.*]] = arith.constant 0 + // CHECK: %[[Y_MAX:.*]] = arith.constant 22 + // CHECK: %[[X_MAX:.*]] = arith.constant 22 + // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] + // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] + // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 4 + // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 1 + // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 4 + // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 1 + // CHECK: %[[IOFFSET_Y:.*]] = arith.constant 0 + // CHECK: %[[IOFFSET_X:.*]] = arith.constant 0 + // CHECK: %[[IBORDER_Y:.*]] = arith.constant 0 + // CHECK: %[[IBORDER_X:.*]] = arith.constant 0 + + // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] + // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] + // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] + // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] + + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] + + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] + // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] + // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] + // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]] + // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]] // Compute the left, right, and top indices for the bilinear interpolation. - // CHECK: %[[ONE:.+]] = arith.constant 1 - // CHECK: %[[Y1:.+]] = arith.addi %[[Y0]], %[[ONE]] - // CHECK: %[[X1:.+]] = arith.addi %[[X0]], %[[ONE]] + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] // Bound check each dimension. - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]] - // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] + // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]] - // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] + // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]] - // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] + // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]] - // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]] - // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] - - // Extract each corner of the bilinear interpolation. + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] + // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] - // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YLOI]], %[[XLOI]], %[[IDX3]]] - // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YLOI]], %[[XHII]], %[[IDX3]]] - // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YHII]], %[[XLOI]], %[[IDX3]]] - // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX0]], %[[YHII]], %[[XHII]], %[[IDX3]]] + // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]] + // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]] + // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]] + // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]] - // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]] - // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]] - // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]] - // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]] + // CHECK: %[[NDX:.+]] = arith.subf %[[SCALE_X_N]], %[[D_X]] + // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]] + // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]] + // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]] + // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]] + // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]] + // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]] + // CHECK: %[[NDY:.+]] = arith.subf %[[SCALE_Y_N]], %[[D_Y]] + // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]] + // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[D_Y]] + // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]] + // CHECK: linalg.yield %[[RESULT]] - // Compute the bilinear interpolation. + // Round by bilinear interpolation + %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32> - // CHECK: %[[SCALE:.+]] = arith.constant 256 - // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE]], %[[DX]] - // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]] - // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[DX]] - // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]] - // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]] - // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[DX]] - // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]] - // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE]], %[[DY]] - // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]] - // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[DY]] - // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]] - // CHECK: linalg.yield %[[RESULT]] - %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>) -> (tensor<1x4x4x1xi32>) return } @@ -1865,10 +1920,10 @@ // CHECK-LABEL: @resize_dyn func.func @resize_dyn(%input: tensor) -> () { - // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 4, 1] // CHECK: %[[GENERIC:.+]] = linalg.generic - %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor) -> (tensor) + %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor) -> (tensor) return } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -448,7 +448,7 @@ // ----- // CHECK-LABEL: resize func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { - %1 = "tosa.resize"(%arg0) {output_size = [64, 64], stride = [1024, 1024], offset = [0, 0], shift = 10 : i32, stride_fp = [0.0 : f32, 0.0 : f32], offset_fp = [0.0 : f32, 0.0 : f32], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> + %1 = "tosa.resize"(%arg0) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> return %1 : tensor<1x64x64x8xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -964,61 +964,71 @@ // ----- -// CHECK-LABEL: @resize_output_size -func.func @resize_output_size(%arg0: tensor<2x?x?x3xi32>) { - // CHECK: -> tensor<2x4x5x3xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 1], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [4, 5], shift = 8 : i32, stride = [1, 1], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<2x?x?x3xi32>) -> tensor +// CHECK-LABEL: @resize_int_horizontal +func.func @resize_int_horizontal(%arg0: tensor<1x15x13x1xi8>) { + // CHECK: -> tensor<1x23x179x1xi8> + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor return } // ----- -// CHECK-LABEL: @resize_int_horizontal -func.func @resize_int_horizontal(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x2x7x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [256, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor +// CHECK-LABEL: @resize_int_vertical +func.func @resize_int_vertical(%arg0: tensor<1x49x42x1xi16>) { + // CHECK: -> tensor<1x112x220x1xi16> + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [37, 16, 219, 41], offset = [0, 0], border = [0, 0]} : (tensor<1x49x42x1xi16>) -> tensor return } // ----- -// CHECK-LABEL: @resize_int_vertical -func.func @resize_int_vertical(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x3x4x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [128, 256], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor +// CHECK-LABEL: @resize_int_power_of_two_upscale +func.func @resize_int_power_of_two_upscale(%arg0: tensor<1x23x19x1xi8>) { + // CHECK: -> tensor<1x353x289x1xi32> + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x19x1xi8>) -> tensor return } // ----- -// CHECK-LABEL: @resize_int_offsetted -func.func @resize_int_offsetted(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x4x6x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [64, 64], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [64, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor +// CHECK-LABEL: @resize_int_power_of_two_upscale_offsetted +func.func @resize_int_power_of_two_upscale_offsetted(%arg0: tensor<1x41x26x1xi16>) { + // CHECK: -> tensor<1x328x208x1xi48> + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 2, 16, 2], offset = [-7, -7], border = [7, 7]} : (tensor<1x41x26x1xi16>) -> tensor return } // ----- - // CHECK-LABEL: @resize_fp_horizontal -func.func @resize_fp_horizontal(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x2x7x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [1.000000e+00 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor +func.func @resize_fp_horizontal(%arg0: tensor<1x50x48x1xf32>) { + // CHECK: -> tensor<1x106x85x1xf32> + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [15, 7, 84, 47], offset = [0, 0], border = [0, 0]} : (tensor<1x50x48x1xf32>) -> tensor return } // ----- - // CHECK-LABEL: @resize_fp_vertical -func.func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x3x4x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor +func.func @resize_fp_vertical(%arg0: tensor<1x50x48x1xf32>) { + // CHECK: -> tensor<1x128x13x1xf32> + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [127, 49, 12, 47], offset = [0, 0], border = [0, 0]} : (tensor<1x50x48x1xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @resize_fp_power_of_two_upscale +func.func @resize_fp_power_of_two_upscale(%arg0: tensor<1x23x23x1xf32>) { + // CHECK: -> tensor<1x89x89x1xf32> + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor return } -// CHECK-LABEL: @resize_fp_offsetted -func.func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) { - // CHECK: -> tensor<1x4x6x1xi32> - %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor + +// ----- + +// CHECK-LABEL: @resize_fp_power_of_two_upscale_offsetted +func.func @resize_fp_power_of_two_upscale_offsetted(%arg0: tensor<1x50x48x1xf32>) { + // CHECK: -> tensor<1x1600x1536x1xf32> + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor return }