diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -634,7 +634,7 @@
 // MaxFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf", [Commutative]> {
+def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf"> {
   let summary = "floating-point maximum operation";
   let description = [{
     Syntax:
@@ -678,7 +678,7 @@
 // MinFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MinFOp : Arith_FloatBinaryOp<"minf", [Commutative]> {
+def Arith_MinFOp : Arith_FloatBinaryOp<"minf"> {
   let summary = "floating-point minimum operation";
   let description = [{
     Syntax:
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -612,6 +612,14 @@
   if (getLhs() == getRhs())
     return getRhs();
 
+  // maxf(c,x) -> maxf(x,c): move a constant LHS onto the RHS so that later
+  // folds/patterns only need to look for a constant in one position. The op
+  // is updated in place; returning its own result reports an in-place fold.
+  if (operands.front() && !operands.back()) {
+    getOperation()->setOperands({getRhs(), getLhs()});
+    return getResult();
+  }
+
   return constFoldBinaryOp<FloatAttr>(
       operands,
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
@@ -682,6 +690,14 @@
   if (getLhs() == getRhs())
     return getRhs();
 
+  // minf(c,x) -> minf(x,c): move a constant LHS onto the RHS so that later
+  // folds/patterns only need to look for a constant in one position. The op
+  // is updated in place; returning its own result reports an in-place fold.
+  if (operands.front() && !operands.back()) {
+    getOperation()->setOperands({getRhs(), getLhs()});
+    return getResult();
+  }
+
   return constFoldBinaryOp<FloatAttr>(
       operands,
       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });