diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -23,6 +23,8 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -687,6 +689,77 @@
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
+
+  auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!operand || !operand.isSplat())
+    return {};
+
+  auto inTy = getInput().getType().cast<ShapedType>();
+  auto outTy = getType().cast<ShapedType>();
+  // DenseElementsAttr (and so SplatElementsAttr::get) needs a static shape,
+  // but the result type may be less refined than the constant operand.
+  if (!outTy.hasStaticShape())
+    return {};
+
+  auto inETy = inTy.getElementType();
+  auto outETy = outTy.getElementType();
+
+  // TOSA treats i1 as a boolean: bool -> T yields 0/1 and T -> bool yields
+  // (in != 0); neither is a plain sign-extension or truncation.
+  bool inIsBool = inETy.isInteger(1);
+  bool outIsBool = outETy.isInteger(1);
+
+  if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
+    bool losesInfo;
+    auto splatVal = operand.getSplatValue<APFloat>();
+    auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
+    splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
+                     &losesInfo);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
+    // Read i1 as unsigned so that 'true' converts to 1.0 rather than -1.0.
+    bool isSigned =
+        !inIsBool && !inETy.cast<IntegerType>().isUnsignedInteger();
+    APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
+    splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), isSigned,
+                              llvm::RoundingMode::NearestTiesToEven);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
+    auto floatVal = operand.getSplatValue<APFloat>();
+    if (outIsBool)
+      return SplatElementsAttr::get(outTy, APInt(1, !floatVal.isZero()));
+
+    // TOSA rounds to the nearest integer (ties to even) and clamps to the
+    // destination range; convertToInteger saturates out-of-range values.
+    auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
+    auto intVal =
+        APSInt(outETy.cast<IntegerType>().getIntOrFloatBitWidth(), unsign);
+    bool exact;
+    floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
+                              &exact);
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
+    auto intVal = operand.getSplatValue<APInt>();
+    auto bitwidth = outETy.getIntOrFloatBitWidth();
+
+    if (outIsBool) {
+      // Any non-zero value is true; truncation would drop the high bits.
+      intVal = APInt(1, !intVal.isZero());
+    } else if (inIsBool || inETy.cast<IntegerType>().isUnsignedInteger()) {
+      intVal = intVal.zextOrTrunc(bitwidth);
+    } else {
+      intVal = intVal.sextOrTrunc(bitwidth);
+    }
+
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -427,3 +427,91 @@
   // CHECK: return %[[SLICE]]
   return %slice : tensor<1x1xi32>
 }
+
+// -----
+
+// CHECK: func.func @cast_float_to_float
+func.func @cast_float_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {value = dense<42.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.200000e+01> : tensor<f16>} : () -> tensor<f16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_float
+func.func @cast_int_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.000000e+00> : tensor<f16>} : () -> tensor<f16>
+  %cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int
+func.func @cast_float_to_int() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int_round
+func.func @cast_float_to_int_round() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<3.7> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_trunc
+func.func @cast_int_to_int_trunc() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_sign
+func.func @cast_int_to_int_sign() -> tensor<i32> {
+  %splat = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %cast = "tosa.cast"(%splat) : (tensor<i16>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i32>
+}
+
+// -----
+
+// CHECK: func.func @cast_bool_to_int
+func.func @cast_bool_to_int() -> tensor<i32> {
+  %splat = "tosa.const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %cast = "tosa.cast"(%splat) : (tensor<i1>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i32>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_bool
+func.func @cast_int_to_bool() -> tensor<i1> {
+  %splat = "tosa.const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  %cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<i1>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i1>
+}