diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -863,6 +863,14 @@ LogicalResult
 applyAnalysisConversion(Operation *op, ConversionTarget &target,
                         const FrozenRewritePatternList &patterns,
                         DenseSet<Operation *> &convertedOps);
+
+/// Report whether the given operation, and all nested operations, are legal as
+/// specified by the given ConversionTarget. On failure, emits an error on each
+/// illegal operation followed by a summary error at the location of `op`
+/// listing each illegal op name with its count. Does not modify the IR.
+LogicalResult verifyAllOperationsAreLegal(Operation *op,
+                                          ConversionTarget &target);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -13,9 +13,12 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/SaveAndRestore.h"
@@ -2784,3 +2787,44 @@
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
                                  convertedOps);
 }
+
+/// Emits an op error on every operation in `illegalOps`, followed by a summary
+/// error at `loc` listing each illegal op name and how often it occurred.
+/// Diagnostics follow the order of `illegalOps`, so callers should pass ops in
+/// a deterministic order (e.g. IR walk order) to get reproducible output.
+static void emitLegalizationErrors(Location loc,
+                                   ArrayRef<Operation *> illegalOps) {
+  // Print op errors for each of the illegal ops that still remain, counting
+  // occurrences per op name. MapVector keeps first-seen order for the summary.
+  llvm::MapVector<StringRef, int> opNameCounts;
+  for (Operation *illegalOp : illegalOps) {
+    StringRef opName = illegalOp->getName().getStringRef();
+    opNameCounts[opName]++;
+    illegalOp->emitOpError() << ": illegal op still exists";
+  }
+
+  std::vector<std::string> errorMessages;
+  errorMessages.reserve(opNameCounts.size());
+  for (const auto &opInfo : opNameCounts) {
+    errorMessages.push_back(
+        llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second));
+  }
+  emitError(loc) << "The following illegal operations still remain: \n"
+                 << llvm::join(errorMessages, "\n") << "\n";
+}
+
+LogicalResult mlir::verifyAllOperationsAreLegal(Operation *op,
+                                                ConversionTarget &target) {
+  // Collect illegal ops in walk order rather than in a pointer-keyed DenseSet:
+  // DenseSet iteration order depends on pointer values, which would make the
+  // order of the emitted diagnostics vary from run to run.
+  SmallVector<Operation *, 8> illegalOps;
+  op->walk([&](Operation *nestedOp) {
+    if (!target.isLegal(nestedOp))
+      illegalOps.push_back(nestedOp);
+  });
+  if (illegalOps.empty())
+    return success();
+  emitLegalizationErrors(op->getLoc(), illegalOps);
+  return failure();
+}
diff --git a/mlir/test/Transforms/test-verify-fully-converted.mlir b/mlir/test/Transforms/test-verify-fully-converted.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-verify-fully-converted.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -test-verify-fully-converted -verify-diagnostics -allow-unregistered-dialect
+
+// -----
+
+// expected-error@+1 {{The following illegal operations still remain}}
+module {
+  func @f() -> () {
+    // expected-error@+1 {{'test.some_op' op : illegal op still exists}}
+    %0 = "test.some_op"() : () -> tensor<f32>
+    // expected-error@+1 {{'test.other_op' op : illegal op still exists}}
+    %1 = "test.other_op"() : () -> tensor<f32>
+    return
+  }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@
   TestSCFUtils.cpp
   TestSparsification.cpp
   TestVectorTransforms.cpp
+  TestVerifyFullyConverted.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp b/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestVerifyFullyConverted.cpp
@@ -0,0 +1,42 @@
+//===- TestVerifyFullyConverted.cpp - Test full-conversion verification ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that runs verifyAllOperationsAreLegal on the module with a target
+/// in which ops of the test dialect are illegal and every other op falls back
+/// to an always-legal unknown-op callback. Signals failure if any remain.
+class TestVerifyFullyConvertedPass
+    : public PassWrapper<TestVerifyFullyConvertedPass,
+                         OperationPass<ModuleOp>> {
+public:
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    // Everything is legal except ops belonging to the test dialect.
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    target.addIllegalDialect<TestDialect>();
+    if (failed(verifyAllOperationsAreLegal(getOperation(), target)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerVerifyFullyConverted() {
+  PassRegistration<TestVerifyFullyConvertedPass>(
+      "test-verify-fully-converted",
+      "Test verification that no illegal ops remain");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -56,6 +56,7 @@
 namespace test {
 void registerConvertCallOpPass();
 void registerInliner();
+void registerVerifyFullyConverted();
 void registerMemRefBoundCheck();
 void registerPatternsTestPass();
 void registerSimpleParametricTilingPass();
@@ -125,6 +126,7 @@
 
   test::registerConvertCallOpPass();
   test::registerInliner();
+  test::registerVerifyFullyConverted();
   test::registerMemRefBoundCheck();
   test::registerPatternsTestPass();
   test::registerSimpleParametricTilingPass();