diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -155,6 +155,10 @@
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyRemoveMarkersPass();
 
+/// Create a LinalgQuantizedMatmulToMatmulPass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgQuantizedMatmulToMatmulPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -392,4 +392,11 @@
   ];
 }
 
+def LinalgQuantizedMatmulToMatmulPass
+    : Pass<"linalg-quantized-matmul-to-matmul", "FuncOp"> {
+  let summary = "lower quantized_matmul to matmul";
+  let constructor = "mlir::createLinalgQuantizedMatmulToMatmulPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1363,6 +1363,9 @@
                                   TiledLoopOp loopOp, int64_t idx,
                                   TiledLoopOp &result);
 
+/// Patterns to lower quantized_matmul to matmul.
+void populateQuantizedMatmulToMatmulPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
   NamedOpConversions.cpp
   PadOpInterchange.cpp
   Promotion.cpp
+  QuantizedMatmulToMatmul.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/QuantizedMatmulToMatmul.cpp
@@ -0,0 +1,179 @@
+//===- QuantizedMatmulToMatmul.cpp - lower quantized_matmul to matmul -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern and pass lowering linalg.quantized_matmul
+// to a plain linalg.matmul on the raw (un-shifted) operands followed by
+// elementwise zero-point corrections.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Sums a 2-D integer tensor `matrix` along its reduction dimension, keeping
+/// only `parallelDim` (0 keeps rows -> per-row sums, 1 keeps columns ->
+/// per-column sums). Each element is sign-extended to `accElTy` before being
+/// accumulated. Returns a 1-D tensor of `accElTy` holding the sums.
+Value additiveReductionLeaving1ParallelDim(PatternRewriter &rewriter,
+                                           Location loc, Value matrix,
+                                           int parallelDim, Type accElTy) {
+  MLIRContext *context = rewriter.getContext();
+  AffineExpr expr[2];
+  bindDims(context, expr[0], expr[1]);
+  assert(parallelDim == 0 || parallelDim == 1);
+  AffineExpr parallelExpr = expr[parallelDim];
+  AffineMap mapIdentity = AffineMap::get(2, 0, expr, context);
+  AffineMap mapToParallelDim = AffineMap::get(2, 0, parallelExpr, context);
+  RankedTensorType matrixType = matrix.getType().cast<RankedTensorType>();
+  int64_t dstStaticSize = matrixType.getShape()[parallelDim];
+  // If the kept dimension is dynamic, query its runtime extent so the
+  // destination tensor can be created with a matching dynamic size.
+  SmallVector<Value> dstDynSizes;
+  if (dstStaticSize == ShapedType::kDynamicSize) {
+    dstDynSizes.push_back(
+        rewriter.create<tensor::DimOp>(loc, matrix, parallelDim));
+  }
+  Value initAcc =
+      rewriter
+          .create<linalg::InitTensorOp>(
+              loc, dstDynSizes, ArrayRef<int64_t>{dstStaticSize}, accElTy)
+          .getResult();
+  // Zero-fill the accumulator before the reduction accumulates into it.
+  Value zeroInt =
+      rewriter.create<arith::ConstantIntOp>(loc, 0, accElTy).getResult();
+  Value zeroAcc =
+      rewriter.create<linalg::FillOp>(loc, zeroInt, initAcc).getResult(0);
+  // Use owning SmallVectors here: a local ArrayRef initialized from an
+  // initializer list would dangle past the end of its own statement.
+  SmallVector<AffineMap> indexingMaps{mapIdentity, mapToParallelDim};
+  auto iterator = [=](int dim) -> StringRef {
+    return dim == parallelDim ? "parallel" : "reduction";
+  };
+  SmallVector<StringRef> iterators{iterator(0), iterator(1)};
+  return rewriter
+      .create<linalg::GenericOp>(
+          loc, zeroAcc.getType(), ValueRange{matrix}, ValueRange{zeroAcc},
+          indexingMaps, iterators,
+          [=](OpBuilder &b, Location loc, ValueRange args) {
+            Value matrixEl = args[0];
+            // Promote to the accumulator type to avoid overflow while
+            // summing.
+            Value promotedMatrixEl =
+                b.create<arith::ExtSIOp>(loc, accElTy, matrixEl);
+            Value accEl = args[1];
+            Value sum = b.create<arith::AddIOp>(loc, promotedMatrixEl, accEl);
+            b.create<linalg::YieldOp>(loc, sum);
+          })
+      .getResult(0);
+}
+
+/// Lowers linalg.quantized_matmul to linalg.matmul plus an elementwise
+/// correction, using the identity (K = reduction size):
+///
+///   sum_k (lhs[m,k] - lhsZp) * (rhs[k,n] - rhsZp)
+///     =   sum_k lhs[m,k] * rhs[k,n]        // plain matmul
+///       - rhsZp * sum_k lhs[m,k]           // lhs row sums  * rhsZp
+///       - lhsZp * sum_k rhs[k,n]           // rhs col sums  * lhsZp
+///       + K * lhsZp * rhsZp
+struct QuantizedMatmulToMatmul
+    : public OpRewritePattern<linalg::QuantizedMatmulOp> {
+  using OpRewritePattern<linalg::QuantizedMatmulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::QuantizedMatmulOp quantizedMatmulOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = quantizedMatmulOp.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    // quantized_matmul takes (lhs, rhs, lhsZeroPoint, rhsZeroPoint).
+    ValueRange inputs = quantizedMatmulOp.inputs();
+    assert(inputs.size() == 4);
+    Value lhs = inputs[0];
+    Value rhs = inputs[1];
+    Value lhsZp = inputs[2];
+    Value rhsZp = inputs[3];
+    ValueRange outputs = quantizedMatmulOp.outputs();
+    assert(outputs.size() == 1);
+    Value acc = outputs[0];
+    // Plain matmul on the raw operands; corrected below.
+    Value matmul = rewriter
+                       .create<linalg::MatmulOp>(loc, ValueRange{lhs, rhs},
+                                                 ValueRange{acc})
+                       .getResult(0);
+    ShapedType accType = acc.getType().cast<ShapedType>();
+    auto accDynShape = linalg::getDynOperands(loc, acc, rewriter);
+    Value initResult = rewriter.create<linalg::InitTensorOp>(
+        loc, accDynShape, accType.getShape(), accType.getElementType());
+    AffineExpr m, n;
+    bindDims(context, m, n);
+    AffineMap mapToNone = AffineMap::get(2, 0, context);
+    AffineMap mapToRowDim = AffineMap::get(2, 0, m, context);
+    AffineMap mapToColumnDim = AffineMap::get(2, 0, n, context);
+    AffineMap mapIdentity =
+        AffineMap::get(2, 0, ArrayRef<AffineExpr>{m, n}, context);
+    // Owning storage (not a dangling local ArrayRef) for the maps of the
+    // correction generic: matmul, lhsSums, rhsSums, lhsZp, rhsZp,
+    // lhsZp*rhsZp*K, result.
+    SmallVector<AffineMap> indexingMaps = {
+        mapIdentity, mapToRowDim, mapToColumnDim, mapToNone,
+        mapToNone,   mapToNone,   mapIdentity};
+    Type accElTy = accType.getElementType();
+    Value lhsSums =
+        additiveReductionLeaving1ParallelDim(rewriter, loc, lhs, 0, accElTy);
+    Value rhsSums =
+        additiveReductionLeaving1ParallelDim(rewriter, loc, rhs, 1, accElTy);
+    Value lhsZpTimesRhsZp = rewriter.create<arith::MulIOp>(loc, lhsZp, rhsZp);
+    // K is the shared reduction dimension: lhs dim 1.
+    Value kSize = rewriter.create<arith::IndexCastOp>(
+        loc, accElTy, rewriter.create<tensor::DimOp>(loc, lhs, 1));
+    Value lhsZpTimesRhsZpTimesKSize =
+        rewriter.create<arith::MulIOp>(loc, lhsZpTimesRhsZp, kSize);
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        quantizedMatmulOp, acc.getType(),
+        ValueRange{matmul, lhsSums, rhsSums, lhsZp, rhsZp,
+                   lhsZpTimesRhsZpTimesKSize},
+        ValueRange{initResult}, indexingMaps,
+        ArrayRef<StringRef>{"parallel", "parallel"},
+        [](OpBuilder &b, Location loc, ValueRange args) {
+          Value matmulEl = args[0];
+          Value lhsSumsEl = args[1];
+          Value rhsSumsEl = args[2];
+          Value lhsZp = args[3];
+          Value rhsZp = args[4];
+          Value lhsZpTimesRhsZpTimesKSize = args[5];
+          Value lhsSumsElTimesRhsZp =
+              b.create<arith::MulIOp>(loc, lhsSumsEl, rhsZp);
+          Value rhsSumsElTimesLhsZp =
+              b.create<arith::MulIOp>(loc, rhsSumsEl, lhsZp);
+          // result = matmul - lhsSums*rhsZp - rhsSums*lhsZp + K*lhsZp*rhsZp.
+          Value result = matmulEl;
+          result = b.create<arith::SubIOp>(loc, result, lhsSumsElTimesRhsZp);
+          result = b.create<arith::SubIOp>(loc, result, rhsSumsElTimesLhsZp);
+          result =
+              b.create<arith::AddIOp>(loc, result, lhsZpTimesRhsZpTimesKSize);
+          b.create<linalg::YieldOp>(loc, result);
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateQuantizedMatmulToMatmulPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<QuantizedMatmulToMatmul>(context);
+}
+
+namespace {
+/// Pass that lowers quantized_matmul to matmul.
+struct LinalgQuantizedMatmulToMatmulPass
+    : public LinalgQuantizedMatmulToMatmulPassBase<
+          LinalgQuantizedMatmulToMatmulPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
+    populateQuantizedMatmulToMatmulPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgQuantizedMatmulToMatmulPass() {
+  return std::make_unique<LinalgQuantizedMatmulToMatmulPass>();
+}