[mlir][tosa] Allow f64 element types on the TOSA cast op only

* TosaTypesBase.td: add Tosa_AnyNumber_Cast (Tosa_AnyNumber plus F64) and
  Tosa_Tensor_Cast (a tensor of Tosa_AnyNumber_Cast element types).
* TosaOps.td: the cast op's $input and $output now use Tosa_Tensor_Cast, and
  its conversion table gains float32 -> float64 and float64 -> float32.
* TosaValidation.cpp: signal pass failure when an operand's element type is
  f64.

NOTE(review): the added validation check, as visible here, only inspects
operands; an f64 result that no op consumes would not be rejected. Confirm
whether results should be checked as well.
NOTE(review): the new branch calls signalPassFailure() without emitting a
diagnostic, so the failure carries no location or reason.
NOTE(review): the `isa()` calls in the TosaValidation.cpp context lines look
like they lost their template arguments (e.g. `isa<...>()`) in this copy, and
the whole patch has been collapsed onto one line. Restore the original
line structure and whitespace before applying.

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1705,14 +1705,16 @@ float to signed 16 float int16 signed 8 to float int8 float signed 16 to float int16 float + float 32 to float 64 float32 float64 + float 64 to float 32 float64 float32 }]; let arguments = (ins - Tosa_Tensor:$input + Tosa_Tensor_Cast:$input ); let results = (outs - Tosa_Tensor:$output + Tosa_Tensor_Cast:$output ); let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -95,6 +95,9 @@ //===----------------------------------------------------------------------===// def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], "number">; +// Tosa_AnyNumber plus F64; intended only for tosa::CastOp. +def Tosa_AnyNumber_Cast : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], "number_cast">; //===----------------------------------------------------------------------===// // Tensor types @@ -105,6 +108,7 @@ // Either ranked or unranked tensor of TOSA supported element types. def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor_Cast : TensorOf<[Tosa_AnyNumber_Cast]>; // Must be ranked but no further constraints def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -58,6 +58,9 @@ getElementTypeOrSelf(operand).isa()) { return signalPassFailure(); } + if (getElementTypeOrSelf(operand).isF64()) { return signalPassFailure(); } } }); }