diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -182,26 +182,50 @@ casts between two different shapes. ```tablegen -def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape] + > { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. The - operation is invalid if converting to a mismatching constant dimension. + must both be tensor types with the same element type. If both are ranked, + then shape is required to match. The operation is invalid if converting + to a mismatching constant dimension. }]; let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor:$output); +} +``` + +Note that the definition of this cast operation adds a `CastOpInterface` to the +traits list. This interface provides several utilities for cast-like operation, +such as folding identity casts and verification. We hook into this interface by +providing a definition for the `areCastCompatible` method: - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; +```c++ +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. 
+ TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; } + ``` -We can then override the necessary hook on the ToyInlinerInterface to insert -this for us when necessary: +With a proper cast operation, we can now override the necessary hook on the +ToyInlinerInterface to insert it for us when necessary: ```c++ struct ToyInlinerInterface : public DialectInlinerInterface { diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt --- a/mlir/examples/toy/Ch4/CMakeLists.txt +++ b/mlir/examples/toy/Ch4/CMakeLists.txt @@ -29,6 +29,7 @@ target_link_libraries(toyc-ch4 PRIVATE MLIRAnalysis + MLIRCastInterfaces MLIRCallInterfaces MLIRIR MLIRParser diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h --- a/mlir/examples/toy/Ch4/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -102,25 +103,25 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + 
DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. The - operation is invalid if converting to a mismatching constant dimension. + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. }]; let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -232,6 +232,21 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. 
+ TensorType input = inputs.front().dyn_cast<TensorType>(); + TensorType output = outputs.front().dyn_cast<TensorType>(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> { diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -33,6 +33,7 @@ ${dialect_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRIR MLIRParser MLIRPass diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h --- a/mlir/examples/toy/Ch5/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td 
@@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -102,25 +103,25 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. The - operation is invalid if converting to a mismatching constant dimension. + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. }]; let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -232,6 +232,21 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. 
This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = inputs.front().dyn_cast<TensorType>(); + TensorType output = outputs.front().dyn_cast<TensorType>(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. 
It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt --- a/mlir/examples/toy/Ch6/CMakeLists.txt +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -39,6 +39,7 @@ ${conversion_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRExecutionEngine MLIRIR MLIRLLVMIR diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h --- a/mlir/examples/toy/Ch6/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -102,25 +103,25 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. 
The - operation is invalid if converting to a mismatching constant dimension. + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. }]; let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -232,6 +232,21 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. 
+ return !input.hasRank() || !output.hasRank() || input == output; +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt --- a/mlir/examples/toy/Ch7/CMakeLists.txt +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -39,6 +39,7 @@ ${conversion_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRExecutionEngine MLIRIR MLIRParser diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -115,25 +116,25 @@ ]; } 
-def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. The - operation is invalid if converting to a mismatching constant dimension. + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. }]; let arguments = (ins F64Tensor:$input); let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -284,6 +284,21 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. 
+ TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// Fold constants. OpFoldResult ConstantOp::fold(ArrayRef operands) { return value(); } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -19,6 +19,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include 
"mlir/Interfaces/VectorInterfaces.td" @@ -45,9 +46,10 @@ // Base class for standard cast operations. Requires single operand and result, // but does not constrain them to specific types. class CastOp traits = []> : - Std_Op { - + Std_Op + ]> { let results = (outs AnyType); let builders = [ @@ -62,9 +64,9 @@ let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; - let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }]; - let hasFolder = 1; + // Cast operations are fully verified by its traits. + let verifier = ?; } // Base class for arithmetic cast operations. @@ -1643,14 +1645,6 @@ The destination type must to be strictly wider than the source type. Only scalars are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1663,14 +1657,6 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) signed integer value. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1683,14 +1669,6 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) unsigned integer value. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1705,14 +1683,6 @@ If the value cannot be exactly represented, it is rounded using the default rounding mode. 
Only scalars are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1849,12 +1819,6 @@ sign-extended. If casting to a narrower integer, the value is truncated. }]; - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - let hasFolder = 1; } @@ -2045,14 +2009,7 @@ let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef); - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a memref_cast is always a memref. - Type getType() { return getResult().getType(); } - }]; + let hasFolder = 1; } @@ -2786,14 +2743,6 @@ exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -3628,14 +3577,6 @@ value cannot be exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. 
- static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -10,6 +10,7 @@ #define TENSOR_OPS include "mlir/Dialect/Tensor/IR/TensorBase.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -24,7 +25,9 @@ // CastOp //===----------------------------------------------------------------------===// -def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> { +def Tensor_CastOp : Tensor_Op<"cast", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "tensor cast operation"; let description = [{ Convert a tensor from one type to an equivalent type without changing any @@ -51,19 +54,9 @@ let arguments = (ins AnyTensor:$source); let results = (outs AnyTensor:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; - let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a tensor.cast is always a tensor. 
- TensorType getType() { return getResult().getType().cast(); } - }]; - let hasFolder = 1; let hasCanonicalizer = 1; + let verifier = ?; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -50,6 +50,34 @@ /// A variant type that holds a single argument for a diagnostic. class DiagnosticArgument { public: + /// Note: The constructors below are only exposed due to problems accessing + /// constructors from type traits, they should not be used directly by users. + // Construct from an Attribute. + explicit DiagnosticArgument(Attribute attr); + // Construct from a floating point number. + explicit DiagnosticArgument(double val) + : kind(DiagnosticArgumentKind::Double), doubleVal(val) {} + explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {} + // Construct from a signed integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(int64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {} + // Construct from an unsigned integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(uint64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} + // Construct from a string reference. + explicit DiagnosticArgument(StringRef val) + : kind(DiagnosticArgumentKind::String), stringVal(val) {} + // Construct from a Type. + explicit DiagnosticArgument(Type val); + /// Enum that represents the different kinds of diagnostic arguments /// supported. enum class DiagnosticArgumentKind { @@ -100,37 +128,6 @@ private: friend class Diagnostic; - // Construct from an Attribute. 
- explicit DiagnosticArgument(Attribute attr); - - // Construct from a floating point number. - explicit DiagnosticArgument(double val) - : kind(DiagnosticArgumentKind::Double), doubleVal(val) {} - explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {} - - // Construct from a signed integer. - template - explicit DiagnosticArgument( - T val, typename std::enable_if::value && - std::numeric_limits::is_integer && - sizeof(T) <= sizeof(int64_t)>::type * = 0) - : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {} - - // Construct from an unsigned integer. - template - explicit DiagnosticArgument( - T val, typename std::enable_if::value && - std::numeric_limits::is_integer && - sizeof(T) <= sizeof(uint64_t)>::type * = 0) - : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} - - // Construct from a string reference. - explicit DiagnosticArgument(StringRef val) - : kind(DiagnosticArgumentKind::String), stringVal(val) {} - - // Construct from a Type. - explicit DiagnosticArgument(Type val); - /// The kind of this argument. DiagnosticArgumentKind kind; @@ -189,8 +186,10 @@ /// Stream operator for inserting new diagnostic arguments. template - typename std::enable_if::value, - Diagnostic &>::type + typename std::enable_if< + !std::is_convertible::value && + std::is_constructible::value, + Diagnostic &>::type operator<<(Arg &&val) { arguments.push_back(DiagnosticArgument(std::forward(val))); return *this; @@ -220,17 +219,17 @@ } /// Stream in a range. - template Diagnostic &operator<<(iterator_range range) { - return appendRange(range); - } - template Diagnostic &operator<<(ArrayRef range) { + template > + std::enable_if_t::value, + Diagnostic &> + operator<<(T &&range) { return appendRange(range); } /// Append a range to the diagnostic. The default delimiter between elements /// is ','. 
- template class Container> - Diagnostic &appendRange(const Container &c, const char *delim = ", ") { + template + Diagnostic &appendRange(const T &c, const char *delim = ", ") { llvm::interleave( c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; }); return *this; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1822,18 +1822,27 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p); } // namespace impl -// These functions are out-of-line implementations of the methods in CastOp, -// which avoids them being template instantiated/duplicated. +// These functions are out-of-line implementations of the methods in +// CastOpInterface, which avoids them being template instantiated/duplicated. namespace impl { +/// Attempt to fold the given cast operation. +LogicalResult foldCastInterfaceOp(Operation *op, + ArrayRef attrOperands, + SmallVectorImpl &foldResults); +/// Attempt to verify the given cast operation. +LogicalResult verifyCastInterfaceOp( + Operation *op, function_ref areCastCompatible); + // TODO: Remove the parse/print/build here (new ODS functionality obsoletes the // need for them, but some older ODS code in `std` still depends on them). void buildCastOp(OpBuilder &builder, OperationState &result, Value source, Type destType); ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); void printCastOp(Operation *op, OpAsmPrinter &p); -// TODO: Create a CastOpInterface with a method areCastCompatible. -// Also, consider adding functionality to CastOpInterface to be able to perform -// the ChainedTensorCast canonicalization generically. +// TODO: These methods are deprecated in favor of CastOpInterface. Remove them +// when all uses have been updated. Also, consider adding functionality to +// CastOpInterface to be able to perform the ChainedTensorCast canonicalization +// generically. 
Value foldCastOp(Operation *op); LogicalResult verifyCastOp(Operation *op, function_ref areCastCompatible); diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_interface(CallInterfaces) +add_mlir_interface(CastInterfaces) add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/CastInterfaces.h @@ -0,0 +1,22 @@ +//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the cast interfaces defined in +// `CastInterfaces.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_CASTINTERFACES_H +#define MLIR_INTERFACES_CASTINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/CastInterfaces.h.inc" + +#endif // MLIR_INTERFACES_CASTINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/CastInterfaces.td @@ -0,0 +1,51 @@ +//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// related to cast-like operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_CASTINTERFACES +#define MLIR_INTERFACES_CASTINTERFACES + +include "mlir/IR/OpBase.td" + +def CastOpInterface : OpInterface<"CastOpInterface"> { + let description = [{ + A cast-like operation is one that converts from a set of input types to a + set of output types. The arity of the inputs may be from 0-N, whereas the + arity of the outputs may be anything from 1-N. Cast-like operations are + trivially removable in cases where they produce a no-op, i.e. when the + input types and output types match 1-1. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod<[{ + Returns true if the given set of input and result types are compatible + to cast using this cast operation. + }], + "bool", "areCastCompatible", + (ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs) + >, + ]; + + let extraTraitClassDeclaration = [{ + /// Attempt to fold the given cast operation.
+ static LogicalResult foldTrait(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return impl::foldCastInterfaceOp(op, operands, results); + } + }]; + let verify = [{ + return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible); + }]; +} + +#endif // MLIR_INTERFACES_CASTINTERFACES diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRShapeOpsIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRControlFlowInterfaces MLIRDialect MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRCallInterfaces + MLIRCastInterfaces MLIRControlFlowInterfaces MLIREDSC MLIRIR diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -195,7 +195,8 @@ /// Returns 'true' if the vector types are cast compatible, and 'false' /// otherwise. 
static bool areVectorCastSimpleCompatible( - Type a, Type b, function_ref areElementsCastCompatible) { + Type a, Type b, + function_ref areElementsCastCompatible) { if (auto va = a.dyn_cast()) if (auto vb = b.dyn_cast()) return va.getShape().equals(vb.getShape()) && @@ -1746,7 +1747,10 @@ // FPExtOp //===----------------------------------------------------------------------===// -bool FPExtOp::areCastCompatible(Type a, Type b) { +bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); @@ -1757,7 +1761,10 @@ // FPToSIOp //===----------------------------------------------------------------------===// -bool FPToSIOp::areCastCompatible(Type a, Type b) { +bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -1767,7 +1774,10 @@ // FPToUIOp //===----------------------------------------------------------------------===// -bool FPToUIOp::areCastCompatible(Type a, Type b) { +bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -1777,7 +1787,10 @@ // FPTruncOp //===----------------------------------------------------------------------===// -bool FPTruncOp::areCastCompatible(Type a, Type b) { +bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = 
outputs.front(); if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); @@ -1889,7 +1902,10 @@ //===----------------------------------------------------------------------===// // Index cast is applicable from index to integer and backwards. -bool IndexCastOp::areCastCompatible(Type a, Type b) { +bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isa()) { auto aShaped = a.cast(); auto bShaped = b.cast(); @@ -1965,7 +1981,10 @@ Value MemRefCastOp::getViewSource() { return source(); } -bool MemRefCastOp::areCastCompatible(Type a, Type b) { +bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); @@ -2036,8 +2055,6 @@ } OpFoldResult MemRefCastOp::fold(ArrayRef operands) { - if (Value folded = impl::foldCastOp(*this)) - return folded; return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); } @@ -2633,7 +2650,10 @@ //===----------------------------------------------------------------------===// // sitofp is applicable from integer types to float types. -bool SIToFPOp::areCastCompatible(Type a, Type b) { +bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -2715,7 +2735,10 @@ //===----------------------------------------------------------------------===// // uitofp is applicable from integer types to float types. 
-bool UIToFPOp::areCastCompatible(Type a, Type b) { +bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -12,6 +12,7 @@ Core LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRIR MLIRSideEffectInterfaces MLIRSupport diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -73,7 +73,10 @@ return true; } -bool CastOp::areCastCompatible(Type a, Type b) { +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) @@ -85,10 +88,6 @@ return succeeded(verifyCompatibleShape(aT, bT)); } -OpFoldResult CastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); -} - /// Compute a TensorType that has the joined shape knowledge of the two /// given TensorTypes. The element types need to match. static TensorType joinShapes(TensorType one, TensorType two) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1208,6 +1208,48 @@ // CastOp implementation //===----------------------------------------------------------------------===// +/// Attempt to fold the given cast operation. 
+LogicalResult +impl::foldCastInterfaceOp(Operation *op, ArrayRef attrOperands, + SmallVectorImpl &foldResults) { + OperandRange operands = op->getOperands(); + if (operands.empty()) + return failure(); + ResultRange results = op->getResults(); + + // Check for the case where the input and output types match 1-1. + if (operands.getTypes() == results.getTypes()) { + foldResults.append(operands.begin(), operands.end()); + return success(); + } + + return failure(); +} + +/// Attempt to verify the given cast operation. +LogicalResult impl::verifyCastInterfaceOp( + Operation *op, function_ref areCastCompatible) { + auto resultTypes = op->getResultTypes(); + if (llvm::empty(resultTypes)) + return op->emitOpError() + << "expected at least one result for cast operation"; + + auto operandTypes = op->getOperandTypes(); + if (!areCastCompatible(operandTypes, resultTypes)) { + InFlightDiagnostic diag = op->emitOpError("operand type"); + if (llvm::empty(operandTypes)) + diag << "s []"; + else if (llvm::size(operandTypes) == 1) + diag << " " << *operandTypes.begin(); + else + diag << "s " << operandTypes; + return diag << " and result type" << (resultTypes.size() == 1 ? 
" " : "s ") + << resultTypes << " are cast incompatible"; + } + + return success(); +} + void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, Type destType) { result.addOperands(source); diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_OPTIONAL_SOURCES CallInterfaces.cpp + CastInterfaces.cpp ControlFlowInterfaces.cpp CopyOpInterface.cpp DerivedAttributeOpInterface.cpp @@ -27,6 +28,7 @@ add_mlir_interface_library(CallInterfaces) +add_mlir_interface_library(CastInterfaces) add_mlir_interface_library(ControlFlowInterfaces) add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DerivedAttributeOpInterface) diff --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/CastInterfaces.cpp @@ -0,0 +1,17 @@ +//===- CastInterfaces.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/CastInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Table-generated class definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/CastInterfaces.cpp.inc"