[mlir][python] Add Python bindings for vector transform ops and enums

Expose the vector transform dialect extension ops, together with the enums that configure them, to Python as `mlir.dialects.transform.vector`. Wire up the enum tablegen target in CMake and Bazel, and add a FileCheck test. The TODOs about evolving the lowering/split strategies to proper enums are dropped.

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -81,7 +81,6 @@ let assemblyFormat = "attr-dict"; } -// TODO: evolve lowering_strategy to proper enums. def ApplyLowerContractionPatternsOp : Op]> { @@ -143,7 +142,6 @@ let assemblyFormat = "attr-dict"; } -// TODO: evolve lowering_strategy to proper enums. def ApplyLowerMultiReductionPatternsOp : Op]> { @@ -232,7 +230,6 @@ }]; } -// TODO: evolve lowering_strategy to proper enums. def ApplyLowerTransposePatternsOp : Op]> { @@ -259,7 +256,6 @@ }]; } -// TODO: evolve split_transfer_strategy to proper enums. def ApplySplitTransferFullPartialPatternsOp : Op]> { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -192,6 +192,24 @@ DIALECT_NAME transform EXTENSION_NAME structured_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/VectorTransformOps.td + SOURCES + dialects/transform/vector.py + DIALECT_NAME transform + EXTENSION_NAME vector_transform) + +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") +mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRVectorTransformPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.vector_transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform + SOURCES "dialects/_vector_transform_enum_gen.py" ) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git
a/mlir/python/mlir/dialects/VectorTransformOps.td b/mlir/python/mlir/dialects/VectorTransformOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorTransformOps.td @@ -0,0 +1,18 @@ +//===-- VectorTransformOps.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the vector transform ops. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTORTRANSFORMOPS +#define PYTHON_BINDINGS_VECTORTRANSFORMOPS + +include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.td" + +#endif // PYTHON_BINDINGS_VECTORTRANSFORMOPS diff --git a/mlir/python/mlir/dialects/transform/vector.py b/mlir/python/mlir/dialects/transform/vector.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/vector.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._vector_transform_enum_gen import * +from .._vector_transform_ops_gen import * diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -0,0 +1,155 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import vector + + +def run_apply_patterns(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + apply = transform.ApplyPatternsOp(sequence.bodyTarget) + with InsertionPoint(apply.patterns): + f() + transform.YieldOp() + print("\nTEST:", f.__name__) + print(module) + return f + + +@run_apply_patterns +def non_configurable_patterns(): + # CHECK-LABEL: TEST: non_configurable_patterns + # CHECK: apply_patterns + # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim + vector.ApplyCastAwayVectorLeadingOneDimPatternsOp() + # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns + vector.ApplyRankReducingSubviewPatternsOp() + # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns + vector.ApplyTransferPermutationPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_broadcast + vector.ApplyLowerBroadcastPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_masks + vector.ApplyLowerMasksPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_masked_transfers + vector.ApplyLowerMaskedTransfersPatternsOp() + # CHECK: transform.apply_patterns.vector.materialize_masks + vector.ApplyMaterializeMasksPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_outerproduct +
    vector.ApplyLowerOuterProductPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_gather + vector.ApplyLowerGatherPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_scan + vector.ApplyLowerScanPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_shape_cast + vector.ApplyLowerShapeCastPatternsOp() + + +@run_apply_patterns +def configurable_patterns(): + # CHECK-LABEL: TEST: configurable_patterns + # CHECK: apply_patterns + # CHECK: transform.apply_patterns.vector.lower_transfer + # CHECK-SAME: max_transfer_rank = 4 + vector.ApplyLowerTransferPatternsOp(max_transfer_rank=4) + # CHECK: transform.apply_patterns.vector.transfer_to_scf + # CHECK-SAME: max_transfer_rank = 3 + # CHECK-SAME: full_unroll = true + vector.ApplyTransferToScfPatternsOp(max_transfer_rank=3, full_unroll=True) + + +@run_apply_patterns +def enum_configurable_patterns(): + # CHECK-LABEL: TEST: enum_configurable_patterns + # CHECK: apply_patterns + # CHECK: transform.apply_patterns.vector.lower_contraction + vector.ApplyLowerContractionPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_contraction + # CHECK-SAME: lowering_strategy = matmulintrinsics + vector.ApplyLowerContractionPatternsOp( + lowering_strategy=vector.VectorContractLowering.MATMUL + ) + # CHECK: transform.apply_patterns.vector.lower_contraction + # CHECK-SAME: lowering_strategy = parallelarith + vector.ApplyLowerContractionPatternsOp( + lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH + ) + + # CHECK: transform.apply_patterns.vector.lower_multi_reduction + vector.ApplyLowerMultiReductionPatternsOp() + # CHECK: transform.apply_patterns.vector.lower_multi_reduction + # This is the default mode, not printed.
+    vector.ApplyLowerMultiReductionPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL + ) + # CHECK: transform.apply_patterns.vector.lower_multi_reduction + # CHECK-SAME: lowering_strategy = innerreduction + vector.ApplyLowerMultiReductionPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION + ) + + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = eltwise + # CHECK-SAME: avx2_lowering_strategy = false + vector.ApplyLowerTransposePatternsOp() + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = eltwise + # CHECK-SAME: avx2_lowering_strategy = false + vector.ApplyLowerTransposePatternsOp( + lowering_strategy=vector.VectorTransposeLowering.ELT_WISE + ) + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = flat_transpose + # CHECK-SAME: avx2_lowering_strategy = false + vector.ApplyLowerTransposePatternsOp( + lowering_strategy=vector.VectorTransposeLowering.FLAT + ) + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = shuffle_1d + # CHECK-SAME: avx2_lowering_strategy = false + vector.ApplyLowerTransposePatternsOp( + lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D + ) + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = shuffle_16x16 + # CHECK-SAME: avx2_lowering_strategy = false + vector.ApplyLowerTransposePatternsOp( + lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16 + ) + # CHECK: transform.apply_patterns.vector.lower_transpose + # CHECK-SAME: lowering_strategy = flat_transpose + # CHECK-SAME: avx2_lowering_strategy = true + vector.ApplyLowerTransposePatternsOp( + lowering_strategy=vector.VectorTransposeLowering.FLAT, + avx2_lowering_strategy=True, + ) + + # CHECK: transform.apply_patterns.vector.split_transfer_full_partial + vector.ApplySplitTransferFullPartialPatternsOp() + #
CHECK: transform.apply_patterns.vector.split_transfer_full_partial + # CHECK-SAME: split_transfer_strategy = none + vector.ApplySplitTransferFullPartialPatternsOp( + split_transfer_strategy=vector.VectorTransferSplit.NONE + ) + # CHECK: transform.apply_patterns.vector.split_transfer_full_partial + # CHECK-SAME: split_transfer_strategy = "vector-transfer" + vector.ApplySplitTransferFullPartialPatternsOp( + split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER + ) + # CHECK: transform.apply_patterns.vector.split_transfer_full_partial + # This is the default mode, not printed. + vector.ApplySplitTransferFullPartialPatternsOp( + split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY + ) + # CHECK: transform.apply_patterns.vector.split_transfer_full_partial + # CHECK-SAME: split_transfer_strategy = "force-in-bounds" + vector.ApplySplitTransferFullPartialPatternsOp( + split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS + ) diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -902,6 +902,44 @@ ], ) +gentbl_filegroup( + name = "VectorTransformEnumPyGen", + tbl_outs = [ + ( + ["-gen-python-enum-bindings"], + "mlir/dialects/_vector_transform_enum_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/VectorTransformOps.td", + deps = [ + "//mlir:OpBaseTdFiles", + "//mlir:TransformDialectTdFiles", + "//mlir:VectorTransformOpsTdFiles", + ], +) + +gentbl_filegroup( + name = "VectorTransformOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=vector_transform", + ], + "mlir/dialects/_vector_transform_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/VectorTransformOps.td", + deps = [ + "//mlir:OpBaseTdFiles", +
        "//mlir:TransformDialectTdFiles", + "//mlir:VectorTransformOpsTdFiles", + ], +) + filegroup( name = "TransformOpsPyFiles", srcs = [ @@ -919,6 +957,8 @@ ":StructuredTransformOpsPyGen", ":TransformEnumPyGen", ":TransformOpsPyGen", + ":VectorTransformEnumPyGen", + ":VectorTransformOpsPyGen", ], )