diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -17,6 +17,7 @@ #define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -74,19 +75,18 @@ /// Helper struct to build simple arithmetic quantities with minimal type /// inference support. -struct ArithBuilder { - ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} +struct ArithBuilder : DialectBuilder { + ArithBuilder(OpBuilder &b, Location loc) : DialectBuilder(b, loc) {} + ArithBuilder(ImplicitLocOpBuilder &lb) : DialectBuilder(lb) {} + ArithBuilder(DialectBuilder &db) : DialectBuilder(db) {} Value _and(Value lhs, Value rhs); Value add(Value lhs, Value rhs); + Value sub(Value lhs, Value rhs); Value mul(Value lhs, Value rhs); Value select(Value cmp, Value lhs, Value rhs); Value sgt(Value lhs, Value rhs); Value slt(Value lhs, Value rhs); - -private: - OpBuilder &b; - Location loc; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h --- a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h +++ b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h @@ -109,6 +109,19 @@ Location curLoc; }; +struct DialectBuilder { + DialectBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} + DialectBuilder(ImplicitLocOpBuilder &lb) : b(lb), loc(lb.getLoc()) {} + DialectBuilder(DialectBuilder &db) : b(db.b), loc(db.loc) {} + + OpBuilder &getBuilder() { return b; } + Location getLoc() { return loc; } + +protected: + OpBuilder &b; + Location loc; +}; + } // namespace mlir #endif // MLIR_IR_IMPLICITLOCOPBUILDER_H diff --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp --- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp +++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp @@ -57,6 +57,11 @@ return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } +Value ArithBuilder::sub(Value lhs, Value rhs) { + if (lhs.getType().isa()) + return b.create(loc, lhs, rhs); + return b.create(loc, lhs, rhs); +} Value ArithBuilder::mul(Value lhs, Value rhs) { if (lhs.getType().isa()) return b.create(loc, lhs, rhs);