diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -11,6 +11,7 @@ #include "BuiltinAttributeInterfaces.h" #include "SubElementInterfaces.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" namespace llvm { struct fltSemantics; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -17,6 +17,7 @@ include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/SubElementInterfaces.td" +include "mlir/Interfaces/JoinMeetTypeInterface.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should @@ -266,6 +267,7 @@ //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", [ + DeclareTypeInterfaceMethods, DeclareTypeInterfaceMethods ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -620,6 +622,7 @@ //===----------------------------------------------------------------------===// def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ + DeclareTypeInterfaceMethods, DeclareTypeInterfaceMethods ], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; @@ -779,6 +782,7 @@ //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ + DeclareTypeInterfaceMethods, DeclareTypeInterfaceMethods ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; @@ -841,6 +845,7 @@ //===----------------------------------------------------------------------===// def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ + DeclareTypeInterfaceMethods, 
DeclareTypeInterfaceMethods ], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -66,6 +66,33 @@ /// Dimensions are compatible if all non-dynamic dims are equal. LogicalResult verifyCompatibleDims(ArrayRef dims); + +/// The element-wise `join` function for shapes, and the partial order "less +/// specialized than or equal". It returns an error if the shapes have different +/// sizes. +/// The `join` relationship for two dimensions is: +/// dim1 | dim2 | join +/// n | n | n +/// n | k | dynamic +/// dynamic | k | dynamic +/// n | dynamic | dynamic +FailureOr> joinShapes(ArrayRef shape1, + ArrayRef shape2, + Optional location = None); + +/// The element-wise `meet` function for shapes, and the partial order "less +/// specialized than or equal". It returns an error if the shapes have different +/// sizes or do not meet. 
+/// The `meet` relationship for two dimensions is: +/// dim1 | dim2 | meet +/// n | n | n +/// n | k | error +/// dynamic | k | k +/// n | dynamic | n +FailureOr> meetShapes(ArrayRef shape1, + ArrayRef shape2, + Optional = None); + //===----------------------------------------------------------------------===// // Utility Iterators //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -34,3 +34,15 @@ DataLayoutOpInterface Interfaces/ -gen-op-interface-docs) + +set(LLVM_TARGET_DEFINITIONS JoinMeetTypeInterface.td) +mlir_tablegen(JoinMeetTypeInterface.h.inc -gen-type-interface-decls) +mlir_tablegen(JoinMeetTypeInterface.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRJoinMeetTypeInterfaceIncGen) +add_dependencies(mlir-generic-headers MLIRJoinMeetTypeInterfaceIncGen) + +add_mlir_doc(JoinMeetTypeInterface + JoinMeetTypeInterface + Interfaces/ + -gen-type-interface-docs) + diff --git a/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h @@ -0,0 +1,121 @@ +//===- JoinMeetTypeInterface.h - Join/Meet Type Interface Decls -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for type join and meet relationships. +// Types can inherit and implement the `JoinMeetTypeInterface::Trait` trait to +// participate in `join` and `meet` functions. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H +#define MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H + +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" + +#include "mlir/Interfaces/JoinMeetTypeInterface.h.inc" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Join and Meet Types +//===----------------------------------------------------------------------===// + +/// The join function for types, and the partial order "less specialized than or +/// equal". +/// It returns the most specialized type that is less specialized than both +/// `ty1` and `ty2`, if it exists, and `Type()` otherwise. +/// The join `j` of `ty1` and `ty2` is such that: +/// * j <= ty1, and j <= ty2 +/// * For any type t such that t <= ty1 and t <= ty2, t <= j. +/// +/// In other words, if types are viewed as sets of values, this function is +/// equivalent to the union of such sets. The top of the lattice represents "any +/// possible type", and the bottom represents "no possible type". +/// +/// For example: +/// ty1 | ty2 | ty1 v ty2 +/// ------------------+-------------------+------------------- +/// i8 | i8 | i8 +/// i8 | i32 | (null type) +/// tensor<1xf32> | tensor | tensor +/// tensor<1x2x?xf32> | tensor<1x?x3xf32> | tensor<1x?x?xf32> +/// tensor<4x5xf32> | tensor<6xf32> | tensor<*xf32> +/// tensor<1xi32> | i32 | (null type) +/// tensor<1xi32> | tensor | tensor<*xi32> +/// tensor<1xi32> | tensor<1xi8> | (null type) +/// +/// The function is monotonic: +/// * idempotence: joinTypes(x,x) == x +/// * commutativity: joinTypes(x,y) == joinTypes(y,x) +/// * associativity: joinTypes(x,joinTypes(y,z)) == joinTypes(joinTypes(x,y),z) +/// +/// Types can participate in this function by implementing +/// `JoinMeetTypeInterface`. 
+Type joinTypes(Type ty1, Type ty2); + +/// The meet function for types, and the partial order "less specialized than or +/// equal". +/// It returns the least specialized type, that is more specialized than both +/// `ty1` and `ty2`, if it exists, and `Type()` otherwise. +/// The meet `m` of `ty1` and `ty2` is such that: +/// * ty1 <= m, and ty2 <= m +/// * For any type t such that ty1 <= t and ty2 <= t, m <= t. +/// +/// In other words, if types are viewed as sets of values, this function is +/// equivalent to the intersection of such sets. The top of the lattice +/// represents "any possible type", and the bottom represents "no possible +/// type". +/// +/// For example: +/// ty1 | ty2 | ty1 ^ ty2 +/// ------------------+-------------------+------------------- +/// i8 | i32 | (null type) +/// tensor<1xf32> | tensor | tensor<1xf32> +/// tensor<1x2x?xf32> | tensor<1x?x3xf32> | tensor<1x2x3xf32> +/// tensor<4x5xf32> | tensor<6xf32> | (null type) +/// tensor<1xi32> | i32 | (null type) +/// tensor<1xi32> | tensor | (null type) +/// tensor<1xi32> | tensor<1xi8> | (null type) +/// +/// The function is monotonic: +/// * idempotence: meetTypes(x,x) == x +/// * commutativity: meetTypes(x,y) == meetTypes(y,x) +/// * associativity: meetTypes(x,meetTypes(y,z)) == meetTypes(meetTypes(x,y),z) +/// +/// Types can participate in this function by implementing +/// `JoinMeetTypeInterface`. +Type meetTypes(Type ty1, Type ty2); + +/// Indicates whether `ty1` and `ty2` are the same, or `ty1` is compatible with +/// `ty2` and less specialized than `ty2`. +inline bool isLessSpecializedOrSame(Type ty1, Type ty2) { + return joinTypes(ty1, ty2) == ty1; +} + +/// Indicates whether `ty1` is compatible with `ty2`, and less specialized than +/// `ty2`. +inline bool isLessSpecialized(Type ty1, Type ty2) { + return ty1 != ty2 && isLessSpecializedOrSame(ty1, ty2); +} + +/// Indicates whether `ty1` and `ty2` are the same, or `ty1` is compatible with +/// `ty2` and more specialized than `ty2`. 
+inline bool isMoreSpecializedOrSame(Type ty1, Type ty2) { + return meetTypes(ty1, ty2) == ty1; +} + +/// Indicates whether `ty1` is compatible with `ty2`, and more specialized than +/// `ty2`. +inline bool isMoreSpecialized(Type ty1, Type ty2) { + return ty1 != ty2 && isMoreSpecializedOrSame(ty1, ty2); +} + +} // namespace mlir + +#endif // MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td @@ -0,0 +1,74 @@ +//===- JoinMeetTypeInterface.td - Join/Meet Type Interface ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for join and meet type relationships. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_JOINMEETTYPEINTERFACE +#define MLIR_JOINMEETTYPEINTERFACE + +include "mlir/IR/OpBase.td" + +def JoinMeetTypeInterface : TypeInterface<"JoinMeetTypeInterface"> { + let cppNamespace = "::mlir"; + + let description = [{ + Interface for types participating in the `join` and `meet` relationships. + + Types willing to participate in the `::mlir::joinTypes(Type, Type)` and + `::mlir::meetTypes(Type, Type)` functions should implement this interface by + providing implementations of `join` and `meet` functions handling types they + can join or meet with. + + The return value of the interface method `someType.join(otherType)` is + interpreted as follows: + - A result with a value (`Optional::hasValue()` is true) represents + the result of the `join`. 
+ If the contained type is null, it indicates `otherType` was handled, and + `someType` and `otherType` do not join; that is, there is no type that is + both less specialized than `someType` and less specialized than + `otherType`. + - A result without a value (`Optional::hasValue()` is false) indicates + `someType` does not know (or ignores) the `otherType` type. + + Returning an `Optional` with no value allows for extensibility of + types. `Type1::join` may handle `Type2` while `Type2::join` ignores `Type1` + or simply does not implement this interface. + + The interface `join` method must help guarantee the monotonicity + of `::mlir::joinTypes`. It must be idempotent (`joinTypes(x, x) = x`). + Associativity, though it should flow naturally, must be maintained. + Commutativity is guaranteed by `::mlir::joinTypes` itself, as it queries both + `someType.join(otherType)` and `otherType.join(someType)`. One type may + ignore the other; in that case the result with a value is used. If both + results have a value, they must match. + + The same comments apply to `meet`. + + Note: for context about join/meet operations on lattices, see: + https://en.wikipedia.org/wiki/Lattice_(order). 
+ }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Return the join with the `other` type.", + /*retTy=*/"Optional", + /*methodName=*/"join", + /*args=*/(ins "Type":$other) + >, + InterfaceMethod< + /*description=*/"Return the meet with the `other` type.", + /*retTy=*/"Optional", + /*methodName=*/"meet", + /*args=*/(ins "Type":$other) + >, + ]; +} + +#endif // MLIR_JOINMEETTYPEINTERFACE diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -17,8 +17,9 @@ MLIRDialectUtils MLIRIR MLIRInferTypeOpInterface + MLIRJoinMeetTypeInterface MLIRSideEffectInterfaces - MLIRSupport MLIRStandard + MLIRSupport MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -125,38 +125,6 @@ return succeeded(verifyCompatibleShape(aT, bT)); } -/// Compute a TensorType that has the joined shape knowledge of the two -/// given TensorTypes. The element types need to match. 
-static TensorType joinShapes(TensorType one, TensorType two) { - assert(one.getElementType() == two.getElementType()); - - if (!one.hasRank()) - return two; - if (!two.hasRank()) - return one; - - int64_t rank = one.getRank(); - if (rank != two.getRank()) - return {}; - - SmallVector join; - join.reserve(rank); - for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); - continue; - } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); - continue; - } - if (one.getDimSize(i) != two.getDimSize(i)) - return {}; - join.push_back(one.getDimSize(i)); - } - return RankedTensorType::get(join, one.getElementType()); -} - namespace { /// Replaces chains of two tensor.cast operations by a single tensor.cast @@ -176,20 +144,20 @@ auto intermediateType = tensorCastOperand.getType().cast(); auto resultType = tensorCast.getType().cast(); - // We can remove the intermediate cast if joining all three produces the - // same result as just joining the source and result shapes. - auto firstJoin = - joinShapes(joinShapes(sourceType, intermediateType), resultType); + // We can remove the intermediate cast if meeting all three produces the + // same result as just meeting the source and result shapes. + auto firstMeet = + meetTypes(meetTypes(sourceType, intermediateType), resultType); - // The join might not exist if the cast sequence would fail at runtime. - if (!firstJoin) + // The meet might not exist if the cast sequence would fail at runtime. + if (!firstMeet) return failure(); - // The newJoin always exists if the above join exists, it might just contain + // The newMeet always exists if the above meet exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. 
- auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) + auto newMeet = meetTypes(sourceType, resultType); + if (firstMeet != newMeet) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType, diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TensorEncoding.h" +#include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/Sequence.h" @@ -498,6 +499,53 @@ !llvm::isa(type.getDialect()); } +static Optional joinTypes(TensorType ty1, TensorType ty2) { + Type elementTy = ty1.getElementType(); + if (ty2.getElementType() != elementTy) + return Type(); + + Attribute encoding1; + Attribute encoding2; + if (auto rankedTy1 = ty1.dyn_cast()) + encoding1 = rankedTy1.getEncoding(); + if (auto rankedTy2 = ty2.dyn_cast()) + encoding2 = rankedTy2.getEncoding(); + if (encoding1 != encoding2) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank() || ty1.getRank() != ty2.getRank()) + return UnrankedTensorType::get(elementTy); + + auto shape = joinShapes(ty1.getShape(), ty2.getShape()).getValue(); + return RankedTensorType::get(shape, elementTy, encoding1); +} + +static Optional meetTypes(TensorType ty1, TensorType ty2) { + Type elementTy = ty1.getElementType(); + if (ty2.getElementType() != elementTy) + return Type(); + + Attribute encoding1; + Attribute encoding2; + if (auto rankedTy1 = ty1.dyn_cast()) + encoding1 = rankedTy1.getEncoding(); + if (auto rankedTy2 = ty2.dyn_cast()) + encoding2 = rankedTy2.getEncoding(); + if (encoding1 != encoding2) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank()) + return ty1.hasRank() ? 
ty1 : ty2; + + if (ty1.getRank() != ty2.getRank()) + return Type(); + + auto shape = meetShapes(ty1.getShape(), ty2.getShape()); + if (succeeded(shape)) + return RankedTensorType::get(shape.getValue(), elementTy, encoding1); + return Type(); +} + //===----------------------------------------------------------------------===// // RankedTensorType //===----------------------------------------------------------------------===// @@ -523,6 +571,18 @@ walkAttrsFn(encoding); } +Optional RankedTensorType::join(Type other) const { + if (!other.isa()) + return None; + return ::joinTypes(*this, other.cast()); +} + +Optional RankedTensorType::meet(Type other) const { + if (!other.isa()) + return None; + return ::meetTypes(*this, other.cast()); +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -539,6 +599,18 @@ walkTypesFn(getElementType()); } +Optional UnrankedTensorType::join(Type other) const { + if (!other.isa()) + return None; + return ::joinTypes(*this, other.cast()); +} + +Optional UnrankedTensorType::meet(Type other) const { + if (!other.isa()) + return None; + return ::meetTypes(*this, other.cast()); +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -555,6 +627,53 @@ return cast().getMemorySpaceAsInt(); } +static LogicalResult commonMemRefJoinMeetChecks(BaseMemRefType ty1, + BaseMemRefType ty2) { + if (ty1.getElementType() != ty2.getElementType()) + return failure(); + + if (ty1.getMemorySpace() != ty2.getMemorySpace()) + return failure(); + + auto memRefTy1 = ty1.dyn_cast(); + auto memRefTy2 = ty2.dyn_cast(); + if (memRefTy1 && memRefTy2 && memRefTy1.getLayout() != memRefTy2.getLayout()) + return failure(); + + return success(); +} + +static Optional joinTypes(BaseMemRefType ty1, 
BaseMemRefType ty2) { + if (failed(commonMemRefJoinMeetChecks(ty1, ty2))) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank() || ty1.getRank() != ty2.getRank()) + return UnrankedMemRefType::get(ty1.getElementType(), ty1.getMemorySpace()); + + auto shape = joinShapes(ty1.getShape(), ty2.getShape()).getValue(); + return MemRefType::get(shape, ty1.getElementType(), + ty1.cast().getLayout(), + ty1.getMemorySpace()); +} + +static Optional meetTypes(BaseMemRefType ty1, BaseMemRefType ty2) { + if (failed(commonMemRefJoinMeetChecks(ty1, ty2))) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank()) + return ty1.hasRank() ? ty1 : ty2; + + if (ty1.getRank() != ty2.getRank()) + return Type(); + + auto shape = meetShapes(ty1.getShape(), ty2.getShape()); + if (succeeded(shape)) + return MemRefType::get(shape.getValue(), ty1.getElementType(), + ty1.cast().getLayout(), + ty1.getMemorySpace()); + return Type(); +} + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// @@ -786,6 +905,18 @@ walkAttrsFn(getMemorySpace()); } +Optional MemRefType::join(Type other) const { + if (!other.isa()) + return None; + return ::joinTypes(*this, other.cast()); +} + +Optional MemRefType::meet(Type other) const { + if (!other.isa()) + return None; + return ::meetTypes(*this, other.cast()); +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -952,6 +1083,18 @@ walkAttrsFn(getMemorySpace()); } +Optional UnrankedMemRefType::join(Type other) const { + if (!other.isa()) + return None; + return ::joinTypes(*this, other.cast()); +} + +Optional UnrankedMemRefType::meet(Type other) const { + if (!other.isa()) + return None; + return ::meetTypes(*this, other.cast()); +} + 
//===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -46,6 +46,7 @@ MLIRCallInterfacesIncGen MLIRCastInterfacesIncGen MLIRDataLayoutInterfacesIncGen + MLIRJoinMeetTypeInterfaceIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSideEffectInterfacesIncGen diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" using namespace mlir; @@ -151,6 +152,44 @@ return success(); } +FailureOr> mlir::joinShapes(ArrayRef shape1, + ArrayRef shape2, + Optional location) { + if (shape1.size() != shape2.size()) + return emitOptionalError(location, "shapes' sizes differ"); + + SmallVector shape(shape1.size()); + for (size_t i = 0, e = shape1.size(); i != e; ++i) { + int64_t dim1 = shape1[i]; + int64_t dim2 = shape2[i]; + shape[i] = dim1 == dim2 ? dim1 : ShapedType::kDynamicSize; + } + + return shape; +} + +FailureOr> mlir::meetShapes(ArrayRef shape1, + ArrayRef shape2, + Optional location) { + if (shape1.size() != shape2.size()) + return emitOptionalError(location, "shapes' sizes differ"); + + SmallVector shape(shape1.size()); + for (size_t i = 0, e = shape1.size(); i != e; ++i) { + int64_t dim1 = shape1[i]; + int64_t dim2 = shape2[i]; + if (dim1 == dim2 || ShapedType::isDynamic(dim1) || + ShapedType::isDynamic(dim2)) { + shape[i] = ShapedType::isDynamic(dim1) ? 
dim2 : dim1; + } else { + return emitOptionalError(location, "dimensions at index ", i, + " are incompatible: ", dim1, ", ", dim2); + } + } + + return shape; +} + OperandElementTypeIterator::OperandElementTypeIterator( Operation::operand_iterator it) : llvm::mapped_iterator( diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp InferTypeOpInterface.cpp + JoinMeetTypeInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp TilingInterface.cpp @@ -36,6 +37,7 @@ add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(JoinMeetTypeInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp b/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp @@ -0,0 +1,62 @@ +//===- JoinMeetTypeInterface.cpp - Join/Meet Type Interface Implementation ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/JoinMeetTypeInterface.h" +#include "mlir/IR/Diagnostics.h" + +using namespace mlir; + +#include "mlir/Interfaces/JoinMeetTypeInterface.cpp.inc" + +Type mlir::joinTypes(Type ty1, Type ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return Type(); + + Optional join1; + Optional join2; + + if (auto interface1 = ty1.dyn_cast()) + join1 = interface1.join(ty2); + +#ifndef NDEBUG + static constexpr bool assertCompatibility = true; +#else + static constexpr bool assertCompatibility = false; +#endif + + auto interface2 = ty2.dyn_cast(); + if (interface2 && (!join1 || assertCompatibility)) + join2 = interface2.join(ty1); + + assert((!assertCompatibility || + (!join1 || !join2 || join1.getValue() == join2.getValue())) && + "joinTypes commutativity was violated"); + + return join1 ? join1.getValue() : join2.getValueOr(Type()); +} + +Type mlir::meetTypes(Type ty1, Type ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return Type(); + + Optional meet1; + if (auto interface1 = ty1.dyn_cast()) + meet1 = interface1.meet(ty2); + + Optional meet2; + if (auto interface2 = ty2.dyn_cast()) + meet2 = interface2.meet(ty1); + + assert(!meet1 || !meet2 || meet1.getValue() == meet2.getValue()); + + return meet1 ? 
meet1.getValue() : meet2.getValueOr(Type()); +} diff --git a/mlir/test/Interfaces/JoinMeetTypeInterface/join_meet.mlir b/mlir/test/Interfaces/JoinMeetTypeInterface/join_meet.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/JoinMeetTypeInterface/join_meet.mlir @@ -0,0 +1,213 @@ +// RUN: mlir-opt --split-input-file --verify-diagnostics --allow-unregistered-dialect --test-join-meet-type-interface %s | FileCheck %s + +#encoding1 = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed" ] +}> + +#encoding2 = #sparse_tensor.encoding<{ + dimLevelType = [ "dense" ] +}> + +#encoding3 = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "dense" ] +}> + +#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> + +func @test_join( + %i8 : i8, + %i32 : i32, + + %tensor_unranked_i8 : tensor<*xi8>, + %tensor_unranked_i32 : tensor<*xi32>, + + %tensor_i32 : tensor, + %tensor_1xi32 : tensor<1xi32>, + %tensor_2xi32 : tensor<2xi32>, + %tensor_3x4xi32 : tensor<3x4xi32>, + %tensor_5x6x_xi32 : tensor<5x6x?xi32>, + %tensor_5x_x7xi32 : tensor<5x?x7xi32>, + + %tensor_1xi32_encoding1 : tensor<1xi32, #encoding1>, + %tensor_2xi32_encoding1 : tensor<2xi32, #encoding1>, + %tensor_2xi32_encoding2 : tensor<2xi32, #encoding2>, + %tensor_2x2xi32_encoding3 : tensor<2x2xi32, #encoding3>, + + %memref_unranked_i32_memspace_1 : memref<*xi32, 1>, + %memref_i32_memspace_1 : memref, + %memref_i32_memspace_2 : memref, + %memref_1x2x_xi32_map_1 : memref<1x2x?xi32, #map1>, + %memref_1x_x3xi32_map_1 : memref<1x?x3xi32, #map1>, + %memref_1x1x1xi32_map_2 : memref<1x1x1xi32, #map2> + ) { + // Test identity. + + "join"(%i8, %i8) : (i8, i8) -> i1 + // CHECK: (i8, i8) -> i8 + + // Test different types and element types. 
+ + "join"(%i8, %i32) : (i8, i32) -> i1 + // expected-error@-1 {{types do not join}} + + "join"(%tensor_unranked_i8, %tensor_unranked_i32) : (tensor<*xi8>, tensor<*xi32>) -> i1 + // expected-error@-1 {{types do not join}} + + + // Test shapes. + + "join"(%tensor_i32, %i32) : (tensor, i32) -> i1 + // expected-error@-1 {{types do not join}} + + "join"(%tensor_1xi32, %tensor_1xi32) : (tensor<1xi32>, tensor<1xi32>) -> i1 + // CHECK: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + "join"(%tensor_1xi32, %tensor_3x4xi32) : (tensor<1xi32>, tensor<3x4xi32>) -> i1 + // CHECK: (tensor<1xi32>, tensor<3x4xi32>) -> tensor<*xi32> + + "join"(%tensor_5x6x_xi32, %tensor_5x_x7xi32) : (tensor<5x6x?xi32>, tensor<5x?x7xi32>) -> i1 + // CHECK: (tensor<5x6x?xi32>, tensor<5x?x7xi32>) -> tensor<5x?x?xi32> + + "join"(%tensor_unranked_i32, %tensor_5x6x_xi32) : (tensor<*xi32>, tensor<5x6x?xi32>) -> i1 + // CHECK: (tensor<*xi32>, tensor<5x6x?xi32>) -> tensor<*xi32> + + + // Test tensor encoding. + + "join"(%tensor_1xi32_encoding1, %tensor_2xi32_encoding1) : (tensor<1xi32, #encoding1>, tensor<2xi32, #encoding1>) -> i1 + // CHECK: (tensor<1xi32, [[ENCODING1:.*]]>, tensor<2xi32, [[ENCODING1]]>) -> tensor + + "join"(%tensor_1xi32_encoding1, %tensor_2xi32_encoding2) : (tensor<1xi32, #encoding1>, tensor<2xi32, #encoding2>) -> i1 + // expected-error@-1 {{types do not join}} + + "join"(%tensor_1xi32_encoding1, %tensor_2x2xi32_encoding3) : (tensor<1xi32, #encoding1>, tensor<2x2xi32, #encoding3>) -> i1 + // expected-error@-1 {{types do not join}} + + "join"(%tensor_1xi32_encoding1, %tensor_unranked_i32) : (tensor<1xi32, #encoding1>, tensor<*xi32>) -> i1 + // expected-error@-1 {{types do not join}} + + + // Test memref memory space. 
+ + "join"(%memref_i32_memspace_1, %memref_i32_memspace_1) : (memref, memref) -> i1 + // CHECK: (memref, memref) -> memref + + "join"(%memref_i32_memspace_1, %memref_i32_memspace_2) : (memref, memref) -> i1 + // expected-error@-1 {{types do not join}} + + "join"(%memref_unranked_i32_memspace_1, %memref_i32_memspace_1) : (memref<*xi32, 1>, memref) -> i1 + // CHECK: (memref<*xi32, 1>, memref) -> memref<*xi32, 1> + + "join"(%memref_unranked_i32_memspace_1, %memref_i32_memspace_2) : (memref<*xi32, 1>, memref) -> i1 + // expected-error@-1 {{types do not join}} + + + // Test memref affine map. + + "join"(%memref_1x2x_xi32_map_1, %memref_1x_x3xi32_map_1) : (memref<1x2x?xi32, affine_map<(d0, d1, d2) -> (d1, d2, d0)>>, memref<1x?x3xi32, affine_map<(d0, d1, d2) -> (d1, d2, d0)>>) -> i1 + // CHECK: (memref<1x2x?xi32, [[MAP1:.*]]>, memref<1x?x3xi32, [[MAP1]]>) -> memref<1x?x?xi32, [[MAP1]]> + + "join"(%memref_1x2x_xi32_map_1, %memref_1x1x1xi32_map_2) : (memref<1x2x?xi32, #map1>, memref<1x1x1xi32, #map2>) -> i1 + // expected-error@-1 {{types do not join}} + +} + +func @test_meet( + %i8 : i8, + %i32 : i32, + + %tensor_unranked_i8 : tensor<*xi8>, + %tensor_unranked_i32 : tensor<*xi32>, + + %tensor_i32 : tensor, + %tensor_1xi32 : tensor<1xi32>, + %tensor_2xi32 : tensor<2xi32>, + %tensor_3x4xi32 : tensor<3x4xi32>, + %tensor_5x6x_xi32 : tensor<5x6x?xi32>, + %tensor_5x_x7xi32 : tensor<5x?x7xi32>, + + %tensor_1xi32_encoding1 : tensor<1xi32, #encoding1>, + %tensor_2xi32_encoding1 : tensor<2xi32, #encoding1>, + %tensor_2xi32_encoding2 : tensor<2xi32, #encoding2>, + %tensor_2x2xi32_encoding3 : tensor<2x2xi32, #encoding3>, + + %memref_unranked_i32_memspace_1 : memref<*xi32, 1>, + %memref_i32_memspace_1 : memref, + %memref_i32_memspace_2 : memref, + %memref_1x2x_xi32_map_1 : memref<1x2x?xi32, #map1>, + %memref_1x_x3xi32_map_1 : memref<1x?x3xi32, #map1>, + %memref_1x1x1xi32_map_2 : memref<1x1x1xi32, #map2> + ) { + // Test identity. 
+ + "meet"(%i8, %i8) : (i8, i8) -> i1 + // CHECK: (i8, i8) -> i8 + + // Test different types and element types. + + "meet"(%i8, %i32) : (i8, i32) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_unranked_i8, %tensor_unranked_i32) : (tensor<*xi8>, tensor<*xi32>) -> i1 + // expected-error@-1 {{types do not meet}} + + + // Test shapes. + + "meet"(%tensor_i32, %i32) : (tensor, i32) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_1xi32, %tensor_1xi32) : (tensor<1xi32>, tensor<1xi32>) -> i1 + // CHECK: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + "meet"(%tensor_1xi32, %tensor_3x4xi32) : (tensor<1xi32>, tensor<3x4xi32>) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_5x6x_xi32, %tensor_5x_x7xi32) : (tensor<5x6x?xi32>, tensor<5x?x7xi32>) -> i1 + // CHECK: (tensor<5x6x?xi32>, tensor<5x?x7xi32>) -> tensor<5x6x7xi32> + + "meet"(%tensor_unranked_i32, %tensor_5x6x_xi32) : (tensor<*xi32>, tensor<5x6x?xi32>) -> i1 + // CHECK: (tensor<*xi32>, tensor<5x6x?xi32>) -> tensor<5x6x?xi32> + + + // Test tensor encoding. + + "meet"(%tensor_1xi32_encoding1, %tensor_2xi32_encoding1) : (tensor<1xi32, #encoding1>, tensor<2xi32, #encoding1>) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_1xi32_encoding1, %tensor_2xi32_encoding2) : (tensor<1xi32, #encoding1>, tensor<2xi32, #encoding2>) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_1xi32_encoding1, %tensor_2x2xi32_encoding3) : (tensor<1xi32, #encoding1>, tensor<2x2xi32, #encoding3>) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%tensor_1xi32_encoding1, %tensor_unranked_i32) : (tensor<1xi32, #encoding1>, tensor<*xi32>) -> i1 + // expected-error@-1 {{types do not meet}} + + + // Test memref memory space. 
+ + "meet"(%memref_i32_memspace_1, %memref_i32_memspace_1) : (memref, memref) -> i1 + // CHECK: (memref, memref) -> memref + + "meet"(%memref_i32_memspace_1, %memref_i32_memspace_2) : (memref, memref) -> i1 + // expected-error@-1 {{types do not meet}} + + "meet"(%memref_unranked_i32_memspace_1, %memref_i32_memspace_1) : (memref<*xi32, 1>, memref) -> i1 + // CHECK: (memref<*xi32, 1>, memref) -> memref + + "meet"(%memref_unranked_i32_memspace_1, %memref_i32_memspace_2) : (memref<*xi32, 1>, memref) -> i1 + // expected-error@-1 {{types do not meet}} + + + // Test memref affine map. + + "meet"(%memref_1x2x_xi32_map_1, %memref_1x_x3xi32_map_1) : (memref<1x2x?xi32, affine_map<(d0, d1, d2) -> (d1, d2, d0)>>, memref<1x?x3xi32, affine_map<(d0, d1, d2) -> (d1, d2, d0)>>) -> i1 + // CHECK: (memref<1x2x?xi32, [[MAP1:.*]]>, memref<1x?x3xi32, [[MAP1]]>) -> memref<1x2x3xi32, [[MAP1]]> + + "meet"(%memref_1x2x_xi32_map_1, %memref_1x1x1xi32_map_2) : (memref<1x2x?xi32, #map1>, memref<1x1x1xi32, #map2>) -> i1 + // expected-error@-1 {{types do not meet}} + +} + diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRTestTransforms TestConstantFold.cpp TestInlining.cpp + TestJoinMeetTypeInterface.cpp TestLoopFusion.cpp TestLoopMapping.cpp TestLoopParametricTiling.cpp diff --git a/mlir/test/lib/Transforms/TestJoinMeetTypeInterface.cpp b/mlir/test/lib/Transforms/TestJoinMeetTypeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestJoinMeetTypeInterface.cpp @@ -0,0 +1,54 @@ +//===- TestJoinMeetTypeInterface.cpp - Test The Join/Meet Type Interface --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Interfaces/JoinMeetTypeInterface.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestJoinMeetTypeInterface
+    : public PassWrapper<TestJoinMeetTypeInterface, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-join-meet-type-interface";
+  }
+  StringRef getDescription() const final {
+    return "Test the join/meet type interface";
+  }
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    func.walk([&](Operation *op) {
+      if (op->getName().getStringRef() == "join") {
+        Type join =
+            joinTypes(op->getOperand(0).getType(), op->getOperand(1).getType());
+        if (join)
+          op->getResult(0).setType(join);
+        else
+          op->emitError("types do not join");
+      }
+      if (op->getName().getStringRef() == "meet") {
+        Type meet =
+            meetTypes(op->getOperand(0).getType(), op->getOperand(1).getType());
+        if (meet)
+          op->getResult(0).setType(meet);
+        else
+          op->emitError("types do not meet");
+      }
+    });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestJoinMeetTypeInterface() {
+  PassRegistration<TestJoinMeetTypeInterface>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
+void registerTestJoinMeetTypeInterface();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
@@ -168,6 +169,7 @@
   mlir::test::registerTestGpuParallelLoopMappingPass();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestInterfaces();
+  mlir::test::registerTestJoinMeetTypeInterface();
mlir::test::registerTestLinalgCodegenStrategy(); mlir::test::registerTestLinalgControlFuseByExpansion(); mlir::test::registerTestLinalgDistribution();