diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -95,6 +95,24 @@ const std::string f64Func; }; +namespace gpu { +/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate +/// returns false for calls to the provided intrinsics and true otherwise. +inline std::function +filterIllegalLLVMIntrinsics(ArrayRef intrinsics, MLIRContext *ctx) { + SmallVector illegalIds(intrinsics.begin(), intrinsics.end()); + return [illegalIds](Operation *op) -> bool { + LLVM::CallOp callOp = dyn_cast(op); + if (!callOp || !callOp.callee()) + return true; + StringRef callee = callOp.callee().getValue(); + return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) { + return callee.equals(intrinsic); + }); + }; +} +} // namespace gpu + } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -714,6 +714,8 @@ target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); + target.addDynamicallyLegalOp( + gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext())); // TODO(csigg): Remove once we support replacing non-root ops. target.addLegalOp(); if (failed(applyPartialConversion(m, target, patterns, &converter))) diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -65,8 +65,9 @@ target.addLegalDialect(); target.addIllegalOp(); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addDynamicallyLegalOp( + gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext())); + target.addIllegalOp(); if (failed(applyPartialConversion(m, target, patterns, &converter))) signalPassFailure(); }