diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -155,6 +155,10 @@
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyRemoveMarkersPass();
 
+/// Create a pass that lowers linalg.quantized_matmul ops to linalg.matmul.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgQuantizedMatmulToMatmulPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -392,4 +392,11 @@
   ];
 }
 
+def LinalgQuantizedMatmulToMatmulPass
+    : Pass<"linalg-quantized-matmul-to-matmul", "FuncOp"> {
+  let summary = "lower quantized_matmul to matmul";
+  let constructor = "mlir::createLinalgQuantizedMatmulToMatmulPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1378,6 +1378,9 @@
                                            TiledLoopOp loopOp, int64_t idx,
                                            TiledLoopOp &result);
 
+/// Patterns to lower quantized_matmul to matmul.
+void populateQuantizedMatmulToMatmulPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ NamedOpConversions.cpp PadOpInterchange.cpp Promotion.cpp + QuantizedMatmulToMatmul.cpp Tiling.cpp Transforms.cpp Vectorization.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp @@ -0,0 +1,246 @@ +//===- QuantizedMatmulToMatmul.cpp - lower quantized_matmul to matmul -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file rewrites any linalg.quantized_matmul into a linalg.matmul plus +// other ops as needed to implement the effect of the zero-points. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+// Returns the add-reduction of the input 2D tensor `matrix` along one of the
+// two dimensions. The `parallelDim` argument specifies which of the two
+// dimensions (0 or 1) is the parallel (i.e. not reduction) dimension, so the
+// result is a 1D tensor with the size of `matrix` along `parallelDim`.
+// The input `matrix`'s element type is assumed to be signless integer.
+// The result's element type is `accElTy`. The input elements are sign-extended
+// to `accElTy` before being added.
+//
+// The ops are created at the rewriter's current insertion point, in this
+// order: init_tensor (accumulator), constant 0, fill, generic (reduction).
+Value additiveReductionLeaving1ParallelDim(PatternRewriter &rewriter,
+                                           Location loc, Value matrix,
+                                           int parallelDim, Type accElTy) {
+  RankedTensorType matrixType = matrix.getType().cast<RankedTensorType>();
+  assert(matrixType.getRank() == 2 && "expected a 2D tensor");
+  assert((parallelDim == 0 || parallelDim == 1) &&
+         "parallelDim must be 0 or 1");
+  // Create the accumulator. Its single dimension is the parallel dimension of
+  // `matrix`; query it at runtime with tensor.dim when it is dynamic.
+  int64_t dstStaticSize = matrixType.getShape()[parallelDim];
+  SmallVector<Value> dstDynSizes;
+  if (dstStaticSize == ShapedType::kDynamicSize) {
+    dstDynSizes.push_back(
+        rewriter.create<tensor::DimOp>(loc, matrix, parallelDim));
+  }
+  Value initAcc =
+      rewriter
+          .create<linalg::InitTensorOp>(
+              loc, dstDynSizes, ArrayRef<int64_t>{dstStaticSize}, accElTy)
+          .getResult();
+  // Zero-fill the accumulator.
+  Value zeroInt =
+      rewriter.create<arith::ConstantIntOp>(loc, 0, accElTy).getResult();
+  Value zeroAcc =
+      rewriter.create<linalg::FillOp>(loc, zeroInt, initAcc).getResult(0);
+  // Create the indexing maps for the generic: the input is read with the
+  // identity map, the accumulator is indexed by the parallel dimension only.
+  MLIRContext *context = rewriter.getContext();
+  AffineExpr expr[2];
+  bindDims(context, expr[0], expr[1]);
+  AffineExpr parallelExpr = expr[parallelDim];
+  AffineMap mapIdentity = AffineMap::get(2, 0, expr, context);
+  AffineMap mapToParallelDim = AffineMap::get(2, 0, parallelExpr, context);
+  SmallVector<AffineMap> indexingMaps{mapIdentity, mapToParallelDim};
+  // Create the iterators for the generic.
+  auto iterator = [=](int dim) -> StringRef {
+    return dim == parallelDim ? "parallel" : "reduction";
+  };
+  SmallVector<StringRef> iterators{iterator(0), iterator(1)};
+  // Create the generic.
+  return rewriter
+      .create<linalg::GenericOp>(
+          loc, zeroAcc.getType(), ValueRange{matrix}, ValueRange{zeroAcc},
+          indexingMaps, iterators,
+          [=](OpBuilder &b, Location loc, ValueRange args) {
+            Value matrixEl = args[0];
+            // Sign-extend the input matrix elem to accElTy before adding.
+            Value promotedMatrixEl =
+                b.create<arith::ExtSIOp>(loc, accElTy, matrixEl);
+            Value accEl = args[1];
+            Value sum = b.create<arith::AddIOp>(loc, promotedMatrixEl, accEl);
+            b.create<linalg::YieldOp>(loc, sum);
+          })
+      .getResult(0);
+}
+
+// Returns true if `val` is the result of a constant integer op whose value is
+// 0. Any other value (e.g. a block argument, or a zero computed by some other
+// op) yields false, even if it happens to be zero at runtime.
+bool isConstantZero(Value val) {
+  auto constIntOp = val.getDefiningOp<arith::ConstantIntOp>();
+  return constIntOp && constIntOp.value() == 0;
+}
+
+// Pattern lowering quantized_matmul to matmul.
+// Succeeds on any quantized_matmul with tensor semantics, which is the only
+// form this rewrite supports: it consumes op results and creates init_tensor
+// ops.
+//
+// This is implementing the math explained in Section 2.3 of
+// https://arxiv.org/abs/1712.05877.
+struct QuantizedMatmulToMatmul
+    : public OpRewritePattern<linalg::QuantizedMatmulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Rewrites `quantizedMatmulOp` into a linalg.matmul on the same lhs, rhs and
+  // accumulator, followed (unless both zero-points are constant zeros) by a
+  // linalg.generic applying the zero-point correction terms. With K the size
+  // of the reduction dimension:
+  //
+  //   sum_k (lhs[m,k] - lhsZp) * (rhs[k,n] - rhsZp)
+  //     =   sum_k lhs[m,k] * rhs[k,n]     (the plain matmul)
+  //       - rhsZp * sum_k lhs[m,k]        (sums along rows of lhs)
+  //       - lhsZp * sum_k rhs[k,n]        (sums along columns of rhs)
+  //       + lhsZp * rhsZp * K
+  //
+  // Terms whose zero-point is a constant zero are not materialized.
+  // Returns failure, without modifying the IR, on ops with buffer semantics.
+  LogicalResult matchAndRewrite(linalg::QuantizedMatmulOp quantizedMatmulOp,
+                                PatternRewriter &rewriter) const override {
+    // Everything below uses op results and materializes new tensors, so only
+    // the tensor form of the op can be rewritten.
+    if (!quantizedMatmulOp.hasTensorSemantics()) {
+      return rewriter.notifyMatchFailure(quantizedMatmulOp,
+                                         "expected tensor semantics");
+    }
+    Location loc = quantizedMatmulOp.getLoc();
+    ValueRange inputs = quantizedMatmulOp.inputs();
+    assert(inputs.size() == 4 && "expected lhs, rhs, lhsZp, rhsZp");
+    Value lhs = inputs[0];
+    Value rhs = inputs[1];
+    Value lhsZp = inputs[2];
+    Value rhsZp = inputs[3];
+    ValueRange outputs = quantizedMatmulOp.outputs();
+    // Compute the matmul part.
+    Value acc = outputs[0];
+    Value matmul = rewriter
+                       .create<linalg::MatmulOp>(loc, ValueRange{lhs, rhs},
+                                                 ValueRange{acc})
+                       .getResult(0);
+    bool lhsZpIsConstantZero = isConstantZero(lhsZp);
+    bool rhsZpIsConstantZero = isConstantZero(rhsZp);
+    if (lhsZpIsConstantZero && rhsZpIsConstantZero) {
+      // Easy case: both zero points are constant zeros, so the quantized_matmul
+      // was just a matmul all along.
+      rewriter.replaceOp(quantizedMatmulOp, matmul);
+      return success();
+    }
+    // Create the result. No need to zero-fill it as we will overwrite it.
+    ShapedType accType = acc.getType().cast<ShapedType>();
+    auto accDynShape = linalg::getDynOperands(loc, acc, rewriter);
+    Value initResult = rewriter.create<linalg::InitTensorOp>(
+        loc, accDynShape, accType.getShape(), accType.getElementType());
+    // Create the indexing maps for the generic.
+    MLIRContext *context = rewriter.getContext();
+    AffineExpr m, n;
+    bindDims(context, m, n);
+    AffineMap mapToNone = AffineMap::get(2, 0, context);
+    AffineMap mapToRowDim = AffineMap::get(2, 0, m, context);
+    AffineMap mapToColumnDim = AffineMap::get(2, 0, n, context);
+    AffineMap mapIdentity =
+        AffineMap::get(2, 0, ArrayRef<AffineExpr>{m, n}, context);
+    SmallVector<AffineMap> indexingMaps;
+    SmallVector<Value> ins;
+    // Appends `val` as an input of the final generic, accessed through `map`,
+    // and returns its position among the generic's block arguments.
+    auto addInput = [&](Value val, AffineMap map) -> int {
+      ins.push_back(val);
+      indexingMaps.push_back(map);
+      return ins.size() - 1;
+    };
+    int indexOfMatmulInput = addInput(matmul, mapIdentity);
+    // The indices below stay 0 for inputs that are not added. The body
+    // builder then reads block argument 0 for them but never uses the value.
+    int indexOfLhsSumsInput = 0;
+    int indexOfLhsZpInput = 0;
+    int indexOfRhsSumsInput = 0;
+    int indexOfRhsZpInput = 0;
+    int indexOfLhsZpTimesRhsZpTimesKSizeInput = 0;
+    Type accElTy = accType.getElementType();
+    if (!rhsZpIsConstantZero) {
+      Value lhsSums =
+          additiveReductionLeaving1ParallelDim(rewriter, loc, lhs, 0, accElTy);
+      indexOfLhsSumsInput = addInput(lhsSums, mapToRowDim);
+      indexOfRhsZpInput = addInput(rhsZp, mapToNone);
+    }
+    if (!lhsZpIsConstantZero) {
+      Value rhsSums =
+          additiveReductionLeaving1ParallelDim(rewriter, loc, rhs, 1, accElTy);
+      indexOfRhsSumsInput = addInput(rhsSums, mapToColumnDim);
+      indexOfLhsZpInput = addInput(lhsZp, mapToNone);
+    }
+    if (!lhsZpIsConstantZero && !rhsZpIsConstantZero) {
+      Value lhsZpTimesRhsZp = rewriter.create<arith::MulIOp>(loc, lhsZp, rhsZp);
+      // K is the size of dimension 1 of lhs, cast to the accumulator type.
+      Value kSize = rewriter.create<arith::IndexCastOp>(
+          loc, accElTy, rewriter.create<tensor::DimOp>(loc, lhs, 1));
+      Value lhsZpTimesRhsZpTimesKSize =
+          rewriter.create<arith::MulIOp>(loc, lhsZpTimesRhsZp, kSize);
+      indexOfLhsZpTimesRhsZpTimesKSizeInput =
+          addInput(lhsZpTimesRhsZpTimesKSize, mapToNone);
+    }
+    // Add the indexing map for the initResult 'output' even though it's unused.
+    indexingMaps.push_back(mapIdentity);
+    // Create the generic putting all the terms together.
+    SmallVector<StringRef> iterators{"parallel", "parallel"};
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        quantizedMatmulOp, acc.getType(), ins, ValueRange{initResult},
+        indexingMaps, iterators,
+        [=](OpBuilder &b, Location loc, ValueRange args) {
+          Value matmulEl = args[indexOfMatmulInput];
+          Value lhsSumsEl = args[indexOfLhsSumsInput];
+          Value rhsSumsEl = args[indexOfRhsSumsInput];
+          Value lhsZp = args[indexOfLhsZpInput];
+          Value rhsZp = args[indexOfRhsZpInput];
+          Value lhsZpTimesRhsZpTimesKSize =
+              args[indexOfLhsZpTimesRhsZpTimesKSizeInput];
+          Value result = matmulEl;
+          // If the rhs zero-point is not a constant zero, we need to subtract
+          // it times the sums along rows of lhs.
+          if (!rhsZpIsConstantZero) {
+            Value lhsSumsElTimesRhsZp =
+                b.create<arith::MulIOp>(loc, lhsSumsEl, rhsZp);
+            result = b.create<arith::SubIOp>(loc, result, lhsSumsElTimesRhsZp);
+          }
+          // If the lhs zero-point is not a constant zero, we need to subtract
+          // it times the sums along columns of rhs.
+          if (!lhsZpIsConstantZero) {
+            Value rhsSumsElTimesLhsZp =
+                b.create<arith::MulIOp>(loc, rhsSumsEl, lhsZp);
+            result = b.create<arith::SubIOp>(loc, result, rhsSumsElTimesLhsZp);
+          }
+          // Add the final correction term, if neither zero-point is cst zero.
+          if (!lhsZpIsConstantZero && !rhsZpIsConstantZero) {
+            result =
+                b.create<arith::AddIOp>(loc, result, lhsZpTimesRhsZpTimesKSize);
+          }
+          b.create<linalg::YieldOp>(loc, result);
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateQuantizedMatmulToMatmulPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<QuantizedMatmulToMatmul>(context);
+}
+
+namespace {
+/// Pass that lowers quantized_matmul to matmul.
+struct LinalgQuantizedMatmulToMatmulPass
+    : public LinalgQuantizedMatmulToMatmulPassBase<
+          LinalgQuantizedMatmulToMatmulPass> {
+  // Greedily applies the quantized_matmul -> matmul patterns to the operation
+  // this pass runs on.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
+    populateQuantizedMatmulToMatmulPatterns(patterns);
+    // The driver's result is deliberately ignored: the pass does not signal
+    // failure whatever the driver returns.
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgQuantizedMatmulToMatmulPass() {
+  return std::make_unique<LinalgQuantizedMatmulToMatmulPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/quantized-matmul-to-matmul.mlir b/mlir/test/Dialect/Linalg/quantized-matmul-to-matmul.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/quantized-matmul-to-matmul.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -linalg-quantized-matmul-to-matmul -split-input-file %s | FileCheck %s
+
+// Tests -linalg-quantized-matmul-to-matmul, converting linalg.quantized_matmul
+// ops to linalg.matmul ops plus additional arithmetic to account for any
+// nonzero zero-point.
+
+func @quantized_matmul_both_zp_0_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %lhs_zp = arith.constant 0 : i32
+  %rhs_zp = arith.constant 0 : i32
+  %1 = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @quantized_matmul_both_zp_0_dynamic
+// CHECK-SAME:    %[[LHS:.+]]: tensor<?x?xi8>, %[[RHS:.+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[ACC:.+]]: tensor<?x?xi32>
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[ACC]] : tensor<?x?xi32>)
+// CHECK:       return %[[MATMUL]]
+// -----
+
+func @quantized_matmul_lhs_zp_0_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %rhs_zp : i32, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %lhs_zp = arith.constant 0 : i32
+  %1 = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @quantized_matmul_lhs_zp_0_dynamic
+// CHECK-SAME:    %[[LHS:.+]]: tensor<?x?xi8>, %[[RHS:.+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[RHS_ZP:.+]]: i32
+// CHECK-SAME:    %[[ACC:.+]]: tensor<?x?xi32>
+// CHECK:       %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[ACC]] : tensor<?x?xi32>)
+// CHECK:       %[[INIT_RESULT:.+]] = linalg.init_tensor
+// CHECK:       %[[INIT_LHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_LHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_LHS_SUMS_ACC]])
+// CHECK:       %[[LHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "reduction"
+// CHECK-SAME:    ins(%[[LHS]] : tensor<?x?xi8>)
+// CHECK-SAME:    outs(%[[ZERO_LHS_SUMS_ACC]] : tensor<?xi32>)
+// CHECK:       %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "parallel"
+// CHECK-SAME:    ins(%[[MATMUL]], %[[LHS_SUMS]], %[[RHS_ZP]] : tensor<?x?xi32>, tensor<?xi32>, i32)
+// CHECK:       return %[[RESULT]]
+// -----
+
+func @quantized_matmul_rhs_zp_0_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %rhs_zp = arith.constant 0 : i32
+  %1 = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @quantized_matmul_rhs_zp_0_dynamic
+// CHECK-SAME:    %[[LHS:.+]]: tensor<?x?xi8>, %[[RHS:.+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[LHS_ZP:.+]]: i32
+// CHECK-SAME:    %[[ACC:.+]]: tensor<?x?xi32>
+// CHECK:       %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[ACC]] : tensor<?x?xi32>)
+// CHECK:       %[[INIT_RESULT:.+]] = linalg.init_tensor
+// CHECK:       %[[INIT_RHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_RHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_RHS_SUMS_ACC]])
+// CHECK:       %[[RHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "reduction", "parallel"
+// CHECK-SAME:    ins(%[[RHS]] : tensor<?x?xi8>)
+// CHECK-SAME:    outs(%[[ZERO_RHS_SUMS_ACC]] : tensor<?xi32>)
+// CHECK:       %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "parallel"
+// CHECK-SAME:    ins(%[[MATMUL]], %[[RHS_SUMS]], %[[LHS_ZP]] : tensor<?x?xi32>, tensor<?xi32>, i32)
+// CHECK:       return %[[RESULT]]
+// -----
+
+func @quantized_matmul_neither_zp_0_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %1 = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @quantized_matmul_neither_zp_0_dynamic
+// CHECK-SAME:    %[[LHS:.+]]: tensor<?x?xi8>, %[[RHS:.+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[LHS_ZP:.+]]: i32, %[[RHS_ZP:.+]]: i32
+// CHECK-SAME:    %[[ACC:.+]]: tensor<?x?xi32>
+// CHECK-DAG:   %[[C1_INDEX:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[ACC]] : tensor<?x?xi32>)
+// CHECK:       %[[INIT_RESULT:.+]] = linalg.init_tensor
+// CHECK:       %[[INIT_LHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_LHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_LHS_SUMS_ACC]])
+// CHECK:       %[[LHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "reduction"
+// CHECK-SAME:    ins(%[[LHS]] : tensor<?x?xi8>)
+// CHECK-SAME:    outs(%[[ZERO_LHS_SUMS_ACC]] : tensor<?xi32>)
+// CHECK:       %[[INIT_RHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_RHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_RHS_SUMS_ACC]])
+// CHECK:       %[[RHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "reduction", "parallel"
+// CHECK-SAME:    ins(%[[RHS]] : tensor<?x?xi8>)
+// CHECK-SAME:    outs(%[[ZERO_RHS_SUMS_ACC]] : tensor<?xi32>)
+// CHECK:       %[[LHS_ZP_TIMES_RHS_ZP:.+]] = arith.muli %[[LHS_ZP]], %[[RHS_ZP]]
+// CHECK:       %[[K_SIZE:.+]] = tensor.dim %[[LHS]], %[[C1_INDEX]]
+// CHECK:       %[[K_SIZE_I32:.+]] = arith.index_cast %[[K_SIZE]] : index to i32
+// CHECK:       %[[PRODUCT_TERM:.+]] = arith.muli %[[LHS_ZP_TIMES_RHS_ZP]], %[[K_SIZE_I32]]
+// CHECK:       %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "parallel"
+// CHECK-SAME:    ins(%[[MATMUL]], %[[LHS_SUMS]], %[[RHS_ZP]], %[[RHS_SUMS]], %[[LHS_ZP]], %[[PRODUCT_TERM]] : tensor<?x?xi32>, tensor<?xi32>, i32, tensor<?xi32>, i32, i32)
+// CHECK:       return %[[RESULT]]
+// -----
+
+func @quantized_matmul_neither_zp_0_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) -> tensor<3x5xi32> {
+  %1 = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
+  return %1 : tensor<3x5xi32>
+}
+// CHECK-LABEL: func @quantized_matmul_neither_zp_0_3x4x5
+// CHECK-SAME:    %[[LHS:.+]]: tensor<3x4xi8>, %[[RHS:.+]]: tensor<4x5xi8>
+// CHECK-SAME:    %[[LHS_ZP:.+]]: i32, %[[RHS_ZP:.+]]: i32
+// CHECK-SAME:    %[[ACC:.+]]: tensor<3x5xi32>
+// CHECK-DAG:   %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG:   %[[C4_I32:.+]] = arith.constant 4 : i32
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<3x4xi8>, tensor<4x5xi8>) outs(%[[ACC]] : tensor<3x5xi32>)
+// CHECK:       %[[INIT_RESULT:.+]] = linalg.init_tensor
+// CHECK:       %[[INIT_LHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_LHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_LHS_SUMS_ACC]])
+// CHECK:       %[[LHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "reduction"
+// CHECK-SAME:    ins(%[[LHS]] : tensor<3x4xi8>)
+// CHECK-SAME:    outs(%[[ZERO_LHS_SUMS_ACC]] : tensor<3xi32>)
+// CHECK:       %[[INIT_RHS_SUMS_ACC:.+]] = linalg.init_tensor
+// CHECK:       %[[ZERO_RHS_SUMS_ACC:.+]] = linalg.fill(%[[C0_I32]], %[[INIT_RHS_SUMS_ACC]])
+// CHECK:       %[[RHS_SUMS:.+]] = linalg.generic
+// CHECK-SAME:    "reduction", "parallel"
+// CHECK-SAME:    ins(%[[RHS]] : tensor<4x5xi8>)
+// CHECK-SAME:    outs(%[[ZERO_RHS_SUMS_ACC]] : tensor<5xi32>)
+// CHECK:       %[[LHS_ZP_TIMES_RHS_ZP:.+]] = arith.muli %[[LHS_ZP]], %[[RHS_ZP]]
+// CHECK:       %[[PRODUCT_TERM:.+]] = arith.muli %[[LHS_ZP_TIMES_RHS_ZP]], %[[C4_I32]]
+// CHECK:       %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:    "parallel", "parallel"
+// CHECK-SAME:    ins(%[[MATMUL]], %[[LHS_SUMS]], %[[RHS_ZP]], %[[RHS_SUMS]], %[[LHS_ZP]], %[[PRODUCT_TERM]] : tensor<3x5xi32>, tensor<3xi32>, i32, tensor<5xi32>, i32, i32)
+// CHECK:       return %[[RESULT]]
+// -----