Index: mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
===================================================================
--- mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -202,6 +202,44 @@
   let hasCanonicalizer = 1;
 }
 
+def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
+    AllTypesMatch<["lhs", "rhs", "sum"]>]> {
+  let summary = "integer addition operation returning both the sum and carry";
+  let description = [{
+    The `addi_carry` operation takes two operands and returns two results: the
+    sum (same type as both operands), and the carry (boolean-like).
+
+    The sum wraps around on overflow. The carry is set if and only if the
+    addition overflows when both operands are treated as unsigned integers,
+    i.e., when the exact sum does not fit in the operand bitwidth.
+
+    Example:
+
+    ```mlir
+    // Scalar addition.
+    %sum, %carry = arith.addi_carry %b, %c : i64, i1
+
+    // Vector element-wise addition.
+    %b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>
+
+    // Tensor element-wise addition.
+    %c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
+  }];
+
+  let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -13,8 +13,11 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -192,6 +195,16 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.addi_carry to spv.IAddCarry.
+class AddICarryOpPattern final
+    : public OpConversionPattern<arith::AddICarryOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.select to spv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
@@ -833,6 +846,37 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // spv.IAddCarry returns a struct of {sum, carry}; both members have the
+  // converted operand type.
+  Type dstElemTy = adaptor.getLhs().getType();
+  auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy});
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<spirv::IAddCarryOp>(
+      loc, resultTy, adaptor.getLhs(), adaptor.getRhs());
+
+  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(0));
+  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(1));
+
+  // The SPIR-V carry is an integer of the operand type holding 0 or 1;
+  // compare it against 1 to produce the boolean carry of arith.addi_carry.
+  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+  Value carryResult =
+      rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+  rewriter.replaceOp(op, {sumResult, carryResult});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOpPattern
 //===----------------------------------------------------------------------===//
@@ -887,7 +931,7 @@
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    SelectOpPattern,
+    AddICarryOpPattern, SelectOpPattern,
 
     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
Index: mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
===================================================================
--- mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -15,9 +15,12 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallString.h"
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -216,6 +219,103 @@
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
+  if (auto vt = getSum().getType().dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vt.getShape());
+  return None;
+}
+
+// Returns the carry bit, assuming that `sum` is the result of addition of
+// `operand` and another number. Because the addition wraps around, it
+// overflowed (as an unsigned operation) iff `sum` is smaller than `operand`.
+static APInt calculateCarry(const APInt &sum, const APInt &operand) {
+  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
+}
+
+// Returns a constant of type `carryTy` (i1, or a statically shaped vector or
+// tensor of i1) whose elements are all `bit`, or a null attribute if such a
+// constant cannot be built.
+static Attribute getUniformCarryAttr(Type carryTy, const APInt &bit) {
+  if (carryTy.isa<IntegerType>())
+    return IntegerAttr::get(carryTy, bit);
+  auto shapedTy = carryTy.dyn_cast<ShapedType>();
+  if (!shapedTy || !shapedTy.hasStaticShape())
+    return {};
+  return DenseElementsAttr::get(shapedTy, llvm::makeArrayRef(bit));
+}
+
+// Computes the carry result from the folded sum `sumAttr` and the constant
+// left-hand side `lhsAttr`. Returns a null attribute for unsupported attribute
+// kinds.
+static Attribute foldCarry(Type carryTy, Attribute sumAttr, Attribute lhsAttr) {
+  if (auto lhs = lhsAttr.dyn_cast<IntegerAttr>()) {
+    auto sum = sumAttr.dyn_cast<IntegerAttr>();
+    if (!sum)
+      return {};
+    return IntegerAttr::get(carryTy,
+                            calculateCarry(sum.getValue(), lhs.getValue()));
+  }
+
+  auto lhs = lhsAttr.dyn_cast<DenseElementsAttr>();
+  auto sum = sumAttr.dyn_cast<DenseElementsAttr>();
+  if (!lhs || !sum || lhs.getNumElements() != sum.getNumElements())
+    return {};
+
+  // Only take the splat shortcut when *both* attributes are splats: a splat
+  // `lhs` added to a non-splat `rhs` produces a non-splat sum.
+  if (lhs.isSplat() && sum.isSplat())
+    return getUniformCarryAttr(carryTy,
+                               calculateCarry(sum.getSplatValue<APInt>(),
+                                              lhs.getSplatValue<APInt>()));
+
+  auto shapedTy = carryTy.dyn_cast<ShapedType>();
+  if (!shapedTy || !shapedTy.hasStaticShape())
+    return {};
+
+  SmallVector<APInt> carryValues;
+  carryValues.reserve(static_cast<size_t>(sum.getNumElements()));
+  auto sumIt = sum.value_begin<APInt>();
+  auto lhsIt = lhs.value_begin<APInt>();
+  for (int64_t i = 0, e = sum.getNumElements(); i != e; ++i, ++sumIt, ++lhsIt)
+    carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
+  return DenseElementsAttr::get(shapedTy, carryValues);
+}
+
+LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+  Type carryTy = getCarry().getType();
+
+  // addi_carry(x, 0) -> x, false
+  if (matchPattern(getRhs(), m_Zero())) {
+    Attribute falseValue = getUniformCarryAttr(carryTy, APInt::getZero(1));
+    if (!falseValue)
+      return failure();
+
+    results.push_back(getLhs());
+    results.push_back(falseValue);
+    return success();
+  }
+
+  // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
+  Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
+      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+  if (!sumAttr)
+    return failure();
+
+  // A folded sum implies that both operands are constant attributes.
+  Attribute carryAttr = foldCarry(carryTy, sumAttr, operands[0]);
+  if (!carryAttr)
+    return failure();
+
+  results.push_back(sumAttr);
+  results.push_back(carryAttr);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
Index:
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
===================================================================
--- mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -72,6 +72,33 @@
   return
 }
 
+// Check integer add-with-carry conversions.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
+  %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+  // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
Index: mlir/test/Dialect/Arithmetic/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -544,6 +544,90 @@
   return %add : index
 }
 
+// Check folding of arith.addi_carry.
+// CHECK-LABEL: @addiCarryZeroRhs
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryZeroRhsSplat
+// CHECK-NEXT: %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  %zero = arith.constant dense<0> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// `addi_carry` is Commutative, so the constant zero is expected to be moved to
+// the rhs, where the `x + 0` fold applies.
+// CHECK-LABEL: @addiCarryZeroLhs
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstants
+// CHECK-DAG: %[[false:.+]] = arith.constant false
+// CHECK-DAG: %[[c50:.+]] = arith.constant 50 : i32
+// CHECK-NEXT: return %[[c50]], %[[false]]
+func.func @addiCarryConstants() -> (i32, i1) {
+  %c13 = arith.constant 13 : i32
+  %c37 = arith.constant 37 : i32
+  %sum, %carry = arith.addi_carry %c13, %c37: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow1
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[c0]], %[[true]]
+func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32
+  %c1 = arith.constant 1 : i32
+  %sum, %carry = arith.addi_carry %max, %c1: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow2
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[c_2:.+]] = arith.constant -2 : i32
+// CHECK-NEXT: return %[[c_2]], %[[true]]
+func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32
+  %sum, %carry = arith.addi_carry %max, %max: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflowVector
+// CHECK-DAG: %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
+// CHECK-DAG: %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
+// CHECK-NEXT: return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
+  %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryConstantsSplatVector
+// CHECK-DAG: %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
+// CHECK-DAG: %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
+// CHECK-NEXT: return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<1> : vector<4xi32>
+  %v2 = arith.constant dense<2> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // CHECK-LABEL: @notCmpEQ
 // CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 // CHECK: return %[[cres]]
Index: mlir/test/Dialect/Arithmetic/invalid.mlir
===================================================================
--- mlir/test/Dialect/Arithmetic/invalid.mlir
+++ mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -110,6 +110,38 @@
 
 // -----
 
+func.func @func_with_ops(%a: f32) {
+  // expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
+  %r:2 = arith.addi_carry %a, %a : f32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: i32) {
+  // expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}}
+  %r:2 = arith.addi_carry %a, %a : i32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
+  return
+}
+
+// -----
+
 func.func @func_with_ops(i32) {
 ^bb0(%a : i32):
   %sf = arith.addf %a, %a : i32 // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}
Index: mlir/test/Dialect/Arithmetic/ops.mlir
===================================================================
--- mlir/test/Dialect/Arithmetic/ops.mlir
+++ mlir/test/Dialect/Arithmetic/ops.mlir
@@ -25,6 +25,30 @@
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_addi_carry
+func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
+  return %sum : i64
+}
+
+// CHECK-LABEL: test_addi_carry_tensor
+func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
+  return %sum : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_vector
+func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
+  return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_scalable_vector
+func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
+  return %0#0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_subi
 func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.subi %arg0, %arg1 : i64