diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -202,6 +202,41 @@
   let hasCanonicalizer = 1;
 }
 
+
+def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
+    AllTypesMatch<["lhs", "rhs", "sum"]>]> {
+  let summary = "integer addition operation returning both the sum and carry";
+  let description = [{
+    The `addi_carry` operation takes two operands and returns two results: the
+    sum (same type as both operands), and the carry (boolean-like).
+
+    Example:
+
+    ```mlir
+    // Scalar addition.
+    %sum, %carry = arith.addi_carry %b, %c : i64, i1
+
+    // Vector element-wise addition.
+    %b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>
+
+    // Tensor element-wise addition.
+    %c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
+  }];
+
+  let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <cassert>
 #include <utility>
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -15,9 +16,9 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/SmallString.h"
 
 using namespace mlir;
 using namespace mlir::arith;
 
@@ -216,6 +217,85 @@
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOp
+//===----------------------------------------------------------------------===//
+
+// Returns the shape of `getType(0)` when that type is a vector type, and
+// `None` otherwise.
+Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
+  if (auto vt = getType(0).dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vt.getShape());
+  return None;
+}
+
+// Returns the carry bit, assuming that `sum` is the result of addition of
+// `operand` and another number.
+static APInt calculateCarry(const APInt &sum, const APInt &operand) {
+  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
+}
+
+// Folds `addi_carry` when the rhs matches `m_Zero()` or when
+// `constFoldBinaryOp` can fold the sum of the constant operands. On success,
+// pushes the sum followed by the carry onto `results`; otherwise fails.
+LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+  auto carryTy = getCarry().getType();
+  // addi_carry(x, 0) -> x, false
+  if (matchPattern(getRhs(), m_Zero())) {
+    Builder builder(getContext());
+    auto falseValue = builder.getZeroAttr(carryTy);
+
+    results.push_back(getLhs());
+    results.push_back(falseValue);
+    return success();
+  }
+
+  // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
+  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
+  // operands. If that succeeds, calculate the carry boolean based on the sum
+  // and the first (constant) operand, `lhs`. Note that we cannot simply call
+  // `constFoldBinaryOp` again to calculate the carry (bit) because the
+  // constructed attribute is of the same element type as both operands.
+  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
+          operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
+    Attribute carryAttr;
+    if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+      // Both arguments are scalars, calculate the scalar carry value.
+      auto sum = sumAttr.cast<IntegerAttr>();
+      carryAttr = IntegerAttr::get(
+          carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
+    } else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
+      // Both arguments are splats, calculate the splat carry value.
+      auto sum = sumAttr.cast<SplatElementsAttr>();
+      APInt carry = calculateCarry(sum.getSplatValue<APInt>(),
+                                   lhs.getSplatValue<APInt>());
+      carryAttr = SplatElementsAttr::get(carryTy, carry);
+    } else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
+      // Otherwise calculate element-wise carry values.
+      auto sum = sumAttr.cast<ElementsAttr>();
+      const auto numElems = static_cast<size_t>(sum.getNumElements());
+      SmallVector<APInt> carryValues;
+      carryValues.reserve(numElems);
+
+      auto sumIt = sum.value_begin<APInt>();
+      auto lhsIt = lhs.value_begin<APInt>();
+      for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
+        carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
+
+      carryAttr = DenseElementsAttr::get(carryTy, carryValues);
+    } else {
+      return failure();
+    }
+
+    results.push_back(sumAttr);
+    results.push_back(carryAttr);
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -544,6 +544,87 @@
   return %add : index
 }
 
+// CHECK-LABEL: @addiCarryZeroRhs
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant false
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryZeroRhsSplat
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  %zero = arith.constant dense<0> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryZeroLhs
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant false
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstants
+//   CHECK-DAG:   %[[false:.+]] = arith.constant false
+//   CHECK-DAG:   %[[c50:.+]] = arith.constant 50 : i32
+//  CHECK-NEXT:   return %[[c50]], %[[false]]
+func.func @addiCarryConstants() -> (i32, i1) {
+  %c13 = arith.constant 13 : i32
+  %c37 = arith.constant 37 : i32
+  %sum, %carry = arith.addi_carry %c13, %c37: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow1
+//   CHECK-DAG:   %[[true:.+]] = arith.constant true
+//   CHECK-DAG:   %[[c0:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[c0]], %[[true]]
+func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32
+  %c1 = arith.constant 1 : i32
+  %sum, %carry = arith.addi_carry %max, %c1: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow2
+//   CHECK-DAG:   %[[true:.+]] = arith.constant true
+//   CHECK-DAG:   %[[c_2:.+]] = arith.constant -2 : i32
+//  CHECK-NEXT:   return %[[c_2]], %[[true]]
+func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32
+  %sum, %carry = arith.addi_carry %max, %max: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflowVector
+//   CHECK-DAG:   %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
+//   CHECK-DAG:   %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
+//  CHECK-NEXT:   return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
+  %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryConstantsSplatVector
+//   CHECK-DAG:   %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
+//   CHECK-DAG:   %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
+//  CHECK-NEXT:   return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<1> : vector<4xi32>
+  %v2 = arith.constant dense<2> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // CHECK-LABEL: @notCmpEQ
 //       CHECK:   %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 //       CHECK:   return %[[cres]]
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -110,6 +110,38 @@
 
 // -----
 
+func.func @func_with_ops(%a: f32) {
+  // expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
+  %r:2 = arith.addi_carry %a, %a : f32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: i32) {
+  // expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}}
+  %r:2 = arith.addi_carry %a, %a : i32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
+  return
+}
+
+// -----
+
 func.func @func_with_ops(i32) {
 ^bb0(%a : i32):
   %sf = arith.addf %a, %a : i32 // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -25,6 +25,30 @@
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_addi_carry
+func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
+  return %sum : i64
+}
+
+// CHECK-LABEL: test_addi_carry_tensor
+func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
+  return %sum : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_vector
+func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
+  return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_scalable_vector
+func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
+  return %0#0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_subi
 func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.subi %arg0, %arg1 : i64