diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -22,6 +22,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -17,7 +17,6 @@ #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -99,7 +99,7 @@ //===----------------------------------------------------------------------===// def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [ - DeclareOpInterfaceMethods, Pure + Pure ]> { let summary = "An unrealized conversion from one set of types to another"; let description = [{ @@ -141,6 +141,7 @@ ($inputs^ `:` type($inputs))? `to` type($outputs) attr-dict }]; let hasFolder = 1; + let hasVerifier = 1; } #endif // BUILTIN_OPS diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -2104,21 +2104,6 @@ friend InterfaceBase; }; -//===----------------------------------------------------------------------===// -// CastOpInterface utilities -//===----------------------------------------------------------------------===// - -// These functions are out-of-line implementations of the methods in -// CastOpInterface, which avoids them being template instantiated/duplicated. -namespace impl { -/// Attempt to fold the given cast operation. -LogicalResult foldCastInterfaceOp(Operation *op, - ArrayRef attrOperands, - SmallVectorImpl &foldResults); -/// Attempt to verify the given cast operation. -LogicalResult verifyCastInterfaceOp( - Operation *op, function_ref areCastCompatible); -} // namespace impl } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -82,6 +82,7 @@ #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" namespace mlir { @@ -145,6 +146,7 @@ arith::registerValueBoundsOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); + builtin::registerCastOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); linalg::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h --- a/mlir/include/mlir/Interfaces/CastInterfaces.h +++ b/mlir/include/mlir/Interfaces/CastInterfaces.h @@ -16,6 +16,24 @@ #include "mlir/IR/OpDefinition.h" +namespace mlir { +class DialectRegistry; + +namespace impl { +/// Attempt to fold the given cast operation. +LogicalResult foldCastInterfaceOp(Operation *op, + ArrayRef attrOperands, + SmallVectorImpl &foldResults); + +/// Attempt to verify the given cast operation. +LogicalResult verifyCastInterfaceOp(Operation *op); +} // namespace impl + +namespace builtin { +void registerCastOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace builtin +} // namespace mlir + /// Include the generated interface declarations. #include "mlir/Interfaces/CastInterfaces.h.inc" diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td --- a/mlir/include/mlir/Interfaces/CastInterfaces.td +++ b/mlir/include/mlir/Interfaces/CastInterfaces.td @@ -44,7 +44,7 @@ } }]; let verify = [{ - return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible); + return impl::verifyCastInterfaceOp($_op); }]; } diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -22,6 +22,7 @@ MLIRArithOpsInterfacesIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRDialect MLIRInferIntRangeCommon MLIRInferIntRangeInterface diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRTransformInterfacesIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRIR MLIRParser MLIRRewrite diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -217,10 +217,12 @@ return success(); } -bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs, - TypeRange outputs) { - // `UnrealizedConversionCastOp` is agnostic of the input/output types. - return true; +LogicalResult UnrealizedConversionCastOp::verify() { + // TODO: The verifier of external models is not called. This op verifier can + // be removed when that is fixed. + if (getNumResults() == 0) + return emitOpError() << "expected at least one result for cast operation"; + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1275,52 +1275,6 @@ op->hasTrait() && op->hasTrait(); } -//===----------------------------------------------------------------------===// -// CastOpInterface -//===----------------------------------------------------------------------===// - -/// Attempt to fold the given cast operation. -LogicalResult -impl::foldCastInterfaceOp(Operation *op, ArrayRef attrOperands, - SmallVectorImpl &foldResults) { - OperandRange operands = op->getOperands(); - if (operands.empty()) - return failure(); - ResultRange results = op->getResults(); - - // Check for the case where the input and output types match 1-1. - if (operands.getTypes() == results.getTypes()) { - foldResults.append(operands.begin(), operands.end()); - return success(); - } - - return failure(); -} - -/// Attempt to verify the given cast operation. -LogicalResult impl::verifyCastInterfaceOp( - Operation *op, function_ref areCastCompatible) { - auto resultTypes = op->getResultTypes(); - if (resultTypes.empty()) - return op->emitOpError() - << "expected at least one result for cast operation"; - - auto operandTypes = op->getOperandTypes(); - if (!areCastCompatible(operandTypes, resultTypes)) { - InFlightDiagnostic diag = op->emitOpError("operand type"); - if (operandTypes.empty()) - diag << "s []"; - else if (llvm::size(operandTypes) == 1) - diag << " " << *operandTypes.begin(); - else - diag << "s " << operandTypes; - return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") - << resultTypes << " are cast incompatible"; - } - - return success(); -} - //===----------------------------------------------------------------------===// // Misc. utils //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp --- a/mlir/lib/Interfaces/CastInterfaces.cpp +++ b/mlir/lib/Interfaces/CastInterfaces.cpp @@ -8,8 +8,83 @@ #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" + using namespace mlir; +//===----------------------------------------------------------------------===// +// Helper functions for CastOpInterface +//===----------------------------------------------------------------------===// + +/// Attempt to fold the given cast operation. +LogicalResult +impl::foldCastInterfaceOp(Operation *op, ArrayRef attrOperands, + SmallVectorImpl &foldResults) { + OperandRange operands = op->getOperands(); + if (operands.empty()) + return failure(); + ResultRange results = op->getResults(); + + // Check for the case where the input and output types match 1-1. + if (operands.getTypes() == results.getTypes()) { + foldResults.append(operands.begin(), operands.end()); + return success(); + } + + return failure(); +} + +/// Attempt to verify the given cast operation. +LogicalResult impl::verifyCastInterfaceOp(Operation *op) { + auto resultTypes = op->getResultTypes(); + if (resultTypes.empty()) + return op->emitOpError() + << "expected at least one result for cast operation"; + + auto operandTypes = op->getOperandTypes(); + if (!cast(op).areCastCompatible(operandTypes, resultTypes)) { + InFlightDiagnostic diag = op->emitOpError("operand type"); + if (operandTypes.empty()) + diag << "s []"; + else if (llvm::size(operandTypes) == 1) + diag << " " << *operandTypes.begin(); + else + diag << "s " << operandTypes; + return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") + << resultTypes << " are cast incompatible"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// External model for BuiltinDialect ops +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace { +// This interface cannot be implemented directly on the op because the IR build +// unit cannot depend on the Interfaces build unit. +struct UnrealizedConversionCastOpInterface + : CastOpInterface::ExternalModel { + static bool areCastCompatible(TypeRange inputs, TypeRange outputs) { + // `UnrealizedConversionCastOp` is agnostic of the input/output types. + return true; + } +}; +} // namespace +} // namespace mlir + +void mlir::builtin::registerCastOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + UnrealizedConversionCastOp::attachInterface< + UnrealizedConversionCastOpInterface>(*ctx); + }); +} + //===----------------------------------------------------------------------===// // Table-generated class definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Tools/PDLL/CMakeLists.txt b/mlir/test/lib/Tools/PDLL/CMakeLists.txt --- a/mlir/test/lib/Tools/PDLL/CMakeLists.txt +++ b/mlir/test/lib/Tools/PDLL/CMakeLists.txt @@ -20,6 +20,7 @@ MLIRTestPDLLPatternsIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRIR MLIRPass MLIRSupport diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp --- a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp +++ b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -53,6 +53,7 @@ ${test_libs} MLIRAffineAnalysis MLIRAnalysis + MLIRCastInterfaces MLIRDialect MLIROptLib MLIRParser diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -304,7 +304,6 @@ ]) + [ "include/mlir/Bytecode/BytecodeImplementation.h", "include/mlir/Interfaces/CallInterfaces.h", - "include/mlir/Interfaces/CastInterfaces.h", "include/mlir/Interfaces/DataLayoutInterfaces.h", "include/mlir/Interfaces/FoldInterfaces.h", "include/mlir/Interfaces/SideEffectInterfaces.h", @@ -320,7 +319,6 @@ ":BuiltinTypeInterfacesIncGen", ":BuiltinTypesIncGen", ":CallOpInterfacesIncGen", - ":CastOpInterfacesIncGen", ":DataLayoutInterfacesIncGen", ":FunctionInterfacesIncGen", ":InferTypeOpInterfaceIncGen", @@ -2865,6 +2863,7 @@ includes = ["include"], deps = [ ":BytecodeOpInterface", + ":CastInterfaces", ":EmitCAttributesIncGen", ":EmitCOpsIncGen", ":IR", @@ -3454,6 +3453,7 @@ deps = [ ":ArithDialect", ":BytecodeOpInterface", + ":CastInterfaces", ":ControlFlowInterfaces", ":Dialect", ":FuncDialect", @@ -3647,7 +3647,7 @@ ":ArithDialect", ":BytecodeOpInterface", ":CallOpInterfaces", - ":CastOpInterfaces", + ":CastInterfaces", ":CommonFolders", ":ControlFlowDialect", ":ControlFlowInterfaces", @@ -5867,7 +5867,7 @@ ":ArithDialect", ":ArithUtils", ":BytecodeOpInterface", - ":CastOpInterfaces", + ":CastInterfaces", ":ComplexDialect", ":ControlFlowInterfaces", ":DestinationStyleOpInterface", @@ -6874,7 +6874,7 @@ ) gentbl_cc_library( - name = "CastOpInterfacesIncGen", + name = "CastInterfacesIncGen", strip_include_prefix = "include", tbl_outs = [ ( @@ -6892,12 +6892,12 @@ ) cc_library( - name = "CastOpInterfaces", + name = "CastInterfaces", srcs = ["lib/Interfaces/CastInterfaces.cpp"], hdrs = ["include/mlir/Interfaces/CastInterfaces.h"], includes = ["include"], deps = [ - ":CastOpInterfacesIncGen", + ":CastInterfacesIncGen", ":IR", "//llvm:Support", ], @@ -7538,6 +7538,7 @@ ":BufferizationDialect", ":BufferizationTransformOps", ":BufferizationTransforms", + ":CastInterfaces", ":ComplexDialect", ":ComplexToLLVM", ":ComplexToLibm", @@ -8698,6 +8699,7 @@ includes = ["include"], deps = [ ":BytecodeOpInterface", + ":CastInterfaces", ":IR", ":IndexEnumsIncGen", ":IndexOpsIncGen", @@ -10024,6 +10026,7 @@ ":Analysis", ":BytecodeOpInterface", ":CallOpInterfaces", + ":CastInterfaces", ":ControlFlowInterfaces", ":IR", ":Rewrite", @@ -10501,6 +10504,7 @@ ":ArithOpsIncGen", ":ArithOpsInterfacesIncGen", ":BytecodeOpInterface", + ":CastInterfaces", ":CommonFolders", ":IR", ":InferIntRangeCommon", @@ -10795,6 +10799,7 @@ ":ArithDialect", ":ArithUtils", ":BytecodeOpInterface", + ":CastInterfaces", ":ComplexDialect", ":ControlFlowInterfaces", ":CopyOpInterface", diff --git a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch4/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch4/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch4/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch4/BUILD.bazel @@ -96,7 +96,7 @@ ":ToyOpsIncGen", "//llvm:Support", "//mlir:Analysis", - "//mlir:CastOpInterfaces", + "//mlir:CastInterfaces", "//mlir:IR", "//mlir:Parser", "//mlir:Pass", diff --git a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch5/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch5/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch5/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch5/BUILD.bazel @@ -101,7 +101,7 @@ "//mlir:AllPassesAndDialects", "//mlir:Analysis", "//mlir:ArithDialect", - "//mlir:CastOpInterfaces", + "//mlir:CastInterfaces", "//mlir:FuncDialect", "//mlir:IR", "//mlir:MemRefDialect", diff --git a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch6/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch6/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch6/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch6/BUILD.bazel @@ -106,7 +106,7 @@ "//mlir:ArithDialect", "//mlir:ArithToLLVM", "//mlir:BuiltinToLLVMIRTranslation", - "//mlir:CastOpInterfaces", + "//mlir:CastInterfaces", "//mlir:ControlFlowToLLVM", "//mlir:ExecutionEngine", "//mlir:ExecutionEngineUtils", diff --git a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch7/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch7/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch7/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/examples/toy/Ch7/BUILD.bazel @@ -106,7 +106,7 @@ "//mlir:ArithDialect", "//mlir:ArithToLLVM", "//mlir:BuiltinToLLVMIRTranslation", - "//mlir:CastOpInterfaces", + "//mlir:CastInterfaces", "//mlir:ControlFlowToLLVM", "//mlir:ExecutionEngine", "//mlir:ExecutionEngineUtils", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -962,6 +962,7 @@ deps = [ ":TestDialect", ":TestPDLLPatternsIncGen", + "//mlir:CastInterfaces", "//mlir:IR", "//mlir:PDLDialect", "//mlir:PDLInterpDialect",