diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -8,11 +8,13 @@
 
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
 
@@ -32,6 +34,87 @@
   return converter.isLegal(op);
 }
 
+/// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
+/// `value` must be exactly 2*N bits wide. The bottom (low) bits are returned
+/// as the first pair element and the top (high) bits as the second one.
+static std::pair<APInt, APInt> getHalves(const APInt &value,
+                                         unsigned newBitWidth) {
+  assert(value.getBitWidth() == 2 * newBitWidth &&
+         "expected a value exactly twice as wide as each half");
+  APInt low = value.extractBits(newBitWidth, 0);
+  APInt high = value.extractBits(newBitWidth, newBitWidth);
+  return {std::move(low), std::move(high)};
+}
+
+/// Converts a wide integer `arith.constant` into a constant of the emulated
+/// vector type: every original element is split into a (low, high) pair of
+/// narrower integers laid out along a new trailing dimension of size 2.
+struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type oldType = op.getType();
+    // `convertType` returns a null type when conversion fails; report a match
+    // failure instead of tripping the assertion inside `cast`.
+    auto newType = getTypeConverter()
+                       ->convertType(oldType)
+                       .dyn_cast_or_null<VectorType>();
+    if (!newType)
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported constant type");
+
+    unsigned newBitWidth = newType.getElementTypeBitWidth();
+    Attribute oldValue = op.getValueAttr();
+
+    // Scalar constant: emit a 2-element vector [low, high].
+    if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+      auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
+      auto newAttr = DenseElementsAttr::get(newType, {low, high});
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+      return success();
+    }
+
+    // Splat constant: split the splat value once and replicate the halves.
+    if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+      auto [low, high] =
+          getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
+      int64_t numSplatElems = splatAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numSplatElems * 2);
+      for (int64_t i = 0; i < numSplatElems; ++i) {
+        values.push_back(low);
+        values.push_back(high);
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    // Non-splat dense constant: split each element, emitting the halves in
+    // element order.
+    if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+      int64_t numElems = elemsAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numElems * 2);
+      for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
+        auto [low, high] = getHalves(origVal, newBitWidth);
+        values.push_back(std::move(low));
+        values.push_back(std::move(high));
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "unhandled constant attribute");
+  }
+};
+
 struct EmulateWideIntPass final
     : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
   using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
@@ -55,7 +138,8 @@
     typeConverter.addTargetMaterialization(addUnrealizedCast);
 
     ConversionTarget target(*ctx);
-    target.addDynamicallyLegalOp<func::FuncOp, func::CallOp, func::ReturnOp>(
+    target.addDynamicallyLegalOp<func::FuncOp, func::CallOp, func::ReturnOp,
+                                 arith::ConstantOp>(
         [&typeConverter](Operation *op) {
           return isOpLegal(op, typeConverter);
         });
@@ -129,6 +213,8 @@
 
 void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
                                       RewritePatternSet &patterns) {
+  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
+
   // Populate `func.*` conversion patterns.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                  typeConverter);
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -62,3 +62,29 @@
   %i = arith.fptosi %f : f64 to i64
   return %i : i64
 }
+
+// CHECK-LABEL: func @constant_scalar
+// CHECK-SAME:    () -> vector<2xi32>
+// CHECK-NEXT:    [[C0:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[C1:%.+]] = arith.constant dense<[0, 1]> : vector<2xi32>
+// CHECK-NEXT:    [[C2:%.+]] = arith.constant dense<[-7, -1]> : vector<2xi32>
+// CHECK-NEXT:    return [[C0]] : vector<2xi32>
+func.func @constant_scalar() -> i64 {
+  %c0 = arith.constant 0 : i64
+  %c1 = arith.constant 4294967296 : i64
+  %c2 = arith.constant -7 : i64
+  return %c0 : i64
+}
+
+// CHECK-LABEL: func @constant_vector
+// CHECK-SAME:    () -> vector<3x2xi32>
+// CHECK-NEXT:    [[C0:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}:          <[[0, 1], [0, 1], [0, 1]]> : vector<3x2xi32>
+// CHECK-NEXT:    [[C1:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}:          <[[0, 0], [1, 0], [-2, -1]]> : vector<3x2xi32>
+// CHECK-NEXT:    return [[C0]] : vector<3x2xi32>
+func.func @constant_vector() -> vector<3xi64> {
+  %c0 = arith.constant dense<4294967296> : vector<3xi64>
+  %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
+  return %c0 : vector<3xi64>
+}