diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_interface(TosaInterfaces) set(LLVM_TARGET_DEFINITIONS TosaOps.td) -mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls) -mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs) -add_public_tablegen_target(MLIRTosaStructsIncGen) - +mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRTosaAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -10,13 +10,16 @@ // //===----------------------------------------------------------------------===// - #ifndef TOSA_OP_BASE #define TOSA_OP_BASE +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + //===----------------------------------------------------------------------===// // The TOSA Dialect. //===----------------------------------------------------------------------===// + def Tosa_Dialect : Dialect { let name = "tosa"; @@ -41,6 +44,16 @@ let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; + let useDefaultAttributePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// TOSA Attributes. +//===----------------------------------------------------------------------===// + +class Tosa_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; } //===----------------------------------------------------------------------===// @@ -51,7 +64,7 @@ // feed numerical precision parameters to the functional implementation of TOSA // operators. // The functional behavior is defined in the TOSA specification maintained at -// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built in +// https://developer.mlplatform.org/w/tosa/. TOSA leverages MLIR's built in // quantization support: https://mlir.llvm.org/docs/Quantization/, and supports // uniform quantization. Depending on datatype, asymmetric and symmetric // quantization are supported. The types themselves are described in @@ -60,12 +73,11 @@ // This quantization attribute expresses numerical behavior of operators where // the operator has a numerical relationship between a single input and output. // For example: tosa.negate. -def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", - Tosa_Dialect, [ - StructFieldAttr<"input_zp", I32Attr>, - StructFieldAttr<"output_zp", I32Attr> - ]> { +def Tosa_UnaryOpQuantizationAttr + : Tosa_Attr<"UnaryOpQuantization", "unary_quant"> { let summary = "Attribute for UnaryOp quantization information."; + let parameters = (ins "int64_t":$input_zp, "int64_t":$output_zp); + let assemblyFormat = "`<` struct(params) `>`"; } // There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In @@ -79,31 +91,28 @@ // the inputs. // The scaling of their accumulator output is done using an explicit // tosa.rescale operator that scales the accumulator result to output scale. -def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", - Tosa_Dialect, [ - StructFieldAttr<"input_zp", I32Attr>, - StructFieldAttr<"weight_zp", I32Attr> - ]> { +def Tosa_ConvOpQuantizationAttr + : Tosa_Attr<"ConvOpQuantization", "conv_quant"> { let summary = "Attribute for Conv type op quantization information."; + let parameters = (ins "int64_t":$input_zp, "int64_t":$weight_zp); + let assemblyFormat = "`<` struct(params) `>`"; } -def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", - Tosa_Dialect, [ - StructFieldAttr<"a_zp", I32Attr>, - StructFieldAttr<"b_zp", I32Attr> - ]> { +def Tosa_MatMulOpQuantizationAttr + : Tosa_Attr< "MatMulOpQuantization", "matmul_quant"> { let summary = "Attribute for MatMulOp quantization information."; + let parameters = (ins "int64_t":$a_zp, "int64_t":$b_zp); + let assemblyFormat = "`<` struct(params) `>`"; } // This attribute holds input zero point correction applied to the padding // zeros to ensure numerical accuracy in the subsequent TOSA operations. // Its functional application is described in the tosa.pad() operator // description in the specification. -def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", - Tosa_Dialect, [ - StructFieldAttr<"input_zp", I32Attr> - ]> { +def Tosa_PadOpQuantizationAttr : Tosa_Attr<"PadOpQuantization", "pad_quant"> { let summary = "Attribute for PadOp quantization information."; + let parameters = (ins "int64_t":$input_zp); + let assemblyFormat = "`<` struct(params) `>`"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -21,8 +21,8 @@ //===----------------------------------------------------------------------===// // TOSA dialect and structs includes. //===----------------------------------------------------------------------===// + #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc" -#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" namespace mlir { class PatternRewriter; @@ -45,6 +45,9 @@ } // namespace tosa } // namespace mlir +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -147,10 +147,8 @@ cast(op).quantization_info()) { auto quantizationInfo = cast(op).quantization_info(); int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); - int64_t inZp = - quantizationInfo.getValue().input_zp().getValue().getSExtValue(); - int64_t outZp = - quantizationInfo.getValue().output_zp().getValue().getSExtValue(); + int64_t inZp = quantizationInfo.getValue().getInput_zp(); + int64_t outZp = quantizationInfo.getValue().getOutput_zp(); // Compute the maximum value that can occur in the intermediate buffer. int64_t zpAdd = inZp + outZp; @@ -1844,13 +1842,13 @@ loc, padOp.pad_const(), ValueRange({})); } else { Attribute constantAttr; - if (elementTy.isa()) + if (elementTy.isa()) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - else if (elementTy.isa() && !padOp.quantization_info()) + } else if (elementTy.isa() && !padOp.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); - else if (elementTy.isa() && padOp.quantization_info()) { - auto value = padOp.quantization_info().getValue().input_zp().getValue(); - constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); + } else if (elementTy.isa() && padOp.quantization_info()) { + int64_t value = padOp.quantization_info().getValue().getInput_zp(); + constantAttr = rewriter.getIntegerAttr(elementTy, value); } if (constantAttr) padConstant = rewriter.create(loc, constantAttr); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -202,7 +202,7 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - auto iZp = quantizationInfo.input_zp().getValue().getSExtValue(); + int64_t iZp = quantizationInfo.getInput_zp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -274,10 +274,8 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - auto iZp = rewriter.getI32IntegerAttr( - quantizationInfo.input_zp().getValue().getSExtValue()); - auto kZp = rewriter.getI32IntegerAttr( - quantizationInfo.weight_zp().getValue().getSExtValue()); + auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); + auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); @@ -368,10 +366,8 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - iZp = rewriter.getI32IntegerAttr( - quantizationInfo.input_zp().getValue().getSExtValue()); - kZp = rewriter.getI32IntegerAttr( - quantizationInfo.weight_zp().getValue().getSExtValue()); + iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); + kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); } auto weightShape = weightTy.getShape(); @@ -382,7 +378,7 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - auto iZp = quantizationInfo.input_zp().getValue().getSExtValue(); + int64_t iZp = quantizationInfo.getInput_zp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -546,11 +542,9 @@ auto quantizationInfo = op.quantization_info().getValue(); auto aZp = rewriter.create( - loc, rewriter.getI32IntegerAttr( - quantizationInfo.a_zp().getValue().getSExtValue())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp())); auto bZp = rewriter.create( - loc, rewriter.getI32IntegerAttr( - quantizationInfo.b_zp().getValue().getSExtValue())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp())); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); @@ -658,11 +652,9 @@ auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr( - quantizationInfo.input_zp().getValue().getSExtValue())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp())); auto outputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr( - quantizationInfo.weight_zp().getValue().getSExtValue())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp())); Value matmul = rewriter .create( @@ -900,7 +892,8 @@ if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, quantizationInfo.input_zp()); + loc, + b.getIntegerAttr(accETy, quantizationInfo.getInput_zp())); Value offset = rewriter.create(loc, accETy, countI, inputZp); poolVal = @@ -936,7 +929,8 @@ if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); auto outputZp = rewriter.create( - loc, quantizationInfo.output_zp()); + loc, b.getIntegerAttr(scaled.getType(), + quantizationInfo.getOutput_zp())); scaled = rewriter.create(loc, scaled, outputZp) .getResult(); } diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -7,8 +7,8 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa DEPENDS + MLIRTosaAttributesIncGen MLIRTosaOpsIncGen - MLIRTosaStructsIncGen MLIRTosaInterfacesIncGen LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -18,12 +18,14 @@ #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::tosa; @@ -33,8 +35,8 @@ //===----------------------------------------------------------------------===// // Tosa dialect structs and interface includes. //===----------------------------------------------------------------------===// + #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" -#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" namespace { //===----------------------------------------------------------------------===// @@ -78,6 +80,10 @@ #define GET_OP_LIST #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" + >(); addInterfaces(); } @@ -336,13 +342,13 @@ Type elementTy = inputTy.getElementType(); Attribute constantAttr; - if (elementTy.isa()) + if (elementTy.isa()) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - else if (elementTy.isa() && !op.quantization_info()) + } else if (elementTy.isa() && !op.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); - else if (elementTy.isa() && op.quantization_info()) { - auto value = op.quantization_info().getValue().input_zp().getValue(); - constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); + } else if (elementTy.isa() && op.quantization_info()) { + auto value = op.quantization_info().getValue().getInput_zp(); + constantAttr = rewriter.getIntegerAttr(elementTy, value); } if (!constantAttr) { @@ -1925,6 +1931,13 @@ return success(); } +//===----------------------------------------------------------------------===// +// TOSA Attribute Definitions. +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -214,8 +214,7 @@ weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, weightPaddingVal, nullptr, - PadOpQuantizationAttr::get(quantInfo.weight_zp(), - rewriter.getContext())); + rewriter.getAttr(quantInfo.getWeight_zp())); } else { weight = createOpAndInfer(rewriter, loc, @@ -279,8 +278,7 @@ input = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(inputETy), input, inputPaddingVal, nullptr, - PadOpQuantizationAttr::get(quantInfo.input_zp(), - rewriter.getContext())); + rewriter.getAttr(quantInfo.getInput_zp())); } else { input = createOpAndInfer(rewriter, loc, UnrankedTensorType::get(inputETy), diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -137,7 +137,6 @@ "Inputs and weights must be all quantized or all not quantized"); if (inputQType) { - int64_t inputZp = inputQType.getZeroPoint(); int64_t weightZp = 0; @@ -147,11 +146,7 @@ weightZp = weightPerAxisQType.getZeroPoints().front(); } - auto quantAttr = tosa::ConvOpQuantizationAttr::get( - builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), - builder.getContext()); - - return quantAttr; + return builder.getAttr(inputZp, weightZp); } return nullptr; @@ -179,15 +174,8 @@ "Matmul operands must be all quantized or all not quantized"); if (aQType) { - - int64_t aZp = aQType.getZeroPoint(); - int64_t bZp = bQType.getZeroPoint(); - - auto quantAttr = tosa::MatMulOpQuantizationAttr::get( - builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), - builder.getContext()); - - return quantAttr; + return builder.getAttr( + aQType.getZeroPoint(), bQType.getZeroPoint()); } return nullptr; @@ -215,15 +203,8 @@ "Unary inputs/outputs must be all quantized or all not quantized"); if (inputQType) { - - int64_t inputZp = inputQType.getZeroPoint(); - int64_t outputZp = outputQType.getZeroPoint(); - - auto quantAttr = tosa::UnaryOpQuantizationAttr::get( - builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), - builder.getContext()); - - return quantAttr; + return builder.getAttr(inputQType.getZeroPoint(), + outputQType.getZeroPoint()); } return nullptr; @@ -242,13 +223,8 @@ auto inputQType = GET_UQTYPE(inputType); if (inputQType) { - - int64_t inputZp = inputQType.getZeroPoint(); - - auto quantAttr = tosa::PadOpQuantizationAttr::get( - builder.getI32IntegerAttr(inputZp), builder.getContext()); - - return quantAttr; + return builder.getAttr( + inputQType.getZeroPoint()); } return nullptr; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -21,7 +21,7 @@ // CHECK: [[ONE:%.+]] = arith.constant 1 // CHECK: [[TWO:%.+]] = arith.constant 2 // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> - %0 = "tosa.matmul"(%arg0, %arg1) {quantization_info = {a_zp = 1 : i32, b_zp = 2 : i32}} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) + %0 = "tosa.matmul"(%arg0, %arg1) {quantization_info = #tosa.matmul_quant} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) return %0 : tensor<1x5x6xi32> } @@ -108,7 +108,7 @@ // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32): // CHECK: [[ADD:%.+]] = arith.addi // CHECK: linalg.yield [[ADD]] : i32 - %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = 1:i32, weight_zp = 2:i32}} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> (tensor<5x6xi32>) + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = #tosa.conv_quant} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> (tensor<5x6xi32>) return %0 : tensor<5x6xi32> } @@ -304,7 +304,7 @@ // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]] // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = #tosa.unary_quant, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> return } @@ -333,7 +333,7 @@ // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]] // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = #tosa.unary_quant, stride = [4, 4]} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16> return } @@ -461,7 +461,7 @@ // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] // CHECK: linalg.conv_2d_nhwc_hwcf_q - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = #tosa.conv_quant, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> return } @@ -557,7 +557,7 @@ // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = #tosa.conv_quant, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> return } @@ -581,7 +581,7 @@ // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> return } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -511,15 +511,15 @@ // CHECK: [[UBOUND:%.+]] = arith.select [[PRED2]], [[MAX]], [[LBOUND]] // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] // CHECK: linalg.yield [[TRUNC]] - %0 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 0 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + %0 = "tosa.negate"(%arg0) {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic // CHECK: [[EXT:%.+]] = arith.extsi %arg1 : i8 to i16 - %1 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32639 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + %1 = "tosa.negate"(%arg0) {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic // CHECK: [[EXT:%.+]] = arith.extsi %arg1 : i8 to i32 - %2 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32640 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> + %2 = "tosa.negate"(%arg0) {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> return } @@ -1257,7 +1257,7 @@ // CHECK: [[CST:%.+]] = arith.constant 42 : i32 // CHECK: tensor.pad // CHECK: tensor.yield [[CST]] - %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) return %1 : tensor<4x9xi32> } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -172,7 +172,7 @@ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor} // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]]) %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor, tensor<2x2xi32>) -> tensor + %1 = "tosa.pad"(%arg0, %arg1) {quantization_info = #tosa.pad_quant} : (tensor, tensor<2x2xi32>) -> tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -12,7 +12,7 @@ // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { // CHECK: tosa.conv2d - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> return %0 : tensor<1x32x32x16x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -28,12 +28,12 @@ // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} // CHECK-SAME: -> tensor<3x2xi8> // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) - // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32} + // CHECK-SAME: quantization_info = #tosa.conv_quant // CHECK-SAME: -> tensor<400x3xi32> // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> return %0 : tensor<4x10x10x3xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -25,7 +25,7 @@ // CHECK-LABEL: @depthwise_conv2d_as_mul_q func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { // CHECK: "tosa.depthwise_conv2d" - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -16,8 +16,8 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) { // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64} // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64} - // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> + // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = #tosa.conv_quant, stride = [1, 1]} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = #tosa.conv_quant, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> return %0 : tensor<2x18x19x5xi32> } @@ -72,7 +72,7 @@ // Manipulate the weight matrix to handle striding. // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>} // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = {input_zp = 42 : i32}} + // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = #tosa.pad_quant} // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]} // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]]) // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]} @@ -82,16 +82,16 @@ // Pad out the input matrix to handle the transpose conv. // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>} // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = {input_zp = -22 : i32}} + // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = #tosa.pad_quant} // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0> : tensor<30xi32>} - // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]} + // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant, stride = [1, 1]} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]} // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]]) // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]} // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]} // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2) - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = #tosa.conv_quant, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7770,6 +7770,14 @@ ["-gen-dialect-defs"], "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc", ), + ( + ["-gen-attrdef-decls"], + "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "include/mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc", + ), ( ["-gen-op-doc"], "g3doc/Dialects/Tosa/TosaOps.md",