diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -280,6 +280,69 @@ }]; } +//===----------------------------------------------------------------------===// +// ShlOp +//===----------------------------------------------------------------------===// + +def Index_ShlOp : IndexBinaryOp<"shl"> { + let summary = "index shift left"; + let description = [{ + The `index.shl` operation shifts an index value to the left by a variable + amount. The low order bits are filled with zeroes. The RHS operand is always + treated as unsigned. If the RHS operand is equal to or greater than the + index bitwidth, the operation is undefined. + + Example: + + ```mlir + // c = a << b + %c = index.shl %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// ShrSOp +//===----------------------------------------------------------------------===// + +def Index_ShrSOp : IndexBinaryOp<"shrs"> { + let summary = "signed index shift right"; + let description = [{ + The `index.shrs` operation shifts an index value to the right by a variable + amount. The LHS operand is treated as signed. The high order bits are filled + with copies of the most significant bit. If the RHS operand is equal to or + greater than the index bitwidth, the operation is undefined. + + Example: + + ```mlir + // c = a >> b + %c = index.shrs %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// ShrUOp +//===----------------------------------------------------------------------===// + +def Index_ShrUOp : IndexBinaryOp<"shru"> { + let summary = "unsigned index shift right"; + let description = [{ + The `index.shru` operation shifts an index value to the right by a variable + amount. The LHS operand is treated as unsigned. The high order bits are + filled with zeroes. If the RHS operand is equal to or greater than the index + bitwidth, the operation is undefined. + + Example: + + ```mlir + // c = a >> b + %c = index.shru %a, %b + ``` + }]; +} + //===----------------------------------------------------------------------===// // CastSOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp --- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp +++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp @@ -268,6 +268,11 @@ mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMaxU = mlir::OneToOneConvertToLLVMPattern; +using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern; +using ConvertIndexShrS = + mlir::OneToOneConvertToLLVMPattern; +using ConvertIndexShrU = + mlir::OneToOneConvertToLLVMPattern; using ConvertIndexBoolConstant = mlir::OneToOneConvertToLLVMPattern; @@ -290,6 +295,9 @@ ConvertIndexRemU, ConvertIndexMaxS, ConvertIndexMaxU, + ConvertIndexShl, + ConvertIndexShrS, + ConvertIndexShrU, ConvertIndexCeilDivS, ConvertIndexCeilDivU, ConvertIndexFloorDivS, diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -62,17 +62,19 @@ /// the integer result, which in turn must satisfy the above property. static OpFoldResult foldBinaryOpUnchecked( ArrayRef operands, - function_ref calculate) { + function_ref(const APInt &, const APInt &)> calculate) { assert(operands.size() == 2 && "binary operation expected 2 operands"); auto lhs = dyn_cast_if_present(operands[0]); auto rhs = dyn_cast_if_present(operands[1]); if (!lhs || !rhs) return {}; - APInt result = calculate(lhs.getValue(), rhs.getValue()); - assert(result.trunc(32) == + Optional result = calculate(lhs.getValue(), rhs.getValue()); + if (!result) + return {}; + assert(result->trunc(32) == calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); - return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(result)); + return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(*result)); } /// Fold an index operation only if the truncated 64-bit result matches the @@ -284,6 +286,50 @@ }); } +//===----------------------------------------------------------------------===// +// ShlOp +//===----------------------------------------------------------------------===// + +OpFoldResult ShlOp::fold(ArrayRef operands) { + return foldBinaryOpUnchecked( + operands, [](const APInt &lhs, const APInt &rhs) -> Optional { + // We cannot fold if the RHS is greater than or equal to 32 because + // this would be UB in 32-bit systems but not on 64-bit systems. RHS is + // already treated as unsigned. + if (rhs.uge(32)) + return {}; + return lhs << rhs; + }); +} + +//===----------------------------------------------------------------------===// +// ShrSOp +//===----------------------------------------------------------------------===// + +OpFoldResult ShrSOp::fold(ArrayRef operands) { + return foldBinaryOpChecked( + operands, [](const APInt &lhs, const APInt &rhs) -> Optional { + // Don't fold if RHS is greater than or equal to 32. + if (rhs.uge(32)) + return {}; + return lhs.ashr(rhs); + }); +} + +//===----------------------------------------------------------------------===// +// ShrUOp +//===----------------------------------------------------------------------===// + +OpFoldResult ShrUOp::fold(ArrayRef operands) { + return foldBinaryOpChecked( + operands, [](const APInt &lhs, const APInt &rhs) -> Optional { + // Don't fold if RHS is greater than or equal to 32. + if (rhs.uge(32)) + return {}; + return lhs.lshr(rhs); + }); +} + //===----------------------------------------------------------------------===// // CastSOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir --- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir +++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir @@ -22,8 +22,14 @@ %7 = index.maxs %a, %b // CHECK: llvm.intr.umax %8 = index.maxu %a, %b + // CHECK: llvm.shl + %9 = index.shl %a, %b + // CHECK: llvm.ashr + %10 = index.shrs %a, %b + // CHECK: llvm.lshr + %11 = index.shru %a, %b // CHECK: llvm.mlir.constant(true - %9 = index.bool.constant true + %12 = index.bool.constant true return } diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -279,6 +279,111 @@ return %0 : index } +// CHECK-LABEL: @shl +func.func @shl() -> index { + %lhs = index.constant 128 + %rhs = index.constant 2 + // CHECK: %[[A:.*]] = index.constant 512 + %0 = index.shl %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @shl_32 +func.func @shl_32() -> index { + %lhs = index.constant 1 + %rhs = index.constant 32 + // CHECK: index.shl + %0 = index.shl %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @shl_edge +func.func @shl_edge() -> index { + %lhs = index.constant 4000000000 + %rhs = index.constant 31 + // CHECK: %[[A:.*]] = index.constant 858{{[0-9]+}} + %0 = index.shl %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @shrs +func.func @shrs() -> index { + %lhs = index.constant 128 + %rhs = index.constant 2 + // CHECK: %[[A:.*]] = index.constant 32 + %0 = index.shrs %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @shrs_32 +func.func @shrs_32() -> index { + %lhs = index.constant 4000000000000 + %rhs = index.constant 32 + // CHECK: index.shrs + %0 = index.shrs %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @shrs_nofold +func.func @shrs_nofold() -> index { + %lhs = index.constant 0x100000000 + %rhs = index.constant 1 + // CHECK: index.shrs + %0 = index.shrs %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @shrs_edge +func.func @shrs_edge() -> index { + %lhs = index.constant 0x10000000000 + %rhs = index.constant 3 + // CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}} + %0 = index.shrs %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @shru +func.func @shru() -> index { + %lhs = index.constant 128 + %rhs = index.constant 2 + // CHECK: %[[A:.*]] = index.constant 32 + %0 = index.shru %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @shru_32 +func.func @shru_32() -> index { + %lhs = index.constant 4000000000000 + %rhs = index.constant 32 + // CHECK: index.shru + %0 = index.shru %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @shru_nofold +func.func @shru_nofold() -> index { + %lhs = index.constant 0x100000000 + %rhs = index.constant 1 + // CHECK: index.shru + %0 = index.shru %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @shru_edge +func.func @shru_edge() -> index { + %lhs = index.constant 0x10000000000 + %rhs = index.constant 3 + // CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}} + %0 = index.shru %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + // CHECK-LABEL: @cmp func.func @cmp() -> (i1, i1, i1, i1) { %a = index.constant 0 diff --git a/mlir/test/Dialect/Index/index-ops.mlir b/mlir/test/Dialect/Index/index-ops.mlir --- a/mlir/test/Dialect/Index/index-ops.mlir +++ b/mlir/test/Dialect/Index/index-ops.mlir @@ -27,6 +27,12 @@ %10 = index.maxs %a, %b // CHECK-NEXT: index.maxu %[[A]], %[[B]] %11 = index.maxu %a, %b + // CHECK-NEXT: index.shl %[[A]], %[[B]] + %12 = index.shl %a, %b + // CHECK-NEXT: index.shrs %[[A]], %[[B]] + %13 = index.shrs %a, %b + // CHECK-NEXT: index.shru %[[A]], %[[B]] + %14 = index.shru %a, %b return }