diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
--- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
@@ -106,6 +106,15 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct AddUICarryOpLowering
+    : public ConvertOpToLLVMPattern<arith::AddUICarryOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -186,6 +195,54 @@
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// AddUICarryOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers `arith.addui_carry` to the `llvm.intr.uadd.with.overflow` intrinsic,
+/// whose `{sum, overflow}` struct result is unpacked with two
+/// `llvm.extractvalue` ops to replace the op's sum and carry results.
+LogicalResult AddUICarryOpLowering::matchAndRewrite(
+    arith::AddUICarryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Type operandType = adaptor.getLhs().getType();
+  Type sumResultType = op.getSum().getType();
+  Type carryResultType = op.getCarry().getType();
+
+  if (!LLVM::isCompatibleType(operandType))
+    return failure();
+
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = op.getLoc();
+
+  // Handle the scalar and 1D vector cases. Converted N-D vector operands are
+  // LLVM arrays and are rejected below.
+  if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    // Convert both result types so the intrinsic's struct is built from LLVM
+    // types, and fail the match rather than build a struct with a null member.
+    Type newSumType = typeConverter->convertType(sumResultType);
+    Type newCarryType = typeConverter->convertType(carryResultType);
+    if (!newSumType || !newCarryType)
+      return rewriter.notifyMatchFailure(loc, "failed to convert result types");
+    Type structType =
+        LLVM::LLVMStructType::getLiteral(ctx, {newSumType, newCarryType});
+    Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
+        loc, structType, adaptor.getLhs(), adaptor.getRhs());
+    Value sumExtracted =
+        rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
+    Value carryExtracted =
+        rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
+    rewriter.replaceOp(op, {sumExtracted, carryExtracted});
+    return success();
+  }
+
+  if (!sumResultType.isa<VectorType>())
+    return rewriter.notifyMatchFailure(loc, "expected vector result types");
+
+  return rewriter.notifyMatchFailure(loc,
+                                     "ND vector types are not supported yet");
+}
+
 //===----------------------------------------------------------------------===//
 // CmpIOpLowering
 //===----------------------------------------------------------------------===//
@@ -300,6 +357,7 @@
     AddFOpLowering,
     AddIOpLowering,
     AndIOpLowering,
+    AddUICarryOpLowering,
     BitcastOpLowering,
     ConstantOpLowering,
     CmpFOpLowering,
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -338,6 +338,30 @@
 
 // -----
 
+// CHECK-LABEL: @addui_carry_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i1)
+func.func @addui_carry_scalar(%arg0: i32, %arg1: i32) -> (i32, i1) {
+  // CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (i32, i32) -> !llvm.struct<(i32, i1)>
+  // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(i32, i1)>
+  // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(i32, i1)>
+  %sum, %carry = arith.addui_carry %arg0, %arg1 : i32, i1
+  // CHECK-NEXT: return [[SUM]], [[CARRY]] : i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addui_carry_vector1d
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi16>, [[ARG1:%.+]]: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>)
+func.func @addui_carry_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) {
+  // CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (vector<3xi16>, vector<3xi16>) -> !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  %sum, %carry = arith.addui_carry %arg0, %arg1 : vector<3xi16>, vector<3xi1>
+  // CHECK-NEXT: return [[SUM]], [[CARRY]] : vector<3xi16>, vector<3xi1>
+  return %sum, %carry : vector<3xi16>, vector<3xi1>
+}
+
+// -----
+
 // CHECK-LABEL: func @cmpf_2dvector(
 func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
   // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast