diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -97,6 +97,88 @@
   unsigned maxIntWidth;
 };
 
+/// Conversion pattern for `arith.constant` ops whose result type the type
+/// converter maps to a vector type. Each original integer is split into a low
+/// and a high half of the converted element bit width; the new constant holds
+/// all low halves first, followed by all high halves.
+struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type oldType = op.getType();
+    // `convertType` returns a null type on failure; a checked cast turns that
+    // (or any non-vector result) into a match failure instead of an assertion.
+    auto newType = getTypeConverter()
+                       ->convertType(oldType)
+                       .dyn_cast_or_null<VectorType>();
+    if (!newType)
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "Unsupported constant type");
+
+    const unsigned newBitWidth = newType.getElementTypeBitWidth();
+    Attribute oldValue = op.getValueAttr();
+
+    // Scalar integer: the result holds exactly {low, high}.
+    if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+      auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
+      auto newAttr = DenseElementsAttr::get(newType, {low, high});
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+      return success();
+    }
+
+    // Splat: split the single value once and replicate each half.
+    if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+      auto [low, high] =
+          getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
+      const auto numSplatElems =
+          static_cast<size_t>(splatAttr.getNumElements());
+      auto values = llvm::to_vector(
+          llvm::concat<APInt>(SmallVector<APInt>(numSplatElems, low),
+                              SmallVector<APInt>(numSplatElems, high)));
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    // General dense elements: split every element individually.
+    if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+      const auto numElems = static_cast<size_t>(elemsAttr.getNumElements());
+      SmallVector<APInt> lowVals;
+      lowVals.reserve(numElems);
+      SmallVector<APInt> highVals;
+      highVals.reserve(numElems);
+
+      for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
+        auto [low, high] = getHalves(origVal, newBitWidth);
+        lowVals.push_back(std::move(low));
+        highVals.push_back(std::move(high));
+      }
+      auto values = llvm::to_vector(
+          llvm::concat<APInt>(std::move(lowVals), std::move(highVals)));
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "Unhandled constant attribute");
+  }
+
+private:
+  /// Returns the bits [0, newBitWidth) and [newBitWidth, 2 * newBitWidth) of
+  /// `value` as a {low, high} pair, each `newBitWidth` bits wide.
+  static std::pair<APInt, APInt> getHalves(const APInt &value,
+                                           unsigned newBitWidth) {
+    APInt low = value.extractBits(newBitWidth, 0);
+    APInt high = value.extractBits(newBitWidth, newBitWidth);
+    return {std::move(low), std::move(high)};
+  }
+};
+
 struct EmulateWideIntPass final
     : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
   EmulateWideIntPass(unsigned widestIntSupported) {
@@ -128,6 +210,8 @@
     ConversionTarget target(*ctx);
     // clang-format off
     target.addDynamicallyLegalOp<
+      // arith ops
+      arith::ConstantOp,
       // func ops
       func::FuncOp, func::CallOp, func::ReturnOp
     >(
@@ -156,6 +240,8 @@
 
 void populateWideIntEmulationPatterns(TypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
+  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
+
   // Populate `func.*` conversion patterns.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                  typeConverter);
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -42,3 +42,27 @@
     %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
     return %res : vector<4xi64>
 }
+
+// CHECK-LABEL: func @constant_scalar
+// CHECK-SAME: () -> vector<2xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense<[0, 1]> : vector<2xi32>
+// CHECK-NEXT: [[C2:%.+]] = arith.constant dense<[-7, -1]> : vector<2xi32>
+// CHECK-NEXT: return [[C0]] : vector<2xi32>
+func.func @constant_scalar() -> i64 {
+    %c0 = arith.constant 0 : i64
+    %c1 = arith.constant 4294967296 : i64
+    %c2 = arith.constant -7 : i64
+    return %c0 : i64
+}
+
+// CHECK-LABEL: func @constant_vector
+// CHECK-SAME: () -> vector<2x3xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense<{{\[\[0, 0, 0\], \[1, 1, 1\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense<{{\[\[0, 1, -2\], \[0, 0, -1\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: return [[C0]] : vector<2x3xi32>
+func.func @constant_vector() -> vector<3xi64> {
+    %c0 = arith.constant dense<4294967296> : vector<3xi64>
+    %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
+    return %c0 : vector<3xi64>
+}