diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -53,6 +53,35 @@ return VectorType::get(newShape, type.getElementType()); } +// Returns a constant of integer or vector type filled with (repeated) `value`. +static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + const APInt &value) { + Attribute attr; + if (auto intTy = type.dyn_cast()) { + attr = rewriter.getIntegerAttr(type, value); + } else { + auto vecTy = type.cast(); + attr = SplatElementsAttr::get(vecTy, value); + } + + return rewriter.create(loc, attr); +} + +// Returns a constant of integer or vector type filled with (repeated) `value`. +static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + int64_t value) { + unsigned elementBitWidth = 0; + if (auto intTy = type.dyn_cast()) + elementBitWidth = intTy.getWidth(); + else + elementBitWidth = type.cast().getElementTypeBitWidth(); + + return createScalarOrSplatConstant(rewriter, loc, type, + APInt(elementBitWidth, value)); +} + // Extracts the `input` vector slice with elements at the last dimension offset // by `lastOffset`. 
Returns a value of vector type with the last dimension // reduced to x1 or fully scalarized, e.g.: @@ -154,8 +183,7 @@ assert(resultShape.back() == static_cast(resultComponents.size()) && "Wrong number of result components"); - Value resultVec = - rewriter.create(loc, rewriter.getZeroAttr(resultType)); + Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0); for (auto [i, component] : llvm::enumerate(resultComponents)) resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i); @@ -232,9 +260,6 @@ matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); auto newTy = getTypeConverter() ->convertType(op.getType()) .dyn_cast_or_null(); @@ -243,8 +268,10 @@ Type newElemTy = reduceInnermostDim(newTy); - auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs); - auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs); + auto [lhsElem0, lhsElem1] = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + auto [rhsElem0, rhsElem1] = + extractLastDimHalves(rewriter, loc, adaptor.getRhs()); auto lowSum = rewriter.create(loc, lhsElem0, rhsElem0); Value carryVal = @@ -260,6 +287,100 @@ } }; +//===----------------------------------------------------------------------===// +// ConvertMulI +//===----------------------------------------------------------------------===// + +struct ConvertMulI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto newTy = getTypeConverter() + ->convertType(op.getType()) + .dyn_cast_or_null(); + if (!newTy) + return rewriter.notifyMatchFailure(loc, "expected scalar or vector type"); + + Type newElemTy = reduceInnermostDim(newTy); + unsigned newBitWidth = 
newTy.getElementTypeBitWidth(); + unsigned digitBitWidth = newBitWidth / 2; + + auto [lhsElem0, lhsElem1] = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + auto [rhsElem0, rhsElem1] = + extractLastDimHalves(rewriter, loc, adaptor.getRhs()); + + // Emulate multiplication by splitting each input element of type i2N into 4 + // digits of type iN and bit width i(N/2). This is so that the intermediate + // multiplications and additions do not overflow. We extract these i(N/2) + // digits from iN vector elements by masking (low digit) and shifting right + // (high digit). + // + // The multiplication algorithm used is the standard (long) multiplication. + // Multiplying two i2N integers produces (at most) an i4N result, but because + // the calculation of top i2N is not necessary, we omit it. + // In total, this implementation performs 10 intermediate multiplications + // and 16 additions. The number of multiplications could be decreased by + // switching to a more efficient algorithm like Karatsuba. This would, + // however, require being able to perform (intermediate) wide additions and + // subtractions, so it is not clear that such an implementation would be more + // efficient. 
+ + APInt lowMaskVal(newBitWidth, 1); + lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1; + Value lowMask = + createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal); + auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) { + return rewriter.create(loc, newElemTy, v, lowMask); + }; + + Value shiftVal = + createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth); + auto getHighDigit = [shiftVal, loc, &rewriter](Value v) { + return rewriter.create(loc, v, shiftVal); + }; + + Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0); + std::array resultDigits = {zeroDigit, zeroDigit, zeroDigit, + zeroDigit}; + std::array lhsDigits = { + getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1), + getHighDigit(lhsElem1)}; + std::array rhsDigits = { + getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1), + getHighDigit(rhsElem1)}; + + for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) { + for (unsigned j = 0; i + j != e; ++j) { + Value mul = + rewriter.create(loc, lhsDigits[i], rhsDigits[j]); + Value current = + rewriter.createOrFold(loc, resultDigits[i + j], mul); + resultDigits[i + j] = getLowDigit(current); + if (i + j + 1 != e) { + Value carry = rewriter.createOrFold( + loc, resultDigits[i + j + 1], getHighDigit(current)); + resultDigits[i + j + 1] = carry; + } + } + } + + auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) { + Value highBits = rewriter.create(loc, high, shiftVal); + return rewriter.create(loc, low, highBits); + }; + Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]); + Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]); + Value resultVec = + constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1}); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertExtSI 
//===----------------------------------------------------------------------===// @@ -287,8 +408,8 @@ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); Value extended = rewriter.createOrFold( loc, newResultComponentTy, newOperand); - Value operandZeroCst = rewriter.create( - loc, rewriter.getZeroAttr(newResultComponentTy)); + Value operandZeroCst = + createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0); Value signBit = rewriter.create( loc, arith::CmpIPredicate::slt, extended, operandZeroCst); Value signValue = @@ -327,8 +448,7 @@ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); Value extended = rewriter.createOrFold( loc, newResultComponentTy, newOperand); - Value zeroCst = rewriter.create( - op->getLoc(), rewriter.getZeroAttr(newTy)); + Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0); Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0); rewriter.replaceOp(op, newRes); return success(); @@ -371,7 +491,7 @@ using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase; void runOnOperation() override { - if (!llvm::isPowerOf2_32(widestIntSupported)) { + if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { signalPassFailure(); return; } @@ -408,7 +528,8 @@ unsigned widestIntSupportedByTarget) : maxIntWidth(widestIntSupportedByTarget) { assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) && - "Only power-of-two integers are supported"); + "Only power-of-two integers with are supported"); + assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow"); // Scalar case. addConversion([this](IntegerType ty) -> Optional { @@ -473,7 +594,7 @@ // Misc ops. ConvertConstant, // Binary ops. - ConvertAddI, + ConvertAddI, ConvertMulI, // Extension and truncation ops. 
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter, patterns.getContext()); diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir @@ -0,0 +1,18 @@ +// Check that emulation of very wide types (>64 bits) works as expected. + +// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=512" %s | FileCheck %s + +// CHECK-LABEL: func.func @muli_scalar +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi512>, [[ARG1:%.+]]: vector<2xi512>) -> vector<2xi512> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi512> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi512> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi512> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi512> +// +// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1. 
+// CHECK-NEXT: {{.+}} = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512 +// CHECK: return {{%.+}} : vector<2xi512> +func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 { + %m = arith.muli %a, %b : i1024 + return %m : i1024 +} diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -205,3 +205,57 @@ %b = arith.trunci %a : vector<3xi64> to vector<3xi16> return %b : vector<3xi16> } + +// CHECK-LABEL: func.func @muli_scalar +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32> +// +// CHECK-DAG: [[MASK:%.+]] = arith.constant 65535 : i32 +// CHECK-DAG: [[C16:%.+]] = arith.constant 16 : i32 +// +// CHECK: [[LOWLOW0:%.+]] = arith.andi [[LOW0]], [[MASK]] : i32 +// CHECK-NEXT: [[HIGHLOW0:%.+]] = arith.shrui [[LOW0]], [[C16]] : i32 +// CHECK-NEXT: [[LOWHIGH0:%.+]] = arith.andi [[HIGH0]], [[MASK]] : i32 +// CHECK-NEXT: [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32 +// CHECK-NEXT: [[LOWLOW1:%.+]] = arith.andi [[LOW1]], [[MASK]] : i32 +// CHECK-NEXT: [[HIGHLOW1:%.+]] = arith.shrui [[LOW1]], [[C16]] : i32 +// CHECK-NEXT: [[LOWHIGH1:%.+]] = arith.andi [[HIGH1]], [[MASK]] : i32 +// CHECK-NEXT: [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32 +// +// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], 
[[HIGHHIGH1]] : i32 +// +// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32 +// +// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32 +// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32 +// +// CHECK-DAG: {{%.+}} = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32 +// +// CHECK: [[RESHIGH0:%.+]] = arith.shli {{%.+}}, [[C16]] : i32 +// CHECK-NEXT: [[RES0:%.+]] = arith.ori {{%.+}}, [[RESHIGH0]] : i32 +// CHECK-NEXT: [[RESHIGH1:%.+]] = arith.shli {{%.+}}, [[C16]] : i32 +// CHECK-NEXT: [[RES1:%.+]] = arith.ori {{%.+}}, [[RESHIGH1]] : i32 +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32> +// CHECK-NEXT: return [[INS1]] : vector<2xi32> +func.func @muli_scalar(%a : i64, %b : i64) -> i64 { + %m = arith.muli %a, %b : i64 + return %m : i64 +} + +// CHECK-LABEL: func.func @muli_vector +// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32> +// CHECK: return {{%.+}} : vector<3x2xi32> +func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> { + %m = arith.muli %a, %b : vector<3xi64> + return %m : vector<3xi64> +}