diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -30,6 +30,8 @@ RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h @@ -0,0 +1,41 @@ +//===- FoldUtils.h - Helper Functions for Folds -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H +#define MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H + +#include +#include + +namespace mlir { +namespace tosa { + +/// Rounding mode (round to nearest, ties to even) used by the TOSA constant folds for floating point operations that require rounding. +static constexpr llvm::RoundingMode tosaRoundingMode = + llvm::APFloat::rmNearestTiesToEven; + +/// Apply the transformation \p toApply to every element of the dense tensor +/// attribute \p toTransform; \p toApply is called with the element and \p targetType. +/// +/// Elements of \p toTransform are extracted as \p SrcValueType. NOTE(review): the definition and an explicit instantiation live in Utils/FoldUtils.cpp, so other template arguments presumably need a further instantiation there - confirm. +/// +/// \returns A tensor with the same size as \p toTransform, containing +/// \p TargetValueType values of type \p TargetType. 
+template +DenseElementsAttr applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -2,6 +2,7 @@ IR/TosaOps.cpp IR/TosaCanonicalizations.cpp Utils/ConversionUtils.cpp + Utils/FoldUtils.cpp Utils/QuantUtils.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -2,6 +2,8 @@ TosaDecomposeTransposeConv.cpp TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp + TosaFoldCommon.cpp + TosaFoldConstantReciprocal.cpp TosaFoldConstantTranspose.cpp TosaInferShapes.cpp TosaLayerwiseConstantFoldPass.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -0,0 +1,49 @@ +//===- TosaFoldCommon.h - Helper Functions for Folds ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. 
+// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H + +#include +#include +#include + +namespace mlir { +namespace tosa { + +/// Checks that \p toCheck is a float tensor defined by a dense TOSA constant; otherwise a match failure on \p location is reported through \p rewriter and returned. +LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter); + +/// Checks that \p toCheck is a dense constant defined by a TOSA constant op; otherwise a match failure on \p location is reported through \p rewriter and returned. +LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter); + +/// Checks that the element type of \p toCheck is a float type; otherwise a match failure on \p location is reported through \p rewriter and returned. +LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, + PatternRewriter &rewriter); + +/// Heuristic to decide when to replace a unary operation on a constant with the +/// folded value. +/// Folding operations on constants can lead to an increased memory usage +/// whenever the input cannot be replaced but a new constant is inserted. Hence, +/// this will currently only suggest folding when the memory impact is +/// negligible, i.e. \p values is a splat or the operand of \p unaryOp has a single use. +/// Takes the \p unaryOp (which must have exactly one operand) and the constant input \p values. +/// \returns Whether folding should be applied. +bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -0,0 +1,83 @@ +//===- TosaFoldCommon.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// + +#include "TosaFoldCommon.h" +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +LogicalResult +mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + // Check the element type first; the failure was already reported to the rewriter + if (failed(notifyIfNotFloat(toCheck, location, rewriter))) { + return failure(); + } + return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter); +} + +LogicalResult +mlir::tosa::notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + // Check whether the tensor is constant and dense. + // TODO: Density is currently ensured by binding the matched constant to a + // DenseElementsAttr, however we do not actually need this value. It would be + // nicer to only have a check here. 
+ DenseElementsAttr tmp; + if (!matchPattern(toCheck, m_Constant(&tmp))) { + return rewriter.notifyMatchFailure(location, + "Non-const or non-dense input tensor"); + } + + // Make sure it actually is a TOSA constant (the match allows for other + // constants as well) + if (isa(toCheck.getDefiningOp())) { + return success(); + } + + return rewriter.notifyMatchFailure(location, + "The operation can only be folded if " + "it operates on a TOSA constant"); +} + +LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + if (isa(toCheck.getType().getElementType())) { + return success(); + } + return rewriter.notifyMatchFailure(location, + "Unexpected input tensor type: the " + "TOSA spec only allows floats"); +} + +bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp, + DenseElementsAttr values) { + assert(unaryOp->getNumOperands() == 1 && "expected a unary operation"); + auto inputValue = unaryOp->getOperand(0); + + // If the input is a splat, we don't care for the number of users + if (isa(values)) { + return true; + } + + // If this is the only use of the tensor it should be replaced as no + // additional memory is required + return inputValue.hasOneUse(); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -0,0 +1,80 @@ +//===- TosaFoldConstantReciprocal.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA reciprocal operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "TosaFoldCommon.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/FoldUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/FloatingPointMode.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantReciprocal : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) { + auto recipAttr = FloatAttr::get(floatTy, 1.0); // the value 1.0 as an attribute of type floatTy + APFloat recip = recipAttr.getValue(); + recip.divide(floatVal, tosaRoundingMode); // in place: recip = 1.0 / floatVal, the returned status is ignored + + return recip; + } + + LogicalResult matchAndRewrite(ReciprocalOp recip, + PatternRewriter &rewriter) const override { + auto inputTensor = recip.getInput1(); + + // Folding requires a float input that is defined by a dense TOSA constant + auto preCondCheck = + notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter); + if (failed(preCondCheck)) { + return preCondCheck; + } + + // Extract the tensor values; the precondition check above already verified that this match succeeds + DenseElementsAttr inputValues; + matchPattern(inputTensor, m_Constant(&inputValues)); + + // Ask the memory heuristic whether this should be folded. 
+ if (!constantUnaryOpShouldBeFolded(recip, inputValues)) { + return rewriter.notifyMatchFailure( + recip, "Currently, reciprocals will only be folded if the input " + "tensor is a splat or has a single user"); + } + + // Create a new tensor holding the reciprocal of every element, keeping the element type of the input + auto newTensor = applyElementWise( + inputValues, &computeReciprocal, + cast(inputValues.getElementType())); + + // Replace the use of the reciprocal with the transformed tensor + rewriter.replaceOpWithNewOp(recip, newTensor.getType(), newTensor); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantReciprocalPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -50,6 +50,7 @@ RewritePatternSet patterns(ctx); auto func = getOperation(); + mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); populateTosaOpsCanonicalizationPatterns(ctx, patterns); diff --git a/mlir/lib/Dialect/Tosa/Utils/FoldUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/FoldUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Utils/FoldUtils.cpp @@ -0,0 +1,48 @@ +//===- FoldUtils.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/FoldUtils.h" + +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +template +DenseElementsAttr mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType) { + SmallVector transformedValues; + // The number of results equals the number of input elements, so reserve + // space for all of them up front to avoid dynamic resizing + transformedValues.reserve(toTransform.getNumElements()); + for (auto val : toTransform.getValues()) { + auto transformedVal = toApply(val, targetType); + transformedValues.push_back(transformedVal); + } + + // Derive the result type from the input type with targetType as element type (NOTE(review): the empty first argument of cloneWith presumably keeps the input shape - confirm) + auto inShape = toTransform.getType(); + auto outTy = inShape.cloneWith({}, targetType); + + return DenseElementsAttr::get(outTy, transformedValues); +} + +template DenseElementsAttr +mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + FloatType targetType); // explicit instantiation of the template defined above in this file diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @reciprocal_fold_single_valued +func.func @reciprocal_fold_single_valued() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_fold_splat +func.func @reciprocal_fold_splat() -> tensor<12x7xf32> { + // CHECK: [[RES:]] 
={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32> + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32> + %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32> + return %1 : tensor<12x7xf32> +} + +// CHECK-LABEL: @reciprocal_div_zero +func.func @reciprocal_div_zero() -> tensor { + // 0x7F800000 is the value for +infinity + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_neg_zero +func.func @reciprocal_div_neg_zero() -> tensor { + // 0xFF800000 is the value for -infinity + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<-0.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_nan +func.func @reciprocal_div_nan() -> tensor { + // 0x7FC00000 is the value for NAN + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_infinity +func.func @reciprocal_div_infinity() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00> + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0x7F800000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_neg_infinity +func.func @reciprocal_div_neg_infinity() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00> + // 
CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0xFF800000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_underflow +func.func @reciprocal_div_underflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-0.{{0*}}e+00, 0.{{0*}}e+00 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-6.0e+15, 6.0e+15]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +} + +// CHECK-LABEL: @reciprocal_div_overflow +func.func @reciprocal_div_overflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[0.0000001, -0.0000001]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +} + +// CHECK-LABEL: @reciprocal_no_fold +// The folding optimization works only intra-procedurally, so we won't be able +// to fold anything here +func.func @reciprocal_no_fold(%arg0: tensor) -> tensor { + // CHECK: tosa.reciprocal + // CHECK-NEXT: return + %0 = "tosa.reciprocal"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @reciprocal_fold +func.func @reciprocal_fold() -> tensor<4x6xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349], + // CHECK-SAME{LITERAL}: [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739], + // CHECK-SAME{LITERAL}: [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776], + // CHECK-SAME{LITERAL}: [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]] + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() { 
value = dense<[ + [ 0.1758, 0.0874, 0.5924, 1.4700, -1.1424, 2.9213], + [-0.2078, 1.4325, 1.5283, -0.0121, -0.2306, -1.3377], + [-0.0804, 0.0761, 0.5277, 1.1292, 0.2446, 0.6946], + [ 0.4929, -0.6524, 1.8092, 0.1397, 1.5505, -1.0270]]> + : tensor<4x6xf32> + } : () -> tensor<4x6xf32> + %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32> + return %1 : tensor<4x6xf32> +} + +// CHECK-LABEL: @reciprocal_of_const_sparse +// Sparse tensors are currently not supported +func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> { + // CHECK: tosa.const + // CHECK: tosa.reciprocal + %0 = "tosa.const"() { value = sparse< + [[0], [3], [11], [17], [20], [23], [25], [30], [31]], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]> + : tensor<32xbf16> } : () -> tensor<32xbf16> + %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16> + return %1 : tensor<32xbf16> +}