diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -360,6 +360,28 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Min/Max Patterns
+//===----------------------------------------------------------------------===//
+
+template <typename MinMaxOp, ExtensionKind Kind>
+struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
+  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
+
+  // Only narrow operands extended with `Kind` (sign for *SI, zero for *UI).
+  bool isSupported(ExtensionOp ext) const override {
+    return ext.getKind() == Kind;
+  }
+
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return operandBits;
+  }
+};
+using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
+using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
+using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
+using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
+
 //===----------------------------------------------------------------------===//
 // *IToFPOp Patterns
 //===----------------------------------------------------------------------===//
@@ -548,7 +570,8 @@
       patterns.getContext(), options, PatternBenefit(2));
 
   patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
-               DivUIPattern, SIToFPPattern, UIToFPPattern>(
+               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
+               MinUIPattern, SIToFPPattern, UIToFPPattern>(
       patterns.getContext(), options);
 }
 
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -495,6 +495,134 @@
   return %f : f16
 }
 
+//===----------------------------------------------------------------------===//
+// arith.maxsi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @maxsi_extsi_i8
+// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
+// CHECK-NEXT:    %[[MAX:.+]] = arith.maxsi %[[LHS]], %[[RHS]] : i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[MAX]] : i8 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @maxsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.maxsi %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.maxsi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @maxsi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MAX:.+]] = arith.maxsi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MAX]] : i32
+func.func @maxsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.maxsi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.maxui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @maxui_extui_i8
+// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
+// CHECK-NEXT:    %[[MAX:.+]] = arith.maxui %[[LHS]], %[[RHS]] : i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[MAX]] : i8 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @maxui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.maxui %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.maxui` ops with zero-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @maxui_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MAX:.+]] = arith.maxui %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MAX]] : i32
+func.func @maxui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.maxui %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.minsi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @minsi_extsi_i8
+// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
+// CHECK-NEXT:    %[[MIN:.+]] = arith.minsi %[[LHS]], %[[RHS]] : i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[MIN]] : i8 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @minsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.minsi %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.minsi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @minsi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MIN:.+]] = arith.minsi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MIN]] : i32
+func.func @minsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.minsi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.minui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @minui_extui_i8
+// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
+// CHECK-NEXT:    %[[MIN:.+]] = arith.minui %[[LHS]], %[[RHS]] : i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[MIN]] : i8 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @minui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.minui %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.minui` ops with zero-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @minui_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MIN:.+]] = arith.minui %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MIN]] : i32
+func.func @minui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.minui %a, %b : i32
+  return %r : i32
+}
+
 //===----------------------------------------------------------------------===//
 // Commute Extension over Vector Ops
 //===----------------------------------------------------------------------===//