# NOTE(review): this patch text appears mangled and will not apply or build
# as shown. Restore it from the original commit rather than by hand:
#   o Line breaks were collapsed, so hunk bodies no longer line up with the
#     line counts in their @@ headers.
#   o Text of the form <Name...> looks stripped throughout, e.g.
#     `Variadic:$inputs`, `ArrayRef operands`, `dyn_cast_or_null(arg()...)`,
#     `dyn_cast()`, `isa()`, `addOpHandler(`, and the parameters of every
#     `!quant.uniform` type (including inside the expected-error strings,
#     which therefore could not match a real printed diagnostic).
#     Presumably these were template/type arguments -- TODO confirm.
#   o For the same reason the FxpMathConfig.cpp hunk shows
#     `- addOpHandler(` / `+ addOpHandler(` as a no-op; the actual change
#     to the handler's template argument is not visible here.
# NOTE(review): issues in the patch content itself, to confirm and fix:
#   o QuantOps.td: quant_QuantizeRegionOp's summary reads "The `region
#     operation" -- the backtick before `region` is never closed.
#   o QuantOps.cpp isValidQuantizationSpec: its comment allows a primitive
#     spec equal to the expressed type, but when `expressed` is neither of
#     the two shaped types tested (the tensorType/vectorType branches), a
#     non-quantized spec falls through to `return false`. If scalar
#     operands/results are meant to be accepted, the final case should
#     presumably be `return spec == expressed;`.
#   o QuantOps.cpp verifyRegionOp: only spec counts and spec/type
#     compatibility are checked; nothing visible here checks that the body
#     block's argument types match the operand types or that quant.return's
#     operands match the result types. TODO confirm whether another
#     verifier covers this.
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -83,6 +83,36 @@ let hasFolder = 1; } +// A QuantizeRegion (region) represents a quantization unit which wraps +// high-precision ops with quantization specifications for all the inputs +// and outputs. Some quantization specifications can be undetermined and +// derived from other ports by the target specification of the kernel. +def quant_QuantizeRegionOp : quant_Op<"region", [ + NoSideEffect, + IsolatedFromAbove, + SingleBlockImplicitTerminator<"ReturnOp">]> { + let summary = [{ + The `region operation wraps high-precision ops as a logical low-precision + quantized kernel. + }]; + + let arguments = (ins Variadic:$inputs, + TypeArrayAttr:$input_specs, + TypeArrayAttr:$output_specs, + StrAttr:$logical_kernel); + let results = (outs Variadic:$outputs); + let regions = (region SizedRegion<1>:$body); + let verifier = [{ return verifyRegionOp(*this); }]; +} + +def quant_ReturnOp : quant_Op<"return", [Terminator]> { + let summary = [{ + The `return` operation terminates a quantize region and returns values. + }]; + + let arguments = (ins Variadic:$results); +} + //===----------------------------------------------------------------------===// // Training integration and instrumentation ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -34,13 +34,63 @@ } OpFoldResult StorageCastOp::fold(ArrayRef operands) { - /// Matches x -> [scast -> scast] -> y, replacing the second scast with the - /// value of x if the casts invert each other.
+ // Matches x -> [scast -> scast] -> y, replacing the second scast with the + // value of x if the casts invert each other. auto srcScastOp = dyn_cast_or_null(arg().getDefiningOp()); if (!srcScastOp || srcScastOp.arg().getType() != getType()) return OpFoldResult(); return srcScastOp.arg(); } +/// The quantization specification should match the expressed type. +static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { + if (auto typeAttr = quantSpec.dyn_cast()) { + Type spec = typeAttr.getValue(); + if (spec.isa() || spec.isa()) + return false; + + // The spec should be either a quantized type which is compatible with the + // expressed type, or a primitive type which is the same as (the element + // type of) the expressed type. + if (auto quantizedType = spec.dyn_cast()) + return quantizedType.isCompatibleExpressedType(expressed); + + if (auto tensorType = expressed.dyn_cast()) + return spec == tensorType.getElementType(); + + if (auto vectorType = expressed.dyn_cast()) + return spec == vectorType.getElementType(); + } + return false; +} + +static LogicalResult verifyRegionOp(QuantizeRegionOp op) { + // There are specifications for both inputs and outputs. + if (op.getNumOperands() != op.input_specs().size() || + op.getNumResults() != op.output_specs().size()) + return op.emitOpError( + "has unmatched operands/results number and spec attributes number"); + + // Verify that quantization specifications are valid.
+ for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { + Type inputType = std::get<0>(input); + Attribute inputSpec = std::get<1>(input); + if (!isValidQuantizationSpec(inputSpec, inputType)) { + return op.emitOpError() << "has incompatible specification " << inputSpec + << " and input type " << inputType; + } + } + + for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { + Type outputType = std::get<0>(result); + Attribute outputSpec = std::get<1>(result); + if (!isValidQuantizationSpec(outputSpec, outputType)) { + return op.emitOpError() << "has incompatible specification " << outputSpec + << " and output type " << outputType; + } + } + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp --- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -60,7 +60,7 @@ // Op handlers.
addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2)); - addOpHandler( + addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2)); addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2)); diff --git a/mlir/test/Dialect/QuantOps/quant_region.mlir b/mlir/test/Dialect/QuantOps/quant_region.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/QuantOps/quant_region.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @source +func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @annotated +func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, f32], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @quantized +func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2)
({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform and input type 'tensor<4xf32>'}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, i32], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
-> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-note @+1 {{required by region isolation constraints}} + %0 = "quant.region"(%arg0, %arg1) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // @expected-error @+1 {{'bar' op using value defined outside the region}} + %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} +