diff --git a/mlir/docs/Dialects/Standard.md b/mlir/docs/Dialects/Standard.md --- a/mlir/docs/Dialects/Standard.md +++ b/mlir/docs/Dialects/Standard.md @@ -587,6 +587,30 @@ scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. +### 'sqrt' operation + +Syntax: + +``` +operation ::= ssa-id `=` `sqrt` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar square root value. +%a = sqrt %b : f64 +// SIMD vector element-wise square root value. +%f = sqrt %g : vector<4xf32> +// Tensor element-wise square root value. +%x = sqrt %y : tensor<4x?xf32> +``` + +The `sqrt` operation computes the square root. It takes one operand and +returns one result of the same type. This type may be a float scalar type, a +vector whose element type is float, or a tensor of floats. It has no standard +attributes. + ### 'tanh' operation Syntax: diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -716,6 +716,7 @@ def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">; def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">; def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">; +def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">; def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>, Arguments<(ins LLVM_Type:$in)>, diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1402,6 +1402,16 @@ let hasCanonicalizer = 1; } +def SqrtOp : FloatUnaryOp<"sqrt"> { + let summary = "square root of the specified value"; + let description = [{ + The `sqrt` operation computes the square root. It takes one operand and + returns one result of the same type. This type may be a float scalar type, a + vector whose element type is float, or a tensor of floats. It has no standard + attributes.
+ }]; +} + def TanhOp : FloatUnaryOp<"tanh"> { let summary = "hyperbolic tangent of the specified value"; let description = [{ diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -807,6 +807,9 @@ : public BinaryOpLLVMOpLowering { using Super::Super; }; +struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> { + using Super::Super; +}; struct UnsignedDivIOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; @@ -2108,6 +2111,7 @@ SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, + SqrtOpLowering, SubFOpLowering, SubIOpLowering, TanhOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -398,8 +398,8 @@ } // CHECK-LABEL: @ops -func @ops(f32, f32, i32, i32) -> (f32, i32) { -^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32): +func @ops(f32, f32, i32, i32, f64) -> (f32, i32) { +^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64): // CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : !llvm.float %0 = subf %arg0, %arg1: f32 // CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : !llvm.i32 @@ -440,7 +440,10 @@ %19 = shift_right_signed %arg2, %arg3 : i32 // CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32 %20 = shift_right_unsigned %arg2, %arg3 : i32 - +// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float + %21 = std.sqrt %arg0 : f32 +// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double + %22 = std.sqrt %arg4 : f64 return %0, %4 : f32, i32 } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ 
b/mlir/test/IR/core-ops.mlir @@ -494,6 +494,18 @@ // CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32> %138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32> + // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 + %139 = "std.sqrt"(%f) : (f32) -> f32 + + // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 + %140 = sqrt %f : f32 + + // CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32> + %141 = sqrt %vcf32 : vector<4xf32> + + // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32> + %142 = sqrt %t : tensor<4x4x?xf32> + return } diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -59,6 +59,15 @@ llvm.return } +// CHECK-LABEL: @sqrt_test +llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) { + // CHECK: call float @llvm.sqrt.f32 + "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float + // CHECK: call <8 x float> @llvm.sqrt.v8f32 + "llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>"> + llvm.return +} + // CHECK-LABEL: @ceil_test llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) { // CHECK: call float @llvm.ceil.f32 @@ -100,6 +109,8 @@ // CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0 // CHECK: declare float @llvm.fabs.f32(float) // CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0 +// CHECK: declare float @llvm.sqrt.f32(float) +// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0 // CHECK: declare float @llvm.ceil.f32(float) // CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0 // CHECK: declare float @llvm.cos.f32(float)