diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -153,48 +153,6 @@ } }; -template <typename OpTy, arith::CmpFPredicate pred> -struct MaxMinFOpConverter : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const final { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - Location loc = op.getLoc(); - // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). - static_assert(pred == arith::CmpFPredicate::UGT || - pred == arith::CmpFPredicate::ULT, - "pred must be either UGT or ULT"); - Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); - Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); - - // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. - Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, - rhs, rhs); - rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); - return success(); - } -}; - -template <typename OpTy, arith::CmpIPredicate pred> -struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const final { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - Location loc = op.getLoc(); - Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs); - rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs); - return success(); - } -}; - struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> { void runOnOperation() override { @@ -208,13 +166,7 @@ target.addIllegalOp< arith::CeilDivSIOp, arith::CeilDivUIOp, - arith::FloorDivSIOp, - arith::MaxFOp, - arith::MaxSIOp, - arith::MaxUIOp, - arith::MinFOp, - arith::MinSIOp, - arith::MinUIOp + arith::FloorDivSIOp >(); // clang-format on if (failed(applyPartialConversion(getOperation(), target, @@ -234,16 +186,6 @@ void 
mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); - // clang-format off - patterns.add< - MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>, - MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>, - MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, - MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, - MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, - MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult> - >(patterns.getContext()); - // clang-format on } std::unique_ptr<Pass> mlir::arith::createArithExpandOpsPass() { diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -145,88 +145,3 @@ // CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index // CHECK: [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[REM]] : index } - -// ----- - -// CHECK-LABEL: func @maxf -func.func @maxf(%a: f32, %b: f32) -> f32 { - %result = arith.maxf %a, %b : f32 - return %result : f32 -} -// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 -// CHECK-NEXT: return %[[RESULT]] : f32 - -// ----- - -// CHECK-LABEL: func @maxf_vector -func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> { - %result = arith.maxf %a, %b : vector<4xf16> - return %result : vector<4xf16> -} -// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] -// CHECK-NEXT: return %[[RESULT]] : vector<4xf16> - -// ----- - -// CHECK-LABEL: func @minf 
-func.func @minf(%a: f32, %b: f32) -> f32 { - %result = arith.minf %a, %b : f32 - return %result : f32 -} -// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 -// CHECK-NEXT: return %[[RESULT]] : f32 - - -// ----- - -// CHECK-LABEL: func @maxsi -func.func @maxsi(%a: i32, %b: i32) -> i32 { - %result = arith.maxsi %a, %b : i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 - -// ----- - -// CHECK-LABEL: func @minsi -func.func @minsi(%a: i32, %b: i32) -> i32 { - %result = arith.minsi %a, %b : i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 - - -// ----- - -// CHECK-LABEL: func @maxui -func.func @maxui(%a: i32, %b: i32) -> i32 { - %result = arith.maxui %a, %b : i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 - - -// ----- - -// CHECK-LABEL: func @minui -func.func @minui(%a: i32, %b: i32) -> i32 { - %result = arith.minui %a, %b : i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32