diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h
--- a/mlir/include/mlir/IR/BuiltinOps.h
+++ b/mlir/include/mlir/IR/BuiltinOps.h
@@ -17,7 +17,6 @@
 #include "mlir/IR/OwningOpRef.h"
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/IR/SymbolTable.h"
-#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -99,7 +99,7 @@
 //===----------------------------------------------------------------------===//
 
 def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
-    DeclareOpInterfaceMethods<CastOpInterface>, Pure
+    Pure
   ]> {
   let summary = "An unrealized conversion from one set of types to another";
   let description = [{
@@ -141,6 +141,7 @@
     ($inputs^ `:` type($inputs))? `to` type($outputs) attr-dict
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 #endif // BUILTIN_OPS
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1930,21 +1930,6 @@
   friend InterfaceBase;
 };
 
-//===----------------------------------------------------------------------===//
-// CastOpInterface utilities
-//===----------------------------------------------------------------------===//
-
-// These functions are out-of-line implementations of the methods in
-// CastOpInterface, which avoids them being template instantiated/duplicated.
-namespace impl {
-/// Attempt to fold the given cast operation.
-LogicalResult foldCastInterfaceOp(Operation *op,
-                                  ArrayRef<Attribute> attrOperands,
-                                  SmallVectorImpl<OpFoldResult> &foldResults);
-/// Attempt to verify the given cast operation.
-LogicalResult verifyCastInterfaceOp(
-    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
-} // namespace impl
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -72,6 +72,7 @@
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 
 namespace mlir {
 
@@ -141,6 +142,7 @@
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
+  registerCastOpInterfaceExternalModels(registry);
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h
--- a/mlir/include/mlir/Interfaces/CastInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.h
@@ -16,6 +16,24 @@
 
 #include "mlir/IR/OpDefinition.h"
 
+namespace mlir {
+class DialectRegistry;
+
+namespace impl {
+/// Attempt to fold the given cast operation.
+LogicalResult foldCastInterfaceOp(Operation *op,
+                                  ArrayRef<Attribute> attrOperands,
+                                  SmallVectorImpl<OpFoldResult> &foldResults);
+/// Attempt to verify the given cast operation. `op` must implement
+/// CastOpInterface.
+LogicalResult verifyCastInterfaceOp(Operation *op);
+} // namespace impl
+
+/// Add a registry extension that attaches the CastOpInterface external models
+/// (currently only the one for UnrealizedConversionCastOp) to builtin ops.
+void registerCastOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace mlir
+
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/CastInterfaces.h.inc"
 
diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td
--- a/mlir/include/mlir/Interfaces/CastInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CastInterfaces.td
@@ -44,7 +44,7 @@
     }
   }];
   let verify = [{
-    return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible);
+    return impl::verifyCastInterfaceOp($_op);
   }];
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -15,6 +15,7 @@
   MLIRArithOpsInterfacesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCastInterfaces
   MLIRDialect
   MLIRInferIntRangeCommon
   MLIRInferIntRangeInterface
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -9,6 +9,7 @@
   MLIRTransformInterfacesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCastInterfaces
   MLIRIR
   MLIRParser
   MLIRPDLDialect
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -217,10 +217,12 @@
   return success();
 }
 
-bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
-                                                   TypeRange outputs) {
-  // `UnrealizedConversionCastOp` is agnostic of the input/output types.
-  return true;
+LogicalResult UnrealizedConversionCastOp::verify() {
+  // TODO: The verifier of external models is not called. This op verifier can
+  // be removed when that is fixed.
+  if (getNumResults() == 0)
+    return emitOpError() << "expected at least one result for cast operation";
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1149,52 +1149,6 @@
          op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
 }
 
-//===----------------------------------------------------------------------===//
-// CastOpInterface
-//===----------------------------------------------------------------------===//
-
-/// Attempt to fold the given cast operation.
-LogicalResult
-impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
-                          SmallVectorImpl<OpFoldResult> &foldResults) {
-  OperandRange operands = op->getOperands();
-  if (operands.empty())
-    return failure();
-  ResultRange results = op->getResults();
-
-  // Check for the case where the input and output types match 1-1.
-  if (operands.getTypes() == results.getTypes()) {
-    foldResults.append(operands.begin(), operands.end());
-    return success();
-  }
-
-  return failure();
-}
-
-/// Attempt to verify the given cast operation.
-LogicalResult impl::verifyCastInterfaceOp(
-    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
-  auto resultTypes = op->getResultTypes();
-  if (resultTypes.empty())
-    return op->emitOpError()
-           << "expected at least one result for cast operation";
-
-  auto operandTypes = op->getOperandTypes();
-  if (!areCastCompatible(operandTypes, resultTypes)) {
-    InFlightDiagnostic diag = op->emitOpError("operand type");
-    if (operandTypes.empty())
-      diag << "s []";
-    else if (llvm::size(operandTypes) == 1)
-      diag << " " << *operandTypes.begin();
-    else
-      diag << "s " << operandTypes;
-    return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
-                << resultTypes << " are cast incompatible";
-  }
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Misc. utils
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp
--- a/mlir/lib/Interfaces/CastInterfaces.cpp
+++ b/mlir/lib/Interfaces/CastInterfaces.cpp
@@ -8,8 +8,82 @@
 
 #include "mlir/Interfaces/CastInterfaces.h"
 
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Helper functions for CastOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Attempt to fold the given cast operation.
+LogicalResult
+impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
+                          SmallVectorImpl<OpFoldResult> &foldResults) {
+  OperandRange operands = op->getOperands();
+  if (operands.empty())
+    return failure();
+  ResultRange results = op->getResults();
+
+  // Check for the case where the input and output types match 1-1.
+  if (operands.getTypes() == results.getTypes()) {
+    foldResults.append(operands.begin(), operands.end());
+    return success();
+  }
+
+  return failure();
+}
+
+/// Attempt to verify the given cast operation.
+LogicalResult impl::verifyCastInterfaceOp(Operation *op) {
+  auto resultTypes = op->getResultTypes();
+  if (resultTypes.empty())
+    return op->emitOpError()
+           << "expected at least one result for cast operation";
+
+  auto operandTypes = op->getOperandTypes();
+  if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) {
+    InFlightDiagnostic diag = op->emitOpError("operand type");
+    if (operandTypes.empty())
+      diag << "s []";
+    else if (llvm::size(operandTypes) == 1)
+      diag << " " << *operandTypes.begin();
+    else
+      diag << "s " << operandTypes;
+    return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
+                << resultTypes << " are cast incompatible";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// External model for BuiltinDialect ops
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace {
+// This interface cannot be implemented directly on the op because the IR build
+// unit cannot depend on the Interfaces build unit.
+struct UnrealizedConversionCastOpInterface
+    : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface,
+                                     UnrealizedConversionCastOp> {
+  static bool areCastCompatible(TypeRange inputs, TypeRange outputs) {
+    // `UnrealizedConversionCastOp` is agnostic of the input/output types.
+    return true;
+  }
+};
+} // namespace
+} // namespace mlir
+
+void mlir::registerCastOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+    UnrealizedConversionCastOp::attachInterface<
+        UnrealizedConversionCastOpInterface>(*ctx);
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Table-generated class definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Tools/PDLL/CMakeLists.txt b/mlir/test/lib/Tools/PDLL/CMakeLists.txt
--- a/mlir/test/lib/Tools/PDLL/CMakeLists.txt
+++ b/mlir/test/lib/Tools/PDLL/CMakeLists.txt
@@ -20,6 +20,7 @@
   MLIRTestPDLLPatternsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCastInterfaces
   MLIRIR
   MLIRPass
   MLIRSupport
diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp
--- a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp
+++ b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -52,6 +52,7 @@
   ${test_libs}
   MLIRAffineAnalysis
   MLIRAnalysis
+  MLIRCastInterfaces
   MLIRDialect
   MLIROptLib
   MLIRParser
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -268,7 +268,6 @@
     ]) + [
         "include/mlir/Bytecode/BytecodeImplementation.h",
         "include/mlir/Interfaces/CallInterfaces.h",
-        "include/mlir/Interfaces/CastInterfaces.h",
         "include/mlir/Interfaces/SideEffectInterfaces.h",
         "include/mlir/Interfaces/DataLayoutInterfaces.h",
         "include/mlir/Interfaces/FoldInterfaces.h",
@@ -283,7 +282,6 @@
         ":BuiltinTypeInterfacesIncGen",
         ":BuiltinTypesIncGen",
         ":CallOpInterfacesIncGen",
-        ":CastOpInterfacesIncGen",
         ":DataLayoutInterfacesIncGen",
         ":FunctionInterfacesIncGen",
         ":InferTypeOpInterfaceIncGen",
@@ -2565,6 +2563,7 @@
     ]),
     includes = ["include"],
     deps = [
+        ":CastOpInterfaces",
         ":EmitCAttributesIncGen",
         ":EmitCOpsIncGen",
         ":IR",
@@ -3135,6 +3134,7 @@
     includes = ["include"],
     deps = [
         ":ArithDialect",
+        ":CastOpInterfaces",
         ":ControlFlowInterfaces",
         ":Dialect",
         ":FuncDialect",
@@ -7064,6 +7064,7 @@
         ":BufferizationDialect",
         ":BufferizationTransformOps",
         ":BufferizationTransforms",
+        ":CastOpInterfaces",
         ":ComplexDialect",
         ":ComplexToLLVM",
         ":ComplexToLibm",
@@ -8151,6 +8152,7 @@
     hdrs = glob(["include/mlir/Dialect/Index/IR/*.h"]),
     includes = ["include"],
     deps = [
+        ":CastOpInterfaces",
         ":IR",
         ":IndexEnumsIncGen",
         ":IndexOpsIncGen",
@@ -9350,6 +9352,7 @@
     deps = [
         ":Analysis",
         ":CallOpInterfaces",
+        ":CastOpInterfaces",
         ":ControlFlowInterfaces",
         ":IR",
         ":PDLDialect",
@@ -9753,6 +9756,7 @@
         ":ArithCanonicalizationIncGen",
         ":ArithOpsIncGen",
         ":ArithOpsInterfacesIncGen",
+        ":CastOpInterfaces",
         ":CommonFolders",
         ":IR",
         ":InferIntRangeCommon",
@@ -10040,6 +10044,7 @@
     deps = [
         ":ArithDialect",
         ":ArithUtils",
+        ":CastOpInterfaces",
         ":ControlFlowInterfaces",
         ":CopyOpInterface",
         ":DialectUtils",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -931,6 +931,7 @@
     deps = [
         ":TestDialect",
         ":TestPDLLPatternsIncGen",
+        "//mlir:CastOpInterfaces",
         "//mlir:IR",
         "//mlir:PDLDialect",
         "//mlir:PDLInterpDialect",