diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -203,24 +203,26 @@ } -def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative, +def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Commutative, AllTypesMatch<["lhs", "rhs", "sum"]>]> { - let summary = "integer addition operation returning both the sum and carry"; + let summary = "unsigned integer addition operation returning sum and carry"; let description = [{ - The `addi_carry` operation takes two operands and returns two results: the - sum (same type as both operands), and the carry (boolean-like). + The `addui_carry` operation takes two operands and returns two results: the + sum (same type as both operands), and the carry (boolean-like). The carry + value `1` indicates unsigned addition overflow, while `0` indicates no + overflow. Example: ```mlir // Scalar addition. - %sum, %carry = arith.addi_carry %b, %c : i64, i1 + %sum, %carry = arith.addui_carry %b, %c : i64, i1 // Vector element-wise addition. - %b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1> + %b:2 = arith.addui_carry %g, %h : vector<4xi32>, vector<4xi1> // Tensor element-wise addition. - %c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1> + %c:2 = arith.addui_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1> ``` }]; diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -195,12 +195,13 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts arith.addi_carry to spv.IAddCarry. -class AddICarryOpPattern final : public OpConversionPattern<arith::AddICarryOp> { +/// Converts arith.addui_carry to spv.IAddCarry. 
+class AddICarryOpPattern final + : public OpConversionPattern<arith::AddUICarryOp> { public: - using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern; + using OpConversionPattern<arith::AddUICarryOp>::OpConversionPattern; LogicalResult - matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor, + matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -850,7 +851,7 @@ //===----------------------------------------------------------------------===// LogicalResult -AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor, +AddICarryOpPattern::matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type dstElemTy = adaptor.getLhs().getType(); auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy}); @@ -866,8 +867,7 @@ // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = - rewriter.create<spirv::IEqualOp>(loc, carryValue, one); + Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -218,10 +218,10 @@ } //===----------------------------------------------------------------------===// -// AddICarryOp +// AddUICarryOp //===----------------------------------------------------------------------===// -Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() { +Optional<SmallVector<int64_t, 4>> arith::AddUICarryOp::getShapeForUnroll() { if (auto vt = getType(0).dyn_cast<VectorType>()) return llvm::to_vector<4>(vt.getShape()); return None; @@ -233,10 +233,11 @@ return sum.ult(operand) ? 
APInt::getAllOnes(1) : APInt::getZero(1); } -LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands, - SmallVectorImpl<OpFoldResult> &results) { +LogicalResult +arith::AddUICarryOp::fold(ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &results) { auto carryTy = getCarry().getType(); - // addi_carry(x, 0) -> x, false + // addui_carry(x, 0) -> x, false if (matchPattern(getRhs(), m_Zero())) { auto carryZero = APInt::getZero(1); Builder builder(getContext()); @@ -247,7 +248,7 @@ return success(); } - // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry + // addui_carry(constant_a, constant_b) -> constant_sum, constant_carry // Let the `constFoldBinaryOp` utility attempt to fold the sum of both // operands. If that succeeds, calculate the carry boolean based on the sum // and the first (constant) operand, `lhs`. Note that we cannot simply call diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -73,29 +73,29 @@ } // Check integer add-with-carry conversions. 
-// CHECK-LABEL: @int32_scalar_addi_carry +// CHECK-LABEL: @int32_scalar_addui_carry // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) -func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) { +func.func @int32_scalar_addui_carry(%lhs: i32, %rhs: i32) -> (i32, i1) { // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)> // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)> // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)> // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32 // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32 // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1 - %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1 + %sum, %carry = arith.addui_carry %lhs, %rhs: i32, i1 return %sum, %carry : i32, i1 } -// CHECK-LABEL: @int32_vector_addi_carry +// CHECK-LABEL: @int32_vector_addui_carry // CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>) -func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { +func.func @int32_vector_addui_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32> // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32> // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1> - %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_carry %lhs, %rhs: vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } diff --git 
a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -549,7 +549,7 @@ // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) { %zero = arith.constant 0 : i32 - %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1 + %sum, %carry = arith.addui_carry %arg0, %zero: i32, i1 return %sum, %carry : i32, i1 } @@ -558,7 +558,7 @@ // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { %zero = arith.constant dense<0> : vector<4xi32> - %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_carry %arg0, %zero: vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } @@ -567,7 +567,7 @@ // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) { %zero = arith.constant 0 : i32 - %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1 + %sum, %carry = arith.addui_carry %zero, %arg0: i32, i1 return %sum, %carry : i32, i1 } @@ -578,7 +578,7 @@ func.func @addiCarryConstants() -> (i32, i1) { %c13 = arith.constant 13 : i32 %c37 = arith.constant 37 : i32 - %sum, %carry = arith.addi_carry %c13, %c37: i32, i1 + %sum, %carry = arith.addui_carry %c13, %c37: i32, i1 return %sum, %carry : i32, i1 } @@ -589,7 +589,7 @@ func.func @addiCarryConstantsOverflow1() -> (i32, i1) { %max = arith.constant 4294967295 : i32 %c1 = arith.constant 1 : i32 - %sum, %carry = arith.addi_carry %max, %c1: i32, i1 + %sum, %carry = arith.addui_carry %max, %c1: i32, i1 return %sum, %carry : i32, i1 } @@ -599,7 +599,7 @@ // CHECK-NEXT: return %[[c_2]], %[[true]] func.func @addiCarryConstantsOverflow2() -> (i32, i1) { %max = arith.constant 4294967295 : i32 - %sum, %carry = arith.addi_carry %max, %max: i32, i1 + %sum, %carry = arith.addui_carry %max, %max: 
i32, i1 return %sum, %carry : i32, i1 } @@ -610,7 +610,7 @@ func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) { %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32> %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32> - %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } @@ -621,7 +621,7 @@ func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) { %v1 = arith.constant dense<1> : vector<4xi32> %v2 = arith.constant dense<2> : vector<4xi32> - %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir --- a/mlir/test/Dialect/Arithmetic/invalid.mlir +++ b/mlir/test/Dialect/Arithmetic/invalid.mlir @@ -111,32 +111,32 @@ // ----- func.func @func_with_ops(%a: f32) { - // expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}} - %r:2 = arith.addi_carry %a, %a : f32, i32 + // expected-error@+1 {{'arith.addui_carry' op operand #0 must be signless-integer-like}} + %r:2 = arith.addui_carry %a, %a : f32, i32 return } // ----- func.func @func_with_ops(%a: i32) { - // expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}} - %r:2 = arith.addi_carry %a, %a : i32, i32 + // expected-error@+1 {{'arith.addui_carry' op result #1 must be bool-like}} + %r:2 = arith.addui_carry %a, %a : i32, i32 return } // ----- func.func @func_with_ops(%a: vector<8xi32>) { - // expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}} - %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1 + // expected-error@+1 {{'arith.addui_carry' op if an operand is non-scalar, then all 
results must be non-scalar}} + %r:2 = arith.addui_carry %a, %a : vector<8xi32>, i1 return } // ----- func.func @func_with_ops(%a: vector<8xi32>) { - // expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}} - %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1> + // expected-error@+1 {{'arith.addui_carry' op all non-scalar operands/results must have the same shape and base type}} + %r:2 = arith.addui_carry %a, %a : vector<8xi32>, tensor<8xi1> return } diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir --- a/mlir/test/Dialect/Arithmetic/ops.mlir +++ b/mlir/test/Dialect/Arithmetic/ops.mlir @@ -25,27 +25,27 @@ return %0 : vector<[8]xi64> } -// CHECK-LABEL: test_addi_carry -func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 { - %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1 +// CHECK-LABEL: test_addui_carry +func.func @test_addui_carry(%arg0 : i64, %arg1 : i64) -> i64 { + %sum, %carry = arith.addui_carry %arg0, %arg1 : i64, i1 return %sum : i64 } -// CHECK-LABEL: test_addi_carry_tensor -func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { - %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1> +// CHECK-LABEL: test_addui_carry_tensor +func.func @test_addui_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { + %sum, %carry = arith.addui_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1> return %sum : tensor<8x8xi64> } -// CHECK-LABEL: test_addi_carry_vector -func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> { - %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1> +// CHECK-LABEL: test_addui_carry_vector +func.func @test_addui_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> { + %0:2 = arith.addui_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1> return 
%0#0 : vector<8xi64> } -// CHECK-LABEL: test_addi_carry_scalable_vector -func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { - %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1> +// CHECK-LABEL: test_addui_carry_scalable_vector +func.func @test_addui_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0:2 = arith.addui_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1> return %0#0 : vector<[8]xi64> }