diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -47,6 +47,10 @@
 /// Create a pass to legalize Arith ops.
 std::unique_ptr<Pass> createArithExpandOpsPass();
 
+/// Create a pass to legalize Arith ops, configured by `options`.
+std::unique_ptr<Pass>
+createArithExpandOpsPass(const ArithExpandOpsOptions &options);
+
 /// Create a pass to replace signed ops with unsigned ones where they are proven
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -304,6 +304,15 @@
 
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
+  ArithExpandOpsPass() = default;
+
+  /// Construct the pass, copying `includeBf16` from `options`. Each option is
+  /// copied by hand, so keep this in sync with the options declared for the
+  /// pass; an option not copied here silently keeps its default value.
+  explicit ArithExpandOpsPass(const arith::ArithExpandOpsOptions &options) {
+    this->includeBf16 = options.includeBf16;
+  }
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
@@ -371,3 +380,8 @@
 std::unique_ptr<Pass> mlir::arith::createArithExpandOpsPass() {
   return std::make_unique<ArithExpandOpsPass>();
 }
+
+std::unique_ptr<Pass> mlir::arith::createArithExpandOpsPass(
+    const ArithExpandOpsOptions &options) {
+  return std::make_unique<ArithExpandOpsPass>(options);
+}