diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,17 +34,6 @@
 } // namespace tosa
 } // namespace mlir
 
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-namespace mlir {
-namespace tosa {
-/// Appends the canonicalization patterns for all the TOSA ops to the `patterns`
-void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
-                                             RewritePatternSet &patterns);
-} // namespace tosa
-} // namespace mlir
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRTosaDialect
+  IR/TosaOps.cpp
+  IR/TosaCanonicalizations.cpp
   Utils/ConversionUtils.cpp
   Utils/QuantUtils.cpp
-  IR/TosaOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -0,0 +1,544 @@
+//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// TOSA canonicalization patterns and folders.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Operator Canonicalizers.
+//===----------------------------------------------------------------------===//
+
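+// Replaces a single-input tosa.concat with its operand, inserting a
+// tensor.cast when the operand and result types differ.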
+struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.input1().size() != 1)
+      return failure();
+    if (op.input1().front().getType() != op.getType()) {
+      rewriter
+          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+                                              op.input1().front())
+          .getResult();
+      return success();
+    }
+
+    rewriter.replaceOp(op, op.input1().front());
+    return success();
+  }
+};
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<ConcatOptimization>(context);
+}
+
+struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    Operation *definingOp = input.getDefiningOp();
+    if (!definingOp)
+      return failure();
+
+    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
+      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+          op, op.getType(), reshapeOp.input1(), op.new_shape());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    ArrayAttr newShape = op.new_shape();
+
+    // Check if input is constant
+    DenseElementsAttr inputAttr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    // Check if has >1 consumer and is not splat
+    if (!input.hasOneUse() && !inputAttr.isSplat())
+      return failure();
+
+    // Grab the new shape
+    SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
+        llvm::map_range(newShape.getValue(), [](const Attribute &val) {
+          return val.cast<IntegerAttr>().getValue().getSExtValue();
+        }));
+
+    // Build new const op with correct output shape
+    ShapedType inputShape = input.getType().cast<ShapedType>();
+    DenseElementsAttr outputAttr =
+        inputAttr.reshape(inputShape.clone(newShapeValues));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
+                                               outputAttr);
+    return success();
+  }
+};
+
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ReshapeReshapeOptimization>(context);
+  results.add<ReshapeConstOptimization>(context);
+}
+
+LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
+  auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
+  if (!notOp)
+    return failure();
+  rewriter.updateRootInPlace(op, [&]() {
+    op.getOperation()->setOperands(
+        {notOp.input1(), op.on_false(), op.on_true()});
+  });
+  return success();
+}
+
+struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto perm = op.perms();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(perm, m_Constant(&permAttr))) {
+      return failure();
+    }
+
+    SmallVector<int64_t> permValues = llvm::to_vector<6>(
+        llvm::map_range(permAttr.getValues<APInt>(),
+                        [](const APInt &val) { return val.getSExtValue(); }));
+
+    for (int i = 0, s = permValues.size(); i < s; i++) {
+      if (i != permValues[i]) {
+        return failure();
+      }
+    }
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<NoOpOptimization>(context);
+}
+
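+// Folds tosa.add of a splat integer zero into the other operand (x + 0 => x).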
+struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
+  using OpRewritePattern<tosa::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AddOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input1 = op.input1();
+    auto input2 = op.input2();
+
+    DenseElementsAttr input1Attr;
+    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
+        input2.getType() == op.getType()) {
+      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
+          input1Attr.getSplatValue<APInt>().isZero()) {
+        rewriter.replaceOp(op, op.input2());
+        return success();
+      }
+    }
+
+    DenseElementsAttr input2Attr;
+    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
+        input1.getType() == op.getType()) {
+      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
+          input2Attr.getSplatValue<APInt>().isZero()) {
+        rewriter.replaceOp(op, op.input1());
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<AddZeroOptimization>(context);
+}
+
+struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
+  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MulOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input1 = op.input1();
+    auto input2 = op.input2();
+
+    DenseElementsAttr input1Attr;
+    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
+        input2.getType() == op.getType()) {
+      if (input1Attr.getType().getElementType().isa<FloatType>() &&
+          input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
+        rewriter.replaceOp(op, op.input2());
+        return success();
+      }
+
+      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
+          matchPattern(input1, m_One())) {
+        rewriter.replaceOp(op, op.input2());
+        return success();
+      }
+    }
+
+    DenseElementsAttr input2Attr;
+    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
+        input1.getType() == op.getType()) {
+      if (input2Attr.getType().getElementType().isa<FloatType>() &&
+          input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
+        rewriter.replaceOp(op, op.input1());
+        return success();
+      }
+
+      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
+          matchPattern(input2, m_One())) {
+        rewriter.replaceOp(op, op.input1());
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<MulOneOptimization>(context);
+}
+
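+// Makes the implicit pad value of a tosa.pad explicit: materializes a zero
+// (or the input zero point, for quantized types) as a tosa.const pad_const
+// operand.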
+struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.pad_const())
+      return failure();
+
+    auto input = op.input1();
+    auto padding = op.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>()) {
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
+      auto value = op.quantization_info()->getInputZp();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value);
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
+    auto denseAttr = DenseElementsAttr::get(
+        RankedTensorType::get({}, elementTy), constantAttr);
+    auto constantVal = rewriter.create<tosa::ConstOp>(
+        op.getLoc(), denseAttr.getType(), denseAttr);
+
+    rewriter.replaceOpWithNewOp<tosa::PadOp>(
+        op, op.getType(), ValueRange{input, padding, constantVal},
+        op->getAttrs());
+    return success();
+  }
+};
+
+void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<MaterializePadValue>(context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value output = op.output();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType outputType = output.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp>(context);
+}
+
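+// Removes a tosa.clamp whose bounds cannot exclude any value of its input
+// type.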
+struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
+    auto inputElementType = inputType.getElementType();
+
+    if (!inputType.hasStaticShape()) {
+      return failure();
+    }
+
+    if (inputElementType.isF32()) {
+      auto minClamp = op.min_fp();
+      auto maxClamp = op.max_fp();
+      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
+                   minClamp.isNegative();
+      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
+                   !maxClamp.isNegative();
+
+      if (isMin && isMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    if (inputElementType.isUnsignedInteger()) {
+      int64_t minClamp = op.min_int();
+      int64_t maxClamp = op.max_int();
+
+      int64_t intMin =
+          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
+              .getZExtValue();
+      int64_t intMax =
+          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
+              .getZExtValue();
+
+      if (minClamp <= intMin && maxClamp >= intMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    if (inputElementType.isa<IntegerType>()) {
+      int64_t minClamp = op.min_int();
+      int64_t maxClamp = op.max_int();
+
+      int64_t intMin =
+          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
+              .getSExtValue();
+      int64_t intMax =
+          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
+              .getSExtValue();
+
+      if (minClamp <= intMin && maxClamp >= intMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    return failure();
+  }
+};
+
+struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+
+    Operation *definingOp = input.getDefiningOp();
+    if (!definingOp)
+      return failure();
+
+    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
+      auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
+      auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
+
+      auto minInt = std::max(op.min_int(), clampOp.min_int());
+      auto maxInt = std::min(op.max_int(), clampOp.max_int());
+
+      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+          op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
+          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
+          rewriter.getF32FloatAttr(maxFp));
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<ClampIsNoOp>(context);
+  results.add<ClampClampOptimization>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Operator Folders.
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  if (input().getType() == getType())
+    return input();
+  return {};
+}
+
+OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return valueAttr();
+}
+
+#define REDUCE_FOLDER(OP)                                                     \
+  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                       \
+    ShapedType inputTy = input().getType().cast<ShapedType>();                \
+    if (!inputTy.hasRank())                                                   \
+      return {};                                                              \
+    if (inputTy.getDimSize(axis()) == 1)                                      \
+      return input();                                                         \
+    return {};                                                                \
+  }
+
+REDUCE_FOLDER(ReduceAllOp)
+REDUCE_FOLDER(ReduceAnyOp)
+REDUCE_FOLDER(ReduceMaxOp)
+REDUCE_FOLDER(ReduceMinOp)
+REDUCE_FOLDER(ReduceProdOp)
+REDUCE_FOLDER(ReduceSumOp)
+#undef REDUCE_FOLDER
+
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
+  auto outputTy = getType().dyn_cast<RankedTensorType>();
+
+  if (!inputTy || !outputTy || inputTy != outputTy)
+    return {};
+  return input1();
+}
+
+OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+  // If the pad is all zeros we can fold this operation away.
+  if (operands[1]) {
+    auto densePad = operands[1].cast<DenseElementsAttr>();
+    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
+      return input1();
+    }
+  }
+
+  return {};
+}
+
+OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
+  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
+  auto outputTy = getType().dyn_cast<RankedTensorType>();
+
+  if (!inputTy || !outputTy || inputTy != outputTy)
+    return {};
+  if (inputTy.hasStaticShape())
+    return input();
+
+  return {};
+}
+
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  if (on_true() == on_false())
+    return on_true();
+
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate)
+    return {};
+
+  if (!predicate.isSplat())
+    return {};
+  return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
+                                                         : on_false();
+}
+
+OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+  bool allOnes = true;
+  for (Attribute val : multiples().getValue()) {
+    allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
+  }
+
+  if (allOnes && input1().getType() == getType())
+    return input1();
+  return {};
+}
+
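+// Folds identity permutations away and turns a transpose of a splat
+// constant into a reshape of that constant.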
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[1])
+    return {};
+
+  // Transposing splat values just means reshaping.
+  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+    if (input.isSplat())
+      return input.reshape(getType().cast<ShapedType>());
+  }
+
+  auto perms = llvm::to_vector<6>(llvm::map_range(
+      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
+      [](const APInt &val) { return val.getSExtValue(); }));
+
+  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
+      input1().getType() == getType())
+    return input1();
+  return {};
+}
+
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -21,9 +21,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -96,533 +94,6 @@
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// Operator Canonicalizers.
-//===----------------------------------------------------------------------===//
-
-template <typename... Args>
-void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
-  (void)std::initializer_list<int>{
-      0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
-}
-
-void mlir::tosa::populateTosaOpsCanonicalizationPatterns(
-    MLIRContext *ctx, RewritePatternSet &patterns) {
-  addOpsCanonicalizations<
-#define GET_OP_LIST
-#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
-      >(ctx, patterns);
-}
-
-struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
-  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ConcatOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.input1().size() != 1)
-      return failure();
-    if (op.input1().front().getType() != op.getType()) {
-      rewriter
-          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
-                                              op.input1().front())
-          .getResult();
-      return success();
-    }
-
-    rewriter.replaceOp(op, op.input1().front());
-    return success();
-  }
-};
-
-void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<ConcatOptimization>(context);
-}
-
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input1();
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
-      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-          op, op.getType(), reshapeOp.input1(), op.new_shape());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input1();
-    ArrayAttr newShape = op.new_shape();
-
-    // Check if input is constant
-    DenseElementsAttr inputAttr;
-    if (!matchPattern(input, m_Constant(&inputAttr)))
-      return failure();
-
-    // Check if has >1 consumer and is not splat
-    if (!input.hasOneUse() && !inputAttr.isSplat())
-      return failure();
-
-    // Grab the new shape
-    SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
-        llvm::map_range(newShape.getValue(), [](const Attribute &val) {
-          return val.cast<IntegerAttr>().getValue().getSExtValue();
-        }));
-
-    // Build new const op with correct output shape
-    ShapedType inputShape = input.getType().cast<ShapedType>();
-    DenseElementsAttr outputAttr =
-        inputAttr.reshape(inputShape.clone(newShapeValues));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
-                                               outputAttr);
-    return success();
-  }
-};
-
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<ReshapeReshapeOptimization>(context);
-  results.add<ReshapeConstOptimization>(context);
-}
-
-LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
-  if (!notOp)
-    return failure();
-  rewriter.updateRootInPlace(op, [&]() {
-    op.getOperation()->setOperands(
-        {notOp.input1(), op.on_false(), op.on_true()});
-  });
-  return success();
-}
-
-struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
-  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto perm = op.perms();
-
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(perm, m_Constant(&permAttr))) {
-      return failure();
-    }
-
-    SmallVector<int64_t> permValues = llvm::to_vector<6>(
-        llvm::map_range(permAttr.getValues<APInt>(),
-                        [](const APInt &val) { return val.getSExtValue(); }));
-
-    for (int i = 0, s = permValues.size(); i < s; i++) {
-      if (i != permValues[i]) {
-        return failure();
-      }
-    }
-
-    rewriter.replaceOp(op, op.input1());
-    return success();
-  }
-};
-
-void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<NoOpOptimization>(context);
-}
-
-struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
-  using OpRewritePattern<tosa::AddOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::AddOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.input1();
-    auto input2 = op.input2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          input1Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          input2Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<AddZeroOptimization>(context);
-}
-
-struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
-  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MulOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.input1();
-    auto input2 = op.input2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<FloatType>() &&
-          input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input1, m_One())) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<FloatType>() &&
-          input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input2, m_One())) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MulOneOptimization>(context);
-}
-
-struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
-  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::PadOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.pad_const())
-      return failure();
-
-    auto input = op.input1();
-    auto padding = op.padding();
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type elementTy = inputTy.getElementType();
-
-    Attribute constantAttr;
-    if (elementTy.isa<FloatType>()) {
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
-      auto value = op.quantization_info()->getInputZp();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value);
-    }
-
-    if (!constantAttr) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "tosa.pad to linalg lowering encountered an unknown element type");
-    }
-
-    auto denseAttr = DenseElementsAttr::get(
-        RankedTensorType::get({}, elementTy), constantAttr);
-    auto constantVal = rewriter.create<tosa::ConstOp>(
-        op.getLoc(), denseAttr.getType(), denseAttr);
-
-    rewriter.replaceOpWithNewOp<tosa::PadOp>(
-        op, op.getType(), ValueRange{input, padding, constantVal},
-        op->getAttrs());
-    return success();
-  }
-};
-
-void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MaterializePadValue>(context);
-}
-
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value output = op.output();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType outputType = output.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
-struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
-  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ClampOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
-    auto inputElementType = inputType.getElementType();
-
-    if (!inputType.hasStaticShape()) {
-      return failure();
-    }
-
-    if (inputElementType.isF32()) {
-      auto minClamp = op.min_fp();
-      auto maxClamp = op.max_fp();
-      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
-                   minClamp.isNegative();
-      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
-                   !maxClamp.isNegative();
-
-      if (isMin && isMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    if (inputElementType.isUnsignedInteger()) {
-      int64_t minClamp = op.min_int();
-      int64_t maxClamp = op.max_int();
-
-      int64_t intMin =
-          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
-      int64_t intMax =
-          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
-
-      if (minClamp <= intMin && maxClamp >= intMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    if (inputElementType.isa<IntegerType>()) {
-      int64_t minClamp = op.min_int();
-      int64_t maxClamp = op.max_int();
-
-      int64_t intMin =
-          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
-      int64_t intMax =
-          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
-
-      if (minClamp <= intMin && maxClamp >= intMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    return failure();
-  }
-};
-
-struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
-  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ClampOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
-      auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
-      auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
-
-      auto minInt = std::max(op.min_int(), clampOp.min_int());
-      auto maxInt = std::min(op.max_int(), clampOp.max_int());
-
-      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-          op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
-          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<ClampIsNoOp>(context);
-  results.add<ClampClampOptimization>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// Operator Folders.
-//===----------------------------------------------------------------------===//
-
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  if (input().getType() == getType())
-    return input();
-  return {};
-}
-
-OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
-  return valueAttr();
-}
-
-#define REDUCE_FOLDER(OP)                                                     \
-  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                       \
-    ShapedType inputTy = input().getType().cast<ShapedType>();                \
-    if (!inputTy.hasRank())                                                   \
-      return {};                                                              \
-    if (inputTy.getDimSize(axis()) == 1)                                      \
-      return input();                                                         \
-    return {};                                                                \
-  }
-
-REDUCE_FOLDER(ReduceAllOp)
-REDUCE_FOLDER(ReduceAnyOp)
-REDUCE_FOLDER(ReduceMaxOp)
-REDUCE_FOLDER(ReduceMinOp)
-REDUCE_FOLDER(ReduceProdOp)
-REDUCE_FOLDER(ReduceSumOp)
-#undef REDUCE_FOLDER
-
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
-  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
-  auto outputTy = getType().dyn_cast<RankedTensorType>();
-
-  if (!inputTy || !outputTy || inputTy != outputTy)
-    return {};
-  return input1();
-}
-
-OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
-  // If the pad is all zeros we can fold this operation away.
-  if (operands[1]) {
-    auto densePad = operands[1].cast<DenseElementsAttr>();
-    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
-      return input1();
-    }
-  }
-
-  return {};
-}
-
-OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
-  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
-  auto outputTy = getType().dyn_cast<RankedTensorType>();
-
-  if (!inputTy || !outputTy || inputTy != outputTy)
-    return {};
-  if (inputTy.hasStaticShape())
-    return input();
-
-  return {};
-}
-
-OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
-  if (on_true() == on_false())
-    return on_true();
-
-  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (!predicate)
-    return {};
-
-  if (!predicate.isSplat())
-    return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
-                                                         : on_false();
-}
-
-OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
-  bool allOnes = true;
-  for (Attribute val : multiples().getValue()) {
-    allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
-  }
-
-  if (allOnes && input1().getType() == getType())
-    return input1();
-  return {};
-}
-
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return {};
-
-  // Transposing splat values just means reshaping.
-  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
-    if (input.isSplat())
-      return input.reshape(getType().cast<ShapedType>());
-  }
-
-  auto perms = llvm::to_vector<6>(llvm::map_range(
-      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
-      [](const APInt &val) { return val.getSExtValue(); }));
-
-  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
-      input1().getType() == getType())
-    return input1();
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -21,6 +21,20 @@
 
 namespace {
 
+template <typename... Args>
+void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
+  (void)std::initializer_list<int>{
+      0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
+}
+
+void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
+                                             RewritePatternSet &patterns) {
+  addOpsCanonicalizations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
+      >(ctx, patterns);
+}
+
 struct TosaLayerwiseConstantFoldPass
     : public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
   void runOnOperation() override {
@@ -29,7 +43,7 @@
     auto func = getOperation();
 
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
-    mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns);
+    populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 
     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
       signalPassFailure();
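
Note: with this change `populateTosaOpsCanonicalizationPatterns` is no longer declared in TosaOps.h; it survives only as a file-local helper of the layerwise constant-fold pass. A downstream pass that called the old public entry point can rebuild the same pattern set itself. A minimal sketch, assuming the standard TOSA headers; the helper name `collectTosaCanonicalizationPatterns` is illustrative, not part of this change:

    // Rebuild the removed public entry point locally: expand the generated
    // TOSA op list and append each op's canonicalization patterns.
    #include "mlir/Dialect/Tosa/IR/TosaOps.h"
    #include "mlir/IR/PatternMatch.h"

    namespace {
    template <typename... OpTys>
    void addOpsCanonicalizations(mlir::MLIRContext *ctx,
                                 mlir::RewritePatternSet &patterns) {
      // Pack expansion: one getCanonicalizationPatterns call per op type.
      (void)std::initializer_list<int>{
          0, (OpTys::getCanonicalizationPatterns(patterns, ctx), 0)...};
    }

    void collectTosaCanonicalizationPatterns(mlir::MLIRContext *ctx,
                                             mlir::RewritePatternSet &patterns) {
      addOpsCanonicalizations<
    #define GET_OP_LIST
    #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
          >(ctx, patterns);
    }
    } // namespace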