diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -24,6 +24,10 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Common Helper Functions +//===----------------------------------------------------------------------===// + // Returns N bottom and N top bits from `value`, where N = `newBitWidth`. // Treats `value` as a 2*N bits-wide integer. // The bottom bits are returned in the first pair element, while the top bits in @@ -35,6 +39,96 @@ return {std::move(low), std::move(high)}; } +// Returns the type with the last (innermost) dimension reduced to x1. +// Scalarizes 1D vector inputs to match how we extract/insert vector values, +// e.g.: +// - vector<3x2xi16> --> vector<3x1xi16> +// - vector<2xi16> --> i16 +static Type reduceInnermostDim(VectorType type) { + if (type.getShape().size() == 1) + return type.getElementType(); + + auto newShape = to_vector(type.getShape()); + newShape.back() = 1; + return VectorType::get(newShape, type.getElementType()); +} + +// Extracts the `input` vector slice with elements at the last dimension offset +// by `lastOffset`. Returns a value of vector type with the last dimension +// reduced to x1 or fully scalarized, e.g.: +// - vector<3x2xi16> --> vector<3x1xi16> +// - vector<2xi16> --> i16 +static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, + Location loc, Value input, + int64_t lastOffset) { + llvm::ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape(); + assert(lastOffset < shape.back() && "Offset out of bounds"); + + // Scalarize the result in case of 1D vectors.
+ if (shape.size() == 1) + return rewriter.create<vector::ExtractOp>(loc, input, lastOffset); + + SmallVector<int64_t> offsets(shape.size(), 0); + offsets.back() = lastOffset; + auto sizes = llvm::to_vector(shape); + sizes.back() = 1; + SmallVector<int64_t> strides(shape.size(), 1); + + return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets, + sizes, strides); +} + +// Extracts two vector slices from the `input` whose type is `vector<...x2xT>`, +// with the first element at offset 0 and the second element at offset 1. +static std::pair<Value, Value> +extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, + Value input) { + return {extractLastDimSlice(rewriter, loc, input, 0), + extractLastDimSlice(rewriter, loc, input, 1)}; +} + +// Inserts the `source` vector slice into the `dest` vector at offset +// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a +// 1D vector. +static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, + Location loc, Value source, Value dest, + int64_t lastOffset) { + llvm::ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape(); + assert(lastOffset < shape.back() && "Offset out of bounds"); + + // A scalar source is inserted as a single element of a 1D `dest`. + if (source.getType().isa<IntegerType>()) + return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset); + + SmallVector<int64_t> offsets(shape.size(), 0); + offsets.back() = lastOffset; + SmallVector<int64_t> strides(shape.size(), 1); + return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest, + offsets, strides); +} + +// Constructs a new vector of type `resultType` by creating a series of +// insertions of `resultComponents`, each at the next offset of the last vector +// dimension. +// When all `resultComponents` are scalars, the result type is `vector<NxT>`; +// when `resultComponents` are `vector<...x1xT>`s, the result type is +// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
+static Value constructResultVector(ConversionPatternRewriter &rewriter, + Location loc, VectorType resultType, + ValueRange resultComponents) { + llvm::ArrayRef<int64_t> resultShape = resultType.getShape(); + assert(!resultShape.empty() && "Result expected to have dimensions"); + assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) && + "Wrong number of result components"); + + Value resultVec = + rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); + for (auto [i, component] : llvm::enumerate(resultComponents)) + resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i); + + return resultVec; +} + namespace { //===----------------------------------------------------------------------===// // ConvertConstant @@ -94,6 +188,48 @@ } }; +//===----------------------------------------------------------------------===// +// ConvertAddI +//===----------------------------------------------------------------------===// + +// Emulates a wide `arith.addi` on the (low, high) halves of each operand: the +// low halves are added with `arith.addui_carry`, and the zero-extended carry +// is summed with both high halves to form the high half of the result. +struct ConvertAddI final : OpConversionPattern<arith::AddIOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + auto newTy = getTypeConverter() + ->convertType(op.getType()) + .dyn_cast_or_null<VectorType>(); + if (!newTy) + return rewriter.notifyMatchFailure(loc, "expected scalar or vector type"); + + Type newElemTy = reduceInnermostDim(newTy); + + auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs); + auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs); + + auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0); + Value carryVal = + rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getCarry()); + + Value high0 = rewriter.create<arith::AddIOp>(loc, carryVal, lhsElem1); + Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1); + + Value resultVec = + constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high}); +
+ rewriter.replaceOp(op, resultVec); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -116,12 +252,12 @@ target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) { return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType()); }); - target.addDynamicallyLegalOp< - // `func.*` ops - func::CallOp, func::ReturnOp, - // `arith.*` ops - arith::ConstantOp>( - [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + auto opLegalCallback = [&typeConverter](Operation *op) { + return typeConverter.isLegal(op); + }; + target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback); + target.addDynamicallyLegalDialect<arith::ArithmeticDialect, + vector::VectorDialect>(opLegalCallback); RewritePatternSet patterns(ctx); arith::populateWideIntEmulationPatterns(typeConverter, patterns); @@ -201,5 +337,6 @@ populateReturnOpTypeConversionPattern(patterns, typeConverter); // Populate `arith.*` conversion patterns.
- patterns.add<ConvertConstant>(typeConverter, patterns.getContext()); + patterns.add<ConvertConstant, ConvertAddI>(typeConverter, + patterns.getContext()); } diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -75,3 +75,39 @@ %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64> return %c0 : vector<3xi64> } + +// CHECK-LABEL: func @addi_scalar_a_b +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32> +// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1 +// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : i1 to i32 +// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32 +// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32 +// CHECK: [[INS0:%.+]] = vector.insert [[SUM_L]], {{%.+}} [0] : i32 into vector<2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUM_H1]], [[INS0]] [1] : i32 into vector<2xi32> +// CHECK-NEXT: return [[INS1]] : vector<2xi32> +func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 { + %x = arith.addi %a, %b : i64 + return %x : i64 +} + +// CHECK-LABEL: func @addi_vector_a_b +// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> +// CHECK-NEXT: 
[[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> +// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1> +// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : vector<4x1xi1> to vector<4x1xi32> +// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4x1xi32> +// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4x1xi32> +// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUM_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUM_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32> +// CHECK-NEXT: return [[INS1]] : vector<4x2xi32> +func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> { + %x = arith.addi %a, %b : vector<4xi64> + return %x : vector<4xi64> +}