diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -95,24 +95,6 @@ const std::string f64Func; }; -namespace gpu { -/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate -/// returns false for calls to the provided intrinsics and true otherwise. -inline std::function -filterIllegalLLVMIntrinsics(ArrayRef intrinsics, MLIRContext *ctx) { - SmallVector illegalIds(intrinsics.begin(), intrinsics.end()); - return [illegalIds](Operation *op) -> bool { - LLVM::CallOp callOp = dyn_cast(op); - if (!callOp || !callOp.callee()) - return true; - StringRef callee = callOp.callee().getValue(); - return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) { - return callee.equals(intrinsic); - }); - }; -} -} // namespace gpu - } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -279,8 +279,6 @@ LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>(); target.addIllegalOp(); target.addLegalDialect(); - target.addDynamicallyLegalOp( - gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext())); // TODO(csigg): Remove once we support replacing non-root ops. target.addLegalOp(); if (failed(applyPartialConversion(m, target, patterns, &converter))) diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -71,8 +71,6 @@ target.addLegalDialect(); target.addIllegalOp(); - target.addDynamicallyLegalOp( - gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext())); target.addIllegalOp(); if (failed(applyPartialConversion(m, target, patterns, &converter))) signalPassFailure(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1737,56 +1737,6 @@ } }; -// A `tanh` is converted into a call to the `tanh` function. -struct TanhOpLowering : public LLVMLegalizationPattern { - using LLVMLegalizationPattern::LLVMLegalizationPattern; - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - - using LLVMFuncOpT = LLVM::LLVMFuncOp; - using LLVMTypeT = LLVM::LLVMType; - - OperandAdaptor transformed(operands); - LLVMTypeT operandType = - transformed.operand().getType().dyn_cast(); - - if (!operandType) - return failure(); - - std::string functionName; - if (operandType.isFloatTy()) - functionName = "tanhf"; - else if (operandType.isDoubleTy()) - functionName = "tanh"; - else - return failure(); - - // Get a reference to the tanh function, inserting it if necessary. - Operation *tanhFunc = - SymbolTable::lookupNearestSymbolFrom(op, functionName); - - LLVMFuncOpT tanhLLVMFunc; - if (tanhFunc) { - tanhLLVMFunc = cast(tanhFunc); - } else { - PatternRewriter::InsertionGuard insertGuard(rewriter); - auto module = op->getParentOfType(); - rewriter.setInsertionPointToStart(module.getBody()); - tanhLLVMFunc = rewriter.create( - module.getLoc(), functionName, - LLVMTypeT::getFunctionTy(operandType, operandType, - /*isVarArg=*/false)); - } - - rewriter.replaceOpWithNewOp( - op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc), - transformed.operand()); - return success(); - } -}; - struct MemRefCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; @@ -2833,7 +2783,6 @@ SqrtOpLowering, SubFOpLowering, SubIOpLowering, - TanhOpLowering, TruncateIOpLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, @@ -3022,6 +2971,7 @@ : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); + this->addIllegalOp(); } std::unique_ptr> diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -407,43 +407,39 @@ // CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32 %2 = cmpi "slt", %arg2, %1 : i32 // CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32 - %4 = divi_signed %arg2, %arg3 : i32 + %3 = divi_signed %arg2, %arg3 : i32 // CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : !llvm.i32 - %5 = divi_unsigned %arg2, %arg3 : i32 + %4 = divi_unsigned %arg2, %arg3 : i32 // CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : !llvm.i32 - %6 = remi_signed %arg2, %arg3 : i32 + %5 = remi_signed %arg2, %arg3 : i32 // CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : !llvm.i32 - %7 = remi_unsigned %arg2, %arg3 : i32 + %6 = remi_unsigned %arg2, %arg3 : i32 // CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32 - %8 = select %2, %arg2, %arg3 : i32 + %7 = select %2, %arg2, %arg3 : i32 // CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : !llvm.float - %9 = divf %arg0, %arg1 : f32 + %8 = divf %arg0, %arg1 : f32 // CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : !llvm.float - %10 = remf %arg0, %arg1 : f32 + %9 = remf %arg0, %arg1 : f32 // CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32 - %11 = and %arg2, %arg3 : i32 + %10 = and %arg2, %arg3 : i32 // CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32 - %12 = or %arg2, %arg3 : i32 + %11 = or %arg2, %arg3 : i32 // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32 - %13 = xor %arg2, %arg3 : i32 + %12 = xor %arg2, %arg3 : i32 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float - %14 = std.exp %arg0 : f32 -// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float - %15 = std.tanh %arg0 : f32 -// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double - %16 = constant 7.9e-01 : f64 -// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double - %17 = std.tanh %16 : f64 -// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32 - %18 = shift_left %arg2, %arg3 : i32 -// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32 - %19 = shift_right_signed %arg2, %arg3 : i32 -// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32 - %20 = shift_right_unsigned %arg2, %arg3 : i32 + %13 = std.exp %arg0 : f32 +// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double + %14 = constant 7.9e-01 : f64 +// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32 + %15 = shift_left %arg2, %arg3 : i32 +// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32 + %16 = shift_right_signed %arg2, %arg3 : i32 +// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32 + %17 = shift_right_unsigned %arg2, %arg3 : i32 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float - %21 = std.sqrt %arg0 : f32 + %18 = std.sqrt %arg0 : f32 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double - %22 = std.sqrt %arg4 : f64 + %19 = std.sqrt %arg4 : f64 return %0, %4 : f32, i32 } @@ -853,22 +849,6 @@ // ----- -module { - func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () { - %f0 = std.tanh %f : f32 - %f1 = std.tanh %f0 : f32 - %lf0 = std.tanh %lf : f64 - %lf1 = std.tanh %lf0 : f64 - return - } -// CHECK: module { -// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double -// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float -// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table -} - -// ----- - // CHECK-LABEL: func @atomic_rmw func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) { atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32