diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -440,6 +441,71 @@
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
+
+  // Constant-fold casts of splat constants; anything else is left unfolded.
+  auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!operand)
+    return {};
+
+  auto inTy = getInput().getType().cast<ShapedType>();
+  auto outTy = getType().cast<ShapedType>();
+  auto inETy = inTy.getElementType();
+  auto outETy = outTy.getElementType();
+
+  if (operand.isSplat()) {
+    // float -> float: convert to the result semantics, rounding to
+    // nearest-even.
+    if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
+      // Required out-parameter of APFloat::convert; its value is not used.
+      bool losesInfo;
+      auto splatVal = operand.getSplatValue<APFloat>();
+      auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
+      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
+                       &losesInfo);
+      return SplatElementsAttr::get(outTy, splatVal);
+    }
+
+    // int -> float: the source is read as signed unless its type is
+    // explicitly unsigned.
+    if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
+      auto unsign = inETy.cast<IntegerType>().isUnsignedInteger();
+      APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
+      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
+                                llvm::RoundingMode::NearestTiesToEven);
+      return SplatElementsAttr::get(outTy, splatVal);
+    }
+
+    // float -> int: TOSA rounds to the nearest integer (ties to even) rather
+    // than truncating toward zero.
+    if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
+      auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
+      auto intVal =
+          APSInt(outETy.cast<IntegerType>().getIntOrFloatBitWidth(), unsign);
+      auto floatVal = operand.getSplatValue<APFloat>();
+      // Required out-parameter of APFloat::convertToInteger; not used.
+      bool exact;
+      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
+                                &exact);
+      return SplatElementsAttr::get(outTy, intVal);
+    }
+
+    // int -> int: truncate when narrowing, otherwise zero- or sign-extend
+    // according to the signedness of the source type.
+    if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
+      auto unsignIn = inETy.cast<IntegerType>().isUnsignedInteger();
+      auto intVal = operand.getSplatValue<APInt>();
+      auto bitwidth = outETy.getIntOrFloatBitWidth();
+
+      // The *OrTrunc helpers also accept an equal source and result width
+      // (e.g. a cast that only changes signedness), where the bits are kept
+      // unchanged.
+      intVal = unsignIn ? intVal.zextOrTrunc(bitwidth)
+                        : intVal.sextOrTrunc(bitwidth);
+
+      return SplatElementsAttr::get(outTy, intVal);
+    }
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -97,3 +97,69 @@
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK: func.func @cast_float_to_float
+func.func @cast_float_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {value = dense<42.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.200000e+01> : tensor<f16>} : () -> tensor<f16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_float
+func.func @cast_int_to_float() -> tensor<f16> {
+  %splat = "tosa.const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.000000e+00> : tensor<f16>} : () -> tensor<f16>
+  %cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<f16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int
+func.func @cast_float_to_int() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int_round
+func.func @cast_float_to_int_round() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<-3.5> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_trunc
+func.func @cast_int_to_int_trunc() -> tensor<i16> {
+  %splat = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  %cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<i16>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_sign
+func.func @cast_int_to_int_sign() -> tensor<i32> {
+  %splat = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %cast = "tosa.cast"(%splat) : (tensor<i16>) -> tensor<i32>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i32>
+}
\ No newline at end of file