diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialect
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -84,45 +86,115 @@
   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
 }
 
-Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
-                                 Type toType, bool isUnsignedCast) {
-  if (operand.getType() == toType)
-    return operand;
-  if (auto toIntType = dyn_cast<IntegerType>(toType)) {
-    // If operand is floating point, cast directly to the int type.
-    if (isa<FloatType>(operand.getType())) {
-      if (isUnsignedCast)
-        return b.create<arith::FPToUIOp>(loc, toType, operand);
-      return b.create<arith::FPToSIOp>(loc, toType, operand);
-    }
-    // Cast index operands directly to the int type.
-    if (operand.getType().isIndex())
-      return b.create<arith::IndexCastOp>(loc, toType, operand);
-    if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
-      // Either extend or truncate.
-      if (toIntType.getWidth() > fromIntType.getWidth()) {
-        if (isUnsignedCast)
-          return b.create<arith::ExtUIOp>(loc, toType, operand);
-        return b.create<arith::ExtSIOp>(loc, toType, operand);
-      }
-      if (toIntType.getWidth() < fromIntType.getWidth())
-        return b.create<arith::TruncIOp>(loc, toType, operand);
-    }
-  } else if (auto toFloatType = dyn_cast<FloatType>(toType)) {
-    // If operand is integer, cast directly to the float type.
-    // Note that it is unclear how to cast from BF16<->FP16.
-    if (isa<IntegerType>(operand.getType())) {
-      if (isUnsignedCast)
-        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
-      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
-    }
-    if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) {
-      if (toFloatType.getWidth() > fromFloatType.getWidth())
-        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
-      if (toFloatType.getWidth() < fromFloatType.getWidth())
-        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
-    }
-  }
+/// Casts `operand` (float, index or integer) to the integer type `toType` with
+/// a single arith op. Returns a null Value for any other operand type and for
+/// equal-width integer types.
+static Value convertToTargetInt(ImplicitLocOpBuilder &b, Value operand,
+                                IntegerType toType, bool isUnsigned) {
+  // If operand is floating point, cast directly to the int type.
+  if (isa<FloatType>(operand.getType())) {
+    if (isUnsigned)
+      return b.create<arith::FPToUIOp>(toType, operand);
+    return b.create<arith::FPToSIOp>(toType, operand);
+  }
+  // Cast index operands directly to the int type.
+  if (operand.getType().isIndex())
+    return b.create<arith::IndexCastOp>(toType, operand);
+  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
+    // Either extend or truncate.
+    if (toType.getWidth() > fromIntType.getWidth()) {
+      if (isUnsigned)
+        return b.create<arith::ExtUIOp>(toType, operand);
+      return b.create<arith::ExtSIOp>(toType, operand);
+    }
+    if (toType.getWidth() < fromIntType.getWidth())
+      return b.create<arith::TruncIOp>(toType, operand);
+  }
+
+  return {};
+}
+
+/// Casts `operand` (integer or float) to the float type `toType` with a single
+/// arith op. Returns a null Value for any other operand type and for distinct
+/// float types of equal width (e.g. bf16 <-> f16), where no cast below applies.
+static Value convertToTargetFp(ImplicitLocOpBuilder &b, Value operand,
+                               FloatType toType, bool isUnsigned) {
+  // If operand is integer, cast directly to the float type.
+  // Note that it is unclear how to cast from BF16<->FP16.
+  if (isa<IntegerType>(operand.getType())) {
+    if (isUnsigned)
+      return b.create<arith::UIToFPOp>(toType, operand);
+    return b.create<arith::SIToFPOp>(toType, operand);
+  }
+  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
+    if (toType.getWidth() > fromFpTy.getWidth())
+      return b.create<arith::ExtFOp>(toType, operand);
+    if (toType.getWidth() < fromFpTy.getWidth())
+      return b.create<arith::TruncFOp>(toType, operand);
+  }
+
+  return {};
+}
+
+/// Converts `operand` (complex, float or integer) to the complex type
+/// `targetType`; a real scalar becomes the real part of a value with a zero
+/// imaginary part. Returns a null Value, without creating any ops, when the
+/// target element type is not a float, when `operand` is of any other type or
+/// is complex with non-float elements, or when the source and target float
+/// types are distinct but equally wide (e.g. bf16 <-> f16): arith.extf and
+/// arith.truncf require a strictly wider/narrower result type.
+static Value convertToTargetComplex(ImplicitLocOpBuilder &b, Value operand,
+                                    ComplexType targetType, bool isUnsigned) {
+  auto toFpTy = dyn_cast<FloatType>(targetType.getElementType());
+  if (!toFpTy)
+    return {};
+
+  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
+    auto fromFpTy = dyn_cast<FloatType>(fromComplexType.getElementType());
+    // Identical types are handled by the caller, so equal widths here mean
+    // distinct float types of the same size.
+    if (!fromFpTy || fromFpTy.getWidth() == toFpTy.getWidth())
+      return {};
+    Value real = b.create<complex::ReOp>(operand);
+    Value imag = b.create<complex::ImOp>(operand);
+    // Widths differ, so these always produce an extf/truncf.
+    real = convertToTargetFp(b, real, toFpTy, isUnsigned);
+    imag = convertToTargetFp(b, imag, toFpTy, isUnsigned);
+    return b.create<complex::CreateOp>(targetType, real, imag);
+  }
+
+  if (!isa<FloatType, IntegerType>(operand.getType()))
+    return {};
+
+  // Convert the real scalar to the element type unless it already matches.
+  Value real = operand;
+  if (real.getType() != toFpTy) {
+    real = convertToTargetFp(b, operand, toFpTy, isUnsigned);
+    if (!real)
+      return {};
+  }
+  Value zero = b.create<mlir::arith::ConstantFloatOp>(
+      mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+  return b.create<complex::CreateOp>(targetType, real, zero);
+}
+
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType, bool isUnsignedCast) {
+  if (operand.getType() == toType)
+    return operand;
+  ImplicitLocOpBuilder ib(loc, b);
+  Value result;
+  if (auto intTy = dyn_cast<IntegerType>(toType)) {
+    result = convertToTargetInt(ib, operand, intTy, isUnsignedCast);
+  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
+    result = convertToTargetFp(ib, operand, floatTy, isUnsignedCast);
+  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
+    result = convertToTargetComplex(ib, operand, complexTy, isUnsignedCast);
+  }
+
+  if (result)
+    return result;
+
   emitWarning(loc) << "could not cast operand of type " << operand.getType()
                    << " to " << toType;
   return operand;