diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1746,11 +1746,11 @@ }]; let arguments = (ins - Tosa_Tensor_Cast:$input + Tosa_Tensor_Plus_F64:$input ); let results = (outs - Tosa_Tensor_Cast:$output + Tosa_Tensor_Plus_F64:$output ); let hasFolder = 1; @@ -1821,7 +1821,7 @@ ); let results = (outs - Tosa_Tensor:$output + Tosa_Tensor_Plus_F64:$output ); let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -95,9 +95,9 @@ //===----------------------------------------------------------------------===// def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], "number">; -// Add F64 type support for just for tosa::CastOp -def Tosa_AnyNumber_Cast : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], - "number_cast">; +// Tosa_AnyNumber extended with F64; intended only for tosa::CastOp and tosa::ConstOp. +def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], + "number_plus_f64">; //===----------------------------------------------------------------------===// // Tensor types @@ -108,7 +108,7 @@ // Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; -def Tosa_Tensor_Cast : TensorOf<[Tosa_AnyNumber_Cast]>; +def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>; // Must be ranked but no further constraints def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>; diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -7,6 +7,20 @@ return %0 : tensor<4xi32> } +// CHECK-LABEL: func @test_const_i64 +func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { + // CHECK: "tosa.const" + %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> + return %0 : tensor<4xi64> +} + +// CHECK-LABEL: func @test_const_f64 +func.func @test_const_f64(%arg0 : index) -> tensor<4xf64> { + // CHECK: "tosa.const" + %0 = "tosa.const"() {value = dense<[3.0, 0.0, 1.0, 2.0]> : tensor<4xf64>} : () -> tensor<4xf64> + return %0 : tensor<4xf64> +} + // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor) { // CHECK: "tosa.equal"