diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -28,10 +28,10 @@ // Common Helper Functions //===----------------------------------------------------------------------===// -// Returns N bottom and N top bits from `value`, where N = `newBitWidth`. -// Treats `value` as a 2*N bits-wide integer. -// The bottom bits are returned in the first pair element, while the top bits in -// the second one. +/// Returns N bottom and N top bits from `value`, where N = `newBitWidth`. +/// Treats `value` as a 2*N bits-wide integer. +/// The bottom bits are returned in the first pair element, while the top bits +/// in the second one. static std::pair getHalves(const APInt &value, unsigned newBitWidth) { APInt low = value.extractBits(newBitWidth, 0); @@ -39,11 +39,11 @@ return {std::move(low), std::move(high)}; } -// Returns the type with the last (innermost) dimention reduced to x1. -// Scalarizes 1D vector inputs to match how we extract/insert vector values, -// e.g.: -// - vector<3x2xi16> --> vector<3x1xi16> -// - vector<2xi16> --> i16 +/// Returns the type with the last (innermost) dimension reduced to x1. +/// Scalarizes 1D vector inputs to match how we extract/insert vector values, +/// e.g.: +/// - vector<3x2xi16> --> vector<3x1xi16> +/// - vector<2xi16> --> i16 static Type reduceInnermostDim(VectorType type) { if (type.getShape().size() == 1) return type.getElementType(); @@ -53,7 +53,7 @@ return VectorType::get(newShape, type.getElementType()); } -// Returns a constant of integer of vector type filled with (repeated) `value`. +/// Returns a constant of integer or vector type filled with (repeated) `value`. 
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, const APInt &value) { @@ -68,7 +68,7 @@ return rewriter.create(loc, attr); } -// Returns a constant of integer of vector type filled with (repeated) `value`. +/// Returns a constant of integer or vector type filled with (repeated) `value`. static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, int64_t value) { @@ -82,11 +82,11 @@ APInt(elementBitWidth, value)); } -// Extracts the `input` vector slice with elements at the last dimension offset -// by `lastOffset`. Returns a value of vector type with the last dimension -// reduced to x1 or fully scalarized, e.g.: -// - vector<3x2xi16> --> vector<3x1xi16> -// - vector<2xi16> --> i16 +/// Extracts the `input` vector slice with elements at the last dimension offset +/// by `lastOffset`. Returns a value of vector type with the last dimension +/// reduced to x1 or fully scalarized, e.g.: +/// - vector<3x2xi16> --> vector<3x1xi16> +/// - vector<2xi16> --> i16 static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset) { @@ -107,8 +107,8 @@ sizes, strides); } -// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, -// with the first element at offset 0 and the second element at offset 1. +/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, +/// with the first element at offset 0 and the second element at offset 1. static std::pair extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input) { @@ -133,8 +133,8 @@ return rewriter.create(loc, newVecTy, input); } -// Performs a vector shape cast to append an x1 dimension. If the -// `input` is a scalar, this is a noop. +/// Performs a vector shape cast to append an x1 dimension. If the +/// `input` is a scalar, this is a noop. 
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { auto vecTy = input.getType().dyn_cast(); @@ -148,9 +148,9 @@ return rewriter.create(loc, newTy, input); } -// Inserts the `source` vector slice into the `dest` vector at offset -// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a -// 1D vector. +/// Inserts the `source` vector slice into the `dest` vector at offset +/// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is +/// a 1D vector. static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset) { @@ -168,12 +168,12 @@ offsets, strides); } -// Constructs a new vector of type `resultType` by creating a series of -// insertions of `resultComponents`, each at the next offset of the last vector -// dimension. -// When all `resultComponents` are scalars, the result type is `vector<NxT>`; -// when `resultComponents` are `vector<...x1xT>`s, the result type is -// `vector<...xNxT>`, where `N` is the number of `resultComponenets`. +/// Constructs a new vector of type `resultType` by creating a series of +/// insertions of `resultComponents`, each at the next offset of the last vector +/// dimension. +/// When all `resultComponents` are scalars, the result type is `vector<NxT>`; +/// when `resultComponents` are `vector<...x1xT>`s, the result type is +/// `vector<...xNxT>`, where `N` is the number of `resultComponents`. 
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents) { @@ -451,6 +451,90 @@ } }; +//===----------------------------------------------------------------------===// +// ConvertShRUI +//===----------------------------------------------------------------------===// + +struct ConvertShRUI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + Type oldTy = op.getType(); + auto newTy = getTypeConverter()->convertType(oldTy).cast(); + Type newOperandTy = reduceInnermostDim(newTy); + unsigned newBitWidth = newTy.getElementTypeBitWidth(); + + auto [lhsElem0, lhsElem1] = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0); + + // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and + // high halves of the results separately: + // 1. low := a or b or c, where: + // a) Bits from LHS.low, shifted by the RHS. + // b) Bits from LHS.high, shifted left. These matter when + // RHS < newBitWidth, e.g.: + // [hhhh][0000] shrui 3 --> [000h][hhh0] + // ^ + // | + // [hhhh] shli (4 - 3) + // c) Bits from LHS.high, shifted right. These come into play when + // RHS >= newBitWidth, e.g.: + // [hhhh][0000] shrui 7 --> [0000][000h] + // ^ + // | + // [hhhh] shrui (7 - 4) + // + // 2. high := LHS.high shrui RHS + // + // Because shifts by values >= newBitWidth are undefined, we ignore the high + // half of RHS, and introduce 'bounds checks' to account for + // RHS.low >= newBitWidth. + // + // NOTE(review): when RHS.low == 0 the left-shift amount in (b) is + // newBitWidth - 0 == newBitWidth, i.e. itself an out-of-range shift whose + // result is still selected; confirm this case is acceptable. + // + // TODO: Explore possible optimizations. 
+ Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0); + Value elemBitWidth = + createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); + + Value illegalElemShift = rewriter.create( + loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); + + Value shiftedElem0 = + rewriter.create(loc, lhsElem0, rhsElem0); + Value resElem0Low = rewriter.create(loc, illegalElemShift, + zeroCst, shiftedElem0); + Value shiftedElem1 = + rewriter.create(loc, lhsElem1, rhsElem0); + Value resElem1 = rewriter.create(loc, illegalElemShift, + zeroCst, shiftedElem1); + + Value cappedShiftAmount = rewriter.create( + loc, illegalElemShift, elemBitWidth, rhsElem0); + Value leftShiftAmount = + rewriter.create(loc, elemBitWidth, cappedShiftAmount); + Value shiftedLeft = + rewriter.create(loc, lhsElem1, leftShiftAmount); + Value overshotShiftAmount = + rewriter.create(loc, rhsElem0, elemBitWidth); + Value shiftedRight = + rewriter.create(loc, lhsElem1, overshotShiftAmount); + + Value resElem0High = rewriter.create( + loc, illegalElemShift, shiftedRight, shiftedLeft); + Value resElem0 = + rewriter.create(loc, resElem0Low, resElem0High); + + Value resultVec = + constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertTruncI //===----------------------------------------------------------------------===// @@ -607,7 +691,7 @@ // Misc ops. ConvertConstant, ConvertVectorPrint, // Binary ops. - ConvertAddI, ConvertMulI, + ConvertAddI, ConvertMulI, ConvertShRUI, // Extension and truncation ops. 
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter, patterns.getContext()); diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -259,3 +259,57 @@ %m = arith.muli %a, %b : vector<3xi64> return %m : vector<3xi64> } + +// CHECK-LABEL: func.func @shrui_scalar +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> +// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32 +// CHECK-NEXT: [[CST32:%.+]] = arith.constant 32 : i32 +// CHECK-DAG: [[OOB:%.+]] = arith.cmpi uge, [[LOW1]], [[CST32]] : i32 +// CHECK-DAG: [[SHLOW0:%.+]] = arith.shrui [[LOW0]], [[LOW1]] : i32 +// CHECK-NEXT: [[RES0LOW:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLOW0]] : i32 +// CHECK-NEXT: [[SHRHIGH0:%.+]] = arith.shrui [[HIGH0]], [[LOW1]] : i32 +// CHECK-NEXT: [[RESLOW1:%.+]] = arith.select [[OOB]], [[CST0]], [[SHRHIGH0]] : i32 +// CHECK-NEXT: [[SHAMT:%.+]] = arith.select [[OOB]], [[CST32]], [[LOW1]] : i32 +// CHECK-NEXT: [[LSHAMT:%.+]] = arith.subi [[CST32]], [[SHAMT]] : i32 +// CHECK-NEXT: [[SHLHIGH0:%.+]] = arith.shli [[HIGH0]], [[LSHAMT]] : i32 +// CHECK-NEXT: [[RSHAMT:%.+]] = arith.subi [[LOW1]], [[CST32]] : i32 +// CHECK-NEXT: [[SHRHIGH0:%.+]] = arith.shrui [[HIGH0]], [[RSHAMT]] : i32 +// CHECK-NEXT: [[RES0HIGH:%.+]] = arith.select [[OOB]], [[SHRHIGH0]], [[SHLHIGH0]] : i32 +// CHECK-NEXT: [[RES0:%.+]] = arith.ori [[RES0LOW]], [[RES0HIGH]] : i32 +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert 
[[RESLOW1]], [[INS0]] [1] : i32 into vector<2xi32> +// CHECK-NEXT: return [[INS1]] : vector<2xi32> +func.func @shrui_scalar(%a : i64, %b : i64) -> i64 { + %c = arith.shrui %a, %b : i64 + return %c : i64 +} + +// CHECK-LABEL: func.func @shrui_scalar_cst_2 +// CHECK-SAME: ({{%.+}}: vector<2xi32>) -> vector<2xi32> +// CHECK: return {{%.+}} : vector<2xi32> +func.func @shrui_scalar_cst_2(%a : i64) -> i64 { + %b = arith.constant 2 : i64 + %c = arith.shrui %a, %b : i64 + return %c : i64 +} + +// CHECK-LABEL: func.func @shrui_scalar_cst_36 +// CHECK-SAME: ({{%.+}}: vector<2xi32>) -> vector<2xi32> +// CHECK: return {{%.+}} : vector<2xi32> +func.func @shrui_scalar_cst_36(%a : i64) -> i64 { + %b = arith.constant 36 : i64 + %c = arith.shrui %a, %b : i64 + return %c : i64 +} + +// CHECK-LABEL: func.func @shrui_vector +// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32> +// CHECK: return {{%.+}} : vector<3x2xi32> +func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> { + %m = arith.shrui %a, %b : vector<3xi64> + return %m : vector<3xi64> +}