diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1710,11 +1710,11 @@ }]; let arguments = (ins - Tosa_Tensor_Cast:$input + Tosa_Tensor_F64:$input ); let results = (outs - Tosa_Tensor_Cast:$output + Tosa_Tensor_F64:$output ); let hasFolder = 1; @@ -1785,7 +1785,7 @@ ); let results = (outs - Tosa_Tensor:$output + Tosa_Tensor_F64:$output ); let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -95,8 +95,8 @@ //===----------------------------------------------------------------------===// def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], "number">; -// Add F64 type support for just for tosa::CastOp -def Tosa_AnyNumber_Cast : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], +// Add F64 type support, only for tosa::CastOp and tosa::ConstOp +def Tosa_AnyNumber_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], "number_cast">; //===----------------------------------------------------------------------===// @@ -108,7 +108,7 @@ // Either ranked or unranked tensor of TOSA supported element types. def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; -def Tosa_Tensor_Cast : TensorOf<[Tosa_AnyNumber_Cast]>; +def Tosa_Tensor_F64 : TensorOf<[Tosa_AnyNumber_F64]>; // Must be ranked but no further constraints def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;