Repository: rG LLVM Github Monorepo
I found that there is a pooling_nhwc_sum_poly op in LinalgNamedStructuredOps.yaml, but no pooling_nhwc_min/max_poly. I wasn't sure whether pooling min/max ops can be added to the yaml file, so I added the 3D pooling ops to the tc file for now.
Let me know if you'd like me to add them to the yaml file instead. I'm happy to learn how. :D
I just landed https://reviews.llvm.org/D105203, which contains a max pooling operation but does not yet support the min operation. I hope to come up with a revision implementing the min operator today.
Also note that the pooling_nhwc_*_poly ops (e.g. pooling_nhwc_sum_poly) are test operations I added. Feel free to extend or modify them so that they fit the rest. E.g., you may want another naming scheme that replaces poly with something else.
Thanks @gysit, I added them to the yaml file instead. PTAL!
I will migrate the 2D pooling ops to the yaml file and replace all their uses in the MHLO/IREE repos in a later patch.
Great!
Did you also edit llvm-project/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py? The YAML file is machine-generated and shouldn't be edited directly. Instead, edit core_named_ops.py and regenerate the YAML by following the procedure in the Basic Usage section of https://mlir.llvm.org/docs/Dialects/Linalg/OpDSL/.
The OpDSL code to generate the 3D pooling operations should look as follows:
```python
# OpDSL names (linalg_structured_op, TensorDef, AttributeDef, S, D, ReduceFn,
# cast, U, ...) as imported at the top of core_named_ops.py; T1/T2 are the
# type aliases that file already defines for the other ops.
from ..lang import *


@linalg_structured_op
def pooling_ndhwc_sum(
    I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C),
    # K only carries the window shape; its values are never read.
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=AttributeDef(S.SD, S.SH, S.SW),
    dilations=AttributeDef(S.DD, S.DH, S.DW)):
  """Performs 3D sum pooling.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
  O[D.n, D.od, D.oh, D.ow, D.c] += cast(
      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
           D.ow * S.SW + D.kw * S.DW, D.c])


@linalg_structured_op
def pooling_ndhwc_max(
    I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=AttributeDef(S.SD, S.SH, S.SW),
    dilations=AttributeDef(S.DD, S.DH, S.DW)):
  """Performs 3D max pooling.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
  # Reduce with max over the three window dimensions.
  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
      cast(
          U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
               D.ow * S.SW + D.kw * S.DW, D.c]))


@linalg_structured_op
def pooling_ndhwc_min(
    I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=AttributeDef(S.SD, S.SH, S.SW),
    dilations=AttributeDef(S.DD, S.DH, S.DW)):
  """Performs 3D min pooling.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
  # Reduce with min over the three window dimensions.
  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
      cast(
          U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
               D.ow * S.SW + D.kw * S.DW, D.c]))
```
PS: there are already some 2D pooling op examples in core_named_ops.py. For now I added the _poly suffix to avoid name collisions with the existing ops; once the existing ops are removed, I think we can just drop the suffix. I also added some initial integration tests for these ops. I don't think we need to test every op there, but I wanted to make you aware of these tests in case you want to verify things:
end-to-end integration test:
llvm-project/mlir/test/python/integration/dialects/linalg/opsrun.py
tests to verify the code generated from the yaml specification:
llvm-project/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
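When picking shapes for tests such as opsrun.py, keep in mind that the pooling definitions index the input as I[n, od*SD + kd*DD, oh*SH + kh*DH, ow*SW + kw*DW, c] and nothing in them clamps these indices, so the chosen output shape has to keep every read inside the input. A minimal sketch of such a shape check for the NDHWC layout, using hypothetical helpers (`max_input_index`, `check_pooling_shapes`) that are not part of the existing tests:

```python
from typing import Sequence

# Hypothetical test helpers, not part of opsrun.py or any MLIR API.


def max_input_index(out_size: int, k: int, stride: int, dilation: int) -> int:
  # Largest index read along one spatial dim: (O-1)*S + (K-1)*Dl.
  return (out_size - 1) * stride + (k - 1) * dilation


def check_pooling_shapes(in_shape: Sequence[int], out_shape: Sequence[int],
                         window: Sequence[int], strides: Sequence[int],
                         dilations: Sequence[int]) -> bool:
  """Returns True if an NDHWC pooling with these shapes only reads in-bounds input."""
  n, *in_spatial, c = in_shape
  on, *out_spatial, oc = out_shape
  if (n, c) != (on, oc):
    return False  # N and C index I and O identically, so their sizes must match.
  for size, o, k, s, dl in zip(in_spatial, out_spatial, window, strides, dilations):
    # An empty output dimension reads nothing along that dim.
    if o > 0 and max_input_index(o, k, s, dl) > size - 1:
      return False
  return True
```

For instance, a 4-wide dimension with a 2-wide window and stride 2 admits an output extent of 2 (largest read index 3) but not 3 (largest read index 5).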
Thanks! I added the defs to core_named_ops.py and copy-pasted the generated output into the yaml file. I also moved the MLIR roundtrip tests from the named-ops file to the poly ops file.