diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -847,7 +847,8 @@ StringRef className = def->getValueAsString("cppClassName"); StringRef cppNamespace = def->getValueAsString("cppNamespace"); std::string codeBlock = - llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) + llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));", + cppNamespace, className) .str(); if (def->isSubClassOf("OpInterface")) { @@ -892,8 +893,9 @@ // Format the condition template. tblgen::FmtContext fmtContext; fmtContext.withSelf("self"); - std::string codeBlock = - tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); + std::string codeBlock = tblgen::tgfmt( + "return ::mlir::success(" + constraint.getConditionTemplate() + ");", + &fmtContext); return createODSNativePDLLConstraintDecl( constraint.getUniqueDefName(), codeBlock, loc, type); diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.pdll b/mlir/test/lib/Tools/PDLL/TestPDLL.pdll --- a/mlir/test/lib/Tools/PDLL/TestPDLL.pdll +++ b/mlir/test/lib/Tools/PDLL/TestPDLL.pdll @@ -7,6 +7,10 @@ //===----------------------------------------------------------------------===// #include "TestOps.td" +#include "mlir/Interfaces/CastInterfaces.td" /// A simple pattern that matches and replaces an operation. Pattern TestSimplePattern => replace op with op; + +// Test the import of interfaces. +Pattern TestInterface => replace _: CastOpInterface with op; diff --git a/mlir/test/mlir-pdll/Integration/test-pdll.mlir b/mlir/test/mlir-pdll/Integration/test-pdll.mlir --- a/mlir/test/mlir-pdll/Integration/test-pdll.mlir +++ b/mlir/test/mlir-pdll/Integration/test-pdll.mlir @@ -6,3 +6,12 @@ "test.simple"() : () -> () return } + +// CHECK-LABEL: func @testImportedInterface +func @testImportedInterface() { + // CHECK: test.non_cast + // CHECK: test.success + "test.non_cast"() : () -> () + "builtin.unrealized_conversion_cast"() : () -> (i1) + return +} diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll --- a/mlir/test/mlir-pdll/Parser/include_td.pdll +++ b/mlir/test/mlir-pdll/Parser/include_td.pdll @@ -32,20 +32,20 @@ // CHECK-NEXT: CppClass: ::mlir::IntegerType // CHECK-NEXT: } -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-AttrConstraintDecl -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-OpConstraintDecl // CHECK: `-OpNameDecl -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints`