diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -440,6 +441,85 @@
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
+
+  // Beyond the identity cast above, only constant splat operands are folded:
+  // the single splat value is converted to the result element type and
+  // re-splatted. i1 is treated as a boolean throughout, matching TOSA.
+  auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!operand)
+    return {};
+
+  auto inTy = getInput().getType().cast<ShapedType>();
+  auto outTy = getType().cast<ShapedType>();
+  auto inETy = inTy.getElementType();
+  auto outETy = outTy.getElementType();
+
+  // A splat attribute can only be built for a statically shaped type, so casts
+  // with a dynamic or unranked result type are left unfolded.
+  if (!operand.isSplat() || !outTy.hasStaticShape())
+    return {};
+
+  if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
+    // Required out-parameter; losing precision when narrowing is expected.
+    bool losesInfo;
+    auto splatVal = operand.getSplatValue<APFloat>();
+    auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
+    splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
+                     &losesInfo);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
+    auto inIntTy = inETy.cast<IntegerType>();
+    // A boolean true is the value 1; reading the single bit as signed would
+    // turn it into -1.
+    bool isSigned = !inIntTy.isUnsignedInteger() && !inIntTy.isInteger(1);
+    APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
+    splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), isSigned,
+                              llvm::RoundingMode::NearestTiesToEven);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
+    auto floatVal = operand.getSplatValue<APFloat>();
+    auto bitwidth = outETy.getIntOrFloatBitWidth();
+
+    // Casting to boolean tests for non-zero instead of converting the value.
+    if (outETy.isInteger(1)) {
+      APInt boolVal(bitwidth, floatVal.isZero() ? 0 : 1);
+      return SplatElementsAttr::get(outTy, boolVal);
+    }
+
+    auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
+    auto intVal = APSInt(bitwidth, unsign);
+    // Required out-parameter; inexact conversions are expected.
+    bool exact;
+    // TOSA rounds to the nearest integer with ties to even rather than
+    // truncating toward zero.
+    floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
+                              &exact);
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
+    auto inIntTy = inETy.cast<IntegerType>();
+    auto intVal = operand.getSplatValue<APInt>();
+    auto bitwidth = outETy.getIntOrFloatBitWidth();
+
+    if (outETy.isInteger(1)) {
+      // Casting to boolean tests for non-zero; truncating would keep only the
+      // lowest bit, e.g. 2 would become false.
+      intVal = APInt(bitwidth, intVal.getBoolValue() ? 1 : 0);
+    } else if (inIntTy.isUnsignedInteger() || inIntTy.isInteger(1)) {
+      // Unsigned and boolean inputs are zero-extended, so true becomes 1.
+      intVal = intVal.zextOrTrunc(bitwidth);
+    } else {
+      intVal = intVal.sextOrTrunc(bitwidth);
+    }
+
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -97,3 +97,58 @@
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK: func.func @cast_float_to_float
+func.func @cast_float_to_float() -> tensor<f16> { // f32 splat -> f16 splat
+  %cst = "tosa.const"() {value = dense<42.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<4.200000e+01> : tensor<f16>} : () -> tensor<f16>
+  %res = "tosa.cast"(%cst) : (tensor<f32>) -> tensor<f16> // expected to fold away
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_float
+func.func @cast_int_to_float() -> tensor<f16> { // i32 splat -> f16 splat
+  %cst = "tosa.const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<4.000000e+00> : tensor<f16>} : () -> tensor<f16>
+  %res = "tosa.cast"(%cst) : (tensor<i32>) -> tensor<f16> // expected to fold away
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<f16>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int
+func.func @cast_float_to_int() -> tensor<i16> { // negative f32 splat -> i16 splat
+  %cst = "tosa.const"() {value = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
+  %res = "tosa.cast"(%cst) : (tensor<f32>) -> tensor<i16> // expected to fold away
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_trunc
+func.func @cast_int_to_int_trunc() -> tensor<i16> { // i32 -1 narrowed to i16 stays -1
+  %cst = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  %res = "tosa.cast"(%cst) : (tensor<i32>) -> tensor<i16> // expected to fold away
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i16>
+}
+
+// -----
+
+// CHECK: func.func @cast_int_to_int_sign
+func.func @cast_int_to_int_sign() -> tensor<i32> { // i16 -1 widened to i32 stays -1
+  %cst = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %res = "tosa.cast"(%cst) : (tensor<i16>) -> tensor<i32> // expected to fold away
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i32>
+}
\ No newline at end of file