diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1122,6 +1122,16 @@
   let summary = "floating point division remainder operation";
 }
 
+def RsqrtOp : FloatUnaryOp<"rsqrt"> {
+  let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
+  let description = [{
+    The `rsqrt` operation computes the reciprocal of the square root. It takes
+    one operand and returns one result of the same type. This type may be a
+    float scalar type, a vector whose element type is float, or a tensor of
+    floats. It has no standard attributes.
+  }];
+}
+
 def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
   let summary = "signed integer division remainder operation";
   let hasFolder = 1;
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -1662,6 +1664,81 @@
   bool useAlloca;
 };
 
+// A `rsqrt` is converted into `1 / sqrt`. Scalars and 1-D vectors lower to a
+// single constant/sqrt/fdiv sequence; n-D vectors, which are LLVM arrays of
+// 1-D vectors, are unrolled and lowered one 1-D vector at a time.
+struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
+  using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<RsqrtOp> transformed(operands);
+    auto operandType =
+        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+    // Bail out unless the operand already has an LLVM dialect type.
+    if (!operandType)
+      return matchFailure();
+
+    auto loc = op->getLoc();
+    auto resultType = *op->result_type_begin();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Scalar or 1-D vector operand: a single constant, sqrt and fdiv suffice.
+    if (!operandType.isArrayTy()) {
+      LLVM::ConstantOp one;
+      if (operandType.isVectorTy()) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+      return matchSuccess();
+    }
+
+    // n-D vector operand: only vector results can be unrolled.
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return matchFailure();
+
+    auto vectorTypeInfo =
+        extractNDVectorTypeInfo(vectorType, this->typeConverter);
+    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+    if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
+      return matchFailure();
+
+    // The splat of ones is the same for every unrolled 1-D vector.
+    auto splatAttr = SplatElementsAttr::get(
+        mlir::VectorType::get(
+            {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
+            floatType),
+        floatOne);
+
+    Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+      // For the 1-D vector at this unrolled `position`, compute `1 / sqrt` and
+      // insert the result at the same `position` of the output array.
+      auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, transformed.operand(), position);
+      auto one =
+          rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+      auto sqrt =
+          rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
+      auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
+                                                  position);
+    });
+    rewriter.replaceOp(op, desc);
+
+    return matchSuccess();
+  }
+};
+
 // A `tanh` is converted into a call to the `tanh` function.
 struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
   using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
@@ -2806,6 +2883,7 @@
       PrefetchOpLowering,
       RemFOpLowering,
       ReturnOpLowering,
+      RsqrtOpLowering,
       SIToFPLowering,
       SelectOpLowering,
       ShiftLeftOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -18,6 +18,56 @@
 
 // -----
 
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: !llvm.float
+func @rsqrt(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+  %0 = rsqrt %arg0 : f32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_double(
+// CHECK-SAME: !llvm.double
+func @rsqrt_double(%arg0 : f64) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.double) -> !llvm.double
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.double
+  %0 = rsqrt %arg0 : f64
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: !llvm<"<4 x float>">
+func @rsqrt_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<4 x float>">
+  %0 = rsqrt %arg0 : vector<4xf32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_multidim_vector(
+// CHECK-SAME: !llvm<"[4 x <3 x float>]">
+func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[4 x <3 x float>]">
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : !llvm<"<3 x float>">
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (!llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<3 x float>">
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %0[0] : !llvm<"[4 x <3 x float>]">
+  %0 = rsqrt %arg0 : vector<4x3xf32>
+  std.return
+}
+
+// -----
+
 // This should not crash. The first operation cannot be converted, so the
 // second should not match. This attempts to convert `return` to `llvm.return`
 // and complains about non-LLVM types.
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -512,6 +512,9 @@
   // CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
   %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
 
+  // CHECK: %{{[0-9]+}} = rsqrt %arg1 : f32
+  %145 = rsqrt %f : f32
+
   return
 }
 