diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -13,8 +13,11 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -192,6 +195,18 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.addi_carry to spv.IAddCarry. spv.IAddCarry yields a
+/// (sum, carry) struct whose carry member is an integer 0 or 1 of the operand
+/// type, so the carry is unpacked and compared against 1 to form the i1 result.
+class AddICarryOpPattern final
+    : public OpConversionPattern<arith::AddICarryOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.select to spv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
@@ -833,6 +848,48 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Type srcType = op.getLhs().getType();
+  Type dstType = adaptor.getLhs().getType();
+
+  // spv.IAddCarry computes the carry at the bitwidth of its converted operands.
+  // If the type converter changed the operand type (e.g. to emulate an integer
+  // width the target does not support), that carry would not match the source
+  // semantics, so refuse to convert. Index has no fixed width; its converted
+  // type is exactly the width the carry should be computed at.
+  Type srcElemType = srcType;
+  if (auto vecType = srcType.dyn_cast<VectorType>())
+    srcElemType = vecType.getElementType();
+  if (srcType != dstType && !srcElemType.isIndex())
+    return rewriter.notifyMatchFailure(
+        op, "unsupported: operand type changed by type conversion");
+
+  auto resultType = spirv::StructType::get({dstType, dstType});
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<spirv::IAddCarryOp>(
+      loc, resultType, adaptor.getLhs(), adaptor.getRhs());
+
+  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(0));
+  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(1));
+
+  // The carry member is 1 on unsigned overflow and 0 otherwise; compare it
+  // against 1 to produce the i1 (or vector of i1) carry result.
+  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+  rewriter.replaceOp(op, {sumResult, carryResult});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOpPattern
 //===----------------------------------------------------------------------===//
@@ -887,7 +944,7 @@
     TypeCastingOpPattern,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    SelectOpPattern,
+    AddICarryOpPattern, SelectOpPattern,
 
     spirv::ElementwiseOpPattern,
     spirv::ElementwiseOpPattern,
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -72,6 +72,33 @@
   return
 }
 
+// Check integer add-with-carry conversions.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
+  %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+  // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {