diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -419,78 +419,26 @@ return rewriter.notifyMatchFailure( loc, llvm::formatv("unsupported type: {0}", op.getType())); - Type newElemTy = reduceInnermostDim(newTy); - unsigned newBitWidth = newTy.getElementTypeBitWidth(); - unsigned digitBitWidth = newBitWidth / 2; - auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getLhs()); auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - // Emulate multiplication by splitting each input element of type i2N into 4 - // digits of type iN and bit width i(N/2). This is so that the intermediate - // multiplications and additions do not overflow. We extract these i(N/2) - // digits from iN vector elements by masking (low digit) and shifting right - // (high digit). - // // The multiplication algorithm used is the standard (long) multiplication. - // Multiplying two i2N integers produces (at most) a i4N result, but because - // the calculation of top i2N is not necessary, we omit it. - // In total, this implementations performs 10 intermediate multiplications - // and 16 additions. The number of multiplications could be decreased by - // switching to a more efficient algorithm like Karatsuba. This would, - // however, require being able to perform (intermediate) wide additions and - // subtractions, so it is not clear that such implementation would be more - // efficient. - - APInt lowMaskVal(newBitWidth, 1); - lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1; - Value lowMask = - createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal); - auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) { - return rewriter.create(loc, newElemTy, v, lowMask); - }; + // Multiplying two i2N integers produces (at most) an i4N result, but + // because the calculation of top i2N is not necessary, we omit it. + auto mulLowLow = + rewriter.create(loc, lhsElem0, rhsElem0); + Value mulLowHi = rewriter.create(loc, lhsElem0, rhsElem1); + Value mulHiLow = rewriter.create(loc, lhsElem1, rhsElem0); + + Value resLow = mulLowLow.getLow(); + Value resHi = + rewriter.create(loc, mulLowLow.getHigh(), mulLowHi); + resHi = rewriter.create(loc, resHi, mulHiLow); - Value shiftVal = - createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth); - auto getHighDigit = [shiftVal, loc, &rewriter](Value v) { - return rewriter.create(loc, v, shiftVal); - }; - - Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0); - std::array resultDigits = {zeroDigit, zeroDigit, zeroDigit, - zeroDigit}; - std::array lhsDigits = { - getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1), - getHighDigit(lhsElem1)}; - std::array rhsDigits = { - getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1), - getHighDigit(rhsElem1)}; - - for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) { - for (unsigned j = 0; i + j != e; ++j) { - Value mul = - rewriter.create(loc, lhsDigits[i], rhsDigits[j]); - Value current = - rewriter.createOrFold(loc, resultDigits[i + j], mul); - resultDigits[i + j] = getLowDigit(current); - if (i + j + 1 != e) { - Value carry = rewriter.createOrFold( - loc, resultDigits[i + j + 1], getHighDigit(current)); - resultDigits[i + j + 1] = carry; - } - } - } - - auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) { - Value highBits = rewriter.create(loc, high, shiftVal); - return rewriter.create(loc, low, highBits); - }; - Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]); - Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]); Value resultVec = - constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1}); + constructResultVector(rewriter, loc, newTy, {resLow, resHi}); rewriter.replaceOp(op, resultVec); return success(); } diff --git a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir --- a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir +++ b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir @@ -9,8 +9,12 @@ // CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi512> // CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi512> // -// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1. -// CHECK-NEXT: {{.+}} = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512 +// CHECK-DAG: arith.mului_extended +// CHECK-DAG: arith.muli +// CHECK-DAG: arith.muli +// CHECK-NEXT: arith.addi +// CHECK-NEXT: arith.addi +// // CHECK: return {{%.+}} : vector<2xi512> func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 { %m = arith.muli %a, %b : i1024 diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir @@ -661,44 +661,20 @@ // CHECK-LABEL: func.func @muli_scalar // CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32> -// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32> -// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> -// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> -// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32> // -// CHECK-DAG: [[MASK:%.+]] = arith.constant 65535 : i32 -// CHECK-DAG: [[C16:%.+]] = arith.constant 16 : i32 +// CHECK-DAG: [[RESLOW:%.+]], [[HI0:%.+]] = arith.mului_extended [[LOW0]], [[LOW1]] : i32 +// CHECK-DAG: [[HI1:%.+]] = arith.muli [[LOW0]], [[HIGH1]] : i32 +// CHECK-DAG: [[HI2:%.+]] = arith.muli [[HIGH0]], [[LOW1]] : i32 +// CHECK-NEXT: [[RESHI1:%.+]] = arith.addi [[HI0]], [[HI1]] : i32 +// CHECK-NEXT: [[RESHI2:%.+]] = arith.addi [[RESHI1]], [[HI2]] : i32 // -// CHECK: [[LOWLOW0:%.+]] = arith.andi [[LOW0]], [[MASK]] : i32 -// CHECK-NEXT: [[HIGHLOW0:%.+]] = arith.shrui [[LOW0]], [[C16]] : i32 -// CHECK-NEXT: [[LOWHIGH0:%.+]] = arith.andi [[HIGH0]], [[MASK]] : i32 -// CHECK-NEXT: [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32 -// CHECK-NEXT: [[LOWLOW1:%.+]] = arith.andi [[LOW1]], [[MASK]] : i32 -// CHECK-NEXT: [[HIGHLOW1:%.+]] = arith.shrui [[LOW1]], [[C16]] : i32 -// CHECK-NEXT: [[LOWHIGH1:%.+]] = arith.andi [[HIGH1]], [[MASK]] : i32 -// CHECK-NEXT: [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32 -// -// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32 -// CHECK-DAG {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32 -// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32 -// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32 -// -// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32 -// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32 -// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32 -// -// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32 -// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32 -// -// CHECK-DAG: {{%.+}} = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32 -// -// CHECK: [[RESHIGH0:%.+]] = arith.shli {{%.+}}, [[C16]] : i32 -// CHECK-NEXT: [[RES0:%.+]] = arith.ori {{%.+}}, [[RESHIGH0]] : i32 -// CHECK-NEXT: [[RESHIGH1:%.+]] = arith.shli {{%.+}}, [[C16]] : i32 -// CHECK-NEXT: [[RES1:%.+]] = arith.ori {{%.+}}, [[RESHIGH1]] : i32 -// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> -// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32> -// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32> +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RESLOW]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RESHI2]], [[INS0]] [1] : i32 into vector<2xi32> // CHECK-NEXT: return [[INS1]] : vector<2xi32> func.func @muli_scalar(%a : i64, %b : i64) -> i64 { %m = arith.muli %a, %b : i64 @@ -707,6 +683,11 @@ // CHECK-LABEL: func.func @muli_vector // CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32> +// CHECK-DAG: arith.mului_extended +// CHECK-DAG: arith.muli +// CHECK-DAG: arith.muli +// CHECK-NEXT: arith.addi +// CHECK-NEXT: arith.addi // CHECK: return {{%.+}} : vector<3x2xi32> func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> { %m = arith.muli %a, %b : vector<3xi64>