[mlir] Expand math.tanh as part of the std-expand pass

Register the Math dialect's tanh expansion pattern in StdExpandOpsPass,
mark math::TanhOp illegal so the partial conversion must rewrite it, and
mark the Math dialect legal so that Math ops other than tanh (e.g. any
the expansion may emit) are not rejected by the conversion target.

NOTE(review): ExpandOps.cpp now calls populateExpandTanhPattern, which
(judging by the mlir-reduce hunk below) is provided by MLIRMathTransforms.
That link dependency belongs on the library that builds ExpandOps.cpp
(presumably MLIRStandardOpsTransforms, declared in
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt) rather than only
on the mlir-reduce tool; otherwise every other consumer of that library
has to add MLIRMathTransforms itself. TODO confirm and move the
dependency.

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -13,6 +13,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
@@ -221,10 +223,12 @@
 
     RewritePatternSet patterns(&ctx);
     populateStdExpandOpsPatterns(patterns);
+    populateExpandTanhPattern(patterns);
 
     ConversionTarget target(getContext());
 
-    target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect>();
+    target.addLegalDialect<math::MathDialect, memref::MemRefDialect,
+                           StandardOpsDialect>();
     target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
       return op.kind() != AtomicRMWKind::maxf &&
              op.kind() != AtomicRMWKind::minf;
@@ -234,6 +238,7 @@
     });
     target.addIllegalOp<SignedCeilDivIOp>();
     target.addIllegalOp<SignedFloorDivIOp>();
+    target.addIllegalOp<math::TanhOp>();
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -14,6 +14,7 @@
   MLIRDialect
   MLIRIR
   MLIRPass
+  MLIRMathTransforms
   MLIRReduceLib
   )
 