diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h rename from mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h rename to mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -30,13 +30,13 @@ // Takes the parameters for a clamp and turns it into a series of ops for float // inputs. -Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter); +Value clampFloatHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter); // Takes the parameters for a clamp and turns it into a series of ops for // integer inputs. -Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter); +Value clampIntHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter); // Returns the values in an attribute as an array of values. template diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -18,7 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" @@ -177,10 +177,10 @@ auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to the negation range. - auto min = rewriter.create( + Value min = rewriter.create( loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), intermediateType); - auto max = rewriter.create( + Value max = rewriter.create( loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), intermediateType); auto clamp = clampIntHelper(loc, sub, min, max, rewriter); @@ -1431,10 +1431,11 @@ LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); auto inputTy = input.getType().cast(); auto resultTy = op.getType().cast(); - auto resultElementTy = resultTy.getElementType(); + auto resultETy = resultTy.getElementType(); auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; @@ -1444,273 +1445,235 @@ if (!dynamicDimsOr.has_value()) return rewriter.notifyMatchFailure( op, "unable to get dynamic dimensions of tosa.resize"); - SmallVector dynamicDims = dynamicDimsOr.value(); if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); - auto emptyTensor = rewriter.create( - loc, resultTy.getShape(), resultElementTy, dynamicDims); - SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - - Value resize = input; - auto genericOp = rewriter.create( - loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, + auto emptyTensor = b.create(resultTy.getShape(), resultETy, + dynamicDimsOr.value()); + auto genericOp = b.create( + resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); - resize = genericOp.getResult(0); - - OpBuilder::InsertionGuard regionGuard(rewriter); - rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), - TypeRange({resultElementTy}), loc); - Value batch = rewriter.create(loc, 0); - Value y = rewriter.create(loc, 1); - Value x = rewriter.create(loc, 2); - Value channel = rewriter.create(loc, 3); - - auto hwMin = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto hMax = rewriter.create( - loc, rewriter.getI32IntegerAttr(imageH - 1)); - auto wMax = rewriter.create( - loc, rewriter.getI32IntegerAttr(imageW - 1)); - - Value inY = - rewriter.create(loc, rewriter.getI32Type(), y); - Value inX = - rewriter.create(loc, rewriter.getI32Type(), x); - - bool floatingPointMode = resultElementTy.isF32(); - - Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder, - xBorder; - SmallVector scale, offset, border; - getValuesFromIntArrayAttribute(op.getScale(), scale); - getValuesFromIntArrayAttribute(op.getOffset(), offset); - getValuesFromIntArrayAttribute(op.getBorder(), border); - - yScaleN = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[0])); - yScaleD = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[1])); - xScaleN = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[2])); - xScaleD = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[3])); - yOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[1])); - yBorder = rewriter.create( - loc, rewriter.getI32IntegerAttr(border[0])); - xBorder = rewriter.create( - loc, rewriter.getI32IntegerAttr(border[1])); - - // Compute the the integer index and partial offset. - Value ix, iy, dx, dy; - // x = x * scale_d + offset; - // ix = floor(x / scale_n) - if (floatingPointMode) { - // dx = x / scale_n - ix - Value y = - rewriter.create(loc, rewriter.getF32Type(), inY); - Value x = - rewriter.create(loc, rewriter.getF32Type(), inX); - - yScaleN = - rewriter.create(loc, rewriter.getF32Type(), yScaleN); - yScaleD = - rewriter.create(loc, rewriter.getF32Type(), yScaleD); - xScaleN = - rewriter.create(loc, rewriter.getF32Type(), xScaleN); - xScaleD = - rewriter.create(loc, rewriter.getF32Type(), xScaleD); - yOffset = - rewriter.create(loc, rewriter.getF32Type(), yOffset); - xOffset = - rewriter.create(loc, rewriter.getF32Type(), xOffset); - - y = rewriter.create(loc, y, yScaleD); - x = rewriter.create(loc, x, xScaleD); - - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); - - y = rewriter.create(loc, y, yScaleN); - x = rewriter.create(loc, x, xScaleN); - - iy = rewriter.create(loc, y); - ix = rewriter.create(loc, x); - - dy = rewriter.create(loc, y, iy); - dx = rewriter.create(loc, x, ix); - - iy = rewriter.create(loc, rewriter.getI32Type(), iy); - ix = rewriter.create(loc, rewriter.getI32Type(), ix); - } else { - // dx = x - ix * scale_n; - Value y = rewriter.create(loc, inY, yScaleD); - Value x = rewriter.create(loc, inX, xScaleD); - - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + Value resize = genericOp.getResult(0); - iy = rewriter.create(loc, y, yScaleN); - ix = rewriter.create(loc, x, xScaleN); - - Value tempY = rewriter.create(loc, iy, yScaleN); - Value tempX = rewriter.create(loc, ix, xScaleN); - - dy = rewriter.create(loc, y, tempY); - dx = rewriter.create(loc, x, tempX); - } - - if (op.getMode() == "NEAREST_NEIGHBOR") { - Value yPred, xPred; - auto zeroVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - - // Round the index position towards the closest pixel location. + { + OpBuilder::InsertionGuard regionGuard(b); + b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), + TypeRange({resultETy}), loc); + Value batch = b.create(0); + Value y = b.create(1); + Value x = b.create(2); + Value channel = b.create(3); + + Value zeroI32 = b.create(b.getI32IntegerAttr(0)); + Value hMax = b.create(b.getI32IntegerAttr(imageH - 1)); + Value wMax = b.create(b.getI32IntegerAttr(imageW - 1)); + + Value inY = b.create(b.getI32Type(), y); + Value inX = b.create(b.getI32Type(), x); + + bool floatingPointMode = resultETy.isF32(); + + SmallVector scale, offset, border; + getValuesFromIntArrayAttribute(op.getScale(), scale); + getValuesFromIntArrayAttribute(op.getOffset(), offset); + getValuesFromIntArrayAttribute(op.getBorder(), border); + + Value yScaleN, yScaleD, xScaleN, xScaleD; + yScaleN = b.create(b.getI32IntegerAttr(scale[0])); + yScaleD = b.create(b.getI32IntegerAttr(scale[1])); + xScaleN = b.create(b.getI32IntegerAttr(scale[2])); + xScaleD = b.create(b.getI32IntegerAttr(scale[3])); + + Value yOffset, xOffset, yBorder, xBorder; + yOffset = b.create(b.getI32IntegerAttr(offset[0])); + xOffset = b.create(b.getI32IntegerAttr(offset[1])); + yBorder = b.create(b.getI32IntegerAttr(border[0])); + xBorder = b.create(b.getI32IntegerAttr(border[1])); + + // Compute the ix and dx values for both the X and Y dimensions. + auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, + Value scaleN, Value scaleD, Value offset, + int size, ImplicitLocOpBuilder &b) { + // x = x * scale_d + offset; + // ix = floor(x / scale_n) + // dx = x / scale_n - ix + Value val = b.create(b.getF32Type(), in); + scaleN = b.create(b.getF32Type(), scaleN); + scaleD = b.create(b.getF32Type(), scaleD); + offset = b.create(b.getF32Type(), offset); + val = b.create(val, scaleD); + val = b.create(val, offset); + val = b.create(val, scaleN); + index = b.create(val); + delta = b.create(val, index); + index = b.create(b.getI32Type(), index); + }; + + // Compute the ix and dx values for the X and Y dimensions - int case. + auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in, + Value scaleN, Value scaleD, Value offset, + int size, ImplicitLocOpBuilder &b) { + // x = x * scale_d + offset; + // ix = floor(x / scale_n) + // dx = x - ix * scale_n; + Value val = b.create(in, scaleD); + val = b.create(val, offset); + index = b.create(val, scaleN); + delta = b.create(index, scaleN); + delta = b.create(val, delta); + }; + + // Compute the ix and dx values for the X and Y dimensions - fp case. + Value ix, iy, dx, dy; if (floatingPointMode) { - auto halfVal = rewriter.create( - loc, rewriter.getF32FloatAttr(0.5f)); - yPred = rewriter.create(loc, arith::CmpFPredicate::OGE, - dy, halfVal); - xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, - dx, halfVal); + getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); + getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } else { - Value dyDoubled = rewriter.create(loc, dy, oneVal); - Value dxDoubled = rewriter.create(loc, dx, oneVal); - yPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dyDoubled, yScaleN); - xPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dxDoubled, xScaleN); + getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); + getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } - auto yOffset = - rewriter.create(loc, yPred, oneVal, zeroVal); - auto xOffset = - rewriter.create(loc, xPred, oneVal, zeroVal); - - iy = rewriter.create(loc, iy, yOffset); - ix = rewriter.create(loc, ix, xOffset); - - // Clamp the to be within the bounds of the input image. - iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter); - ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter); - - // Read the value from the input array. - iy = - rewriter.create(loc, rewriter.getIndexType(), iy); - ix = - rewriter.create(loc, rewriter.getIndexType(), ix); + if (op.getMode() == "NEAREST_NEIGHBOR") { + auto one = b.create(b.getI32IntegerAttr(1)); - Value result = rewriter.create( - loc, input, ValueRange{batch, iy, ix, channel}); - - rewriter.create(loc, result); - } else { - // The mode here must be BILINEAR. - assert(op.getMode() == "BILINEAR"); - Value y0 = iy; - Value x0 = ix; - - auto oneVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - Value y1 = rewriter.create(loc, y0, oneVal); - Value x1 = rewriter.create(loc, x0, oneVal); - - y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter); - y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter); - - x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter); - x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter); - - y0 = - rewriter.create(loc, rewriter.getIndexType(), y0); - y1 = - rewriter.create(loc, rewriter.getIndexType(), y1); - x0 = - rewriter.create(loc, rewriter.getIndexType(), x0); - x1 = - rewriter.create(loc, rewriter.getIndexType(), x1); - - Value y0x0 = rewriter.create( - loc, input, ValueRange{batch, y0, x0, channel}); - Value y0x1 = rewriter.create( - loc, input, ValueRange{batch, y0, x1, channel}); - Value y1x0 = rewriter.create( - loc, input, ValueRange{batch, y1, x0, channel}); - Value y1x1 = rewriter.create( - loc, input, ValueRange{batch, y1, x1, channel}); + auto roundIndex = [&](Value val, Value dval, Value scale, Value max, + int size, ImplicitLocOpBuilder &b) -> Value { + if (size == 1) { + return b.create(0); + } - if (floatingPointMode) { - Value rightPart = dx; - auto oneVal = rewriter.create( - loc, rewriter.getF32FloatAttr(1.0f)); - Value leftPart = rewriter.create(loc, oneVal, dx); + Value pred; + if (floatingPointMode) { + auto h = b.create(b.getF32FloatAttr(0.5f)); + pred = b.create(arith::CmpFPredicate::OGE, dval, h); + } else { + Value dvalDouble = b.create(dval, one); + pred = b.create(arith::CmpIPredicate::sge, + dvalDouble, scale); + } - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + auto offset = b.create(pred, one, zeroI32); + val = b.create(val, offset); + val = clampIntHelper(loc, val, zeroI32, max, b); + return b.create(b.getIndexType(), val); + }; - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + iy = roundIndex(iy, dy, yScaleN, hMax, imageH, b); + ix = roundIndex(ix, dx, xScaleN, wMax, imageW, b); - Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value result = b.create( + input, ValueRange{batch, iy, ix, channel}); - rewriter.create(loc, result); + b.create(result); } else { - // Perform in quantized space. - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = rewriter.create(loc, resultElementTy, y1x1); - - if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); - } + // The mode here must be BILINEAR. + assert(op.getMode() == "BILINEAR"); + + auto oneVal = b.create(b.getI32IntegerAttr(1)); - Value topAcc, bottomAcc; - if (imageW == 1) { - topAcc = rewriter.create(loc, y0x0, xScaleN); - bottomAcc = rewriter.create(loc, y1x0, xScaleN); + auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, + Value max, ImplicitLocOpBuilder &b) { + if (size == 1) { + val0 = b.create(0); + val1 = val0; + return; + } + val0 = in; + val1 = b.create(val0, oneVal); + val0 = clampIntHelper(loc, val0, zeroI32, max, b); + val1 = clampIntHelper(loc, val1, zeroI32, max, b); + val0 = b.create(b.getIndexType(), val0); + val1 = b.create(b.getIndexType(), val1); + }; + + // Linalg equivalent to the section below: + // int16_t iy0 = apply_max(iy, 0); + // int16_t iy1 = apply_min(iy + 1, IH - 1); + // int16_t ix0 = apply_max(ix, 0); + // int16_t ix1 = apply_min(ix + 1, IW - 1); + Value x0, x1, y0, y1; + getClampedIdxs(y0, y1, imageH, iy, hMax, b); + getClampedIdxs(x0, x1, imageW, ix, wMax, b); + + Value y0x0 = b.create( + input, ValueRange{batch, y0, x0, channel}); + Value y0x1 = b.create( + input, ValueRange{batch, y0, x1, channel}); + Value y1x0 = b.create( + input, ValueRange{batch, y1, x0, channel}); + Value y1x1 = b.create( + input, ValueRange{batch, y1, x1, channel}); + + if (floatingPointMode) { + auto oneVal = b.create(b.getF32FloatAttr(1.0f)); + auto interpolate = [&](Value val0, Value val1, Value delta, + ImplicitLocOpBuilder &b) -> Value { + Value oneMinusDelta = b.create(oneVal, delta); + Value mul0 = b.create(val0, oneMinusDelta); + Value mul1 = b.create(val1, delta); + return b.create(mul0, mul1); + }; + + // Linalg equivalent to the section below: + // topAcc = v00 * (unit_x - dx); + // topAcc += v01 * dx; + Value topAcc = interpolate(y0x0, y0x1, dx, b); + + // Linalg equivalent to the section below: + // bottomAcc = v10 * (unit_x - dx); + // bottomAcc += v11 * dx; + Value bottomAcc = interpolate(y1x0, y1x1, dx, b); + + // Linalg equivalent to the section below: + // result = topAcc * (unit_y - dy) + bottomAcc * dy + Value result = interpolate(topAcc, bottomAcc, dy, b); + b.create(result); } else { - Value rightPart = dx; - Value leftPart = rewriter.create(loc, xScaleN, dx); + // Perform in quantized space. + y0x0 = b.create(resultETy, y0x0); + y0x1 = b.create(resultETy, y0x1); + y1x0 = b.create(resultETy, y1x0); + y1x1 = b.create(resultETy, y1x1); + + if (resultETy.getIntOrFloatBitWidth() > 32) { + dx = b.create(resultETy, dx); + dy = b.create(resultETy, dy); + } - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - topAcc = rewriter.create(loc, y0x0, y0x1); + Value yScaleNExt = yScaleN; + Value xScaleNExt = xScaleN; - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - bottomAcc = rewriter.create(loc, y1x0, y1x1); - } + if (resultETy.getIntOrFloatBitWidth() > 32) { + yScaleNExt = b.create(resultETy, yScaleN); + xScaleNExt = b.create(resultETy, xScaleN); + } - Value result; - if (imageH == 1) { - result = rewriter.create(loc, topAcc, yScaleN); - } else { - Value bottomPart = dy; - Value topPart = rewriter.create(loc, yScaleN, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = - rewriter.create(loc, bottomAcc, bottomPart); - result = rewriter.create(loc, topAcc, bottomAcc); + auto interpolate = [](Value val0, Value val1, Value weight0, + Value weight1, int64_t size, + ImplicitLocOpBuilder &b) -> Value { + if (size == 1) + return val0; + Value mul0 = b.create(val0, weight0); + Value mul1 = b.create(val1, weight1); + return b.create(mul0, mul1); + }; + + Value weight0 = b.create(xScaleNExt, dx); + Value weight1 = dx; + Value topAcc = interpolate(y0x0, y0x1, weight0, weight1, imageW, b); + Value bottomAcc = + interpolate(y1x0, y1x1, weight0, weight1, imageW, b); + + weight0 = b.create(yScaleNExt, dy); + weight1 = dy; + Value result = + interpolate(topAcc, bottomAcc, weight0, weight1, imageH, b); + b.create(result); } - - rewriter.create(loc, result); } } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -18,7 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" using namespace mlir; using namespace mlir::tosa; @@ -30,15 +30,14 @@ return condensedValues; } -Value mlir::tosa::clampFloatHelper(Location loc, Value arg, - arith::ConstantOp min, arith::ConstantOp max, - OpBuilder &rewriter) { +Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, + Value max, OpBuilder &rewriter) { Value minValue = rewriter.create(loc, arg, max); return rewriter.create(loc, minValue, min); } -Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter) { +Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, arith::CmpIPredicate::slt, arg, min); auto minOrArg = diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir @@ -124,7 +124,7 @@ // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 14 // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 12 @@ -142,66 +142,62 @@ // find the remainder and integer component of the target index. // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] - // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]] - // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]] // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] - // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] - // Round to the nearest neighor. + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] + // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // Compute the offset and bound for the Y position. // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK: %[[D_Y_DOUBLE:.*]] = arith.shli %[[D_Y]], %[[ONE]] - // CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]] // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y_DOUBLE]], %[[SCALE_Y_N]] - // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]] // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] - // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]] - // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]] - - // This section applies bound checking to be within the input image. - - // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]] - // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]] + // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[ZERO]] + // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[ZERO]], %[[VAL_39]] // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]] // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]] - // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]] - // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]] + // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]] + + // Compute the offset and bound for the X position. + // CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]] + // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]] + // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]] + // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[ZERO]] + // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[ZERO]], %[[VAL_40]] // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]] // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]] - - // Extract the nearest value using the computed indices. - - // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]] // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]] // CHECK: linalg.yield %[[EXTRACT]] // Round to the nearest index. %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8> - return + return } // ----- // CHECK-LABEL: @resize_bilinear_int // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) { - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32> +func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) { + // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x304x320x1xi48> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX_0:.+]] = linalg.index 0 // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 18 - // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 18 + // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 19 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] // CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 16 @@ -214,51 +210,53 @@ // CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0 // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] - // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]] - // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]] // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] + + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] // Compute the left, right, and top indices for the bilinear interpolation. // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] - // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] // Bound check each dimension. - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] + // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] + + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // Extract each corner of the bilinear interpolation. - - // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] - // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] + // Extract each corner of the bilinear interpolation. + // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]] // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]] // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]] @@ -269,24 +267,29 @@ // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]] // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]] + // CHECK-NEXT: %[[D_X_EXT:.+]] = arith.extsi %[[D_X]] + // CHECK-NEXT: %[[D_Y_EXT:.+]] = arith.extsi %[[D_Y]] + // CHECK-NEXT: %[[Y_N_EXT:.+]] = arith.extsi %[[SCALE_Y_N]] + // CHECK-NEXT: %[[X_N_EXT:.+]] = arith.extsi %[[SCALE_X_N]] + // Compute the bilinear interpolation. - // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]] + // CHECK: %[[NDX:.+]] = arith.subi %[[X_N_EXT]], %[[D_X_EXT]] // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]] - // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]] + // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X_EXT]] // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]] // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]] - // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]] + // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X_EXT]] // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]] - // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]] + // CHECK: %[[NDY:.+]] = arith.subi %[[Y_N_EXT]], %[[D_Y_EXT]] // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]] - // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]] + // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y_EXT]] // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]] // CHECK: linalg.yield %[[RESULT]] // Round to the nearest index. - %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32> - return + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x20x1xi8>) -> tensor<1x304x320x1xi48> + return } // ----- @@ -299,7 +302,7 @@ // CHECK: %[[IDX1:.+]] = linalg.index 1 // CHECK: %[[IDX2:.+]] = linalg.index 2 // CHECK: %[[IDX3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XYMIN:.*]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[YMAX:.*]] = arith.constant 49 // CHECK-DAG: %[[XMAX:.*]] = arith.constant 47 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]] @@ -314,72 +317,68 @@ // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 31 // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] - // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] + // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]] + + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] - // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] - - // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] - // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] - // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] - - // Find the remainder and integer component of the target index. - - // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] - // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] - // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]] // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]] - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]] - // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]] // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] - // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]] - // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]] - - // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]] - // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]] + // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[ZERO]] + // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[ZERO]], %[[VAL_48]] // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]] // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]] - // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]] - // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]] + // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]] + + // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 + // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]] + // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]] + // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[ZERO]] + // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[ZERO]], %[[VAL_49]] // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]] // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]] - - // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]] // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]] // CHECK: linalg.yield %[[EXTRACT]] %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32> - return } // ----- // CHECK-LABEL: @resize_bilinear_fp -func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32> +func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () { + // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x92x96x1xf32> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX_0:.+]] = linalg.index 0 // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.*]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.*]] = arith.constant 22 - // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 22 + // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 23 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] // CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 4 @@ -392,58 +391,58 @@ // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 0 // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] - // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] + // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]] + + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] - // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] - - // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] - // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] - // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] - - // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] - // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] - // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]] // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]] // Compute the left, right, and top indices for the bilinear interpolation. - // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 - // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] - // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[ONE:.*]] = arith.constant 1 // Bound check each dimension. - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] + + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] + // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] + // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] - // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] @@ -457,6 +456,7 @@ // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]] // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]] // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]] + // CHECK: %[[NDX:.+]] = arith.subf %[[ONE]], %[[D_X]] // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]] // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]] // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]] @@ -467,7 +467,7 @@ // CHECK: linalg.yield %[[RESULT]] // Round by bilinear interpolation - %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32> + %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x24x1xf32>) -> tensor<1x92x96x1xf32> return }