[mlir] Add PatternBenefit::increase() and raise the GPU tanh lowering benefit

* PatternBenefit::increase() returns a copy of a benefit incremented by one,
  saturating one below the impossible-to-match sentinel; an impossible
  benefit is returned unchanged.
* LLVMOpLowering::getDefaultBenefit() names the default benefit (1) that was
  previously a bare default argument.
* OpToFuncCallLowering accepts an optional benefit and forwards it to
  LLVMOpLowering.
* The NVVM and ROCDL TanhOp-to-libcall patterns are registered with a benefit
  one above the default.

NOTE(review): this patch was recovered from a copy in which all newlines and
every `<...>` template argument list had been stripped. Line breaks, context
indentation and the template arguments (e.g. OpToFuncCallLowering<TanhOp>,
std::unique_ptr<OpPassBase<ModuleOp>>) are reconstructed and must be checked
against the target tree before applying. Text before the first "diff --git"
line is ignored by git apply.

diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -230,8 +230,12 @@
 /// purpose of type conversions.
 class LLVMOpLowering : public ConversionPattern {
 public:
+  /// Returns the default benefit value of all LLVM op lowering patterns.
+  static PatternBenefit getDefaultBenefit() { return PatternBenefit(1); }
+
   LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
-                 LLVMTypeConverter &lowering, PatternBenefit benefit = 1);
+                 LLVMTypeConverter &lowering,
+                 PatternBenefit benefit = getDefaultBenefit());
 
 protected:
   // Back-reference to the lowering class, used to call type and function
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -41,6 +41,10 @@
   // corresponding pattern isImpossibleToMatch() then this aborts.
   unsigned short getBenefit() const;
 
+  /// Returns this benefit plus one, without modifying this object. Saturates
+  /// one below the impossible-to-match sentinel; impossible stays impossible.
+  PatternBenefit increase() const;
+
   bool operator==(const PatternBenefit &rhs) const {
     return representation == rhs.representation;
   }
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -29,9 +29,10 @@
 struct OpToFuncCallLowering : public LLVMOpLowering {
 public:
   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
-                                StringRef f64Func)
+                                StringRef f64Func,
+                                PatternBenefit benefit = getDefaultBenefit())
       : LLVMOpLowering(SourceOp::getOperationName(),
-                       lowering_.getDialect()->getContext(), lowering_),
+                       lowering_.getDialect()->getContext(), lowering_, benefit),
         f32Func(f32Func), f64Func(f64Func) {}
 
   PatternMatchResult
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -746,8 +746,9 @@
                                                 "__nv_cos");
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                                 "__nv_exp");
-  patterns.insert<OpToFuncCallLowering<TanhOp>>(converter, "__nv_tanhf",
-                                                 "__nv_tanh");
+  patterns.insert<OpToFuncCallLowering<TanhOp>>(
+      converter, "__nv_tanhf", "__nv_tanh",
+      LLVMOpLowering::getDefaultBenefit().increase());
 }
 
 std::unique_ptr<OpPassBase<ModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -58,8 +58,9 @@
                                                   "__ocml_cos_f64");
     patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__ocml_exp_f32",
                                                   "__ocml_exp_f64");
-    patterns.insert<OpToFuncCallLowering<TanhOp>>(converter, "__ocml_tanh_f32",
-                                                   "__ocml_tanh_f64");
+    patterns.insert<OpToFuncCallLowering<TanhOp>>(
+        converter, "__ocml_tanh_f32", "__ocml_tanh_f64",
+        LLVMOpLowering::getDefaultBenefit().increase());
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -24,6 +24,12 @@
   return representation;
 }
 
+PatternBenefit PatternBenefit::increase() const {
+  if (representation + 1 < ImpossibleToMatchSentinel)
+    return PatternBenefit(representation + 1);
+  return *this;
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern implementation
 //===----------------------------------------------------------------------===//