diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -661,7 +661,9 @@
 
   /// The signature of the callback used to determine if an operation is
   /// dynamically legal on the target.
-  using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+  /// Returning llvm::None means the callback has no opinion, in which case the
+  /// callback registered before it (if any) is consulted instead.
+  using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
 
   ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
   virtual ~ConversionTarget() = default;
@@ -827,10 +829,10 @@
   /// The set of information that configures the legalization of an operation.
   struct LegalizationInfo {
     /// The legality action this operation was given.
-    LegalizationAction action;
+    LegalizationAction action = LegalizationAction::Illegal;
 
     /// If some legal instances of this operation may also be recursively legal.
-    bool isRecursivelyLegal;
+    bool isRecursivelyLegal = false;
 
     /// The legality callback if this operation is dynamically legal.
     DynamicLegalityCallbackFn legalityFn;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2681,7 +2681,7 @@
 /// Register a legality action for the given operation.
 void ConversionTarget::setOpAction(OperationName op,
                                    LegalizationAction action) {
-  legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
+  legalOperations[op].action = action;
 }
 
 /// Register a legality action for the given dialects.
@@ -2710,8 +2710,11 @@
   // Returns true if this operation instance is known to be legal.
   auto isOpLegal = [&] {
     // Handle dynamic legality either with the provided legality function.
-    if (info->action == LegalizationAction::Dynamic)
-      return info->legalityFn(op);
+    if (info->action == LegalizationAction::Dynamic) {
+      Optional<bool> result = info->legalityFn(op);
+      if (result)
+        return *result;
+    }
 
     // Otherwise, the operation is only legal if it was marked 'Legal'.
     return info->action == LegalizationAction::Legal;
@@ -2723,14 +2726,36 @@
   LegalOpDetails legalityDetails;
   if (info->isRecursivelyLegal) {
     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
-    if (legalityFnIt != opRecursiveLegalityFns.end())
-      legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
-    else
+    if (legalityFnIt != opRecursiveLegalityFns.end()) {
+      // A callback returning llvm::None leaves the op recursively legal.
+      legalityDetails.isRecursivelyLegal =
+          legalityFnIt->second(op).getValueOr(true);
+    } else {
       legalityDetails.isRecursivelyLegal = true;
+    }
   }
   return legalityDetails;
 }
 
+/// Chains two legality callbacks: `newCallback` is queried first, and
+/// `oldCallback` is only consulted when `newCallback` returns llvm::None. If
+/// `oldCallback` is empty, `newCallback` is returned unchanged.
+static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
+    ConversionTarget::DynamicLegalityCallbackFn oldCallback,
+    ConversionTarget::DynamicLegalityCallbackFn newCallback) {
+  if (!oldCallback)
+    return newCallback;
+
+  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
+                   Operation *op) -> Optional<bool> {
+    if (Optional<bool> result = newCl(op))
+      return *result;
+
+    return oldCl(op);
+  };
+  return chain;
+}
+
 /// Set the dynamic legality callback for the given operation.
 void ConversionTarget::setLegalityCallback(
     OperationName name, const DynamicLegalityCallbackFn &callback) {
@@ -2739,7 +2764,8 @@
   assert(infoIt != legalOperations.end() &&
          infoIt->second.action == LegalizationAction::Dynamic &&
          "expected operation to already be marked as dynamically legal");
-  infoIt->second.legalityFn = callback;
+  infoIt->second.legalityFn =
+      composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
 }
 
 /// Set the recursive legality callback for the given operation and mark the
@@ -2752,7 +2778,8 @@
          "expected operation to already be marked as legal");
   infoIt->second.isRecursivelyLegal = true;
   if (callback)
-    opRecursiveLegalityFns[name] = callback;
+    opRecursiveLegalityFns[name] = composeLegalityCallbacks(
+        std::move(opRecursiveLegalityFns[name]), callback);
   else
     opRecursiveLegalityFns.erase(name);
 }
@@ -2762,14 +2789,16 @@
     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
   assert(callback && "expected valid legality callback");
   for (StringRef dialect : dialects)
-    dialectLegalityFns[dialect] = callback;
+    dialectLegalityFns[dialect] = composeLegalityCallbacks(
+        std::move(dialectLegalityFns[dialect]), callback);
 }
 
 /// Set the dynamic legality callback for the unknown ops.
 void ConversionTarget::setLegalityCallback(
     const DynamicLegalityCallbackFn &callback) {
   assert(callback && "expected valid legality callback");
-  unknownLegalityFn = callback;
+  unknownLegalityFn =
+      composeLegalityCallbacks(std::move(unknownLegalityFn), callback);
 }
 
 /// Get the legalization information for the given operation.
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -12,3 +12,4 @@
 add_subdirectory(Pass)
 add_subdirectory(Rewrite)
 add_subdirectory(TableGen)
+add_subdirectory(Transforms)
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRTransformsTests
+  DialectConversion.cpp
+)
+target_link_libraries(MLIRTransformsTests
+  PRIVATE
+  MLIRTransforms)
diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -0,0 +1,94 @@
+//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Creates an unregistered "foo.bar" operation. Note that this enables
+// unregistered dialects on `context` as a side effect.
+static Operation *createOp(MLIRContext *context) {
+  context->allowUnregisteredDialects();
+  return Operation::create(UnknownLoc::get(context),
+                           OperationName("foo.bar", context), llvm::None,
+                           llvm::None, llvm::None, llvm::None, 0);
+}
+
+namespace {
+// Minimal op class naming "foo.bar" for addDynamicallyLegalOp<OpT>().
+struct DummyOp {
+  static StringRef getOperationName() { return "foo.bar"; }
+};
+
+// Callbacks registered later are queried first and defer on llvm::None.
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled1 = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
+    callbackCalled1 = ++index;
+    return true;
+  });
+
+  int callbackCalled2 = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    callbackCalled2 = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    callbackCalled = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_FALSE(target.isLegal(op));
+  EXPECT_EQ(1, callbackCalled);
+  op->destroy();
+}
+
+TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled1 = 0;
+  target.markUnknownOpDynamicallyLegal([&](Operation *) {
+    callbackCalled1 = ++index;
+    return true;
+  });
+
+  int callbackCalled2 = 0;
+  target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
+    callbackCalled2 = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+} // namespace