diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1086,6 +1086,16 @@
   let summary = "floating point division remainder operation";
 }
 
+def RsqrtOp : FloatUnaryOp<"rsqrt"> {
+  let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
+  let description = [{
+    The `rsqrt` operation computes the reciprocal of the square root. It takes
+    one operand and returns one result of the same type. This type may be a
+    float scalar type, a vector whose element type is float, or a tensor of
+    floats. It has no standard attributes.
+  }];
+}
+
 def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
   let summary = "signed integer division remainder operation";
   let hasFolder = 1;
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1662,6 +1662,47 @@
   bool useAlloca;
 };
 
+// A `rsqrt` is converted into `1 / sqrt`, i.e. an `llvm.fdiv` of a constant
+// one by the result of the `llvm.intr.sqrt` intrinsic. Scalar floats and 1-D
+// float vectors are supported; n-D vectors are rejected because they lower to
+// LLVM arrays of vectors, which the sqrt intrinsic is not defined on.
+struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
+  using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<RsqrtOp> transformed(operands);
+    auto operandType =
+        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+    if (!operandType)
+      return matchFailure();
+
+    // The constant's value attribute must use the *standard* type (f32,
+    // vector<4xf32>, ...): FloatAttr only accepts builtin float types, so it
+    // cannot be built from the converted LLVM dialect type.
+    Type resultType = op->getResult(0).getType();
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (vectorType && vectorType.getRank() != 1)
+      return matchFailure();
+    auto floatType = (vectorType ? vectorType.getElementType() : resultType)
+                         .dyn_cast<FloatType>();
+    if (!floatType)
+      return matchFailure();
+
+    FloatAttr floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    Attribute oneAttr = floatOne;
+    if (vectorType)
+      oneAttr = DenseElementsAttr::get(vectorType, floatOne);
+
+    auto loc = op->getLoc();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, operandType, oneAttr);
+    auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+    return matchSuccess();
+  }
+};
+
 // A `tanh` is converted into a call to the `tanh` function.
 struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
   using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
@@ -2806,6 +2847,7 @@
       PrefetchOpLowering,
       RemFOpLowering,
       ReturnOpLowering,
+      RsqrtOpLowering,
       SIToFPLowering,
       SelectOpLowering,
       ShiftLeftOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -18,6 +18,42 @@
 
 // -----
 
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: !llvm.float
+func @rsqrt(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+  %0 = rsqrt %arg0 : f32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_double(
+// CHECK-SAME: !llvm.double
+func @rsqrt_double(%arg0 : f64) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.double) -> !llvm.double
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.double
+  %0 = rsqrt %arg0 : f64
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: !llvm<"<4 x float>">
+func @rsqrt_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<4 x float>">
+  %0 = rsqrt %arg0 : vector<4xf32>
+  std.return
+}
+
+// -----
+
 // This should not crash. The first operation cannot be converted, so the
 // second should not match. This attempts to convert `return` to `llvm.return`
 // and complains about non-LLVM types.
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -512,6 +512,9 @@
   // CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
   %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
 
+  // CHECK: %{{[0-9]+}} = rsqrt %arg1 : f32
+  %145 = rsqrt %f : f32
+
   return
 }
 