diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h rename from mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h rename to mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -29,13 +29,13 @@ // Takes the parameters for a clamp and turns it into a series of ops for float // inputs. -Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter); +Value clampFloatHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter); // Takes the parameters for a clamp and turns it into a series of ops for // integer inputs. -Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter); +Value clampIntHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter); // Returns the values in an attribute as an array of values. template diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -18,7 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" @@ -177,10 +177,10 @@ auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to the negation range. 
- auto min = rewriter.create( + Value min = rewriter.create( loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), intermediateType); - auto max = rewriter.create( + Value max = rewriter.create( loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), intermediateType); auto clamp = clampIntHelper(loc, sub, min, max, rewriter); @@ -1426,6 +1426,7 @@ LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); auto inputTy = input.getType().cast(); auto resultTy = op.getType().cast(); @@ -1440,276 +1441,283 @@ return failure(); SmallVector dynamicDims = dynamicDimsOr.value(); + llvm::SmallVector resizeShape(resultTy.getShape()); + if (imageH == 1) + resizeShape[1] = 1; + if (imageW == 1) + resizeShape[2] = 1; + if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return failure(); - auto emptyTensor = rewriter.create( - loc, resultTy.getShape(), resultElementTy, dynamicDims); + RankedTensorType resizeTy = + RankedTensorType::get(resizeShape, resultTy.getElementType()); + auto emptyTensor = + b.create(resizeShape, resultElementTy, dynamicDims); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - Value resize = input; - auto genericOp = rewriter.create( - loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, + // TODO(suderman): Override resultTy. 
+ auto genericOp = b.create( + resizeTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); - resize = genericOp.getResult(0); - - OpBuilder::InsertionGuard regionGuard(rewriter); - rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), - TypeRange({resultElementTy}), loc); - Value batch = rewriter.create(loc, 0); - Value y = rewriter.create(loc, 1); - Value x = rewriter.create(loc, 2); - Value channel = rewriter.create(loc, 3); - - auto hwMin = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto hMax = rewriter.create( - loc, rewriter.getI32IntegerAttr(imageH - 1)); - auto wMax = rewriter.create( - loc, rewriter.getI32IntegerAttr(imageW - 1)); - - Value inY = - rewriter.create(loc, rewriter.getI32Type(), y); - Value inX = - rewriter.create(loc, rewriter.getI32Type(), x); - - bool floatingPointMode = resultElementTy.isF32(); - - Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder, - xBorder; - SmallVector scale, offset, border; - getValuesFromIntArrayAttribute(op.getScale(), scale); - getValuesFromIntArrayAttribute(op.getOffset(), offset); - getValuesFromIntArrayAttribute(op.getBorder(), border); - - yScaleN = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[0])); - yScaleD = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[1])); - xScaleN = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[2])); - xScaleD = rewriter.create( - loc, rewriter.getI32IntegerAttr(scale[3])); - yOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( - loc, rewriter.getI32IntegerAttr(offset[1])); - yBorder = rewriter.create( - loc, rewriter.getI32IntegerAttr(border[0])); - xBorder = rewriter.create( - loc, rewriter.getI32IntegerAttr(border[1])); - - // Compute the the integer index and partial offset. 
- Value ix, iy, dx, dy; - // x = x * scale_d + offset; - // ix = floor(x / scale_n) - if (floatingPointMode) { - // dx = x / scale_n - ix - Value y = - rewriter.create(loc, rewriter.getF32Type(), inY); - Value x = - rewriter.create(loc, rewriter.getF32Type(), inX); - - yScaleN = - rewriter.create(loc, rewriter.getF32Type(), yScaleN); - yScaleD = - rewriter.create(loc, rewriter.getF32Type(), yScaleD); - xScaleN = - rewriter.create(loc, rewriter.getF32Type(), xScaleN); - xScaleD = - rewriter.create(loc, rewriter.getF32Type(), xScaleD); - yOffset = - rewriter.create(loc, rewriter.getF32Type(), yOffset); - xOffset = - rewriter.create(loc, rewriter.getF32Type(), xOffset); - - y = rewriter.create(loc, y, yScaleD); - x = rewriter.create(loc, x, xScaleD); - - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); - - y = rewriter.create(loc, y, yScaleN); - x = rewriter.create(loc, x, xScaleN); - - iy = rewriter.create(loc, y); - ix = rewriter.create(loc, x); - - dy = rewriter.create(loc, y, iy); - dx = rewriter.create(loc, x, ix); - - iy = rewriter.create(loc, rewriter.getI32Type(), iy); - ix = rewriter.create(loc, rewriter.getI32Type(), ix); - } else { - // dx = x - ix * scale_n; - Value y = rewriter.create(loc, inY, yScaleD); - Value x = rewriter.create(loc, inX, xScaleD); - - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + Value resize = genericOp.getResult(0); - iy = rewriter.create(loc, y, yScaleN); - ix = rewriter.create(loc, x, xScaleN); - - Value tempY = rewriter.create(loc, iy, yScaleN); - Value tempX = rewriter.create(loc, ix, xScaleN); - - dy = rewriter.create(loc, y, tempY); - dx = rewriter.create(loc, x, tempX); - } - - if (op.getMode() == "NEAREST_NEIGHBOR") { - Value yPred, xPred; - auto zeroVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - - // Round the index position towards the closest pixel location. 
+ { + OpBuilder::InsertionGuard regionGuard(b); + b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), + TypeRange({resultElementTy}), loc); + Value batch = b.create(0); + Value y = b.create(1); + Value x = b.create(2); + Value channel = b.create(3); + + Value zeroI32 = b.create(b.getI32IntegerAttr(0)); + Value zeroF32 = + b.create(b.getZeroAttr(rewriter.getF32Type())); + Value hMax = b.create(b.getI32IntegerAttr(imageH - 1)); + Value wMax = b.create(b.getI32IntegerAttr(imageW - 1)); + + Value inY = b.create(b.getI32Type(), y); + Value inX = b.create(b.getI32Type(), x); + + bool floatingPointMode = resultElementTy.isF32(); + + SmallVector scale, offset, border; + getValuesFromIntArrayAttribute(op.getScale(), scale); + getValuesFromIntArrayAttribute(op.getOffset(), offset); + getValuesFromIntArrayAttribute(op.getBorder(), border); + + Value yScaleN, yScaleD, xScaleN, xScaleD; + yScaleN = b.create(b.getI32IntegerAttr(scale[0])); + yScaleD = b.create(b.getI32IntegerAttr(scale[1])); + xScaleN = b.create(b.getI32IntegerAttr(scale[2])); + xScaleD = b.create(b.getI32IntegerAttr(scale[3])); + + Value yOffset, xOffset, yBorder, xBorder; + yOffset = b.create(b.getI32IntegerAttr(offset[0])); + xOffset = b.create(b.getI32IntegerAttr(offset[1])); + yBorder = b.create(b.getI32IntegerAttr(border[0])); + xBorder = b.create(b.getI32IntegerAttr(border[1])); + + // Compute the index and delta values for the float case. 
+ auto floatIndices = [&](Value &index, Value &delta, Value in, + Value scaleN, Value scaleD, Value offset, + int size, ImplicitLocOpBuilder &b) { + if (size == 1) { + index = zeroI32; + delta = zeroF32; + return; + } + // x = x * scale_d + offset; + // ix = floor(x / scale_n) + // dx = x / scale_n - ix + Value val = b.create(b.getF32Type(), in); + scaleN = b.create(b.getF32Type(), scaleN); + scaleD = b.create(b.getF32Type(), scaleD); + offset = b.create(b.getF32Type(), offset); + val = b.create(val, scaleD); + val = b.create(val, offset); + val = b.create(val, scaleN); + index = b.create(val); + delta = b.create(val, index); + index = b.create(b.getI32Type(), index); + }; + + // Compute the index and delta values for the integer case. + auto intIndices = [&](Value &index, Value &delta, Value in, Value scaleN, + Value scaleD, Value offset, int size, + ImplicitLocOpBuilder &b) { + if (size == 1) { + index = zeroI32; + delta = zeroI32; + return; + } + // x = x * scale_d + offset; + // ix = floor(x / scale_n) + // dx = x - ix * scale_n; + Value val = b.create(in, scaleD); + val = b.create(val, offset); + index = b.create(val, scaleN); + delta = b.create(index, scaleN); + delta = b.create(val, delta); + }; + + // Compute the integer index and partial offset. 
+ Value ix, iy, dx, dy; if (floatingPointMode) { - auto halfVal = rewriter.create( - loc, rewriter.getF32FloatAttr(0.5f)); - yPred = rewriter.create(loc, arith::CmpFPredicate::OGE, - dy, halfVal); - xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, - dx, halfVal); + floatIndices(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); + floatIndices(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } else { - Value yScaleNHalfVal = - rewriter.create(loc, yScaleN, oneVal); - Value xScaleNHalfVal = - rewriter.create(loc, xScaleN, oneVal); - yPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dy, yScaleNHalfVal); - xPred = rewriter.create(loc, arith::CmpIPredicate::sge, - dx, xScaleNHalfVal); + intIndices(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); + intIndices(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } - auto yOffset = - rewriter.create(loc, yPred, oneVal, zeroVal); - auto xOffset = - rewriter.create(loc, xPred, oneVal, zeroVal); - - iy = rewriter.create(loc, iy, yOffset); - ix = rewriter.create(loc, ix, xOffset); + if (op.getMode() == "NEAREST_NEIGHBOR") { + auto one = b.create(b.getI32IntegerAttr(1)); - // Clamp the to be within the bounds of the input image. - iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter); - ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter); + auto roundIndex = [&](Value val, Value dval, Value scale, Value max, + int size, ImplicitLocOpBuilder &b) -> Value { + if (size == 1) { + return b.create(0); + } - // Read the value from the input array. 
- iy = - rewriter.create(loc, rewriter.getIndexType(), iy); - ix = - rewriter.create(loc, rewriter.getIndexType(), ix); + Value pred; + if (floatingPointMode) { + auto h = b.create(b.getF32FloatAttr(0.5f)); + pred = b.create(arith::CmpFPredicate::OGE, dval, h); + } else { + Value scaleH = b.create(scale, one); + pred = b.create(arith::CmpIPredicate::sge, dval, + scaleH); + } - Value result = rewriter.create( - loc, input, ValueRange{batch, iy, ix, channel}); + auto offset = b.create(pred, one, zeroI32); + val = b.create(val, offset); + val = clampIntHelper(loc, val, zeroI32, max, b); + return b.create(b.getIndexType(), val); + }; - rewriter.create(loc, result); - } else { - // The mode here must be BILINEAR. - assert(op.getMode() == "BILINEAR"); - Value y0 = iy; - Value x0 = ix; - - auto oneVal = rewriter.create( - loc, rewriter.getI32IntegerAttr(1)); - Value y1 = rewriter.create(loc, y0, oneVal); - Value x1 = rewriter.create(loc, x0, oneVal); - - y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter); - y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter); - - x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter); - x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter); - - y0 = - rewriter.create(loc, rewriter.getIndexType(), y0); - y1 = - rewriter.create(loc, rewriter.getIndexType(), y1); - x0 = - rewriter.create(loc, rewriter.getIndexType(), x0); - x1 = - rewriter.create(loc, rewriter.getIndexType(), x1); - - Value y0x0 = rewriter.create( - loc, input, ValueRange{batch, y0, x0, channel}); - Value y0x1 = rewriter.create( - loc, input, ValueRange{batch, y0, x1, channel}); - Value y1x0 = rewriter.create( - loc, input, ValueRange{batch, y1, x0, channel}); - Value y1x1 = rewriter.create( - loc, input, ValueRange{batch, y1, x1, channel}); + iy = roundIndex(iy, dy, yScaleN, hMax, imageH, b); + ix = roundIndex(ix, dx, xScaleN, wMax, imageW, b); - if (floatingPointMode) { - Value rightPart = dx; - auto oneVal = rewriter.create( - loc, rewriter.getF32FloatAttr(1.0f)); - Value 
leftPart = rewriter.create(loc, oneVal, dx); + Value result = b.create( + input, ValueRange{batch, iy, ix, channel}); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + b.create(result); + } else { + // The mode here must be BILINEAR. + assert(op.getMode() == "BILINEAR"); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + auto oneVal = b.create(b.getI32IntegerAttr(1)); - Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value x0, x1, y0, y1; - rewriter.create(loc, result); - } else { - // Perform in quantized space. - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = rewriter.create(loc, resultElementTy, y1x1); - - if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); - } - - Value topAcc, bottomAcc; - if (imageW == 1) { - topAcc = rewriter.create(loc, y0x0, xScaleN); - bottomAcc = rewriter.create(loc, y1x0, xScaleN); + auto clampEdges = [&](Value &val0, Value &val1, int size, Value in, + Value max, ImplicitLocOpBuilder &b) { + if (size == 1) { + val0 = b.create(0); + val1 = val0; + return; + } + val0 = in; + val1 = b.create(val0, oneVal); + val0 = clampIntHelper(loc, val0, zeroI32, max, b); + val1 = clampIntHelper(loc, val1, zeroI32, max, b); + val0 = b.create(b.getIndexType(), val0); + val1 = b.create(b.getIndexType(), val1); + }; + + clampEdges(y0, y1, imageH, iy, hMax, b); + clampEdges(x0, x1, imageW, ix, wMax, b); + + Value y0x0 = b.create( + input, 
ValueRange{batch, y0, x0, channel}); + Value y0x1 = b.create( + input, ValueRange{batch, y0, x1, channel}); + Value y1x0 = b.create( + input, ValueRange{batch, y1, x0, channel}); + Value y1x1 = b.create( + input, ValueRange{batch, y1, x1, channel}); + + if (floatingPointMode) { + auto oneVal = b.create(b.getF32FloatAttr(1.0f)); + Value w0 = b.create(oneVal, dx); + Value w1 = dx; + + auto interpolate = [](Value val0, Value val1, Value d0, Value d1, + int64_t size, + ImplicitLocOpBuilder &b) -> Value { + if (size == 1) + return val0; + Value mul0 = b.create(val0, d0); + Value mul1 = b.create(val1, d1); + return b.create(mul0, mul1); + }; + + Value topAcc = interpolate(y0x0, y0x1, w0, w1, imageW, b); + Value bottomAcc = interpolate(y1x0, y1x1, w0, w1, imageW, b); + + w0 = b.create(oneVal, dy); + w1 = dy; + Value result = interpolate(topAcc, bottomAcc, w0, w1, imageH, b); + b.create(result); } else { - Value rightPart = dx; - Value leftPart = rewriter.create(loc, xScaleN, dx); - - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - topAcc = rewriter.create(loc, y0x0, y0x1); + // Perform in quantized space. 
+ y0x0 = b.create(resultElementTy, y0x0); + y0x1 = b.create(resultElementTy, y0x1); + y1x0 = b.create(resultElementTy, y1x0); + y1x1 = b.create(resultElementTy, y1x1); + + if (resultElementTy.getIntOrFloatBitWidth() > 32) { + dx = b.create(resultElementTy, dx); + dy = b.create(resultElementTy, dy); + } - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - bottomAcc = rewriter.create(loc, y1x0, y1x1); + auto interpolate = [](Value val0, Value val1, Value d0, Value d1, + int64_t size, + ImplicitLocOpBuilder &b) -> Value { + if (size == 1) + return val0; + + Value mul0 = b.create(val0, d0); + Value mul1 = b.create(val1, d1); + return b.create(mul0, mul1); + }; + + Value w0 = b.create(xScaleN, dx); + Value w1 = dx; + Value topAcc = interpolate(y0x0, y0x1, w0, w1, imageW, b); + Value bottomAcc = interpolate(y1x0, y1x1, w0, w1, imageW, b); + + w0 = b.create(yScaleN, dy); + w1 = dy; + Value result = interpolate(topAcc, bottomAcc, w0, w1, imageH, b); + b.create(result); } + } + } - Value result; - if (imageH == 1) { - result = rewriter.create(loc, topAcc, yScaleN); - } else { - Value bottomPart = dy; - Value topPart = rewriter.create(loc, yScaleN, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = - rewriter.create(loc, bottomAcc, bottomPart); - result = rewriter.create(loc, topAcc, bottomAcc); - } + if (resizeTy == resultTy) { + rewriter.replaceOp(op, resize); + return success(); + } - rewriter.create(loc, result); + // Collapse the length-1 width/height values before the broadcast. 
+ SmallVector collapseShape; + SmallVector collapseMap; + SmallVector broadcastExprs; + + for (int i = 0, s = resizeTy.getRank(); i < s; i++) { + if (resizeTy.getDimSize(i) == resultTy.getDimSize(i)) { + collapseShape.push_back(resizeTy.getDimSize(i)); + collapseMap.push_back({rewriter.getAffineDimExpr(i)}); + broadcastExprs.push_back(rewriter.getAffineDimExpr(i)); + continue; } + + collapseMap.back().push_back(rewriter.getAffineDimExpr(i)); } - rewriter.replaceOp(op, resize); + resizeTy = RankedTensorType::get(collapseShape, resizeTy.getElementType()); + resize = b.create(resizeTy, resize, collapseMap); + + auto broadcastEmpty = b.create( + resultTy.getShape(), resultElementTy, dynamicDims); + + SmallVector broadcastMaps = { + AffineMap::get(resultTy.getRank(), 0, broadcastExprs, + rewriter.getContext()), + rewriter.getMultiDimIdentityMap(resultTy.getRank())}; + + rewriter.replaceOpWithNewOp( + op, resultTy, ArrayRef({resize}), ValueRange{broadcastEmpty}, + broadcastMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(op.getLoc(), args[0]); + }); + return success(); } }; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -18,7 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ 
b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" using namespace mlir; using namespace mlir::tosa; @@ -29,15 +29,14 @@ return condensedValues; } -Value mlir::tosa::clampFloatHelper(Location loc, Value arg, - arith::ConstantOp min, arith::ConstantOp max, - OpBuilder &rewriter) { +Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, + Value max, OpBuilder &rewriter) { Value minValue = rewriter.create(loc, arg, max); return rewriter.create(loc, minValue, min); } -Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, OpBuilder &rewriter) { +Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, + OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, arith::CmpIPredicate::slt, arg, min); auto minOrArg = diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir @@ -124,7 +124,7 @@ // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 
2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 14 // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 12 @@ -142,66 +142,62 @@ // find the remainder and integer component of the target index. // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] - // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]] - // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] - // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] - // Round to the nearest neighor. + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] + // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // Compute the offset and bound for the Y position. 
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]] - // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]] // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]] - // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]] // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] - // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]] - // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]] - - // This section applies bound checking to be within the input image. - - // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]] - // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]] + // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[ZERO]] + // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[ZERO]], %[[VAL_39]] // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]] // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]] - // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]] - // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]] + // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]] + + // Compute the offset and bound for the X position. 
+ // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]] + // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]] + // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]] + // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[ZERO]] + // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[ZERO]], %[[VAL_40]] // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]] // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]] - - // Extract the nearest value using the computed indices. - - // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]] // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]] // CHECK: linalg.yield %[[EXTRACT]] // Round to the nearest index. %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8> - return + return } // ----- // CHECK-LABEL: @resize_bilinear_int // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) { - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32> +func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) { + // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x304x320x1xi32> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX_0:.+]] = linalg.index 0 // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 18 - // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 18 + // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 19 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] // 
CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 16 @@ -214,51 +210,53 @@ // CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0 // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]] - // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]] - // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]] - // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]] + + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]] + // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]] + // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]] + // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]] // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]] // Compute the left, right, and top indices for the bilinear interpolation. // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] - // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] // Bound check each dimension. 
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] + // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] + + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // Extract each corner of the bilinear interpolation. 
- - // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] - // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] + // Extract each corner of the bilinear interpolation. + // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]] // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]] // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]] @@ -285,8 +283,8 @@ // CHECK: linalg.yield %[[RESULT]] // Round to the nearest index. - %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32> - return + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x20x1xi8>) -> tensor<1x304x320x1xi32> + return } // ----- @@ -299,7 +297,7 @@ // CHECK: %[[IDX1:.+]] = linalg.index 1 // CHECK: %[[IDX2:.+]] = linalg.index 2 // CHECK: %[[IDX3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XYMIN:.*]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[YMAX:.*]] = arith.constant 49 // CHECK-DAG: %[[XMAX:.*]] = arith.constant 47 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]] @@ -314,72 +312,68 @@ // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 31 // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] - // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[D_Y:.*]] = arith.subf 
%[[VAL_33]], %[[VAL_35]] + // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]] + + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] - // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] - - // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] - // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] - // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] - - // Find the remainder and integer component of the target index. - - // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] - // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] - // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]] // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]] - // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]] - // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]] // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]] - // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]] - // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]] - - // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]] - // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]] + // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[ZERO]] + // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[ZERO]], 
%[[VAL_48]] // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]] // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]] - // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]] - // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]] + // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]] + + // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 + // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]] + // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]] + // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]] + // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[ZERO]] + // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[ZERO]], %[[VAL_49]] // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]] // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]] - - // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]] // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]] + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]] // CHECK: linalg.yield %[[EXTRACT]] %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32> - return } // ----- // CHECK-LABEL: @resize_bilinear_fp -func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () { - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32> +func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () { + // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x92x96x1xf32> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX_0:.+]] = linalg.index 0 // CHECK: %[[IDX_1:.+]] = linalg.index 1 // CHECK: %[[IDX_2:.+]] = linalg.index 2 // CHECK: %[[IDX_3:.+]] = linalg.index 3 - // CHECK-DAG: %[[XY_MIN:.*]] = arith.constant 0 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 // CHECK-DAG: %[[Y_MAX:.*]] = 
arith.constant 22 - // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 22 + // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 23 // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]] // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]] // CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 4 @@ -392,58 +386,58 @@ // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 0 // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]] - // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]] // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]] + // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] + // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] + // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] + // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] + // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] + // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] + // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]] + + // CHECK: %[[X0:.+]] = arith.uitofp %[[X]] // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]] // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]] - // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]] // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]] - - // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]] // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]] - // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]] // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]] - // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]] // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]] - - // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]] // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]] - // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]] // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]] - // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]] // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]] // Compute the left, 
right, and top indices for the bilinear interpolation. - // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 - // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] - // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[ONE:.*]] = arith.constant 1 // Bound check each dimension. - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]] + // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]] + + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]] // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]] // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]] + // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] + // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]] + // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]] // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]] // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]] - // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]] + // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]] + // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]] // CHECK: 
%[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]] // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]] - // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]] - // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]] // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]] // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]] @@ -467,7 +461,7 @@ // CHECK: linalg.yield %[[RESULT]] // Round by bilinear interpolation - %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32> + %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x24x1xf32>) -> tensor<1x92x96x1xf32> return } @@ -484,3 +478,102 @@ %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor) -> (tensor) return } + +// ----- + +// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> + +// CHECK-LABEL: @resize_nearest_w_one +func.func @resize_nearest_w_one(%arg0: tensor<1x1x13x1xf32>) -> (tensor<1x23x179x1xf32>) { + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x179x1xf32> + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK-SAME: indexing_maps = [#map0] + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x179x1xf32>) + // CHECK: %[[I0:.+]] = arith.constant 0 : index + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%{{.*}}, %[[I0]], %{{.*}}, %{{.*}}] : tensor<1x1x13x1xf32> + // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] + // CHECK-SAME{literal}:[[0, 1], [2], [3]] + + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x23x179x1xf32> + // CHECK: %[[BROADCAST:.+]] = linalg.generic + // CHECK-SAME: indexing_maps = [#map1, #map0] + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", 
"parallel"]} + // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x179x1xf32>) + // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x23x179x1xf32>) + // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32): + // CHECK: linalg.yield %[[IN]] + // CHECK: return %[[BROADCAST]] + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x1x13x1xf32>) -> tensor<1x23x179x1xf32> + return %0 : tensor<1x23x179x1xf32> +} + +// ----- + +// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> + +// CHECK-LABEL: @resize_nearest_h_one +func.func @resize_nearest_h_one(%arg0: tensor<1x13x1x1xf32>) -> (tensor<1x179x23x1xf32>) { + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x179x1x1xf32> + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK-SAME: indexing_maps = [#map0] + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x179x1x1xf32>) + // CHECK: %[[I0:.+]] = arith.constant 0 : index + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%{{.*}}, %{{.*}}, %[[I0]], %{{.*}}] : tensor<1x13x1x1xf32> + // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] + // CHECK-SAME{literal}:[[0], [1, 2], [3]] + + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x179x23x1xf32> + // CHECK: %[[BROADCAST:.+]] = linalg.generic + // CHECK-SAME: indexing_maps = [#map1, #map0] + // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x179x1xf32>) + // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x179x23x1xf32>) + // CHECK: return %[[BROADCAST]] + %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x13x1x1xf32>) -> tensor<1x179x23x1xf32> + return %0 : tensor<1x179x23x1xf32> +} + +// ----- + +// CHECK-LABEL: @resize_bilinear_h_one +func.func @resize_bilinear_h_one(%arg0: tensor<1x13x1x1xf32>) -> (tensor<1x179x23x1xf32>) { + %0 = "tosa.resize"(%arg0) 
{mode = "BILINEAR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x13x1x1xf32>) -> tensor<1x179x23x1xf32> + return %0 : tensor<1x179x23x1xf32> +} + +// ----- + +// CHECK-LABEL: @resize_bilinear_w_one +func.func @resize_bilinear_w_one(%arg0: tensor<1x1x13x1xf32>) -> (tensor<1x23x179x1xf32>) { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[EXTRACT0:.+]] = tensor.extract %arg0[%{{.*}}, %[[C0]], %{{.*}}, %{{.*}}] : tensor<1x1x13x1xf32> + // CHECK: %[[EXTRACT1:.+]] = tensor.extract %arg0[%{{.*}}, %[[C0]], %{{.*}}, %{{.*}}] : tensor<1x1x13x1xf32> + // CHECK: %[[EXTRACT2:.+]] = tensor.extract %arg0[%{{.*}}, %[[C0]], %{{.*}}, %{{.*}}] : tensor<1x1x13x1xf32> + // CHECK: %[[EXTRACT3:.+]] = tensor.extract %arg0[%{{.*}}, %[[C0]], %{{.*}}, %{{.*}}] : tensor<1x1x13x1xf32> + // CHECK: %[[MUL1:.+]] = arith.mulf %[[EXTRACT0]], %{{.*}} + // CHECK: %[[MUL2:.+]] = arith.mulf %[[EXTRACT1]], %{{.*}} + // CHECK: %[[ADD:.+]] = arith.addf %[[MUL1]], %[[MUL2]] : f32 + // CHECK: linalg.yield %[[ADD]] : f32 + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x1x13x1xf32>) -> tensor<1x23x179x1xf32> + return %0 : tensor<1x23x179x1xf32> +} + +// ----- + +// CHECK-LABEL: @resize_bilinear_h_one +func.func @resize_bilinear_h_one(%arg0: tensor<1x13x1x1xf32>) -> (tensor<1x179x23x1xf32>) { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[EXTRACT0:.+]] = tensor.extract %arg0[%{{.*}}, %{{.*}}, %[[C0]], %{{.*}}] : tensor<1x13x1x1xf32> + // CHECK: %[[EXTRACT1:.+]] = tensor.extract %arg0[%{{.*}}, %{{.*}}, %[[C0]], %{{.*}}] : tensor<1x13x1x1xf32> + // CHECK: %[[EXTRACT2:.+]] = tensor.extract %arg0[%{{.*}}, %{{.*}}, %[[C0]], %{{.*}}] : tensor<1x13x1x1xf32> + // CHECK: %[[EXTRACT3:.+]] = tensor.extract %arg0[%{{.*}}, %{{.*}}, %[[C0]], %{{.*}}] : tensor<1x13x1x1xf32> + // CHECK: %[[MUL1:.+]] = arith.mulf %[[EXTRACT0]], %{{.*}} + // CHECK: %[[MUL2:.+]] = arith.mulf %[[EXTRACT2]], %{{.*}} + 
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL1]], %[[MUL2]] : f32 + // CHECK: linalg.yield %[[ADD]] : f32 + %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x13x1x1xf32>) -> tensor<1x179x23x1xf32> + return %0 : tensor<1x179x23x1xf32> +}