diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -237,6 +237,10 @@ /// this, taking into account `BinaryOp` semantics. virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; + /// Customization point for patterns that should only apply to one kind of + /// extension op (zero- or sign-extension); accepts every kind by default. + virtual bool isSupported(ExtensionOp) const { return true; } + LogicalResult matchAndRewrite(BinaryOp op, PatternRewriter &rewriter) const final { Type origTy = op.getType(); @@ -247,7 +251,7 @@ // For the optimization to apply, we expect the lhs to be an extension op, // and for the rhs to either be the same extension op or a constant. FailureOr ext = ExtensionOp::from(op.getLhs().getDefiningOp()); - if (failed(ext)) + if (failed(ext) || !isSupported(*ext)) return failure(); FailureOr lhsBitsRequired = @@ -291,6 +295,23 @@ } }; +//===----------------------------------------------------------------------===// +// SubIOp Pattern +//===----------------------------------------------------------------------===// + +struct SubIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to sign-extended arguments. 
+ bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + //===----------------------------------------------------------------------===// // MulIOp Pattern //===----------------------------------------------------------------------===// @@ -303,6 +324,42 @@ } }; +//===----------------------------------------------------------------------===// +// DivSIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivSIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to sign-extended arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + // Unlike multiplication, division requires only one more result bit. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + +//===----------------------------------------------------------------------===// +// DivUIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivUIPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to zero-extended (unsigned) arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Zero; + } + + // Adds one extra bit like divsi. NOTE(review): quotient <= dividend; confirm. 
+ unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + //===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// @@ -490,7 +547,8 @@ ExtensionOverExtractStridedSlice, ExtensionOverInsert>( patterns.getContext(), options, PatternBenefit(2)); - patterns.add( + patterns.add( patterns.getContext(), options); } diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -101,6 +101,75 @@ return %r : vector<3xi32> } +//===----------------------------------------------------------------------===// +// arith.subi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @subi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.subi` ops with sign-extended +// arguments. 
+// +// CHECK-LABEL: func.func @subi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This case should not get optimized: it mixes sign- and zero-extension. +// +// CHECK-LABEL: func.func @subi_mixed_ext_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[ADD]] : i32 +func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// `arith.subi` produces one more bit of result than the operand bitwidth. 
+// +// CHECK-LABEL: func.func @subi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + //===----------------------------------------------------------------------===// // arith.muli //===----------------------------------------------------------------------===// @@ -183,6 +252,114 @@ return %r : vector<3xi32> } +//===----------------------------------------------------------------------===// +// arith.divsi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divsi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.divsi` ops with sign-extended +// arguments. 
+// +// CHECK-LABEL: func.func @divsi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// `arith.divsi` produces one more bit of result than the operand bitwidth. +// +// CHECK-LABEL: func.func @divsi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.divui +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divui_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[SUB]] : i16 to i32 +// 
CHECK-NEXT: return %[[RET]] : i32 +func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.divui` ops with zero-extended +// arguments. +// +// CHECK-LABEL: func.func @divui_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + +// `arith.divui` produces one more bit of result than the operand bitwidth. +// +// CHECK-LABEL: func.func @divui_extui_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.divui %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divui_extui_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extui %lhs : i16 to i32 + %b = arith.extui %rhs : i16 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + //===----------------------------------------------------------------------===// // arith.*itofp //===----------------------------------------------------------------------===//