diff --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Interfaces)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS VectorTransformOps.td)
+mlir_tablegen(VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(VectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRVectorTransformOpsIncGen)
+
+add_mlir_doc(VectorTransformOps VectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -0,0 +1,32 @@
+//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace vector {
+/// Registers the Vector transform dialect extension with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -0,0 +1,19 @@
+//===- VectorTransformOps.td - Vector transform ops --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VECTOR_TRANSFORM_OPS
+#define VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -66,6 +66,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Dialect.h"
@@ -115,12 +116,13 @@
   // clang-format on
 
   // Register all dialect extensions.
+  affine::registerTransformDialectExtension(registry);
   bufferization::registerTransformDialectExtension(registry);
+  gpu::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
   memref::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
-  gpu::registerTransformDialectExtension(registry);
-  affine::registerTransformDialectExtension(registry);
+  vector::registerTransformDialectExtension(registry);
 
   // Register all external models.
   arith::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
 add_subdirectory(Interfaces)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRVectorTransformOps
+  VectorTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/TransformOps
+
+  DEPENDS
+  MLIRVectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRParser
+  MLIRPDLDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRVectorDialect
+  MLIRVectorTransforms
+  )
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -0,0 +1,51 @@
+//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Registers the Vector transform ops and declares the PDL and Vector dialects
+/// as dependent dialects (the ops use PDL types for operands and results).
+class VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          VectorTransformDialectExtension> {
+public:
+  VectorTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
+
+void mlir::vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<VectorTransformDialectExtension>();
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3349,6 +3349,29 @@
     ],
 )
 
+cc_library(
+    name = "VectorTransformOps",
+    srcs = [
+        "lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":PDLDialect",
+        ":Parser",
+        ":SideEffectInterfaces",
+        ":TransformDialect",
+        ":TransformUtils",
+        ":VectorDialect",
+        ":VectorTransformOpsIncGen",
+        ":VectorTransforms",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "VectorPassIncGen",
     strip_include_prefix = "include",
@@ -6798,6 +6821,7 @@
         ":VectorToLLVM",
         ":VectorToSCF",
         ":VectorToSPIRV",
+        ":VectorTransformOps",
         ":VectorTransforms",
         ":X86VectorDialect",
         ":X86VectorTransforms",
@@ -8373,6 +8397,20 @@
     ],
 )
 
+td_library(
+    name = "VectorTransformOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+        ":PDLDialectTdFiles",
+        ":SideEffectInterfacesTdFiles",
+        ":TransformDialectTdFiles",
+    ],
+)
+
 gentbl_cc_library(
     name = "MaskableOpInterfaceIncGen",
     strip_include_prefix = "include",
@@ -8461,6 +8499,26 @@
     deps = [":VectorOpsTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "VectorTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td",
+    deps = [
+        ":VectorTransformOpsTdFiles",
+    ],
+)
+
 cc_library(
     name = "MaskableOpInterface",
     srcs = ["lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp"],