diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h
@@ -0,0 +1,22 @@
+//===- DialectExtension.h - Linalg transform dialect extension --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_DIALECTEXTENSION_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_DIALECTEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Registers the Linalg transform dialect extension (Linalg transform ops and
+/// Linalg match ops) with the given dialect registry.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_DIALECTEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h
@@ -0,0 +1,51 @@
+//===- LinalgMatchOps.h - Linalg transform matcher ops ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+
+namespace mlir {
+namespace transform {
+
+namespace detail {
+LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
+                                                 Value structuredOpHandle);
+} // namespace detail
+
+/// Op trait whose verifier calls `detail::verifyStructuredOpPredicateOpTrait`
+/// on the handle returned by the op's `getOperandHandle()`. Ops carrying this
+/// trait must also carry `SingleOpMatcherOpTrait` (via `static_assert`).
+template <typename OpTy>
+class StructuredOpPredicateOpTrait
+    : public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(
+        OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
+        "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
+
+    return detail::verifyStructuredOpPredicateOpTrait(
+        op, cast<OpTy>(op).getOperandHandle());
+  }
+};
+
+} // namespace transform
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Linalg Matcher Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,7 +11,6 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
-#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -55,30 +54,7 @@
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
 
-namespace detail {
-LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
-                                                 Value structuredOpHandle);
-} // namespace detail
-
-template <typename OpTy>
-class StructuredOpPredicateOpTrait
-    : public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
-public:
-  static LogicalResult verifyTrait(Operation *op) {
-    static_assert(
-        OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
-        "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
-
-    return detail::verifyStructuredOpPredicateOpTrait(
-        op, cast<OpTy>(op).getOperandHandle());
-  }
-};
-
 } // namespace transform
-
-namespace linalg {
-void registerTransformDialectExtension(DialectRegistry &registry);
-} // namespace linalg
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -90,7 +66,4 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
-
 #endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -42,7 +42,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransformOps
+  DialectExtension.cpp
   LinalgMatchOps.cpp
   LinalgTransformOps.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp
@@ -0,0 +1,62 @@
+//===- DialectExtension.cpp - Linalg transform dialect extension ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+
+namespace {
+/// Registers the Linalg transform ops and the Linalg match ops with the
+/// Transform dialect. PDL is declared as a dependent dialect because these ops
+/// use PDL types for operands and results.
+class LinalgTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          LinalgTransformDialectExtension> {
+public:
+  using Base::Base;
+
+  void init() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect<linalg::LinalgDialect>();
+
+    // Dialects whose operations these transformations may create in the
+    // payload IR.
+    declareGeneratedDialect<affine::AffineDialect>();
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<scf::SCFDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    declareGeneratedDialect<gpu::GPUDialect>();
+    declareGeneratedDialect<tensor::TensorDialect>();
+
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
+        >();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+void mlir::linalg::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<LinalgTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/FunctionImplementation.h"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3231,46 +3231,7 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// Transform op registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Registers new ops and declares PDL as dependent dialect since the
-/// additional ops are using PDL types for operands and results.
-class LinalgTransformDialectExtension
-    : public transform::TransformDialectExtension<
-          LinalgTransformDialectExtension> {
-public:
-  using Base::Base;
-
-  void init() {
-    declareDependentDialect<pdl::PDLDialect>();
-    declareDependentDialect<LinalgDialect>();
-    declareGeneratedDialect<affine::AffineDialect>();
-    declareGeneratedDialect<arith::ArithDialect>();
-    declareGeneratedDialect<scf::SCFDialect>();
-    declareGeneratedDialect<vector::VectorDialect>();
-    declareGeneratedDialect<gpu::GPUDialect>();
-
-    registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
-        >();
-    registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
-        >();
-  }
-};
-} // namespace
-
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
-
-void mlir::linalg::registerTransformDialectExtension(
-    DialectRegistry &registry) {
-  registry.addExtensions<LinalgTransformDialectExtension>();
-}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8700,9 +8700,9 @@
     srcs = glob([
         "lib/Dialect/Linalg/TransformOps/*.cpp",
     ]),
-    hdrs = [
-        "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
-    ],
+    hdrs = glob([
+        "include/mlir/Dialect/Linalg/TransformOps/*.h",
+    ]),
     includes = ["include"],
     deps = [
         ":AffineDialect",