diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
@@ -20,8 +20,8 @@
         result_type: Type,
         target: Union[Operation, OpView, Value],
         *,
-        grid_dims: Optional[Sequence[int]] = None,
-        generate_gpu_launch: Optional[bool] = None,
+        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
         ip=None
     ):
@@ -32,8 +32,8 @@
         self,
         target: Union[Operation, OpView, Value],
         *,
-        grid_dims: Optional[Sequence[int]] = None,
-        generate_gpu_launch: Optional[bool] = None,
+        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
         ip=None
     ):
@@ -44,8 +44,8 @@
         result_type_or_target: Union[Operation, OpView, Type, Value],
         target_or_none: Optional[Union[Operation, OpView, Value]] = None,
         *,
-        grid_dims: Optional[Sequence[int]] = None,
-        generate_gpu_launch: Optional[bool] = None,
+        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
         ip=None
     ):
@@ -56,9 +56,6 @@
             result_type = transform.AnyOpType.get()
             target = result_type_or_target
 
-        if grid_dims is not None and not isinstance(grid_dims, ArrayAttr):
-            grid_dims = DenseI64ArrayAttr.get(grid_dims)
-
         super().__init__(
             result_type,
             target,
@@ -67,3 +64,61 @@
             loc=loc,
             ip=ip,
         )
+
+
+class MapNestedForallToThreads:
+    """Specialization for MapNestedForallToThreads class."""
+
+    @overload
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        block_dims: Optional[Sequence[int]] = None,
+        warp_size: Optional[Sequence[int]] = None,
+        sync_after_distribute: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        block_dims: Optional[Sequence[int]] = None,
+        warp_size: Optional[Sequence[int]] = None,
+        sync_after_distribute: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        result_type_or_target: Union[Operation, OpView, Value, Type],
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
+        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
+        sync_after_distribute: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(result_type_or_target, Type):
+            result_type = result_type_or_target
+            target = target_or_none
+        else:
+            result_type = result_type_or_target.type
+            target = result_type_or_target
+        super().__init__(
+            result_type,
+            target,
+            block_dims=block_dims,
+            warp_size=warp_size,
+            sync_after_distribute=sync_after_distribute,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/python/dialects/transform_gpu_ext.py b/mlir/test/python/dialects/transform_gpu_ext.py
--- a/mlir/test/python/dialects/transform_gpu_ext.py
+++ b/mlir/test/python/dialects/transform_gpu_ext.py
@@ -9,20 +9,22 @@
     with Context(), Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
-            print("\nTEST:", f.__name__)
-            f()
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.PROPAGATE,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                f(sequence.bodyTarget)
+                transform.YieldOp()
+        print("\nTEST:", f.__name__)
         print(module)
     return f
 
 
 @run
-def testMapForallToBlocksCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        gpu.MapForallToBlocks(sequence.bodyTarget)
-        transform.YieldOp()
+def testMapForallToBlocksCompact(target):
+    gpu.MapForallToBlocks(target)
     # CHECK-LABEL: TEST: testMapForallToBlocksCompact
     # CHECK: = transform.gpu.map_forall_to_blocks
     # CHECK-NOT: grid_dims
@@ -31,29 +33,47 @@
 
 
 @run
-def testMapForallToBlocksTyped():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        gpu.MapForallToBlocks(
-            transform.OperationType.get("test.dummy"), sequence.bodyTarget
-        )
-        transform.YieldOp()
+def testMapForallToBlocksTyped(target):
+    gpu.MapForallToBlocks(transform.OperationType.get("test.dummy"), target)
     # CHECK-LABEL: TEST: testMapForallToBlocksTyped
     # CHECK: = transform.gpu.map_forall_to_blocks
     # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
 
 
 @run
-def testMapForallToBlocksGridDims():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        gpu.MapForallToBlocks(sequence.bodyTarget, grid_dims=[4, 2])
-        transform.YieldOp()
+def testMapForallToBlocksGridDims(target):
+    gpu.MapForallToBlocks(target, grid_dims=[4, 2])
     # CHECK-LABEL: TEST: testMapForallToBlocksGridDims
     # CHECK: = transform.gpu.map_forall_to_blocks
     # CHECK-SAME: grid_dims = [4, 2]
     # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMapNestedForallToThreadsCompact(target):
+    gpu.MapNestedForallToThreads(target)
+    # CHECK-LABEL: TEST: testMapNestedForallToThreadsCompact
+    # CHECK: transform.gpu.map_nested_forall_to_threads
+    # CHECK-SAME: block_dims = []
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMapNestedForallToThreadsTyped(target):
+    gpu.MapNestedForallToThreads(transform.OperationType.get("test.dummy"), target)
+    # CHECK-LABEL: TEST: testMapNestedForallToThreadsTyped
+    # CHECK: transform.gpu.map_nested_forall_to_threads
+    # CHECK-SAME: block_dims = []
+    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
+@run
+def testMapNestedForallToThreadsAttributes(target):
+    gpu.MapNestedForallToThreads(
+        target, block_dims=[4, 2], warp_size=64, sync_after_distribute=False
+    )
+    # CHECK-LABEL: TEST: testMapNestedForallToThreadsAttributes
+    # CHECK: transform.gpu.map_nested_forall_to_threads
+    # CHECK-SAME: block_dims = [4, 2]
+    # CHECK-SAME: sync_after_distribute = false
+    # CHECK-SAME: warp_size = 64
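
For reference, a minimal standalone sketch of how the two mix-ins might be used together outside of lit. Assumptions: the import paths mirror the ones the test file already uses but are not shown in this diff, and the chaining of the block-mapping result into the thread-mapping op is illustrative rather than something the tests themselves do.

# Standalone usage sketch; import paths below are assumed, not part of this diff.
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform
from mlir.dialects.transform import gpu  # assumed module path for the extension

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.PROPAGATE,
            [],
            transform.AnyOpType.get(),
        )
        with InsertionPoint(sequence.body):
            # Compact form: result type defaults to !transform.any_op and
            # grid_dims is passed as a plain Python list.
            blocks = gpu.MapForallToBlocks(sequence.bodyTarget, grid_dims=[4, 2])
            # Feed the block-mapped handle into the thread mapping, using the
            # attributes exercised by the new tests.
            gpu.MapNestedForallToThreads(
                blocks.result,
                block_dims=[4, 2],
                warp_size=64,
                sync_after_distribute=False,
            )
            transform.YieldOp()
    print(module)

Passing plain Python lists, ints, and bools here relies on the generated op builders converting them to the corresponding attributes, which is presumably why the manual DenseI64ArrayAttr conversion for grid_dims could be dropped from MapForallToBlocks in this patch.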