Review notes for this patch (free text placed before the first diff header).

Summary: turns mlir::ArithBuilder from a struct holding `OpBuilder &b` and
`Location loc` into a class publicly deriving from ImplicitLocOpBuilder,
adds ArithBuilder::sub, and drops the explicit `loc` argument from the
`create` calls in Utils.cpp.

NOTE(review): this copy of the patch looks damaged. Every angle-bracketed
template argument list appears to have been stripped (e.g. `create(lhs, rhs)`
in both branches of `add`, `lhs.getType().isa()`), and the original line
breaks are collapsed, so it will neither apply nor compile as shown.
Regenerate it from the source tree before applying.

NOTE(review): `ArithBuilder(OpBuilder &b, Location loc)` now passes `b` to
ImplicitLocOpBuilder(loc, b) instead of storing a reference; presumably this
copies the builder state, so later insertion-point changes made on `b` would
no longer be seen by the ArithBuilder. Confirm no caller relies on that.

NOTE(review): the new one-argument constructor
`ArithBuilder(ImplicitLocOpBuilder &lb)` is not `explicit` and takes a
non-const reference although it only forwards `lb` to the base; consider
`explicit ArithBuilder(const ImplicitLocOpBuilder &lb)`.

NOTE(review): add/sub/mul/sgt/slt pick the integer or float op from a
`lhs.getType().isa<...>()` test whose type list was lost in this copy. If
that list is only IntegerType, index-typed operands would take the float
path; verify against the original source.

diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -17,6 +17,7 @@ #define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -74,19 +75,18 @@ /// Helper struct to build simple arithmetic quantities with minimal type /// inference support. -struct ArithBuilder { - ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} +class ArithBuilder : public ImplicitLocOpBuilder { +public: + ArithBuilder(OpBuilder &b, Location loc) : ImplicitLocOpBuilder(loc, b) {} + ArithBuilder(ImplicitLocOpBuilder &lb) : ImplicitLocOpBuilder(lb) {} Value _and(Value lhs, Value rhs); Value add(Value lhs, Value rhs); + Value sub(Value lhs, Value rhs); Value mul(Value lhs, Value rhs); Value select(Value cmp, Value lhs, Value rhs); Value sgt(Value lhs, Value rhs); Value slt(Value lhs, Value rhs); - -private: - OpBuilder &b; - Location loc; }; } // end namespace mlir diff --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp --- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp +++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp @@ -50,28 +50,33 @@ } Value ArithBuilder::_and(Value lhs, Value rhs) { - return b.create(loc, lhs, rhs); + return create(lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { if (lhs.getType().isa()) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs); + return create(lhs, rhs); + return create(lhs, rhs); +} +Value ArithBuilder::sub(Value lhs, Value rhs) { + if (lhs.getType().isa()) + return create(lhs, rhs); + return create(lhs, rhs); } Value ArithBuilder::mul(Value lhs, Value rhs) { if (lhs.getType().isa()) - return b.create(loc, lhs, rhs); - return 
b.create(loc, lhs, rhs); + return create(lhs, rhs); + return create(lhs, rhs); } Value ArithBuilder::sgt(Value lhs, Value rhs) { if (lhs.getType().isa()) - return b.create(loc, CmpIPredicate::sgt, lhs, rhs); - return b.create(loc, CmpFPredicate::OGT, lhs, rhs); + return create(CmpIPredicate::sgt, lhs, rhs); + return create(CmpFPredicate::OGT, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { if (lhs.getType().isa()) - return b.create(loc, CmpIPredicate::slt, lhs, rhs); - return b.create(loc, CmpFPredicate::OLT, lhs, rhs); + return create(CmpIPredicate::slt, lhs, rhs); + return create(CmpFPredicate::OLT, lhs, rhs); } Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { - return b.create(loc, cmp, lhs, rhs); + return create(cmp, lhs, rhs); }