diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -182,26 +182,52 @@ casts between two different shapes.
 ```tablegen
-def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    NoSideEffect,
+    SameOperandsAndResultShape]
+  > {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
     without changing any data elements. The source and destination types
-    must both be tensor types with the same element type. If both are ranked
+    must both be tensor types with the same element type. If both are ranked,
     then the rank should be the same and static dimensions should match. The
     operation is invalid if converting to a mismatching constant dimension.
   }];
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
+}
+```
+
+Note that the definition of this cast operation adds a `CastOpInterface` to the
+traits list. This interface provides several utilities for cast-like
+operations, such as folding identity casts and verification. We hook into this
+interface by providing a definition for the `areCastCompatible` method:
 
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
+```c++
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // The inputs must be Tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Known static dimensions are expected to match.
+  if (!input.hasRank() || !output.hasRank())
+    return true;
+  return input.getShape() == output.getShape();
 }
+
 ```
 
-We can then override the necessary hook on the ToyInlinerInterface to insert
-this for us when necessary:
+With a proper cast operation, we can now override the necessary hook on the
+ToyInlinerInterface to insert it for us when necessary:
 
 ```c++
 struct ToyInlinerInterface : public DialectInlinerInterface {
diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt
--- a/mlir/examples/toy/Ch4/CMakeLists.txt
+++ b/mlir/examples/toy/Ch4/CMakeLists.txt
@@ -29,6 +29,7 @@
 target_link_libraries(toyc-ch4
   PRIVATE
     MLIRAnalysis
+    MLIRCastInterfaces
     MLIRCallInterfaces
     MLIRIR
     MLIRParser
diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h
--- a/mlir/examples/toy/Ch4/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"
 
diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -14,6 +14,7 @@
 #define TOY_OPS
 
 include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "toy/ShapeInferenceInterface.td"
 
@@ -102,9 +103,12 @@
   ];
 }
 
-def CastOp : Toy_Op<"cast",
-    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
-     SameOperandsAndResultShape]> {
+def CastOp : Toy_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+    NoSideEffect,
+    SameOperandsAndResultShape
+  ]> {
   let summary = "shape cast operation";
   let description = [{
     The "cast" operation converts a tensor from one type to an equivalent type
@@ -118,9 +122,6 @@
   let results = (outs F64Tensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
-
-  // Set the folder bit so that we can fold redundant cast operations.
-  let hasFolder = 1;
 }
 
 def GenericCallOp : Toy_Op<"generic_call",
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -232,6 +232,23 @@
 /// inference interface.
 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
 
+/// Returns true if the given set of input and result types are compatible with
+/// this cast operation. This is required by the `CastOpInterface` to verify
+/// this operation and provide other additional utilities.
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  // The inputs must be Tensors with the same element type.
+  TensorType input = inputs.front().dyn_cast<TensorType>();
+  TensorType output = outputs.front().dyn_cast<TensorType>();
+  if (!input || !output || input.getElementType() != output.getElementType())
+    return false;
+  // Known static dimensions are expected to match.
+ if (!input.hasRank() || !output.hasRank()) + return true; + return input.getShape() == output.getShape(); +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -33,6 +33,7 @@ ${dialect_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRIR MLIRParser MLIRPass diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h --- a/mlir/examples/toy/Ch5/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -102,9 +103,12 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type @@ -118,9 +122,6 @@ let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -232,6 +232,23 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // They inputs must be Tensors with the same element type. 
+ TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // Known static dimensions are expected to match. + if (!input.hasRank() || !output.hasRank()) + return true; + return input.getShape() == output.getShape(); +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt --- a/mlir/examples/toy/Ch6/CMakeLists.txt +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -39,6 +39,7 @@ ${conversion_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRExecutionEngine MLIRIR MLIRLLVMIR diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h --- a/mlir/examples/toy/Ch6/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -102,9 +103,12 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type @@ -118,9 +122,6 @@ let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -232,6 +232,23 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. 
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // They inputs must be Tensors with the same element type. + TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // Known static dimensions are expected to match. + if (!input.hasRank() || !output.hasRank()) + return true; + return input.getShape() == output.getShape(); +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt --- a/mlir/examples/toy/Ch7/CMakeLists.txt +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -39,6 +39,7 @@ ${conversion_libs} MLIRAnalysis MLIRCallInterfaces + MLIRCastInterfaces MLIRExecutionEngine MLIRIR MLIRParser diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -14,6 +14,7 @@ #define TOY_OPS include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "toy/ShapeInferenceInterface.td" @@ -115,9 +116,12 @@ ]; } -def CastOp : Toy_Op<"cast", - [DeclareOpInterfaceMethods, NoSideEffect, - SameOperandsAndResultShape]> { +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect, + SameOperandsAndResultShape + ]> { let summary = "shape cast operation"; let description = [{ The "cast" operation converts a tensor from one type to an equivalent type @@ -131,9 +135,6 @@ let results = (outs F64Tensor:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; - - // Set the folder bit so that we can fold redundant cast operations. - let hasFolder = 1; } def GenericCallOp : Toy_Op<"generic_call", diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -284,6 +284,23 @@ /// inference interface. void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. 
This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // They inputs must be Tensors with the same element type. + TensorType input = inputs.front().dyn_cast(); + TensorType output = outputs.front().dyn_cast(); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // Known static dimensions are expected to match. + if (!input.hasRank() || !output.hasRank()) + return true; + return input.getShape() == output.getShape(); +} + //===----------------------------------------------------------------------===// // GenericCallOp diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -23,11 +23,6 @@ #include "ToyCombine.inc" } // end anonymous namespace -/// Fold simple cast operations that return the same type as the input. -OpFoldResult CastOp::fold(ArrayRef operands) { - return mlir::impl::foldCastOp(*this); -} - /// Fold constants. OpFoldResult ConstantOp::fold(ArrayRef operands) { return value(); } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -19,6 +19,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -45,8 +46,10 @@ // Base class for standard cast operations. Requires single operand and result, // but does not constrain them to specific types. class CastOp traits = []> : - Std_Op { + Std_Op + ])> { let results = (outs AnyType); @@ -62,9 +65,9 @@ let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; - let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }]; - let hasFolder = 1; + // Cast operations are fully verified by its traits. + let verifier = ?; } // Base class for arithmetic cast operations. @@ -1723,14 +1726,6 @@ The destination type must to be strictly wider than the source type. Only scalars are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1743,14 +1738,6 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) signed integer value. 
}]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1763,14 +1750,6 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) unsigned integer value. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1785,14 +1764,6 @@ If the value cannot be exactly represented, it is rounded using the default rounding mode. Only scalars are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -1929,12 +1900,6 @@ sign-extended. If casting to a narrower integer, the value is truncated. }]; - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - let hasFolder = 1; } @@ -2096,14 +2061,7 @@ let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef); - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a memref_cast is always a memref. - Type getType() { return getResult().getType(); } - }]; + let hasFolder = 1; } @@ -2837,14 +2795,6 @@ exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// @@ -3668,14 +3618,6 @@ value cannot be exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. 
- static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -10,6 +10,7 @@ #define TENSOR_OPS include "mlir/Dialect/Tensor/IR/TensorBase.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" class Tensor_Op traits = []> @@ -23,7 +24,9 @@ // CastOp //===----------------------------------------------------------------------===// -def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> { +def Tensor_CastOp : Tensor_Op<"cast", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "tensor cast operation"; let description = [{ Convert a tensor from one type to an equivalent type without changing any @@ -50,19 +53,9 @@ let arguments = (ins AnyTensor:$source); let results = (outs AnyTensor:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; - let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a tensor.cast is always a tensor. - TensorType getType() { return getResult().getType().cast(); } - }]; - - let hasFolder = 1; let hasCanonicalizer = 1; + let verifier = ?; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -50,6 +50,34 @@ /// A variant type that holds a single argument for a diagnostic. class DiagnosticArgument { public: + /// Note: The constructors below are only exposed due to problems accessing + /// constructors from type traits, they should not be used directly by users. + // Construct from an Attribute. + explicit DiagnosticArgument(Attribute attr); + // Construct from a floating point number. + explicit DiagnosticArgument(double val) + : kind(DiagnosticArgumentKind::Double), doubleVal(val) {} + explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {} + // Construct from a signed integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(int64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {} + // Construct from an unsigned integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(uint64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} + // Construct from a string reference. 
+ explicit DiagnosticArgument(StringRef val) + : kind(DiagnosticArgumentKind::String), stringVal(val) {} + // Construct from a Type. + explicit DiagnosticArgument(Type val); + /// Enum that represents the different kinds of diagnostic arguments /// supported. enum class DiagnosticArgumentKind { @@ -100,37 +128,6 @@ private: friend class Diagnostic; - // Construct from an Attribute. - explicit DiagnosticArgument(Attribute attr); - - // Construct from a floating point number. - explicit DiagnosticArgument(double val) - : kind(DiagnosticArgumentKind::Double), doubleVal(val) {} - explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {} - - // Construct from a signed integer. - template - explicit DiagnosticArgument( - T val, typename std::enable_if::value && - std::numeric_limits::is_integer && - sizeof(T) <= sizeof(int64_t)>::type * = 0) - : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {} - - // Construct from an unsigned integer. - template - explicit DiagnosticArgument( - T val, typename std::enable_if::value && - std::numeric_limits::is_integer && - sizeof(T) <= sizeof(uint64_t)>::type * = 0) - : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} - - // Construct from a string reference. - explicit DiagnosticArgument(StringRef val) - : kind(DiagnosticArgumentKind::String), stringVal(val) {} - - // Construct from a Type. - explicit DiagnosticArgument(Type val); - /// The kind of this argument. DiagnosticArgumentKind kind; @@ -189,8 +186,10 @@ /// Stream operator for inserting new diagnostic arguments. template - typename std::enable_if::value, - Diagnostic &>::type + typename std::enable_if< + !std::is_convertible::value && + std::is_constructible::value, + Diagnostic &>::type operator<<(Arg &&val) { arguments.push_back(DiagnosticArgument(std::forward(val))); return *this; @@ -220,17 +219,17 @@ } /// Stream in a range. - template Diagnostic &operator<<(iterator_range range) { - return appendRange(range); - } - template Diagnostic &operator<<(ArrayRef range) { + template > + std::enable_if_t::value, + Diagnostic &> + operator<<(T &&range) { return appendRange(range); } /// Append a range to the diagnostic. The default delimiter between elements /// is ','. - template class Container> - Diagnostic &appendRange(const Container &c, const char *delim = ", ") { + template + Diagnostic &appendRange(const T &c, const char *delim = ", ") { llvm::interleave( c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; }); return *this; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1822,18 +1822,27 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p); } // namespace impl -// These functions are out-of-line implementations of the methods in CastOp, -// which avoids them being template instantiated/duplicated. +// These functions are out-of-line implementations of the methods in +// CastOpInterface, which avoids them being template instantiated/duplicated. namespace impl { +/// Attempt to fold the given cast operation. +LogicalResult foldCastInterfaceOp(Operation *op, + ArrayRef attrOperands, + SmallVectorImpl &foldResults); +/// Attempt to verify the given cast operation. +LogicalResult verifyCastInterfaceOp( + Operation *op, function_ref areCastCompatible); + // TODO: Remove the parse/print/build here (new ODS functionality obsoletes the // need for them, but some older ODS code in `std` still depends on them). 
void buildCastOp(OpBuilder &builder, OperationState &result, Value source, Type destType); ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); void printCastOp(Operation *op, OpAsmPrinter &p); -// TODO: Create a CastOpInterface with a method areCastCompatible. -// Also, consider adding functionality to CastOpInterface to be able to perform -// the ChainedTensorCast canonicalization generically. +// TODO: These methods are deprecated in favor of CastOpInterface. Remove them +// when all uses have been updated. Also, consider adding functionality to +// CastOpInterface to be able to perform the ChainedTensorCast canonicalization +// generically. Value foldCastOp(Operation *op); LogicalResult verifyCastOp(Operation *op, function_ref areCastCompatible); diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_interface(CallInterfaces) +add_mlir_interface(CastInterfaces) add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.h b/mlir/include/mlir/Interfaces/CastInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/CastInterfaces.h @@ -0,0 +1,22 @@ +//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the cast interfaces defined in +// `CastInterfaces.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_CASTINTERFACES_H +#define MLIR_INTERFACES_CASTINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/CastInterfaces.h.inc" + +#endif // MLIR_INTERFACES_CASTINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/CastInterfaces.td b/mlir/include/mlir/Interfaces/CastInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/CastInterfaces.td @@ -0,0 +1,53 @@ +//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// related to cast-like operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_CASTINTERFACES +#define MLIR_INTERFACES_CASTINTERFACES + +include "mlir/IR/OpBase.td" + +def CastOpInterface : OpInterface<"CastOpInterface"> { + let description = [{ + A cast-like operation is one that converts from a set of input types to a + set of output types. The arity of the inputs may be from 0-N, whereas the + arity of the outputs may be anything from 1-N. 
Cast-like operations are + trivially removable in cases where they produce an identity, i.e: + * When the input types and output types match 1-1. + * When the input is another cast, of the same type as this cast, with + input types that match 1-1 with the output types of this cast. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod<[{ + Returns true if the given set of input and result types are compatible + to cast using this cast operation. + }], + "bool", "areCastCompatible", + (ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs) + >, + ]; + + let extraTraitClassDeclaration = [{ + /// Attempt to fold the given cast operation. + static LogicalResult foldTrait(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return impl::foldCastInterfaceOp(op, operands, results); + } + }]; + let verify = [{ + return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible); + }]; +} + +#endif // MLIR_INTERFACES_CASTINTERFACES diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRShapeOpsIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRControlFlowInterfaces MLIRDialect MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRCallInterfaces + MLIRCastInterfaces MLIRControlFlowInterfaces MLIREDSC MLIRIR diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -195,7 +195,8 @@ /// Returns 'true' if the vector types are cast compatible, and 'false' /// otherwise. 
static bool areVectorCastSimpleCompatible( - Type a, Type b, function_ref areElementsCastCompatible) { + Type a, Type b, + function_ref areElementsCastCompatible) { if (auto va = a.dyn_cast()) if (auto vb = b.dyn_cast()) return va.getShape().equals(vb.getShape()) && @@ -1990,7 +1991,10 @@ // FPExtOp //===----------------------------------------------------------------------===// -bool FPExtOp::areCastCompatible(Type a, Type b) { +bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); @@ -2001,7 +2005,10 @@ // FPToSIOp //===----------------------------------------------------------------------===// -bool FPToSIOp::areCastCompatible(Type a, Type b) { +bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -2011,7 +2018,10 @@ // FPToUIOp //===----------------------------------------------------------------------===// -bool FPToUIOp::areCastCompatible(Type a, Type b) { +bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -2021,7 +2031,10 @@ // FPTruncOp //===----------------------------------------------------------------------===// -bool FPTruncOp::areCastCompatible(Type a, Type b) { +bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); @@ -2133,7 +2146,10 @@ //===----------------------------------------------------------------------===// // Index cast is applicable from index to integer and backwards. -bool IndexCastOp::areCastCompatible(Type a, Type b) { +bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isa() && b.isa()) { auto aShaped = a.cast(); auto bShaped = b.cast(); @@ -2148,11 +2164,6 @@ } OpFoldResult IndexCastOp::fold(ArrayRef cstOperands) { - // Fold IndexCast(IndexCast(x)) -> x - auto cast = getOperand().getDefiningOp(); - if (cast && cast.getOperand().getType() == getType()) - return cast.getOperand(); - // Fold IndexCast(constant) -> constant // A little hack because we go through int. Otherwise, the size // of the constant might need to change. @@ -2209,7 +2220,10 @@ Value MemRefCastOp::getViewSource() { return source(); } -bool MemRefCastOp::areCastCompatible(Type a, Type b) { +bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); @@ -2280,8 +2294,6 @@ } OpFoldResult MemRefCastOp::fold(ArrayRef operands) { - if (Value folded = impl::foldCastOp(*this)) - return folded; return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); } @@ -2877,7 +2889,10 @@ //===----------------------------------------------------------------------===// // sitofp is applicable from integer types to float types. -bool SIToFPOp::areCastCompatible(Type a, Type b) { +bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); @@ -2959,7 +2974,10 @@ //===----------------------------------------------------------------------===// // uitofp is applicable from integer types to float types. -bool UIToFPOp::areCastCompatible(Type a, Type b) { +bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -12,6 +12,7 @@ Core LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRIR MLIRSupport ) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -71,7 +71,10 @@ return true; } -bool CastOp::areCastCompatible(Type a, Type b) { +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) @@ -83,10 +86,6 @@ return succeeded(verifyCompatibleShape(aT, bT)); } -OpFoldResult CastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); -} - /// Compute a TensorType that has the joined shape knowledge of the two /// given TensorTypes. The element types need to match. static TensorType joinShapes(TensorType one, TensorType two) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1208,6 +1208,53 @@ // CastOp implementation //===----------------------------------------------------------------------===// +/// Attempt to fold the given cast operation. +LogicalResult +impl::foldCastInterfaceOp(Operation *op, ArrayRef attrOperands, + SmallVectorImpl &foldResults) { + OperandRange operands = op->getOperands(); + if (operands.empty()) + return failure(); + ResultRange results = op->getResults(); + + // Check for the case where the input and output types match 1-1. + if (operands.getTypes() == results.getTypes()) { + foldResults.append(operands.begin(), operands.end()); + return success(); + } + + // Check for the case where there is a single input that is a cast operation + // of the same type as 'op', with input types that match this operations + // results. + if (operands.size() == 1) { + Operation *inputOp = operands.front().getDefiningOp(); + if (inputOp && inputOp->getName() == op->getName() && + inputOp->getOperandTypes() == results.getTypes()) { + foldResults.append(inputOp->operand_begin(), inputOp->operand_end()); + return success(); + } + } + + return failure(); +} + +/// Attempt to verify the given cast operation. 
+LogicalResult impl::verifyCastInterfaceOp( + Operation *op, function_ref areCastCompatible) { + auto resultTypes = op->getResultTypes(); + if (llvm::empty(resultTypes)) + return op->emitOpError() + << "expected at least one result for cast operation"; + + auto operandTypes = op->getOperandTypes(); + if (!areCastCompatible(operandTypes, resultTypes)) + return op->emitOpError("operand types ") + << operandTypes << " and result types " << resultTypes + << " are cast incompatible"; + + return success(); +} + void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, Type destType) { result.addOperands(source); diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_OPTIONAL_SOURCES CallInterfaces.cpp + CastInterfaces.cpp ControlFlowInterfaces.cpp CopyOpInterface.cpp DerivedAttributeOpInterface.cpp @@ -27,6 +28,7 @@ add_mlir_interface_library(CallInterfaces) +add_mlir_interface_library(CastInterfaces) add_mlir_interface_library(ControlFlowInterfaces) add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DerivedAttributeOpInterface) diff --git a/mlir/lib/Interfaces/CastInterfaces.cpp b/mlir/lib/Interfaces/CastInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/CastInterfaces.cpp @@ -0,0 +1,17 @@ +//===- CastInterfaces.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/CastInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Table-generated class definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/CastInterfaces.cpp.inc" diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt <%s -split-input-file -verify-diagnostics func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) { - // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}} + // expected-error@+1 {{operand types 'tensor<1xf32>' and result types 'tensor<2xf32>' are cast incompatible}} %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32> return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1008,7 +1008,7 @@ // ----- func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { - // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} + // expected-error@+1{{operand types 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result types 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, 
offset:0, strides:[128, 32, 2]> return } @@ -1016,7 +1016,7 @@ // ----- func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { - // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}} + // expected-error@+1{{operand types 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result types 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]> return } @@ -1026,7 +1026,7 @@ // incompatible element types func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> - // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}} + // expected-error@+1 {{operand types 'memref<2x5xf32>' and result types 'memref<*xi32>' are cast incompatible}} %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xi32> return } @@ -1063,7 +1063,7 @@ // incompatible memory space func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> - // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}} + // expected-error@+1 {{operand types 'memref<2x5xf32>' and result types 'memref<*xf32, 1>' are cast incompatible}} %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1> return } @@ -1074,7 +1074,7 @@ func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0> - // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}} + // expected-error@+1 {{operand types 'memref<*xf32>' and result types 'memref<*xf32>' are cast incompatible}} %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0> return }
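
A few illustrative sketches follow; they are not part of the patch itself. The first isolates the shape rule that the Toy `areCastCompatible` implementation (added to each chapter's `Dialect.cpp` above) encodes. The helper name and the `main` driver are hypothetical; the type APIs (`TensorType`, `RankedTensorType`, `UnrankedTensorType`) are the same ones the patch uses.

```c++
// Minimal sketch (not part of the patch): the compatibility rule installed for
// toy.cast, extracted as a free function so it can be exercised directly.
#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Unranked tensors are compatible with anything of the same element type;
// ranked tensors must match shape exactly.
static bool toyCastCompatible(TensorType in, TensorType out) {
  if (in.getElementType() != out.getElementType())
    return false;
  if (!in.hasRank() || !out.hasRank())
    return true;
  return in.getShape() == out.getShape();
}

int main() {
  MLIRContext ctx;
  auto f64 = FloatType::getF64(&ctx);
  auto t2x3 = RankedTensorType::get({2, 3}, f64);
  auto t3x2 = RankedTensorType::get({3, 2}, f64);
  auto unranked = UnrankedTensorType::get(f64);

  assert(toyCastCompatible(t2x3, unranked));  // erasing the rank is allowed
  assert(toyCastCompatible(unranked, t2x3));  // refining the rank is allowed
  assert(!toyCastCompatible(t2x3, t3x2));     // mismatched static shapes fail
  return 0;
}
```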
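The `Diagnostics.h` changes (public `DiagnosticArgument` constructors and a range-aware `operator<<`) exist so that helpers such as the new `impl::verifyCastInterfaceOp` can stream entire type ranges into an error, which is also why the expected messages in the tests change from "operand type" to "operand types". A sketch of that usage pattern, with a hypothetical function name, looks like this:

```c++
// Sketch of the diagnostic pattern used by impl::verifyCastInterfaceOp above:
// a range of types streams directly into the diagnostic, with the elements
// interleaved by ", ".
#include "mlir/IR/Operation.h"

using namespace mlir;

static LogicalResult emitCastMismatch(Operation *op) {
  return op->emitOpError("operand types ")
         << op->getOperandTypes() << " and result types "
         << op->getResultTypes() << " are cast incompatible";
}
```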
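Finally, because `CastOpInterface` is an ordinary op interface, passes can reason about cast-like operations without naming concrete ops. The helper below is hypothetical and only mirrors the first case handled by `impl::foldCastInterfaceOp` in the patch, where a cast whose operand and result types already match 1-1 folds away to its operands.

```c++
// Sketch: detect an identity cast generically via the new interface.
#include "mlir/Interfaces/CastInterfaces.h"

using namespace mlir;

static bool isIdentityCast(Operation *op) {
  if (!isa<CastOpInterface>(op))
    return false;
  // Same comparison the generic folder performs on the operand/result types.
  return op->getOperands().getTypes() == op->getResults().getTypes();
}
```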