diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt @@ -6,6 +6,7 @@ LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDialect MLIRDialect MLIRIR ) diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "llvm/ADT/SmallBitVector.h" using namespace mlir; @@ -122,7 +123,57 @@ if (toFloatType.getWidth() < fromFloatType.getWidth()) return b.create(loc, toFloatType, operand); } + } else if (auto toComplexType = dyn_cast(toType)) { + if (auto fromComplexType = dyn_cast(operand.getType())) { + if (isa(toComplexType.getElementType()) && + isa(fromComplexType.getElementType())) { + Value real = b.create(loc, operand); + Value imag = b.create(loc, operand); + if (toComplexType.getElementType().getIntOrFloatBitWidth() < + fromComplexType.getElementType().getIntOrFloatBitWidth()) { + real = b.create(loc, toComplexType.getElementType(), + real); + imag = b.create(loc, toComplexType.getElementType(), + imag); + } else { + real = b.create(loc, toComplexType.getElementType(), + real); + imag = b.create(loc, toComplexType.getElementType(), + imag); + } + return b.create(loc, toComplexType, real, imag); + } + } + + if (auto fromFpType = dyn_cast(operand.getType())) { + FloatType toFloatTy = toComplexType.getElementType().cast(); + int32_t toBitwidth = toFloatTy.getIntOrFloatBitWidth(); + Value from = operand; + if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { + from = b.create(loc, toFloatTy, from); + } + if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { + from = b.create(loc, toFloatTy, from); + } + Value zero = b.create( + loc, mlir::APFloat(toFloatTy.getFloatSemantics(), 0), toFloatTy); + return b.create(loc, toComplexType, from, zero); + } + + if (auto fromIntType = dyn_cast(operand.getType())) { + FloatType toFloatTy = toComplexType.getElementType().cast(); + Value from = operand; + if (isUnsignedCast) { + from = b.create(loc, toFloatTy, from); + } else { + from = b.create(loc, toFloatTy, from); + } + Value zero = b.create( + loc, mlir::APFloat(toFloatTy.getFloatSemantics(), 0), toFloatTy); + return b.create(loc, toComplexType, from, zero); + } } + emitWarning(loc) << "could not cast operand of type " << operand.getType() << " to " << toType; return operand;