[mlir][tosa] Add TOSA f64 type support for const op
Repository: rG LLVM Github Monorepo
This patch fixes torch.prim.NumToTensor.Scalar float64 support: https://github.com/llvm/torch-mlir/pull/1802
This fix should have been added in
https://reviews.llvm.org/rGa2dcd994a7f8cc33640f58105276b78acf3483e5
Inline comment on mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td, line 1824:
This naming and comment around the def is confusing now. Could you update?
Inline comment on mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td, line 1824:
How about now? I named every tensor type that supports f64 Tosa_Tensor_64.
This is definitely an improvement, and I'm okay with this type, but the name reads as if it were only a 64-bit tensor. Perhaps Jacques is better at naming than I am, but I'd go with something closer to Tosa_Tensor_Plus_F64, indicating that this is an extension of Tosa_Tensor with F64 (as opposed to I64).
Quick question @eric-k256 @AmosLewis,
Will it also be folded to a non-F64 type? That is, can TOSA consumers avoid seeing f64 once the IR is in canonical form?
I think this should be an explicit pass, as it may be interprocedural. Consumers downstream of the "verify for profile" step (or runtime verification, I forget what it is called) should not be getting these: either that verify pass errors out, or the values are converted to f32 by some such pass, which would be opt-in and may change precision.
For me canonical would be a function of profile.
(This effectively means no, not canonical; but there are sections with clear expectations, and after the "deployment" profile verification consumers should see none of these.)
Thanks, that sounds like a no for now, and I agree about the interprocedural implications.
I was curious because https://reviews.llvm.org/D142599 sounded like F64 wasn't in any profile, and also because there are folders for splat-based consts (which might help a bit) that carry the same implications.
You're right that F64 isn't in any of the TOSA profiles today, which is why I was discouraging its use in the overall definition of Tosa_Tensor. It's something we look at adding, but it adds significant requirements to any profiles it goes into. It would be good to get a sense of which networks need F64 versus ending up with that type as a default. Many systems that support f64 do it at a performance cost against f32, and some systems don't implement it at all. Ideally the tooling would have a way to guide developers to limit their use of f64 to places where the improvement in results is significant enough to justify the extra computation cost.
I concur.
Hence the question: should we also look to fold it as much as possible?
I think the cases we have seen are just F64 constants being created (maybe due to default Python behaviour?) and then immediately cast to lower bitwidths.
In other words, exactly what this patch is doing, so it's a step in the right direction.
I think splats are folded, so we are good there; that might cover most bases.
While we are here, I was just curious whether extending the folding of tosa.cast over tosa.const operands, to limit the appearance of f64 where possible, would be a good idea or not.
Especially if we can catch them before outlining, i.e. in scenarios where we wouldn't hit the interprocedural case.
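To make the folding question concrete, here is a minimal C++ sketch of the idea under discussion: replacing a cast of an f64 constant with a single f32 constant. It is a toy model, not the MLIR or TOSA API. The names ElemType, ConstTensor, narrowToF32, and foldCastOfConst are hypothetical, and the allowPrecisionLoss flag is an assumption standing in for the "opt-in, may change precision" behaviour mentioned above.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical toy IR, not the MLIR API: a constant is a flat list of
// values tagged with an element type.
enum class ElemType { F32, F64 };

struct ConstTensor {
  ElemType type;
  std::vector<double> values;  // F32 constants hold values already rounded to float
};

// Narrow one value to float. Returns nullopt for finite values outside the
// float range, so the caller never performs an out-of-range conversion.
std::optional<float> narrowToF32(double v) {
  if (std::isfinite(v) && std::fabs(v) > std::numeric_limits<float>::max())
    return std::nullopt;
  return static_cast<float>(v);
}

// Models folding cast(const : f64) -> f32 into one f32 constant.
// Returns nullopt when the fold does not apply or is not allowed.
std::optional<ConstTensor> foldCastOfConst(const ConstTensor &input,
                                           ElemType target,
                                           bool allowPrecisionLoss) {
  if (input.type != ElemType::F64 || target != ElemType::F32)
    return std::nullopt;

  ConstTensor folded{ElemType::F32, {}};
  folded.values.reserve(input.values.size());
  for (double v : input.values) {
    std::optional<float> narrowed = narrowToF32(v);
    if (!narrowed)
      return std::nullopt;  // value does not fit in f32 at all
    // NaN never compares equal, so treat it as exactly representable.
    bool exact = std::isnan(v) || static_cast<double>(*narrowed) == v;
    if (!exact && !allowPrecisionLoss)
      return std::nullopt;  // strict mode: keep the f64 constant and the cast
    folded.values.push_back(static_cast<double>(*narrowed));
  }
  assert(folded.values.size() == input.values.size());
  return folded;
}
```

In this sketch a constant such as 1.5 folds in strict mode because it is exactly representable in f32, while 0.1 folds only when allowPrecisionLoss is set. A real folder would also have to respect the interprocedural concerns raised earlier in the thread, which this model ignores.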