diff --git a/mlir/test/python/dialects/ext_test_helper.py b/mlir/test/python/dialects/ext_test_helper.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/ext_test_helper.py
@@ -0,0 +1,41 @@
+import inspect
+
+
+def print_gen_init_arg_names(cls: type):
+    """Print the `__init__` argument names of the generated base of `cls`.
+
+    The generated base is taken to be the first class in `cls.__mro__` whose
+    module name ends with "_gen". The names of its `__init__` parameters,
+    including `self`, are printed as a Python list.
+
+    This is useful for ensuring that mix-in classes remain in sync with their
+    underlying op. If operands are added or existing ones are renamed in that
+    op, the same change needs to occur in the mix-in class. However, in many
+    cases, these changes do not otherwise break the tests of the mix-ins, so
+    the mix-in classes can silently fall out of sync. This function prints the
+    argument names of the generated base class so that they can be checked
+    for in a `CHECK` statement.
+
+    If you are modifying a tablegen definition and that modification breaks a
+    test that uses this function, please update the corresponding mix-in class
+    in the dialect's `*_ext.py` file.
+
+    Args:
+      cls: The op class to inspect, typically one extended by a mix-in.
+
+    Raises:
+      ValueError: If no class in `cls.__mro__` is defined in a module whose
+        name ends with "_gen".
+    """
+    # Search the MRO instead of indexing it at a fixed position: where the
+    # generated class lands depends on how the mix-in is combined with it.
+    gen_cls = next(
+        (base for base in cls.__mro__ if base.__module__.endswith("_gen")),
+        None,
+    )
+    if gen_cls is None:
+        raise ValueError(
+            f"No generated ('*_gen') base class in the MRO of {cls!r}."
+        )
+    # Iterating a signature's `parameters` mapping yields parameter names.
+    print(list(inspect.signature(gen_cls.__init__).parameters))
diff --git a/mlir/test/python/dialects/transform_gpu_ext.py b/mlir/test/python/dialects/transform_gpu_ext.py
--- a/mlir/test/python/dialects/transform_gpu_ext.py
+++ b/mlir/test/python/dialects/transform_gpu_ext.py
@@ -4,8 +4,11 @@
 from mlir.dialects import transform
 from mlir.dialects.transform import gpu
 
+from ext_test_helper import print_gen_init_arg_names
+
 
 def run(f):
+    print("\nTEST:", f.__name__)
     with Context(), Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
@@ -17,7 +20,6 @@
             with InsertionPoint(sequence.body):
                 f(sequence.bodyTarget)
                 transform.YieldOp()
-        print("\nTEST:", f.__name__)
         print(module)
     return f
 
@@ -77,3 +79,10 @@
     # CHECK-SAME: block_dims = [4, 2]
     # CHECK-SAME: sync_after_distribute = false
     # CHECK-SAME: warp_size = 64
+
+
+@run
+def testMapNestedForallToThreadsGenArgs(_):
+    print_gen_init_arg_names(gpu.MapNestedForallToThreads)
+    # CHECK-LABEL: TEST: testMapNestedForallToThreadsGenArgs
+    # CHECK: ['self', 'result', 'target', 'block_dims', 'sync_after_distribute', 'warp_size', 'loc', 'ip']