diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -793,9 +793,19 @@ /// If the given operation instance is legal on this target, a structure /// containing legality information is returned. If the operation is not - /// legal, None is returned. + /// legal, None is returned. Also returns None if operation legality wasn't + /// registered by user or dynamic legality callbacks returned None. + /// + /// Note: Legality is actually a 4-state value: Legal(recursive=true), + /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated + /// either as Legal or Illegal depending on context. Optional isLegal(Operation *op) const; + /// Returns true if operation instance is illegal on this target. Returns + /// false if operation is legal, operation legality wasn't registered by user + /// or dynamic legality callbacks returned None. + bool isIllegal(Operation *op) const; + private: /// Set the dynamic legality callback for the given operation. void setLegalityCallback(OperationName name, diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1816,14 +1816,7 @@ } bool OperationLegalizer::isIllegal(Operation *op) const { - // Check if the target explicitly marked this operation as illegal. 
- if (auto info = target.getOpAction(op->getName())) { - if (*info == LegalizationAction::Dynamic) - return !target.isLegal(op); - return *info == LegalizationAction::Illegal; - } - - return false; + return target.isIllegal(op); } LogicalResult @@ -3137,6 +3130,22 @@ return legalityDetails; } +bool ConversionTarget::isIllegal(Operation *op) const { + Optional info = getOpInfo(op->getName()); + if (!info) + return false; + + if (info->action == LegalizationAction::Dynamic) { + Optional result = info->legalityFn(op); + if (!result) + return false; + + return !(*result); + } + + return info->action == LegalizationAction::Illegal; +} + static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback) { diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp --- a/mlir/unittests/Transforms/DialectConversion.cpp +++ b/mlir/unittests/Transforms/DialectConversion.cpp @@ -44,6 +44,9 @@ EXPECT_TRUE(target.isLegal(op)); EXPECT_EQ(2, callbackCalled1); EXPECT_EQ(1, callbackCalled2); + EXPECT_FALSE(target.isIllegal(op)); + EXPECT_EQ(4, callbackCalled1); + EXPECT_EQ(3, callbackCalled2); op->destroy(); } @@ -61,6 +64,8 @@ auto *op = createOp(&context); EXPECT_FALSE(target.isLegal(op)); EXPECT_EQ(1, callbackCalled); + EXPECT_FALSE(target.isIllegal(op)); + EXPECT_EQ(2, callbackCalled); op->destroy(); } @@ -85,6 +90,43 @@ EXPECT_TRUE(target.isLegal(op)); EXPECT_EQ(2, callbackCalled1); EXPECT_EQ(1, callbackCalled2); + EXPECT_FALSE(target.isIllegal(op)); + EXPECT_EQ(4, callbackCalled1); + EXPECT_EQ(3, callbackCalled2); + op->destroy(); +} + +TEST(DialectConversionTest, DynamicallyLegalReturnNone) { + MLIRContext context; + ConversionTarget target(context); + + target.addDynamicallyLegalOp( + [&](Operation *) -> Optional { return llvm::None; }); + + auto *op = createOp(&context); + 
EXPECT_FALSE(target.isLegal(op)); + EXPECT_FALSE(target.isIllegal(op)); + + EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); + EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); + + op->destroy(); +} + +TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) { + MLIRContext context; + ConversionTarget target(context); + + target.markUnknownOpDynamicallyLegal( + [&](Operation *) -> Optional { return llvm::None; }); + + auto *op = createOp(&context); + EXPECT_FALSE(target.isLegal(op)); + EXPECT_FALSE(target.isIllegal(op)); + + EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); + EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); + op->destroy(); } } // namespace