diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -216,6 +216,93 @@
   return calculateBitsRequired(value.getType());
 }
 
+/// Base pattern for arith binary ops.
+/// Example:
+/// ```
+/// %lhs = arith.extsi %a : i8 to i32
+/// %rhs = arith.extsi %b : i8 to i32
+/// %r = arith.addi %lhs, %rhs : i32
+/// ==>
+/// %lhs = arith.extsi %a : i8 to i16
+/// %rhs = arith.extsi %b : i8 to i16
+/// %add = arith.addi %lhs, %rhs : i16
+/// %r = arith.extsi %add : i16 to i32
+/// ```
+template <typename BinaryOp>
+struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
+  using NarrowingPattern<BinaryOp>::NarrowingPattern;
+
+  /// Returns the number of bits required to represent the full result, assuming
+  /// that both operands are `operandBits`-wide. Derived classes must implement
+  /// this, taking into account `BinaryOp` semantics.
+  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
+
+  LogicalResult matchAndRewrite(BinaryOp op,
+                                PatternRewriter &rewriter) const final {
+    Type origTy = op.getType();
+    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
+    if (failed(resultBits))
+      return failure();
+
+    // For the optimization to apply, we expect the lhs to be an extension op,
+    // and for the rhs to either be the same extension op or a constant.
+    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    FailureOr<unsigned> lhsBitsRequired =
+        calculateBitsRequired(ext->getIn(), ext->getKind());
+    if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
+      return failure();
+
+    FailureOr<unsigned> rhsBitsRequired =
+        calculateBitsRequired(op.getRhs(), ext->getKind());
+    if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
+      return failure();
+
+    // Negotiate a common bit requirements for both lhs and rhs, accounting for
+    // the result requiring more bits than the operands.
+    unsigned commonBitsRequired =
+        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
+    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
+    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value newLhs =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
+    Value newRhs =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
+    Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
+    ext->recreateAndReplace(rewriter, op, newAdd);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AddIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return operandBits + 1;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// MulIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return 2 * operandBits;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // *IToFPOp Patterns
 //===----------------------------------------------------------------------===//
@@ -538,7 +625,8 @@
                ExtensionOverTranspose, ExtensionOverFlatTranspose>(
           patterns.getContext(), options, PatternBenefit(2));
 
-  patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+  patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
+      patterns.getContext(), options);
 }
 
 } // namespace mlir::arith
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -1,6 +1,188 @@
-// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
 // RUN:   --verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// arith.addi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @addi_extsi_i24
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT:    %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT:    %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT:    %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT:    return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+  %a = arith.extsi %lhs : i8 to i16
+  %b = arith.extsi %rhs : i8 to i16
+  %r = arith.addi %a, %b : i16
+  return %r : i16
+}
+
+// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi8>)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT:    %[[EXT:.+]]  = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
+  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
+  %r = arith.addi %a, %cst : vector<3xi32>
+  return %r : vector<3xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @muli_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// We do not expect this case to be optimized because given n-bit operands,
+// arith.muli produces 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:    return %[[RET]] : i32
func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MUL]] : i32
+func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi8>)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT:    %[[EXT:.+]]  = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
+  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
+  %r = arith.muli %a, %cst : vector<3xi32>
+  return %r : vector<3xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // arith.*itofp
 //===----------------------------------------------------------------------===//