diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -23,7 +24,87 @@
 
 using namespace mlir;
 
+// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
+// Treats `value` as a 2*N bits-wide integer.
+// The bottom bits are returned in the first pair element, while the top bits in
+// the second one.
+static std::pair<APInt, APInt> getHalves(const APInt &value,
+                                         unsigned newBitWidth) {
+  APInt low = value.extractBits(newBitWidth, 0);
+  APInt high = value.extractBits(newBitWidth, newBitWidth);
+  return {std::move(low), std::move(high)};
+}
+
 namespace {
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+// Rewrites an `arith.constant` by splitting each integer of its value into a
+// (low, high) pair of elements of the converted vector type.
+struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The type converter may be unable to emulate this type; report a match
+    // failure in that case instead of asserting in an unchecked cast.
+    Type convertedType = getTypeConverter()->convertType(op.getType());
+    auto newType = convertedType.dyn_cast_or_null<VectorType>();
+    if (!newType)
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported constant type");
+
+    unsigned newBitWidth = newType.getElementTypeBitWidth();
+    Attribute oldValue = op.getValueAttr();
+
+    // Scalar constant: a single (low, high) pair.
+    if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+      auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
+      auto newAttr = DenseElementsAttr::get(newType, {low, high});
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+      return success();
+    }
+
+    // Splat constant: split the splat value once and repeat the pair for
+    // every element.
+    if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+      auto [low, high] =
+          getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
+      int64_t numSplatElems = splatAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numSplatElems * 2);
+      for (int64_t i = 0; i < numSplatElems; ++i) {
+        values.push_back(low);
+        values.push_back(high);
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    // General dense constant: split every element individually.
+    if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+      int64_t numElems = elemsAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numElems * 2);
+      for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
+        auto [low, high] = getHalves(origVal, newBitWidth);
+        values.push_back(std::move(low));
+        values.push_back(std::move(high));
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "unhandled constant attribute");
+  }
+};
+
 struct EmulateWideIntPass final
     : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
   using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
@@ -42,7 +123,11 @@
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
     });
-    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+    target.addDynamicallyLegalOp<
+        // `func.*` ops
+        func::CallOp, func::ReturnOp,
+        // `arith.*` ops
+        arith::ConstantOp>(
         [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
 
     RewritePatternSet patterns(ctx);
@@ -119,6 +204,9 @@
                                                                  typeConverter);
   populateCallOpTypeConversionPattern(patterns, typeConverter);
   populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+  // Populate `arith.*` conversion patterns.
+  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
 }
 
 } // namespace mlir
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -49,3 +49,29 @@
     %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
     return %res : vector<4xi64>
 }
+
+// CHECK-LABEL: func @constant_scalar
+// CHECK-SAME: () -> vector<2xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense<[0, 1]> : vector<2xi32>
+// CHECK-NEXT: [[C2:%.+]] = arith.constant dense<[-7, -1]> : vector<2xi32>
+// CHECK-NEXT: return [[C0]] : vector<2xi32>
+func.func @constant_scalar() -> i64 {
+    %c0 = arith.constant 0 : i64
+    %c1 = arith.constant 4294967296 : i64
+    %c2 = arith.constant -7 : i64
+    return %c0 : i64
+}
+
+// CHECK-LABEL: func @constant_vector
+// CHECK-SAME: () -> vector<3x2xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}: <[[0, 1], [0, 1], [0, 1]]> : vector<3x2xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}: <[[0, 0], [1, 0], [-2, -1]]> : vector<3x2xi32>
+// CHECK-NEXT: return [[C0]] : vector<3x2xi32>
+func.func @constant_vector() -> vector<3xi64> {
+    %c0 = arith.constant dense<4294967296> : vector<3xi64>
+    %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
+    return %c0 : vector<3xi64>
+}