diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
--- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
@@ -38,6 +38,9 @@
 Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
                      arith::ConstantOp max, OpBuilder &rewriter);
 
+// Determines whether the integer value falls within the range of integer type.
+bool validIntegerRange(IntegerType ty, int64_t value);
+
 // Returns the values in an attribute as an array of values.
 template <typename T>
 void getValuesFromIntArrayAttribute(ArrayAttr attr,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -56,6 +57,49 @@
     if (weightShape[1] != 1 || weightShape[2] != 1)
       return failure();
 
+    auto padAttr = op.getPad();
+    llvm::SmallVector<int64_t> pad(8, 0);
+    for (auto it : llvm::enumerate(padAttr.getValue()))
+      pad[it.index() + 2] =
+          it.value().cast<IntegerAttr>().getValue().getSExtValue();
+
+    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
+      Type inputETy = inputType.getElementType();
+      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+      if (op.getQuantizationInfo()) {
+        auto quantizationInfo = op.getQuantizationInfo();
+        int64_t iZp = quantizationInfo->getInputZp();
+
+        if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+          return rewriter.notifyMatchFailure(
+              op, "tosa.conv op quantization has zp outside of input range");
+
+        zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+      }
+
+      llvm::SmallVector<int64_t> newShape(inputType.getShape());
+
+      for (int i = 0, s = newShape.size(); i < s; ++i) {
+        if (newShape[i] != ShapedType::kDynamic) {
+          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
+        }
+      }
+
+      auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+      auto padSize =
+          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
+      Value padSizeVal =
+          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+
+      auto padTy = RankedTensorType::get({}, inputETy);
+      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
+      Value padVal =
+          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
+      inputType = RankedTensorType::get(newShape, inputETy);
+      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
+                                           padSizeVal, padVal);
+    }
+
     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
     ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t combined = ShapedType::kDynamic;
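Note: the pad handling above scatters the 4-element TOSA `pad` attribute `[top, bottom, left, right]` into rows 1 (H) and 2 (W) of a `{4, 2}` NHWC pad tensor, then grows each static dimension by its low/high padding. A minimal standalone sketch of that bookkeeping (illustrative only, not part of the patch; values taken from the `conv2d_as_fully_connected_padded` test added below):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // tosa.conv2d pad attribute: [top, bottom, left, right].
  const std::array<int64_t, 4> padAttr = {1, 1, 1, 1};

  // Rows index [N, H, W, C]; only H and W receive padding, which is why
  // the patch writes to pad[it.index() + 2] in the flat 8-element vector.
  std::array<std::array<int64_t, 2>, 4> pad = {};
  pad[1] = {padAttr[0], padAttr[1]}; // H: top / bottom
  pad[2] = {padAttr[2], padAttr[3]}; // W: left / right

  // Each static dim grows by low + high padding, mirroring
  // newShape[i] += pad[i * 2] + pad[i * 2 + 1].
  std::array<int64_t, 4> shape = {4, 10, 10, 2};
  for (int i = 0; i < 4; ++i)
    shape[i] += pad[i][0] + pad[i][1];

  // Prints 4 x 12 x 12 x 2, matching tensor<4x12x12x2xi8> in the test.
  std::printf("%lld x %lld x %lld x %lld\n", (long long)shape[0],
              (long long)shape[1], (long long)shape[2], (long long)shape[3]);
}
```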
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -31,18 +31,12 @@
     ShapedType inputType = input.getType().cast<ShapedType>();
     ShapedType weightType = weight.getType().cast<ShapedType>();
     ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
-    Type inputEType = inputType.getElementType();
 
     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
           resultType.hasStaticShape())) {
       return failure();
     }
 
-    // Quantization information needs to still be performed.
-    if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
-      return failure();
-    }
-
     // Stride must be 1 for this optimization.
     for (Attribute stride : op.getStride().getValue()) {
       if (!stride.cast<IntegerAttr>().getValue().isOne()) {
@@ -60,39 +54,88 @@
     ArrayRef<int64_t> inputShape = inputType.getShape();
     llvm::SmallVector<int64_t> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
-    auto revisedInputShapeType = RankedTensorType::get(
+    inputType = RankedTensorType::get(
         revisedInputShape,
         input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
-    llvm::SmallVector<int64_t> revisedWeightShape{1, 1, 1, weightShape[2],
-                                                  weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
+    input = rewriter
+                .create<tosa::ReshapeOp>(
+                    op.getLoc(), inputType, input,
+                    rewriter.getI64ArrayAttr(revisedInputShape))
+                .getResult();
+
+    if (inputType.getElementType() != resultType.getElementType()) {
+      inputType = inputType.clone(resultType.getElementType());
+      input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
+    }
+
+    if (weightType.getElementType() != resultType.getElementType()) {
+      weightType = weightType.clone(resultType.getElementType());
+      weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
+    }
+
+    if (auto quantizationInfo = op.getQuantizationInfo()) {
+      auto iZp = quantizationInfo->getInputZp();
+      auto wZp = quantizationInfo->getWeightZp();
+
+      auto applyZp = [&](Value val, int64_t zp) -> Value {
+        if (zp == 0)
+          return val;
+        auto ety = val.getType().cast<ShapedType>().getElementType();
+        auto zpTy = RankedTensorType::get({}, ety);
+        auto zpAttr =
+            DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
+        auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
+        return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
+                                            zpVal);
+      };
+
+      input = applyZp(input, iZp);
+      weight = applyZp(weight, wZp);
+    }
+
+    auto padAttr = op.getPad();
+    llvm::SmallVector<int64_t> pad(10, 0);
+    for (auto it : llvm::enumerate(padAttr.getValue()))
+      pad[it.index() + 2] =
+          it.value().cast<IntegerAttr>().getValue().getSExtValue();
+
+    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
+      Type inputETy = inputType.getElementType();
+      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+
+      llvm::SmallVector<int64_t> newShape(inputType.getShape());
+      for (int i = 0, s = pad.size(); i < s; ++i) {
+        if (newShape[i / 2] != ShapedType::kDynamic) {
+          newShape[i / 2] += pad[i];
+        }
+      }
+
+      auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+      auto padSize =
+          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
+      Value padSizeVal =
+          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+
+      auto padTy = RankedTensorType::get({}, inputETy);
+      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
+      Value padVal =
+          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
+      inputType = RankedTensorType::get(newShape, inputETy);
+      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
+                                           padSizeVal, padVal);
+    }
 
     // Perform an elementwise mul over the reshaped input and weight.
-    llvm::SmallVector<int64_t> mulShape{inputShape[0], inputShape[1],
-                                        inputShape[2], inputShape[3],
-                                        weightShape[3]};
+    llvm::SmallVector<int64_t> mulShape{
+        inputType.getDimSize(0), inputType.getDimSize(1),
+        inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
     auto mulShapeType = RankedTensorType::get(
         mulShape,
         weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    Value mulValue =
-        rewriter
-            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
-                                 reshapedWeight, /*shift=*/0)
-            .getResult();
+    Value mulValue = rewriter
+                         .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
+                                              weight, /*shift=*/0)
+                         .getResult();
 
     // Reshape output to [N, H, W, C * M].
     auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
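Note: the depthwise rewrite now admits quantized ops by casting both operands to the accumulator element type and subtracting the zero points before the elementwise multiply. A scalar sketch of the arithmetic being materialized (illustrative only; `quantizedProduct` is a hypothetical name, zero points taken from the `depthwise_conv2d_as_mul_q` test below):

```cpp
#include <cstdint>
#include <cstdio>

// Widen first (tosa.cast), subtract zero points (tosa.sub), then multiply
// (tosa.mul with shift = 0): acc = (x - iZp) * (w - wZp) computed in i32.
int32_t quantizedProduct(int8_t x, int8_t w, int32_t iZp, int32_t wZp) {
  return (int32_t(x) - iZp) * (int32_t(w) - wZp);
}

int main() {
  // input_zp = 7 and weight_zp = 11, as in the test.
  std::printf("%d\n", quantizedProduct(9, 13, 7, 11)); // (9-7)*(13-11) = 4
}
```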
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -47,3 +47,17 @@
       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
   return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
 }
+
+bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
+  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
+  if (ty.getSignedness() == IntegerType::Unsigned) {
+    uint64_t uvalue = value;
+    APInt intMin = APInt::getMinValue(bitwidth);
+    APInt intMax = APInt::getMaxValue(bitwidth);
+    return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
+  }
+
+  APInt intMin = APInt::getSignedMinValue(bitwidth);
+  APInt intMax = APInt::getSignedMaxValue(bitwidth);
+  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -54,3 +54,17 @@
   return %0 : tensor<?x14x14x384xi32>
 }
 
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected_padded
+func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
+  // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>}
+  // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() {value = dense<42> : tensor<i8>}
+  // CHECK-DAG: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[PAD_SHAPE]], %[[PAD_VAL]]) : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+  // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = "tosa.reshape"(%[[PAD]]) {new_shape = [576, 2]}
+  // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-DAG: %[[FULLY:.+]] = "tosa.fully_connected"(%[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2) {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+  // CHECK: %[[RESHAPE:.+]] = "tosa.reshape"(%[[FULLY]]) {new_shape = [4, 12, 12, 3]}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
+  return %0 : tensor<4x12x12x3xi32>
+}
\ No newline at end of file
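Note: `validIntegerRange` gates the conv2d rewrite: a zero point that cannot be represented in the input element type would make the `tosa.pad` fill value wrong, so the pattern bails out. A plain-integer restatement of the check (a sketch, not the patch code; `fitsInIntegerType` is a hypothetical name and APInt is replaced with explicit bounds):

```cpp
#include <cstdint>
#include <cstdio>

bool fitsInIntegerType(unsigned bitwidth, bool isUnsigned, int64_t value) {
  if (isUnsigned) {
    // Negative values wrap to huge uint64_t and are rejected for widths < 64.
    uint64_t max =
        (bitwidth == 64) ? UINT64_MAX : (uint64_t(1) << bitwidth) - 1;
    return uint64_t(value) <= max;
  }
  int64_t min =
      (bitwidth == 64) ? INT64_MIN : -(int64_t(1) << (bitwidth - 1));
  int64_t max =
      (bitwidth == 64) ? INT64_MAX : (int64_t(1) << (bitwidth - 1)) - 1;
  return value >= min && value <= max;
}

int main() {
  std::printf("%d\n", fitsInIntegerType(8, false, 42));  // 1: i8 holds 42
  std::printf("%d\n", fitsInIntegerType(8, false, 200)); // 0: out of i8 range
  std::printf("%d\n", fitsInIntegerType(8, true, 200));  // 1: ui8 holds 200
}
```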
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -7,9 +7,7 @@
   // CHECK-NOT: "tosa.depthwise_conv2d"
   // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
   // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
-  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %arg1)
   // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
   // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
   // CHECK-SAME: -> tensor<4x10x10x6xf32>
@@ -24,9 +22,31 @@
 
 // CHECK-LABEL: @depthwise_conv2d_as_mul_q
 func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: "tosa.depthwise_conv2d"
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+  // CHECK: %[[iZp:.+]] = "tosa.const"() {value = dense<7> : tensor<i32>}
+  // CHECK: %[[wZp:.+]] = "tosa.const"() {value = dense<11> : tensor<i32>}
+  // CHECK: %[[rIn:.+]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK: %[[cIn:.+]] = "tosa.cast"(%[[rIn]]) : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32>
+  // CHECK: %[[cWe:.+]] = "tosa.cast"(%arg1) : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32>
+  // CHECK: %[[sIn:.+]] = "tosa.sub"(%[[cIn]], %[[iZp]])
+  // CHECK: %[[sWe:.+]] = "tosa.sub"(%[[cWe]], %[[wZp]])
+  // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[sWe]]) {shift = 0 : i32}
+  // CHECK: %[[reO:.+]] = "tosa.reshape"(%[[mul]]) {new_shape = [4, 10, 10, 6]}
+  // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %arg2)
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
   return %0 : tensor<4x10x10x6xi32>
 }
 
 // -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
+func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
+  // CHECK: %[[pad:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>}
+  // CHECK: %[[zero:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK: %[[reIn:.+]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK: %[[padded:.+]] = "tosa.pad"(%[[reIn]], %[[pad]], %[[zero]]) : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[mul:.+]] = "tosa.mul"(%[[padded]], %arg1) {shift = 0 : i32}
+  // CHECK: %[[reOut:.+]] = "tosa.reshape"(%[[mul]]) {new_shape = [4, 12, 12, 6]}
+  // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %arg2)
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
+  return %0 : tensor<4x12x12x6xf32>
+}
\ No newline at end of file
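Note: the whole decomposition rests on the fact that a 1x1, stride-1 depthwise convolution performs no spatial reduction. A scalar model of that equivalence (illustrative only; shapes and values chosen arbitrarily):

```cpp
#include <cstdio>

int main() {
  // out[n][h][w][c*M + m] = in[n][h][w][c] * weight[0][0][c][m]; this is
  // exactly reshape([N,H,W,C] -> [N,H,W,C,1]), an elementwise mul against
  // the [C,M] weights, then reshape([N,H,W,C,M] -> [N,H,W,C*M]).
  const int N = 1, H = 2, W = 2, C = 2, M = 3;
  float in[N][H][W][C] = {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};
  float wt[C][M] = {{1, 10, 100}, {2, 20, 200}}; // weight[0][0][c][m]
  float out[N][H][W][C * M] = {};

  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c)
          for (int m = 0; m < M; ++m)
            out[n][h][w][c * M + m] = in[n][h][w][c] * wt[c][m];

  // First pixel: 1 10 100 4 40 400.
  for (int k = 0; k < C * M; ++k)
    std::printf("%g ", out[0][0][0][k]);
  std::printf("\n");
}
```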