diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,12 @@
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for integer bitwidth narrowing.
+/// `options.bitwidthsSupported` lists the integer bitwidths the patterns may
+/// narrow to; it must be non-empty and must not contain 0.
+void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
+                                       const ArithIntNarrowingOptions &options);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -83,4 +83,17 @@
   let dependentDialects = ["vector::VectorDialect"];
 }
 
+def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
+  let summary = "Reduce integer operation bitwidth";
+  let description = [{
+    Reduce bitwidths of integer types used in arith operations. This pass
+    prefers the narrowest available integer bitwidths that are guaranteed to
+    produce the same results.
+  }];
+  let options = [
+    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+               "Integer bitwidths supported">,
+  ];
+ }
+
 #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   Bufferize.cpp
   EmulateWideInt.cpp
   ExpandOps.cpp
+  IntNarrowing.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
   UnsignedWhenEquivalent.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -0,0 +1,175 @@
+//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <cstdint>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHINTNARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+namespace mlir::arith {
+namespace {
+//===----------------------------------------------------------------------===//
+// Common Helpers
+//===----------------------------------------------------------------------===//
+
+/// The base for integer bitwidth narrowing patterns.
+///
+/// Stores a sorted copy of `options.bitwidthsSupported` and offers helpers
+/// that map a required number of bits to a supported, narrower type.
+template <typename SourceOp>
+struct NarrowingPattern : OpRewritePattern<SourceOp> {
+  /// `options.bitwidthsSupported` must be non-empty and must not contain 0;
+  /// both conditions are asserted.
+  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<SourceOp>(ctx, benefit),
+        supportedBitwidths(options.bitwidthsSupported.begin(),
+                           options.bitwidthsSupported.end()) {
+    assert(!supportedBitwidths.empty() && "Invalid options");
+    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
+    llvm::sort(supportedBitwidths);
+  }
+
+  /// Returns the smallest supported bitwidth that is greater than or equal to
+  /// `bitsRequired`, or failure if every supported bitwidth is smaller.
+  FailureOr<unsigned>
+  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
+    // `supportedBitwidths` is sorted, so the first match is the narrowest.
+    for (unsigned candidate : supportedBitwidths)
+      if (candidate >= bitsRequired)
+        return candidate;
+
+    return failure();
+  }
+
+  /// Returns the narrowest supported type that fits `bitsRequired`.
+  ///
+  /// The result has the same shape as `origTy` (a scalar integer or a shaped
+  /// type with an integer element type) and a strictly narrower element type.
+  /// Fails when:
+  ///   * no supported bitwidth can hold `bitsRequired`,
+  ///   * the element type of `origTy` is not an integer type,
+  ///   * the selected bitwidth is not narrower than the original one, or
+  ///   * `origTy` is neither a scalar integer nor a shaped type.
+  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
+    assert(origTy);
+    FailureOr<unsigned> bestBitwidth =
+        getNarrowestCompatibleBitwidth(bitsRequired);
+    if (failed(bestBitwidth))
+      return failure();
+
+    auto elemTy = dyn_cast<IntegerType>(getElementTypeOrSelf(origTy));
+    if (!elemTy)
+      return failure();
+
+    // Rounding `bitsRequired` up to a supported bitwidth may reach or exceed
+    // the original width. That is not a narrowing, and callers truncate to
+    // the returned type, so it must be rejected.
+    if (*bestBitwidth >= elemTy.getWidth())
+      return failure();
+
+    // Build the type from the *supported* bitwidth, not from `bitsRequired`,
+    // which may be a width that the pass options do not allow.
+    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
+    if (origTy == elemTy)
+      return newElemTy;
+
+    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
+      if (isa<IntegerType>(shapedTy.getElementType()))
+        return shapedTy.clone(shapedTy.getShape(), newElemTy);
+
+    return failure();
+  }
+
+private:
+  // Supported integer bitwidths in the ascending order.
+  llvm::SmallVector<unsigned> supportedBitwidths;
+};
+
+/// Returns the integer bitwidth required to represent `type`, i.e. the width
+/// of its integer element type. Fails if the element type is not an integer.
+FailureOr<unsigned> calculateBitsRequired(Type type) {
+  assert(type);
+  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
+    return intTy.getWidth();
+
+  return failure();
+}
+
+/// The kind of integer extension a pattern is allowed to look through.
+enum class ExtensionKind { Sign, Zero };
+
+/// Returns the integer bitwidth required to represent `value`.
+/// Looks through either sign- or zero-extension as specified by
+/// `lookThroughExtension`. Only the op that directly defines `value` is
+/// inspected; an extension of the other kind is not looked through.
+FailureOr<unsigned> calculateBitsRequired(Value value,
+                                          ExtensionKind lookThroughExtension) {
+  if (lookThroughExtension == ExtensionKind::Sign) {
+    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
+      return calculateBitsRequired(sext.getIn().getType());
+  } else if (lookThroughExtension == ExtensionKind::Zero) {
+    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
+      return calculateBitsRequired(zext.getIn().getType());
+  }
+
+  // If nothing else worked, return the type requirements for this element type.
+  return calculateBitsRequired(value.getType());
+}
+
+//===----------------------------------------------------------------------===//
+// *IToFPOp Patterns
+//===----------------------------------------------------------------------===//
+
+/// Narrows the integer operand of an int-to-float conversion `IToFPOp`.
+///
+/// `Extension` selects which extension op may be looked through when
+/// computing the bits required by the operand. The operand is truncated to
+/// the narrow type with `createOrFold`, and the conversion is recreated on
+/// the truncated value with the original result type.
+template <typename IToFPOp, ExtensionKind Extension>
+struct IToFPPattern final : NarrowingPattern<IToFPOp> {
+  using NarrowingPattern<IToFPOp>::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(IToFPOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<unsigned> narrowestWidth =
+        calculateBitsRequired(op.getIn(), Extension);
+    if (failed(narrowestWidth))
+      return failure();
+
+    FailureOr<Type> narrowTy =
+        this->getNarrowType(*narrowestWidth, op.getIn().getType());
+    if (failed(narrowTy))
+      return failure();
+
+    // `narrowTy` is strictly narrower than the operand type, so a truncation
+    // is always well-formed here.
+    Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
+                                                         op.getIn());
+    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
+    return success();
+  }
+};
+using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
+using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
+
+//===----------------------------------------------------------------------===//
+// Pass Definitions
+//===----------------------------------------------------------------------===//
+
+/// Pass that greedily applies the integer narrowing patterns to the operation
+/// it runs on, using the bitwidths given by the `int-bitwidths-supported`
+/// option.
+struct ArithIntNarrowingPass final
+    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
+  using ArithIntNarrowingBase::ArithIntNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    // Reject invalid options up front. The pattern constructors only assert
+    // on them, which aborts in assert-enabled builds and is silent otherwise.
+    if (bitwidthsSupported.empty() ||
+        llvm::is_contained(bitwidthsSupported, 0)) {
+      op->emitError("'int-bitwidths-supported' must be a non-empty list of "
+                    "non-zero bitwidths");
+      return signalPassFailure();
+    }
+
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    populateArithIntNarrowingPatterns(
+        patterns, ArithIntNarrowingOptions{bitwidthsSupported});
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public API
+//===----------------------------------------------------------------------===//
+
+/// Adds the signed and unsigned int-to-float narrowing patterns to `patterns`.
+/// `options.bitwidthsSupported` must be non-empty and must not contain 0.
+void populateArithIntNarrowingPatterns(
+    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
+  patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+}
+
+} // namespace mlir::arith
diff --git
a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN:          --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func.func @sitofp_extsi_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extsi_i16(%a: i16) -> f16 {
+  %b = arith.extsi %a : i16 to i32
+  %f = arith.sitofp %b : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @sitofp_extsi_vector_i16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : vector<3xi16> to vector<3xf16>
+// CHECK-NEXT:    return %[[RET]] : vector<3xf16>
+func.func @sitofp_extsi_vector_i16(%a: vector<3xi16>) -> vector<3xf16> {
+  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %f = arith.sitofp %b : vector<3xi32> to vector<3xf16>
+  return %f : vector<3xf16>
+}
+
+// CHECK-LABEL: func.func @sitofp_extsi_tensor_i16
+// CHECK-SAME:    (%[[ARG:.+]]: tensor<3x?xi16>)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : tensor<3x?xi16> to tensor<3x?xf16>
+// CHECK-NEXT:    return %[[RET]] : tensor<3x?xf16>
+func.func @sitofp_extsi_tensor_i16(%a: tensor<3x?xi16>) -> tensor<3x?xf16> {
+  %b = arith.extsi %a : tensor<3x?xi16> to tensor<3x?xi32>
+  %f = arith.sitofp %b : tensor<3x?xi32> to tensor<3x?xf16>
+  return %f : tensor<3x?xf16>
+}
+
+// Narrowing to i64 is not enabled in pass options.
+//
+// CHECK-LABEL: func.func @sitofp_extsi_i64
+// CHECK-SAME:    (%[[ARG:.+]]: i64)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i64 to i128
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i128 to f32
+// CHECK-NEXT:    return %[[RET]] : f32
+func.func @sitofp_extsi_i64(%a: i64) -> f32 {
+  %b = arith.extsi %a : i64 to i128
+  %f = arith.sitofp %b : i128 to f32
+  return %f : f32
+}
+
+// CHECK-LABEL: func.func @uitofp_extui_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extui_i16(%a: i16) -> f16 {
+  %b = arith.extui %a : i16 to i32
+  %f = arith.uitofp %b : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @sitofp_extsi_extsi_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extsi_extsi_i8(%a: i8) -> f16 {
+  %b = arith.extsi %a : i8 to i16
+  %c = arith.extsi %b : i16 to i32
+  %f = arith.sitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @uitofp_extui_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extui_extui_i8(%a: i8) -> f16 {
+  %b = arith.extui %a : i8 to i16
+  %c = arith.extui %b : i16 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @uitofp_extsi_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i8 to i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extsi_extui_i8(%a: i8) -> f16 {
+  %b = arith.extsi %a : i8 to i16
+  %c = arith.extui %b : i16 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @uitofp_trunci_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[TR:.+]] = arith.trunci %[[ARG]] : i16 to i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[TR]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_trunci_extui_i8(%a: i16) -> f16 {
+  %b = arith.trunci %a : i16 to i8
+  %c = arith.extui %b : i8 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// This should not be folded because arith.extui changes the signed
+// range of the number. For example:
+//  extsi -1 : i16 to i32 ==> -1
+//  extui -1 : i16 to i32 ==> U16_MAX
+//
+// CHECK-LABEL: func.func @sitofp_extui_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[ARG]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i32 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extui_i16(%a: i16) -> f16 {
+  %b = arith.extui %a : i16 to i32
+  %f = arith.sitofp %b : i32 to f16
+  return %f : f16
+}
+
+// This should not be folded because arith.extsi changes the unsigned
+// range of the number. For example:
+//  extsi -1 : i16 to i32 ==> U32_MAX
+//  extui -1 : i16 to i32 ==> U16_MAX
+//
+// CHECK-LABEL: func.func @uitofp_extsi_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i32 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extsi_i16(%a: i16) -> f16 {
+  %b = arith.extsi %a : i16 to i32
+  %f = arith.uitofp %b : i32 to f16
+  return %f : f16
+}