diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -159,14 +159,65 @@ } // tosa::NegateOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + + if (isa(op) && elementTy.isa() && + !cast(op).quantization_info()) { auto constant = - rewriter.create(loc, IntegerAttr::get(elementTy, -1)); - return rewriter.create(loc, resultTypes, args[0], constant); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + return rewriter.create(loc, resultTypes, constant, args[0]); } - if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + if (isa(op) && elementTy.isa() && + cast(op).quantization_info()) { + auto quantizationInfo = cast(op).quantization_info(); + int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); + int64_t inZp = + quantizationInfo.getValue().input_zp().getValue().getSExtValue(); + int64_t outZp = + quantizationInfo.getValue().output_zp().getValue().getSExtValue(); + + // Compute the maximum value that can occur in the intermediate buffer. + int64_t zpAdd = inZp + outZp; + int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + + std::abs(zpAdd) + 1; + + // Convert that maximum value into the maximum bitwidth needed to represent + // it. We assume 48-bit numbers may be supported further in the pipeline. + int intermediateBitWidth = 64; + if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { + intermediateBitWidth = 16; + } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { + intermediateBitWidth = 32; + } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { + intermediateBitWidth = 48; + } + + Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); + Value zpAddValue = rewriter.create( + loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + + // The negation can be applied by doing: + // outputValue = inZp + outZp - inputValue + auto ext = rewriter.create(loc, intermediateType, args[0]); + auto sub = rewriter.create(loc, zpAddValue, ext); + + // Clamp to the negation range. + auto min = rewriter.create( + loc, rewriter.getIntegerAttr( + intermediateType, + APInt::getSignedMinValue(inputBitWidth).getSExtValue())); + auto max = rewriter.create( + loc, rewriter.getIntegerAttr( + intermediateType, + APInt::getSignedMaxValue(inputBitWidth).getSExtValue())); + auto clamp = clampHelper(loc, sub, min, max, + CmpIPredicate::slt, rewriter); + + // Truncate to the final value. + return rewriter.create(loc, elementTy, clamp); + } // tosa::BitwiseAndOp if (isa(op) && elementTy.isa()) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -258,7 +258,8 @@ %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic - // CHECK: muli + // CHECK: [[ZERO:%.+]] = constant 0 + // CHECK: subi [[ZERO]], %arg1 %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic @@ -363,6 +364,35 @@ // ----- +// CHECK-LABEL: @test_negate_quantized +func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { + // CHECK: linalg.generic + // CHECK: [[ZERO:%.+]] = constant 0 + // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16 + // CHECK: [[SUB:%.+]] = subi [[ZERO]], [[EXT]] + // CHECK: [[MIN:%.+]] = constant -128 + // CHECK: [[MAX:%.+]] = constant 127 + // CHECK: [[PRED1:%.+]] = cmpi slt, [[SUB]], [[MIN]] + // CHECK: [[LBOUND:%.+]] = select [[PRED1]], [[MIN]], [[SUB]] + // CHECK: [[PRED2:%.+]] = cmpi slt, [[MAX]], [[SUB]] + // CHECK: [[UBOUND:%.+]] = select [[PRED2]], [[MAX]], [[LBOUND]] + // CHECK: [[TRUNC:%.+]] = trunci [[UBOUND]] + // CHECK: linalg.yield [[TRUNC]] + %0 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 0 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + + // CHECK: linalg.generic + // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16 + %1 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32639 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + + // CHECK: linalg.generic + // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i32 + %2 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32640 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + + return +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @test_reshape_downrank func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {