diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -97,6 +97,19 @@
   unsigned maxIntWidth;
 };
 
+/// Returns `integerLike` with its outermost dimension removed: the element
+/// type for a 1-D vector, an (N-1)-D vector otherwise. Returns a null type
+/// when `integerLike` is null or is not a vector.
+Type peelOutermostDim(ShapedType integerLike) {
+  if (auto ty = integerLike.dyn_cast_or_null<VectorType>()) {
+    if (ty.getShape().size() == 1)
+      return ty.getElementType();
+    return VectorType::get(ty.getShape().drop_front(), ty.getElementType());
+  }
+
+  return nullptr;
+}
+
 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -162,6 +175,56 @@
   }
 };
 
+/// Rewrites a wide `arith.addi` in terms of the converted vector type, whose
+/// outermost dimension holds the low (index 0) and high (index 1) halves. The
+/// low halves are added with a carry-producing add, and the carry is then
+/// added into the sum of the high halves.
+struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getResult().getType())
+                     .dyn_cast_or_null<VectorType>();
+    // `newTy` is null when conversion fails or yields a non-vector type;
+    // `peelOutermostDim` returns null for it, so one check covers both cases.
+    Type newElemTy = peelOutermostDim(newTy);
+    if (!newElemTy)
+      return rewriter.notifyMatchFailure(loc, "Expected scalar or vector type");
+
+    const unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    Value lhsElem0 = rewriter.create<vector::ExtractOp>(loc, lhs, 0);
+    Value lhsElem1 = rewriter.create<vector::ExtractOp>(loc, lhs, 1);
+
+    Value rhsElem0 = rewriter.create<vector::ExtractOp>(loc, rhs, 0);
+    Value rhsElem1 = rewriter.create<vector::ExtractOp>(loc, rhs, 1);
+
+    auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
+    Value carryVal =
+        rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getCarry());
+
+    Value high0 = rewriter.create<arith::AddIOp>(loc, carryVal, lhsElem1);
+    Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
+
+    // Start from an all-zero vector; both halves are overwritten below.
+    Attribute zeroAttr =
+        DenseElementsAttr::get(newTy, APInt::getZero(newBitWidth));
+    Value zeroVec = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+    Value vecLow =
+        rewriter.create<vector::InsertOp>(loc, lowSum.getSum(), zeroVec, 0);
+    Value vecLowHigh = rewriter.create<vector::InsertOp>(loc, high, vecLow, 1);
+    rewriter.replaceOp(op, vecLowHigh);
+    return success();
+  }
+};
+
 struct EmulateWideIntPass final
     : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
   EmulateWideIntPass(unsigned widestIntSupported) {
@@ -195,6 +258,7 @@
     target.addDynamicallyLegalOp<
       // arith ops
       arith::ConstantOp,
+      arith::AddIOp,
       // func ops
       func::FuncOp, func::CallOp, func::ReturnOp
     >(
@@ -223,7 +287,12 @@
 
 void populateWideIntEmulationPatterns(TypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
-  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
+  // clang-format off
+  patterns.add<
+    ConvertConstant,
+    ConvertAddI
+  >(typeConverter, patterns.getContext());
+  // clang-format on
 
   // Populate `func.*` conversion patterns.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -66,3 +66,39 @@
     %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
     return %c0 : vector<3xi64>
 }
+
+// CHECK-LABEL: func @addi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : i1 to i32
+// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32
+// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[SUM_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUM_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+    %x = arith.addi %a, %b : i64
+    return %x : i64
+}
+
+// CHECK-LABEL: func @addi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2x4xi32>, [[ARG1:%.+]]: vector<2x4xi32>) -> vector<2x4xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2x4xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2x4xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2x4xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2x4xi32>
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4xi32>, vector<4xi1>
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : vector<4xi1> to vector<4xi32>
+// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4xi32>
+// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4xi32>
+// CHECK: [[INS0:%.+]] = vector.insert [[SUM_L]], {{%.+}} [0] : vector<4xi32> into vector<2x4xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUM_H1]], [[INS0]] [1] : vector<4xi32> into vector<2x4xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2x4xi32>
+func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+    %x = arith.addi %a, %b : vector<4xi64>
+    return %x : vector<4xi64>
+}