diff --git a/mlir/include/mlir/Transforms/VerifyFullyConverted.h b/mlir/include/mlir/Transforms/VerifyFullyConverted.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/VerifyFullyConverted.h @@ -0,0 +1,67 @@ +//===- VerifyFullyConverted.h - Conversion verification pass ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines a base class for a generic pass that verifies all +// operations conform to the specification of a ConversionTarget. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_VERIFY_FULLY_CONVERTED_H +#define MLIR_TRANSFORMS_VERIFY_FULLY_CONVERTED_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +template +class VerifyFullyConvertedPass + : public PassWrapper> { +public: + void runOnOperation() override { + DenseSet illegalOps; + this->getOperation()->walk([&](Operation *op) { + if (!this->getConversionTarget().isLegal(op)) { + illegalOps.insert(op); + } + }); + if (!illegalOps.empty()) { + emitLegalizationErrors(illegalOps); + return this->signalPassFailure(); + } + } + +protected: + virtual ConversionTarget getConversionTarget() = 0; + +private: + void emitLegalizationErrors(const DenseSet &illegalOps) { + // Print op errors for each of the illegal ops that still remain. + llvm::MapVector opNameCounts; + for (Operation *illegalOp : illegalOps) { + StringRef opName = illegalOp->getName().getStringRef(); + opNameCounts[opName]++; + illegalOp->emitOpError() << ": illegal op still exists"; + } + + std::vector errorMessages; + errorMessages.reserve(opNameCounts.size()); + for (const auto &opInfo : opNameCounts) { + errorMessages.push_back( + llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second)); + } + emitError(this->getOperation()->getLoc()) + << "The following illegal operations still remain: \n" + << llvm::join(errorMessages, "\n") << "\n"; + } +}; + +#endif // MLIR_TRANSFORMS_VERIFY_FULLY_CONVERTED_H diff --git a/mlir/test/Transforms/test-verify-fully-converted.mlir b/mlir/test/Transforms/test-verify-fully-converted.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-verify-fully-converted.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -split-input-file -test-verify-fully-converted -verify-diagnostics -allow-unregistered-dialect + +// ----- + +// expected-error@+1 {{The following illegal operations still remain}} +module { + func @f() -> () { + // expected-error@+1 {{'test.some_op' op : illegal op still exists}} + %0 = "test.some_op"() : () -> tensor + // expected-error@+1 {{'test.other_op' op : illegal op still exists}} + %1 = "test.other_op"() : () -> tensor + return + } +} + + diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -32,6 +32,7 @@ TestSCFUtils.cpp TestSparsification.cpp TestVectorTransforms.cpp + TestVerifyFullyConverted.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp b/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp @@ -0,0 +1,29 @@ +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/VerifyFullyConverted.h" + +using namespace mlir; + +namespace { +class TestVerifyFullyConvertedPass + : public VerifyFullyConvertedPass { +protected: + ConversionTarget getConversionTarget() override { + ConversionTarget target(getContext()); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.addIllegalDialect(); + return target; + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerVerifyFullyConverted() { + PassRegistration( + "test-verify-fully-converted", + "Test verification that no illegal ops remain"); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -56,6 +56,7 @@ namespace test { void registerConvertCallOpPass(); void registerInliner(); +void registerVerifyFullyConverted(); void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); @@ -125,6 +126,7 @@ test::registerConvertCallOpPass(); test::registerInliner(); + test::registerVerifyFullyConverted(); test::registerMemRefBoundCheck(); test::registerPatternsTestPass(); test::registerSimpleParametricTilingPass();