diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1126,6 +1126,299 @@
   }
 };
 
+// Lowers tosa.resize to a linalg.indexed_generic over the static result
+// shape. Each output element (n, y, x, c) is mapped back to the source
+// coordinate `y * stride + offset` (likewise for x), which is split into an
+// integer index and a fractional remainder. The arithmetic is done in f32
+// when `shift == 0` and in fixed point with `shift` fractional bits
+// otherwise. NEAREST_NEIGHBOR rounds to the closest pixel; BILINEAR blends
+// the four surrounding pixels. Indices are clamped to the input image.
+class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
+public:
+  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ResizeOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    auto input = op.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto resultElementTy = resultTy.getElementType();
+    int32_t shift = op.shift();
+    bool floatingPointMode = shift == 0;
+
+    // Every check must run before the IR is modified: a pattern must not
+    // return failure() after it has already mutated the IR.
+    // The input H/W become clamp bounds, so the input must be static too.
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
+      return failure();
+    // Keeps the host-side `1 << shift` computations below well defined.
+    if (shift < 0 || shift > 30)
+      return failure();
+    if (op.mode() == "BILINEAR") {
+      // Float mode blends with f32 weights; integer mode sign-extends the
+      // input pixels to the wider integer result type. Reject any other
+      // combination instead of emitting ill-typed IR.
+      Type inputElementTy = inputTy.getElementType();
+      if (floatingPointMode && !resultElementTy.isF32())
+        return failure();
+      if (!floatingPointMode &&
+          (!inputElementTy.isa<IntegerType>() ||
+           !resultElementTy.isa<IntegerType>() ||
+           inputElementTy.getIntOrFloatBitWidth() >=
+               resultElementTy.getIntOrFloatBitWidth()))
+        return failure();
+    }
+
+    auto imageH = inputTy.getShape()[1];
+    auto imageW = inputTy.getShape()[2];
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()));
+    rewriter.replaceOp(op, genericOp.getResult(0));
+
+    {
+      OpBuilder::InsertionGuard regionGuard(rewriter);
+      Block *block = rewriter.createBlock(
+          &genericOp.region(), genericOp.region().end(),
+          TypeRange({rewriter.getIndexType(), rewriter.getIndexType(),
+                     rewriter.getIndexType(), rewriter.getIndexType(),
+                     resultElementTy}));
+      Value batch = block->getArgument(0);
+      Value channel = block->getArgument(3);
+
+      auto hwMin =
+          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+      auto hMax = rewriter.create<ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(imageH - 1));
+      auto wMax = rewriter.create<ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(imageW - 1));
+
+      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
+                                               block->getArgument(1));
+      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
+                                               block->getArgument(2));
+
+      Value yStride, xStride, yOffset, xOffset;
+      if (floatingPointMode) {
+        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
+        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
+        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
+        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
+      } else {
+        SmallVector<int32_t> stride, offset;
+        getValuesFromIntArrayAttribute(op.stride(), stride);
+        getValuesFromIntArrayAttribute(op.offset(), offset);
+
+        yStride = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(stride[0]));
+        xStride = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(stride[1]));
+        yOffset = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(offset[0]));
+        xOffset = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(offset[1]));
+      }
+
+      // Compute the integer index and the fractional remainder:
+      //   x = x * stride + offset
+      //   ix = floor(x)
+      //   dx = x - ix
+      Value ix, iy, dx, dy;
+      if (floatingPointMode) {
+        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
+        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);
+
+        y = rewriter.create<MulFOp>(loc, y, yStride);
+        x = rewriter.create<MulFOp>(loc, x, xStride);
+
+        y = rewriter.create<AddFOp>(loc, y, yOffset);
+        x = rewriter.create<AddFOp>(loc, x, xOffset);
+
+        iy = rewriter.create<FloorFOp>(loc, y);
+        ix = rewriter.create<FloorFOp>(loc, x);
+
+        dy = rewriter.create<SubFOp>(loc, y, iy);
+        dx = rewriter.create<SubFOp>(loc, x, ix);
+
+        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
+        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
+      } else {
+        // The integer part is the value shifted right by `shift`; the
+        // remainder is the low `shift` bits that the shift discarded.
+        Value shiftVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));
+
+        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
+        Value x = rewriter.create<MulIOp>(loc, inX, xStride);
+
+        y = rewriter.create<AddIOp>(loc, y, yOffset);
+        x = rewriter.create<AddIOp>(loc, x, xOffset);
+
+        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
+        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);
+
+        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
+        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);
+
+        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
+        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
+      }
+
+      if (op.mode() == "NEAREST_NEIGHBOR") {
+        Value yPred, xPred;
+        // Round the index position towards the closest pixel location.
+        if (floatingPointMode) {
+          auto halfVal =
+              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
+          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
+                                                halfVal);
+          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
+                                                halfVal);
+        } else {
+          auto halfVal = rewriter.create<ConstantOp>(
+              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
+          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
+                                                halfVal);
+          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
+                                                halfVal);
+        }
+
+        auto zeroVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+        auto oneVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+
+        // Distinct names so the stride offsets above are not shadowed.
+        auto yRound = rewriter.create<SelectOp>(loc, yPred, oneVal, zeroVal);
+        auto xRound = rewriter.create<SelectOp>(loc, xPred, oneVal, zeroVal);
+
+        iy = rewriter.create<AddIOp>(loc, iy, yRound);
+        ix = rewriter.create<AddIOp>(loc, ix, xRound);
+
+        // Clamp the indices to be within the bounds of the input image.
+        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
+                                       rewriter);
+        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
+                                       rewriter);
+
+        // Read the value from the input array.
+        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
+        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);
+
+        Value result = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, iy, ix, channel});
+
+        rewriter.create<linalg::YieldOp>(loc, result);
+        return success();
+      }
+
+      // BILINEAR: the only mode left after the checks at the top.
+      Value y0 = iy;
+      Value x0 = ix;
+
+      auto oneVal =
+          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+      Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
+      Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);
+
+      y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
+                                     rewriter);
+      y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
+                                     rewriter);
+
+      x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
+                                     rewriter);
+      x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
+                                     rewriter);
+
+      y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
+      y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
+      x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
+      x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);
+
+      Value y0x0 = rewriter.create<tensor::ExtractOp>(
+          loc, input, ValueRange{batch, y0, x0, channel});
+      Value y0x1 = rewriter.create<tensor::ExtractOp>(
+          loc, input, ValueRange{batch, y0, x1, channel});
+      Value y1x0 = rewriter.create<tensor::ExtractOp>(
+          loc, input, ValueRange{batch, y1, x0, channel});
+      Value y1x1 = rewriter.create<tensor::ExtractOp>(
+          loc, input, ValueRange{batch, y1, x1, channel});
+
+      // result = (y0x0 * (1 - dx) + y0x1 * dx) * (1 - dy)
+      //        + (y1x0 * (1 - dx) + y1x1 * dx) * dy
+      // where "1" is 1.0 in float mode and 1 << shift in fixed point.
+      Value result;
+      if (floatingPointMode) {
+        auto unitVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
+        Value rightPart = dx;
+        Value leftPart = rewriter.create<SubFOp>(loc, unitVal, dx);
+
+        y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
+        y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
+        Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);
+
+        y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
+        y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
+        Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);
+
+        Value bottomPart = dy;
+        Value topPart = rewriter.create<SubFOp>(loc, unitVal, dy);
+        topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
+        bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
+        result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);
+      } else {
+        y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
+        y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
+        y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
+        y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);
+
+        // dx/dy are i32; widen them when the result type is wider.
+        if (resultElementTy.getIntOrFloatBitWidth() > 32) {
+          dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
+          dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
+        }
+
+        auto unitVal = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
+        Value rightPart = dx;
+        Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);
+
+        y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
+        y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
+        Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);
+
+        y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
+        y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
+        Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);
+
+        Value bottomPart = dy;
+        Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
+        topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
+        bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
+        result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);
+      }
+
+      rewriter.create<linalg::YieldOp>(loc, result);
+      return success();
+    }
+  }
+};
+
 // At the codegen level any identity operations should be removed. Any cases
 // where identity is load-bearing (e.g. cross device computation) should be
 // handled before lowering to codegen.
@@ -1817,6 +2088,7 @@
       PadConverter,
       ReshapeConverter,
       RescaleConverter,
+      ResizeConverter,
       ReverseConverter,
       TableConverter,
       TileConverter,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -963,3 +963,291 @@
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @resize_nearest
+func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK-DAG: %[[XYMIN:.+]] = constant 0
+  // CHECK-DAG: %[[YMAX:.+]] = constant 1
+  // CHECK-DAG: %[[XMAX:.+]] = constant 1
+  // CHECK-DAG: %[[Y:.+]] = index_cast %arg2
+  // CHECK-DAG: %[[X:.+]] = index_cast %arg3
+  // CHECK-DAG: %[[STRIDEY:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[STRIDEX:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[OFFSETY:.+]] = constant 1.000000e-01
+  // CHECK-DAG: %[[OFFSETX:.+]] = constant 2.000000e-01
+  // CHECK-DAG: %[[VAL4:.+]] = uitofp %[[Y]]
+  // CHECK-DAG: %[[VAL5:.+]] = uitofp %[[X]]
+  // CHECK-DAG: %[[VAL6:.+]] = mulf %[[VAL4]], %[[STRIDEY]]
+  // CHECK-DAG: %[[VAL7:.+]] = mulf %[[VAL5]], %[[STRIDEX]]
+  // CHECK-DAG: %[[VAL8:.+]] = addf %[[VAL6]], %[[OFFSETY]]
+  // CHECK-DAG: %[[VAL9:.+]] = addf %[[VAL7]], %[[OFFSETX]]
+
+  // Find the remainder and integer component of the target index.
+
+  // CHECK-DAG: %[[VAL10:.+]] = floorf %[[VAL8]]
+  // CHECK-DAG: %[[VAL11:.+]] = floorf %[[VAL9]]
+  // CHECK-DAG: %[[VAL12:.+]] = subf %[[VAL8]], %[[VAL10]]
+  // CHECK-DAG: %[[VAL13:.+]] = subf %[[VAL9]], %[[VAL11]]
+  // CHECK-DAG: %[[VAL14:.+]] = fptosi %[[VAL10]]
+  // CHECK-DAG: %[[VAL15:.+]] = fptosi %[[VAL11]]
+
+  // Round to the nearest index.
+
+  // CHECK-DAG: %[[ROUND:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[VAL16:.+]] = cmpf oge, %[[VAL12]], %[[ROUND]]
+  // CHECK-DAG: %[[VAL17:.+]] = cmpf oge, %[[VAL13]], %[[ROUND]]
+  // CHECK-DAG: %[[ZERO:.+]] = constant 0
+  // CHECK-DAG: %[[ONE:.+]] = constant 1
+  // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL20:.+]] = addi %[[VAL14]], %[[VAL18]]
+  // CHECK-DAG: %[[VAL21:.+]] = addi %[[VAL15]], %[[VAL19]]
+
+  // This section applies bound checking to be within the input image.
+
+  // CHECK-DAG: %[[VAL22:.+]] = cmpi slt, %[[VAL20]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL24:.+]] = cmpi slt, %[[YMAX]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]]
+  // CHECK-DAG: %[[VAL26:.+]] = cmpi slt, %[[VAL21]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL28:.+]] = cmpi slt, %[[XMAX]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+
+  // Extract the nearest value using the computed indices.
+
+  // CHECK-DAG: %[[IDY:.+]] = index_cast %[[VAL25]]
+  // CHECK-DAG: %[[IDX:.+]] = index_cast %[[VAL29]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%arg1, %[[IDY]], %[[IDX]], %arg4]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_bilinear
+func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK: %[[XYMIN:.+]] = constant 0
+  // CHECK: %[[YMAX:.+]] = constant 1
+  // CHECK: %[[XMAX:.+]] = constant 1
+
+  // CHECK: %[[VAL10:.+]] = floorf %[[VAL8:[a-zA-Z0-9_]+]]
+  // CHECK: %[[VAL11:.+]] = floorf %[[VAL9:[a-zA-Z0-9_]+]]
+
+  // CHECK: %[[DY:.+]] = subf %[[VAL8]], %[[VAL10]]
+  // CHECK: %[[DX:.+]] = subf %[[VAL9]], %[[VAL11]]
+
+  // CHECK: %[[Y0:.+]] = fptosi %[[VAL10]]
+  // CHECK: %[[X0:.+]] = fptosi %[[VAL11]]
+
+  // Compute the left, right, and top indices for the bilinear interpolation.
+
+  // CHECK: %[[ONE:.+]] = constant 1
+  // CHECK: %[[Y1:.+]] = addi %[[Y0]], %[[ONE]]
+  // CHECK: %[[X1:.+]] = addi %[[X0]], %[[ONE]]
+
+  // Bound check each dimension.
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y0]]
+  // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y1]]
+  // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X0]]
+  // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X1]]
+  // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // Extract each corner of the bilinear interpolation.
+
+  // CHECK: %[[YLOI:.+]] = index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XLOI]], %arg4]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XHII]], %arg4]
+  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XLOI]], %arg4]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XHII]], %arg4]
+
+  // Compute the bilinear interpolation.
+
+  // CHECK: %[[ONE:.+]] = constant 1.000000e+00
+  // CHECK: %[[NDX:.+]] = subf %[[ONE]], %[[DX]]
+  // CHECK: %[[WLOLO:.+]] = mulf %[[LOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = mulf %[[LOHI]], %[[DX]]
+  // CHECK: %[[LO:.+]] = addf %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = mulf %[[HILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = mulf %[[HIHI]], %[[DX]]
+  // CHECK: %[[HI:.+]] = addf %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = subf %[[ONE]], %[[DY]]
+  // CHECK: %[[WLO:.+]] = mulf %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = mulf %[[HI]], %[[DY]]
+  // CHECK: %[[RESULT:.+]] = addf %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_nearest_int
+func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK-DAG: %[[XYMIN:.+]] = constant 0
+  // CHECK-DAG: %[[YMAX:.+]] = constant 1
+  // CHECK-DAG: %[[XMAX:.+]] = constant 1
+  // CHECK-DAG: %[[Y:.+]] = index_cast %arg2
+  // CHECK-DAG: %[[X:.+]] = index_cast %arg3
+  // CHECK-DAG: %[[STRIDEY:.+]] = constant 128
+  // CHECK-DAG: %[[STRIDEX:.+]] = constant 128
+  // CHECK-DAG: %[[OFFSETY:.+]] = constant 1
+  // CHECK-DAG: %[[OFFSETX:.+]] = constant 2
+  // CHECK-DAG: %[[EIGHT:.+]] = constant 8
+  // CHECK-DAG: %[[VAL4:.+]] = muli %[[Y]], %[[STRIDEY]]
+  // CHECK-DAG: %[[VAL5:.+]] = muli %[[X]], %[[STRIDEX]]
+  // CHECK-DAG: %[[VAL6:.+]] = addi %[[VAL4]], %[[OFFSETY]]
+  // CHECK-DAG: %[[VAL7:.+]] = addi %[[VAL5]], %[[OFFSETX]]
+
+  // Find the remainder and integer component of the target index.
+
+  // CHECK-DAG: %[[VAL8:.+]] = shift_right_signed %[[VAL6]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL9:.+]] = shift_right_signed %[[VAL7]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL10:.+]] = shift_left %[[VAL8]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL11:.+]] = shift_left %[[VAL9]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL12:.+]] = subi %[[VAL6]], %[[VAL10]]
+  // CHECK-DAG: %[[VAL13:.+]] = subi %[[VAL7]], %[[VAL11]]
+
+  // Round to the nearest index.
+
+  // CHECK-DAG: %[[ROUND:.+]] = constant 128
+  // CHECK-DAG: %[[VAL16:.+]] = cmpi sge, %[[VAL12]], %[[ROUND]]
+  // CHECK-DAG: %[[VAL17:.+]] = cmpi sge, %[[VAL13]], %[[ROUND]]
+  // CHECK-DAG: %[[ZERO:.+]] = constant 0
+  // CHECK-DAG: %[[ONE:.+]] = constant 1
+  // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL20:.+]] = addi %[[VAL8]], %[[VAL18]]
+  // CHECK-DAG: %[[VAL21:.+]] = addi %[[VAL9]], %[[VAL19]]
+
+  // This section applies bound checking to be within the input image.
+
+  // CHECK-DAG: %[[VAL22:.+]] = cmpi slt, %[[VAL20]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL24:.+]] = cmpi slt, %[[YMAX]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]]
+  // CHECK-DAG: %[[VAL26:.+]] = cmpi slt, %[[VAL21]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL28:.+]] = cmpi slt, %[[XMAX]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+
+  // Extract the nearest value using the computed indices.
+
+  // CHECK-DAG: %[[IDY:.+]] = index_cast %[[VAL25]]
+  // CHECK-DAG: %[[IDX:.+]] = index_cast %[[VAL29]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%arg1, %[[IDY]], %[[IDX]], %arg4]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xi32>) -> (tensor<1x4x4x1xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_bilinear_int
+func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+
+  // CHECK: %[[XYMIN:.+]] = constant 0
+  // CHECK: %[[YMAX:.+]] = constant 1
+  // CHECK: %[[XMAX:.+]] = constant 1
+
+  // CHECK: %[[Y0:.+]] = shift_right_signed %[[YACC:[a-zA-Z0-9_]+]]
+  // CHECK: %[[X0:.+]] = shift_right_signed %[[XACC:[a-zA-Z0-9_]+]]
+  // CHECK: %[[ROUNDY:.+]] = shift_left %[[Y0]]
+  // CHECK: %[[ROUNDX:.+]] = shift_left %[[X0]]
+  // CHECK: %[[DY:.+]] = subi %[[YACC]], %[[ROUNDY]]
+  // CHECK: %[[DX:.+]] = subi %[[XACC]], %[[ROUNDX]]
+
+  // Compute the left, right, and top indices for the bilinear interpolation.
+
+  // CHECK: %[[ONE:.+]] = constant 1
+  // CHECK: %[[Y1:.+]] = addi %[[Y0]], %[[ONE]]
+  // CHECK: %[[X1:.+]] = addi %[[X0]], %[[ONE]]
+
+  // Bound check each dimension.
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y0]]
+  // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y1]]
+  // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X0]]
+  // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X1]]
+  // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // Extract each corner of the bilinear interpolation.
+
+  // CHECK: %[[YLOI:.+]] = index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XLOI]], %arg4]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XHII]], %arg4]
+  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XLOI]], %arg4]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XHII]], %arg4]
+
+  // CHECK: %[[XLOLO:.+]] = sexti %[[LOLO]]
+  // CHECK: %[[XLOHI:.+]] = sexti %[[LOHI]]
+  // CHECK: %[[XHILO:.+]] = sexti %[[HILO]]
+  // CHECK: %[[XHIHI:.+]] = sexti %[[HIHI]]
+
+  // Compute the bilinear interpolation.
+
+  // CHECK: %[[SCALE:.+]] = constant 256
+  // CHECK: %[[NDX:.+]] = subi %[[SCALE]], %[[DX]]
+  // CHECK: %[[WLOLO:.+]] = muli %[[XLOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = muli %[[XLOHI]], %[[DX]]
+  // CHECK: %[[LO:.+]] = addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = muli %[[XHILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = muli %[[XHIHI]], %[[DX]]
+  // CHECK: %[[HI:.+]] = addi %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = subi %[[SCALE]], %[[DY]]
+  // CHECK: %[[WLO:.+]] = muli %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = muli %[[HI]], %[[DY]]
+  // CHECK: %[[RESULT:.+]] = addi %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>) -> (tensor<1x4x4x1xi32>)
+  return
+}