diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1401,6 +1401,13 @@
   let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c);
 }
 
+// Count-zeros intrinsics take the value plus an i1 `zero_undefined` flag.
+class LLVM_CountZerosIntrinsicOp<string func, list<OpTrait> traits = []> :
+    LLVM_OneResultIntrOp<func, [], [0],
+           !listconcat([NoSideEffect], traits)> {
+  let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
+}
+
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
@@ -1421,6 +1428,8 @@
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
 def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
+def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">;
+def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">;
 def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
 def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
 def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -297,6 +297,56 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CountLeadingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
+  let summary = "counts the leading zeros of an integer value";
+  let description = [{
+    The `ctlz` operation computes the number of leading zeros of an integer value.
+    Vector and tensor operands are processed element-wise.
+
+    Example:
+
+    ```mlir
+    // Scalar ctlz function value.
+    %a = math.ctlz %b : i32
+
+    // SIMD vector element-wise ctlz function value.
+    %f = math.ctlz %g : vector<4xi16>
+
+    // Tensor element-wise ctlz function value.
+    %x = math.ctlz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CountTrailingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
+  let summary = "counts the trailing zeros of an integer value";
+  let description = [{
+    The `cttz` operation computes the number of trailing zeros of an integer value.
+    Vector and tensor operands are processed element-wise.
+
+    Example:
+
+    ```mlir
+    // Scalar cttz function value.
+    %a = math.cttz %b : i32
+
+    // SIMD vector element-wise cttz function value.
+    %f = math.cttz %g : vector<4xi16>
+
+    // Tensor element-wise cttz function value.
+    %x = math.cttz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CtPopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -38,6 +38,54 @@
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
 
+// `CtLz/CtTz(a)` -> `CtLz/CtTz(a, false)`; `false` keeps ctlz/cttz(0) defined.
+template <typename MathOp, typename LLVMOp>
+struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
+  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto operandType = adaptor.getOperand().getType();
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto boolType = rewriter.getIntegerType(1);
+    auto boolZero = rewriter.getIntegerAttr(boolType, 0);
+
+    // Scalars and 1-D vectors (not converted to arrays): one intrinsic call.
+    if (!operandType.template isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp zero =
+          rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
+      rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
+                                          zero);
+      return success();
+    }
+
+    auto vectorType = resultType.template dyn_cast<VectorType>();
+    if (!vectorType)
+      return failure();
+
+    // The helper replaces `op` itself; the callback must only create ops.
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+          LLVM::ConstantOp zero =
+              rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
+          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
+                                         zero);
+        },
+        rewriter);
+  }
+};
+
+using CountLeadingZerosOpLowering =
+    CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
+using CountTrailingZerosOpLowering =
+    CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
+
 // A `expm1` is converted into `exp - 1`.
 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -222,6 +270,8 @@
     CeilOpLowering,
     CopySignOpLowering,
     CosOpLowering,
+    CountLeadingZerosOpLowering,
+    CountTrailingZerosOpLowering,
     CtPopFOpLowering,
     ExpOpLowering,
     Exp2OpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -74,6 +74,60 @@
 
 // -----
 
+// CHECK-LABEL: func @ctlz(
+// CHECK-SAME: i32
+func @ctlz(%arg0 : i32) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+  %0 = math.ctlz %arg0 : i32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @ctlz_vec(
+// CHECK-SAME: vector<4xi32>
+func @ctlz_vec(%arg0 : vector<4xi32>) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.ctlz %arg0 : vector<4xi32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cttz(
+// CHECK-SAME: i32
+func @cttz(%arg0 : i32) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+  %0 = math.cttz %arg0 : i32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cttz_vec(
+// CHECK-SAME: vector<4xi32>
+func @cttz_vec(%arg0 : vector<4xi32>) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.cttz %arg0 : vector<4xi32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cttz_vec_2d(
+// CHECK-SAME: vector<2x4xi32>
+func @cttz_vec_2d(%arg0 : vector<2x4xi32>) {
+  // CHECK-COUNT-2: "llvm.intr.cttz"({{.*}}) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.cttz %arg0 : vector<2x4xi32>
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @ctpop(
 // CHECK-SAME: i32
 func @ctpop(%arg0 : i32) {
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -135,6 +135,26 @@
   llvm.return
 }
 
+// CHECK-LABEL: @ctlz_test
+llvm.func @ctlz_test(%arg0: i32, %arg1: vector<8xi32>) {
+  %i1 = llvm.mlir.constant(false) : i1
+  // CHECK: call i32 @llvm.ctlz.i32
+  "llvm.intr.ctlz"(%arg0, %i1) : (i32, i1) -> i32
+  // CHECK: call <8 x i32> @llvm.ctlz.v8i32
+  "llvm.intr.ctlz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: @cttz_test
+llvm.func @cttz_test(%arg0: i32, %arg1: vector<8xi32>) {
+  %i1 = llvm.mlir.constant(false) : i1
+  // CHECK: call i32 @llvm.cttz.i32
+  "llvm.intr.cttz"(%arg0, %i1) : (i32, i1) -> i32
+  // CHECK: call <8 x i32> @llvm.cttz.v8i32
+  "llvm.intr.cttz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+  llvm.return
+}
+
 // CHECK-LABEL: @ctpop_test
 llvm.func @ctpop_test(%arg0: i32, %arg1: vector<8xi32>) {
   // CHECK: call i32 @llvm.ctpop.i32