NOTE(review): this patch was recovered from a copy in which every line break
had been collapsed and every template argument list (`<...>`) had been
stripped, which left it both unapplyable and uncompilable. The hunk line
structure is restored (hunk counts re-checked: 6->7 and 24->29, including
the blank context line that opens the header hunk). The `isa<...>` and
`create<...>` arguments are reconstructed from the visible code: the
function names, and which branch holds the CmpF/CmpI predicates (the new
code tests for a float type first; the old code tested for an integer type
first). The least certain lines are the exact type lists on the removed
(`-`) lines and the op named in the trailing `select` context line.
Confirm those against upstream before applying.

Summary: add ArithBuilder::sub, and make add/sub/mul/sgt/slt test for
FloatType first so that any non-float operand type takes the integer op
(presumably so index-typed values stop falling through to float ops).

diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -99,6 +99,7 @@
 
   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
+  Value sub(Value lhs, Value rhs);
   Value mul(Value lhs, Value rhs);
   Value select(Value cmp, Value lhs, Value rhs);
   Value sgt(Value lhs, Value rhs);
diff --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -93,24 +93,29 @@
   return b.create<arith::AndIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::add(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IntegerType>())
-    return b.create<arith::AddIOp>(loc, lhs, rhs);
-  return b.create<arith::AddFOp>(loc, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::AddFOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs);
+}
+Value ArithBuilder::sub(Value lhs, Value rhs) {
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::SubFOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IntegerType>())
-    return b.create<arith::MulIOp>(loc, lhs, rhs);
-  return b.create<arith::MulFOp>(loc, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::MulFOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IndexType, IntegerType>())
-    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
-  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
 }
 Value ArithBuilder::slt(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IndexType, IntegerType>())
-    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
-  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
 }
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);