[mlir] Copy-construct PassGen-generated pass bases from the source pass

The copy constructor emitted by mlir-tblgen's PassGen for `{0}Base`
ignored the pass being copied: it initialized its base class from
`::mlir::TypeID::get<DerivedT>()`. It now forwards `other` to the base
class copy constructor.

To make that forwarding possible, `OperationPass<OpT>`,
`OperationPass<void>` and `PassWrapper` get explicitly defaulted copy
constructors in their `protected` sections.

A new unit test, PassGenTest, works as follows:
  - It generates `TestPassBase` from `passes.td`.
  - It clones a `TestPass` through `TestPassBase<TestPass>::clone()`.
  - It checks that the clone keeps the value of `extraVal`, a member
    declared only in the derived pass.

diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -339,6 +339,7 @@
 template <typename OpT = void> class OperationPass : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
+  OperationPass(const OperationPass &) = default;
 
   /// Support isa/dyn_cast functionality.
   static bool classof(const Pass *pass) {
@@ -371,6 +372,7 @@
 template <> class OperationPass<void> : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID) {}
+  OperationPass(const OperationPass &) = default;
 };
 
 /// A model for providing function pass specific utilities.
@@ -409,6 +411,7 @@
 
 protected:
   PassWrapper() : BaseT(TypeID::get<PassT>()) {}
+  PassWrapper(const PassWrapper &) = default;
 
   /// Returns the derived pass name.
   StringRef getName() const override { return llvm::getTypeName<PassT>(); }
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -48,7 +48,7 @@
   using Base = {0}Base;
 
   {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
-  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+  {0}Base(const {0}Base &other) : {1}(other) {{}
 
   /// Returns the command-line argument attached to this pass.
   static constexpr ::llvm::StringLiteral getArgumentName() {
diff --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt
--- a/mlir/unittests/TableGen/CMakeLists.txt
+++ b/mlir/unittests/TableGen/CMakeLists.txt
@@ -8,15 +8,21 @@
 mlir_tablegen(StructAttrGenTest.cpp.inc -gen-struct-attr-defs)
 add_public_tablegen_target(MLIRTableGenStructAttrIncGen)
 
+set(LLVM_TARGET_DEFINITIONS passes.td)
+mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
+add_public_tablegen_target(MLIRTableGenTestPassIncGen)
+
 add_mlir_unittest(MLIRTableGenTests
   EnumsGenTest.cpp
   StructsGenTest.cpp
   FormatTest.cpp
   OpBuildGen.cpp
+  PassGenTest.cpp
 )
 
 add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
 add_dependencies(MLIRTableGenTests MLIRTableGenStructAttrIncGen)
+add_dependencies(MLIRTableGenTests MLIRTableGenTestPassIncGen)
 add_dependencies(MLIRTableGenTests MLIRTestDialect)
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../test/lib/Dialect/Test)
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -0,0 +1,48 @@
+//===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+
+#include "gmock/gmock.h"
+
+std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
+
+#define GEN_PASS_REGISTRATION
+#include "PassGenTest.h.inc"
+
+#define GEN_PASS_CLASSES
+#include "PassGenTest.h.inc"
+
+struct TestPass : public TestPassBase<TestPass> {
+  explicit TestPass(int v) : extraVal(v) {}
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassBase<TestPass>::clone();
+  }
+
+  int extraVal;
+};
+
+std::unique_ptr<mlir::Pass> createTestPass(int v) {
+  return std::make_unique<TestPass>(v);
+}
+
+TEST(PassGenTest, PassClone) {
+  mlir::MLIRContext context;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPass *>(pass.get());
+  };
+
+  const auto origPass = createTestPass(10);
+  const auto clonePass = unwrap(origPass)->clone();
+
+  EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
+}
diff --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/TableGen/passes.td
@@ -0,0 +1,18 @@
+//===-- passes.td - PassGen test definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+include "mlir/Rewrite/PassUtil.td"
+
+def TestPass : Pass<"test"> {
+  let summary = "Test pass";
+
+  let constructor = "::createTestPass()";
+
+  let options = RewritePassUtils.options;
+}