diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -68,6 +68,45 @@
   }
   return {};
 }
+
+/// Performs constant folding `calculate` with element-wise behavior on the one
+/// attribute in `operands` and returns the result if possible.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &calculate) {
+  assert(operands.size() == 1 && "unary op takes one operand");
+  if (!operands[0])
+    return {};
+
+  if (operands[0].isa<AttrElementT>()) {
+    auto s = operands[0].cast<AttrElementT>();
+
+    return AttrElementT::get(s.getType(),
+                             calculate(s.getValue()));
+  } else if (operands[0].isa<SplatElementsAttr>()) {
+    // The operand is a splat so we can avoid expanding the value out and
+    // just fold based on the splat value.
+    auto s = operands[0].cast<SplatElementsAttr>();
+
+    auto elementResult = calculate(s.getSplatValue<ElementValueT>());
+    return DenseElementsAttr::get(s.getType(), elementResult);
+  } else if (operands[0].isa<ElementsAttr>()) {
+    // Operands are ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values.
+    auto s = operands[0].cast<ElementsAttr>();
+
+    auto sIt = s.getValues<ElementValueT>().begin();
+    SmallVector<Attribute, 4> elementResults;
+    elementResults.reserve(s.getNumElements());
+    for (size_t i = 0, e = s.getNumElements(); i < e; ++i, ++sIt)
+      elementResults.push_back(calculate(*sIt));
+    return DenseElementsAttr::get(s.getType(), elementResults);
+  }
+  return {};
+}
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -804,6 +804,8 @@
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
 def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
 def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
+def LLVM_CttzOp : LLVM_OneResultIntrOp<"cttz", [], [0], []>,
+                  Arguments<(ins LLVM_Type:$src, LLVM_Type:$isZeroUndef)>;
 
 def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]>,
                     Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src,
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -88,6 +88,13 @@
             [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins FloatLike:$operand)>;
 
+class IntUnaryOp<string mnemonic, list<OpTrait> traits = []> :
+    UnaryOpSameOperandAndResultType<mnemonic,
+        !listconcat(traits,
+            [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+    Results<(outs IntegerLike:$out)>,
+    Arguments<(ins IntegerLike:$operand)>;
+
 // Base class for standard arithmetic operations.  Requires operands and
 // results to be of the same type, but does not constrain them to specific
 // types.  Individual classes will have `lhs` and `rhs` accessor to operands.
@@ -1955,6 +1962,73 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// BitReverseOp
+//===----------------------------------------------------------------------===//
+
+def BitReverseOp : IntUnaryOp<"bitreverse"> {
+  let summary = "reverse bits";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `bitreverse` ssa-use `:` type
+    ```
+
+    The `bitreverse` operation takes one operand and returns one result.
+    This type may be an integer scalar type, a vector whose element type is
+    integer, or a tensor of integers. It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar bit reverse.
+    %a = bitreverse %b : i64
+
+    // SIMD vector element-wise bit reverse.
+    %f = bitreverse %g : vector<4xi32>
+
+    // Tensor element-wise bit reverse.
+    %x = bitreverse %y : tensor<4x?xi8>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// CtPopOp
+//===----------------------------------------------------------------------===//
+
+def CtPopOp : IntUnaryOp<"ctpop"> {
+  let summary = "population count";
+  let description = [{
+    Syntax:
+
+
+    ```
+    operation ::= ssa-id `=` `ctpop` ssa-use `:` type
+    ```
+
+    The `ctpop` operation takes one operand and returns one result.
+    This type may be an integer scalar type, a vector whose element type is
+    integer, or a tensor of integers. It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar population count.
+    %a = ctpop %b : i64
+
+    // SIMD vector element-wise population count.
+    %f = ctpop %g : vector<4xi32>
+
+    // Tensor element-wise population count.
+    %x = ctpop %y : tensor<4x?xi8>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -710,6 +710,14 @@
         TensorOf<[I1]>.predicate]>,
     "bool-like">;
 
+// Type constraint for integer-like types: integers, indices,
+// vectors of integers, tensors of integers.
+def IntegerLike : TypeConstraint<Or<[AnyInteger.predicate, Index.predicate,
+        VectorOf<[AnyInteger]>.predicate,
+        TensorOf<[AnyInteger]>.predicate]>,
+    "integer-like">;
+
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers, tensors of signless integers.
 def SignlessIntegerLike : TypeConstraint<Or<[AnySignlessInteger.predicate,
        Index.predicate,
        VectorOf<[AnySignlessInteger]>.predicate,
        TensorOf<[AnySignlessInteger]>.predicate]>,
    "signless-integer-like">;
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1569,6 +1569,8 @@
 using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
 using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
+using BitReverseOpLowering = VectorConvertToLLVMPattern<BitReverseOp, LLVM::BitReverseOp>;
+using CtPopOpLowering = VectorConvertToLLVMPattern<CtPopOp, LLVM::CtPopOp>;
 using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
 using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
 using ShiftLeftOpLowering =
@@ -3175,6 +3177,8 @@
       MulIOpLowering,
       NegFOpLowering,
       OrOpLowering,
+      BitReverseOpLowering,
+      CtPopOpLowering,
       PrefetchOpLowering,
       ReOpLowering,
       RemFOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1938,6 +1938,24 @@
                            [](APInt a, APInt b) { return a | b; });
 }
 
+//===----------------------------------------------------------------------===//
+// BitReverseOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BitReverseOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldUnaryOp<IntegerAttr>(operands,
+      [](APInt a) { return a.reverseBits(); });
+}
+
+//===----------------------------------------------------------------------===//
+// CtPopOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CtPopOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldUnaryOp<IntegerAttr>(operands,
+      [](APInt a) { return APInt(a.getBitWidth(), a.countPopulation()); });
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -87,3 +87,13 @@
   // expected-error@+1 {{must be LLVM dialect type}}
   return %1 : i32
 }
+
+// -----
+
+func @unary_intrinsics(%arg: i32) -> i32 {
+  // CHECK: %0 = "llvm.intr.bitreverse"(%arg0) : (!llvm.i32) -> !llvm.i32
+  %0 = bitreverse %arg: i32
+  // CHECK: %1 = "llvm.intr.ctpop"(%0) : (!llvm.i32) -> !llvm.i32
+  %1 = ctpop %0: i32
+  std.return %1 : i32
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -800,6 +800,13 @@
   %13 = llvm.lshr %arg2, %arg2 : !llvm<"<4 x i64>">
 // CHECK-NEXT: %17 = ashr <4 x i64> %2, %2
   %14 = llvm.ashr %arg2, %arg2 : !llvm<"<4 x i64>">
+// CHECK-NEXT: %18 = call <4 x i64> @llvm.bitreverse.v4i64(<4 x i64> %2)
+  %15 = "llvm.intr.bitreverse"(%arg2) : (!llvm<"<4 x i64>">) -> (!llvm<"<4 x i64>">)
+// CHECK-NEXT: %19 = call <4 x i64> @llvm.ctpop.v4i64(<4 x i64> %2)
+  %16 = "llvm.intr.ctpop"(%arg2) : (!llvm<"<4 x i64>">) -> (!llvm<"<4 x i64>">)
+// CHECK-NEXT: %20 = call <4 x i64> @llvm.cttz.v4i64(<4 x i64> %2, i1 false)
+  %false = llvm.mlir.constant(0 : i1) : !llvm.i1
+  %17 = "llvm.intr.cttz"(%arg2, %false) : (!llvm<"<4 x i64>">, !llvm.i1) -> (!llvm<"<4 x i64>">)
 // CHECK-NEXT: ret <4 x float> %4
   llvm.return %1 : !llvm<"<4 x float>">
 }