diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -237,6 +237,10 @@ /// this, taking into account `BinaryOp` semantics. virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; + /// Customization point for patterns that should only apply with + /// zero/sign-extension ops as arguments. + virtual bool isSupported(ExtensionOp) const { return true; } + LogicalResult matchAndRewrite(BinaryOp op, PatternRewriter &rewriter) const final { Type origTy = op.getType(); @@ -247,7 +251,7 @@ // For the optimization to apply, we expect the lhs to be an extension op, // and for the rhs to either be the same extension op or a constant. FailureOr ext = ExtensionOp::from(op.getLhs().getDefiningOp()); - if (failed(ext)) + if (failed(ext) || !isSupported(*ext)) return failure(); FailureOr lhsBitsRequired = @@ -286,6 +290,27 @@ struct AddIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + // Addition may require one extra bit for the result. + // Example: `UINT8_MAX + 1 == 255 + 1 == 256`. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + +//===----------------------------------------------------------------------===// +// SubIOp Pattern +//===----------------------------------------------------------------------===// + +struct SubIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to signed arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + // Subtraction may require one extra bit for the result. + // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`. 
unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits + 1; } @@ -298,11 +323,50 @@ struct MulIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + // Multiplication may require up to double the operand bits. + // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`. unsigned getResultBitsProduced(unsigned operandBits) const override { return 2 * operandBits; } }; +//===----------------------------------------------------------------------===// +// DivSIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivSIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to signed arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + // Unlike multiplication, signed division requires only one more result bit. + // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + +//===----------------------------------------------------------------------===// +// DivUIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivUIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to unsigned arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Zero; + } + + // Unsigned division does not require any extra result bits. 
+ unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits; + } +}; + //===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// @@ -625,7 +689,8 @@ ExtensionOverTranspose, ExtensionOverFlatTranspose>( patterns.getContext(), options, PatternBenefit(2)); - patterns.add( + patterns.add( patterns.getContext(), options); } diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -101,6 +101,75 @@ return %r : vector<3xi32> } +//===----------------------------------------------------------------------===// +// arith.subi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @subi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.subi` ops with sign-extended +// arguments. 
+// +// CHECK-LABEL: func.func @subi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This case should not get optimized because of mixed extensions. +// +// CHECK-LABEL: func.func @subi_mixed_ext_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[ADD]] : i32 +func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// arith.subi produces one more bit of result than the operand bitwidth. 
+// +// CHECK-LABEL: func.func @subi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + //===----------------------------------------------------------------------===// // arith.muli //===----------------------------------------------------------------------===// @@ -183,6 +252,92 @@ return %r : vector<3xi32> } +//===----------------------------------------------------------------------===// +// arith.divsi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divsi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.divsi` ops with sign-extended +// arguments. 
+// +// CHECK-LABEL: func.func @divsi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// arith.divsi produces one more bit of result than the operand bitwidth. +// +// CHECK-LABEL: func.func @divsi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.divui +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divui_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[ARG0]], %[[ARG1]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[SUB]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to 
`arith.divui` ops with zero-extended +// arguments. +// +// CHECK-LABEL: func.func @divui_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + //===----------------------------------------------------------------------===// // arith.*itofp //===----------------------------------------------------------------------===//