diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -11,6 +11,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" namespace llvm { struct fltSemantics; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -15,14 +15,17 @@ #define BUILTIN_TYPES include "mlir/IR/BuiltinDialect.td" +include "mlir/Interfaces/JoinMeetTypeInterface.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should // remove the definitions in OpBase.td, and repoint users to this file instead. // Base class for Builtin dialect types. -class Builtin_Type - : TypeDef { +class Builtin_Type traits = []> + : TypeDef { let mnemonic = ?; } @@ -253,7 +256,9 @@ // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> { +def Builtin_MemRef + : Builtin_Type<"MemRef", "BaseMemRefType", + [DeclareTypeInterfaceMethods]> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -629,7 +634,9 @@ // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> { +def Builtin_RankedTensor + : Builtin_Type<"RankedTensor", "TensorType", + [DeclareTypeInterfaceMethods]> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -784,7 +791,9 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> 
{ +def Builtin_UnrankedMemRef + : Builtin_Type<"UnrankedMemRef", "BaseMemRefType", + [DeclareTypeInterfaceMethods]> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -844,7 +853,9 @@ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> { +def Builtin_UnrankedTensor + : Builtin_Type<"UnrankedTensor", "TensorType", + [DeclareTypeInterfaceMethods]> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -66,6 +66,33 @@ /// Dimensions are compatible if all non-dynamic dims are equal. LogicalResult verifyCompatibleDims(ArrayRef dims); + +/// The element-wise `join` function for shapes, and the partial order "less +/// specialized than or equal". It returns an error if the shapes have different +/// sizes. +/// The `join` relationship for two dimensions is: +/// dim1 | dim2 | join +/// n | n | n +/// n | k | dynamic +/// dynamic | k | dynamic +/// n | dynamic | dynamic +FailureOr> +joinShapes(ArrayRef shape1, ArrayRef shape2, + Optional location = None); + +/// The element-wise `meet` function for shapes, and the partial order "less +/// specialized than or equal". It returns an error if the shapes have different +/// sizes or do not meet. 
+/// The `meet` relationship for two dimensions is: +/// dim1 | dim2 | meet +/// n | n | n +/// n | k | error +/// dynamic | k | k +/// n | dynamic | n +FailureOr> meetShapes(ArrayRef shape1, + ArrayRef shape2, + Optional = None); + //===----------------------------------------------------------------------===// // Utility Iterators //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -33,3 +33,15 @@ DataLayoutOpInterface Interfaces/ -gen-op-interface-docs) + +set(LLVM_TARGET_DEFINITIONS JoinMeetTypeInterface.td) +mlir_tablegen(JoinMeetTypeInterface.h.inc -gen-type-interface-decls) +mlir_tablegen(JoinMeetTypeInterface.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRJoinMeetTypeInterfaceIncGen) +add_dependencies(mlir-generic-headers MLIRJoinMeetTypeInterfaceIncGen) + +add_mlir_doc(JoinMeetTypeInterface + JoinMeetTypeInterface + Interfaces/ + -gen-type-interface-docs) + diff --git a/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.h @@ -0,0 +1,96 @@ +//===- JoinMeetTypeInterface.h - Join/Meet Type Interface Decls -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for type join and meet relationships. +// Types can inherit and implement the `JoinMeetTypeInterface::Trait` trait to +// participate in `join` and `meet` functions. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H +#define MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H + +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" + +#include "mlir/Interfaces/JoinMeetTypeInterface.h.inc" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Join and Meet Types +//===----------------------------------------------------------------------===// + +/// The join function for types, and the partial order "less specialized than or +/// equal", noted `≤`. +/// It returns the most specialized type that is less specialized than both +/// `ty1` and `ty2`, if it exists, and `Type()` otherwise. +/// The join `j` of `ty1` and `ty2` is such that: +/// * j ≤ ty1, and j ≤ ty2 +/// * For any type t such that t ≤ ty1 and t ≤ ty2, t ≤ j. +/// For example: +/// ty1 | ty2 | ty1 v ty2 +/// ------------------+-------------------+------------------- +/// i8 | i8 | i8 +/// i8 | i32 | (null type) +/// tensor<1xf32> | tensor<?xf32> | tensor<?xf32> +/// tensor<1x2x?xf32> | tensor<1x?x3xf32> | tensor<1x?x?xf32> +/// tensor<4x5xf32> | tensor<6xf32> | tensor<*xf32> +/// tensor<1xi32> | i32 | (null type) +/// tensor<1xi32> | tensor<i32> | tensor<*xi32> +/// +/// The function is monotonic: +/// * idempotence: join(x,x) == x +/// * commutativity: join(x,y) == join(y,x) +/// * associativity: join(x,join(y,z)) == join(join(x,y),z) +/// +/// Types can participate in this function by implementing +/// `JoinMeetTypeInterface`. +Type join(Type ty1, Type ty2); + +/// The meet function for types, and the partial order "less specialized than or +/// equal", noted `≤`. +/// It returns the least specialized type, that is more specialized than both +/// `ty1` and `ty2`, if it exists, and `Type()` otherwise. +/// The meet `m` of `ty1` and `ty2` is such that: +/// * ty1 ≤ m, and ty2 ≤ m +/// * For any type t such that ty1 ≤ t and ty2 ≤ t, m ≤ t. 
+/// For example: +/// ty1 | ty2 | ty1 ^ ty2 +/// ------------------+-------------------+------------------- +/// i8 | i32 | (null type) +/// tensor<1xf32> | tensor<?xf32> | tensor<1xf32> +/// tensor<1x2x?xf32> | tensor<1x?x3xf32> | tensor<1x2x3xf32> +/// tensor<4x5xf32> | tensor<6xf32> | (null type) +/// tensor<1xi32> | i32 | (null type) +/// tensor<1xi32> | tensor<i32> | (null type) +/// +/// The function is monotonic: +/// * idempotence: meet(x,x) == x +/// * commutativity: meet(x,y) == meet(y,x) +/// * associativity: meet(x,meet(y,z)) == meet(meet(x,y),z) +/// +/// Types can participate in this function by implementing +/// `JoinMeetTypeInterface`. +Type meet(Type ty1, Type ty2); + +/// Indicates whether `ty1` is compatible with `ty2`, and less specialized than +/// `ty2`. +inline bool isLessSpecialized(Type ty1, Type ty2) { + return join(ty1, ty2) == ty1; +} + +/// Indicates whether `ty1` is compatible with `ty2`, and more specialized than +/// `ty2`. +inline bool isMoreSpecialized(Type ty1, Type ty2) { + return meet(ty1, ty2) == ty1; +} + +} // namespace mlir + +#endif // MLIR_INTERFACES_JOINMEETTYPEINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/JoinMeetTypeInterface.td @@ -0,0 +1,68 @@ +//===- JoinMeetTypeInterface.td - Join/Meet Type Interface ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for join and meet type relationships. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_JOINMEETTYPEINTERFACE +#define MLIR_JOINMEETTYPEINTERFACE + +include "mlir/IR/OpBase.td" + +def JoinMeetTypeInterface : TypeInterface<"JoinMeetTypeInterface"> { + let cppNamespace = "::mlir"; + + let description = [{ + Interface for types participating in the `join` and `meet` relationships. + + Types willing to participate in the `::mlir::join(Type, Type)` and + `::mlir::meet(Type, Type)` functions should implement this interface by + providing implementations of `join` and `meet` functions handling types they + can join or meet with. + + The return value of the interface method `someType.join(otherType)` is + interpreted as follows: + - A result with a value (`Optional<Type>::hasValue()` is true) represents + the result of the `join`. If null, it indicates `otherType` was handled, + and the result of the `join` is null. + - A result without a value (`Optional<Type>::hasValue()` is false) indicates + `someType` does not know (or ignores) `otherType`. + + Returning an `Optional<Type>` with no value allows for extensibility of + types. `Type1::join` may handle `Type2` while `Type2::join` ignores `Type1` + or simply does not implement this interface. + + The interface `join` method must participate to guarantee the monotonicity + of `::mlir::join`. It must be idempotent (`join(x, x) = x`). Associativity, + though it should flow naturally, must be maintained. Commutativity is + guaranteed by `::mlir::join` itself, as it queries both + `someType.join(otherType)` and `otherType.join(someType)`. One type may + ignore the other; in that case the result with a value is used. If both + results have a value, they must match. + + The same comments apply to `meet`. 
+ }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Return the join with the `other` type.", + /*retTy=*/"Optional<Type>", + /*methodName=*/"join", + /*args=*/(ins "Type":$other) + >, + InterfaceMethod< + /*description=*/"Return the meet with the `other` type.", + /*retTy=*/"Optional<Type>", + /*methodName=*/"meet", + /*args=*/(ins "Type":$other) + >, + ]; +} + +#endif // MLIR_JOINMEETTYPEINTERFACE diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRCastInterfaces MLIRIR + MLIRJoinMeetTypeInterface MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -102,38 +102,6 @@ return succeeded(verifyCompatibleShape(aT, bT)); } -/// Compute a TensorType that has the joined shape knowledge of the two -/// given TensorTypes. The element types need to match. 
-static TensorType joinShapes(TensorType one, TensorType two) { - assert(one.getElementType() == two.getElementType()); - - if (!one.hasRank()) - return two; - if (!two.hasRank()) - return one; - - int64_t rank = one.getRank(); - if (rank != two.getRank()) - return {}; - - SmallVector join; - join.reserve(rank); - for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); - continue; - } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); - continue; - } - if (one.getDimSize(i) != two.getDimSize(i)) - return {}; - join.push_back(one.getDimSize(i)); - } - return RankedTensorType::get(join, one.getElementType()); -} - namespace { /// Replaces chains of two tensor.cast operations by a single tensor.cast @@ -153,20 +121,19 @@ auto intermediateType = tensorCastOperand.getType().cast(); auto resultType = tensorCast.getType().cast(); - // We can remove the intermediate cast if joining all three produces the - // same result as just joining the source and result shapes. - auto firstJoin = - joinShapes(joinShapes(sourceType, intermediateType), resultType); + // We can remove the intermediate cast if meeting all three produces the + // same result as just meeting the source and result shapes. + auto firstMeet = meet(meet(sourceType, intermediateType), resultType); - // The join might not exist if the cast sequence would fail at runtime. - if (!firstJoin) + // The meet might not exist if the cast sequence would fail at runtime. + if (!firstMeet) return failure(); - // The newJoin always exists if the above join exists, it might just contain + // The newMeet always exists if the above meet exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. 
- auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) + auto newMeet = meet(sourceType, resultType); + if (firstMeet != newMeet) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType, diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TensorEncoding.h" +#include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/Sequence.h" @@ -436,6 +437,47 @@ !type.getDialect().getNamespace().empty(); } +namespace { +Optional join(TensorType ty1, TensorType ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return {}; + + Type elementTy = ty1.getElementType(); + if (ty2.getElementType() != elementTy) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank() || ty1.getRank() != ty2.getRank()) + return UnrankedTensorType::get(elementTy); + + auto shape = joinShapes(ty1.getShape(), ty2.getShape()).getValue(); + return RankedTensorType::get(shape, elementTy); +} + +Optional meet(TensorType ty1, TensorType ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return {}; + + Type elementTy = ty1.getElementType(); + if (ty2.getElementType() != elementTy) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank()) + return ty1.hasRank() ? 
ty1 : ty2; + + if (ty1.getRank() != ty2.getRank()) + return Type(); + + auto shape = meetShapes(ty1.getShape(), ty2.getShape()); + if (succeeded(shape)) + return RankedTensorType::get(shape.getValue(), elementTy); + return Type(); +} +} // namespace + //===----------------------------------------------------------------------===// // RankedTensorType //===----------------------------------------------------------------------===// @@ -453,6 +495,14 @@ return checkTensorElementType(emitError, elementType); } +Optional RankedTensorType::join(Type other) const { + return ::join(*this, other.dyn_cast_or_null()); +} + +Optional RankedTensorType::meet(Type other) const { + return ::meet(*this, other.dyn_cast_or_null()); +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -463,6 +513,14 @@ return checkTensorElementType(emitError, elementType); } +Optional UnrankedTensorType::join(Type other) const { + return ::join(*this, other.dyn_cast_or_null()); +} + +Optional UnrankedTensorType::meet(Type other) const { + return ::meet(*this, other.dyn_cast_or_null()); +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -479,6 +537,66 @@ return cast().getMemorySpaceAsInt(); } +namespace { +LogicalResult commonMemRefJoinMeetChecks(BaseMemRefType ty1, + BaseMemRefType ty2) { + if (ty1.getElementType() != ty2.getElementType()) + return failure(); + + if (ty1.getMemorySpace() != ty2.getMemorySpace()) + return failure(); + + auto memRefTy1 = ty1.dyn_cast(); + auto memRefTy2 = ty2.dyn_cast(); + if (memRefTy1 && memRefTy2 && + memRefTy1.getAffineMaps() != memRefTy2.getAffineMaps()) + return failure(); + + return success(); +} + +Optional join(BaseMemRefType ty1, BaseMemRefType ty2) { + if (ty1 == ty2) + return ty1; + if 
(!ty1 || !ty2) + return {}; + + if (failed(commonMemRefJoinMeetChecks(ty1, ty2))) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank() || ty1.getRank() != ty2.getRank()) + return UnrankedMemRefType::get(ty1.getElementType(), ty1.getMemorySpace()); + + auto shape = joinShapes(ty1.getShape(), ty2.getShape()).getValue(); + return MemRefType::get(shape, ty1.getElementType(), + ty1.cast().getAffineMaps(), + ty1.getMemorySpace()); +} + +Optional meet(BaseMemRefType ty1, BaseMemRefType ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return {}; + + if (failed(commonMemRefJoinMeetChecks(ty1, ty2))) + return Type(); + + if (!ty1.hasRank() || !ty2.hasRank()) + return ty1.hasRank() ? ty1 : ty2; + + if (ty1.getRank() != ty2.getRank()) + return Type(); + + auto shape = meetShapes(ty1.getShape(), ty2.getShape()); + if (succeeded(shape)) + return MemRefType::get(shape.getValue(), ty1.getElementType(), + ty1.cast().getAffineMaps(), + ty1.getMemorySpace()); + return Type(); +} +} // namespace + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// @@ -606,6 +724,14 @@ return success(); } +Optional MemRefType::join(Type other) const { + return ::join(*this, other.dyn_cast_or_null()); +} + +Optional MemRefType::meet(Type other) const { + return ::meet(*this, other.dyn_cast_or_null()); +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -773,6 +899,14 @@ return success(); } +Optional UnrankedMemRefType::join(Type other) const { + return ::join(*this, other.dyn_cast_or_null()); +} + +Optional UnrankedMemRefType::meet(Type other) const { + return ::meet(*this, other.dyn_cast_or_null()); +} + //===----------------------------------------------------------------------===// /// TupleType 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -42,6 +42,7 @@ MLIRCallInterfacesIncGen MLIRCastInterfacesIncGen MLIRDataLayoutInterfacesIncGen + MLIRJoinMeetTypeInterfaceIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSideEffectInterfacesIncGen diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" using namespace mlir; @@ -151,6 +152,46 @@ return success(); } +FailureOr> +mlir::joinShapes(ArrayRef shape1, ArrayRef shape2, + Optional location) { + if (shape1.size() != shape2.size()) + return emitOptionalError(location, "shapes' sizes differ"); + + SmallVector shape(shape1.size()); + + for (size_t i = 0, e = shape1.size(); i != e; ++i) { + int64_t dim1 = shape1[i]; + int64_t dim2 = shape2[i]; + shape[i] = dim1 == dim2 ? dim1 : ShapedType::kDynamicSize; + } + + return shape; +} + +FailureOr> +mlir::meetShapes(ArrayRef shape1, ArrayRef shape2, + Optional location) { + if (shape1.size() != shape2.size()) + return emitOptionalError(location, "shapes' sizes differ"); + + SmallVector shape(shape1.size()); + + for (size_t i = 0, e = shape1.size(); i != e; ++i) { + int64_t dim1 = shape1[i]; + int64_t dim2 = shape2[i]; + if (dim1 == dim2 || ShapedType::isDynamic(dim1) || + ShapedType::isDynamic(dim2)) { + shape[i] = ShapedType::isDynamic(dim1) ? 
dim2 : dim1; + } else { + return emitOptionalError(location, "dimensions at index ", i, + " are incompatible: ", dim1, ", ", dim2); + } + } + + return shape; +} + OperandElementTypeIterator::OperandElementTypeIterator( Operation::operand_iterator it) : llvm::mapped_iterator( diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp InferTypeOpInterface.cpp + JoinMeetTypeInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp VectorInterfaces.cpp @@ -35,6 +36,7 @@ add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(JoinMeetTypeInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(VectorInterfaces) diff --git a/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp b/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/JoinMeetTypeInterface.cpp @@ -0,0 +1,52 @@ +//===- JoinMeetTypeInterface.cpp - Join/Meet Type Interface Implementation ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/JoinMeetTypeInterface.h" +#include "mlir/IR/Diagnostics.h" + +using namespace mlir; + +#include "mlir/Interfaces/JoinMeetTypeInterface.cpp.inc" + +Type mlir::join(Type ty1, Type ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return Type(); + + Optional join1; + if (auto interface1 = ty1.dyn_cast()) + join1 = interface1.join(ty2); + + Optional join2; + if (auto interface2 = ty2.dyn_cast()) + join2 = interface2.join(ty1); + + assert(!join1 || !join2 || join1.getValue() == join2.getValue()); + + return join1.hasValue() ? join1.getValue() : join2.getValueOr(Type()); +} + +Type mlir::meet(Type ty1, Type ty2) { + if (ty1 == ty2) + return ty1; + if (!ty1 || !ty2) + return Type(); + + Optional meet1; + if (auto interface1 = ty1.dyn_cast()) + meet1 = interface1.meet(ty2); + + Optional meet2; + if (auto interface2 = ty2.dyn_cast()) + meet2 = interface2.meet(ty1); + + assert(!meet1 || !meet2 || meet1.getValue() == meet2.getValue()); + + return meet1.hasValue() ? 
meet1.getValue() : meet2.getValueOr(Type()); +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(CAPI) add_subdirectory(EDSC) +add_subdirectory(Interfaces) add_subdirectory(SDBM) add_subdirectory(lib) diff --git a/mlir/test/Interfaces/CMakeLists.txt b/mlir/test/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(JoinMeetTypeInterface) diff --git a/mlir/test/Interfaces/JoinMeetTypeInterface/CMakeLists.txt b/mlir/test/Interfaces/JoinMeetTypeInterface/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/JoinMeetTypeInterface/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LLVM_LINK_COMPONENTS + Support + ) + +add_llvm_executable(mlir-test-join-meet-type-interface + TestJoinMeetTypeInterface.cpp + ) +llvm_update_compile_flags(mlir-test-join-meet-type-interface) +target_link_libraries(mlir-test-join-meet-type-interface + PRIVATE + MLIRIR + MLIRParser + MLIRJoinMeetTypeInterface + ) +target_include_directories(mlir-test-join-meet-type-interface PRIVATE ../..) + diff --git a/mlir/test/Interfaces/JoinMeetTypeInterface/TestJoinMeetTypeInterface.cpp b/mlir/test/Interfaces/JoinMeetTypeInterface/TestJoinMeetTypeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/JoinMeetTypeInterface/TestJoinMeetTypeInterface.cpp @@ -0,0 +1,92 @@ +//===- TestJoinMeetTypeInterface.cpp - Test Join/Meet interface -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// RUN: mlir-test-join-meet-type-interface + +#include "APITest.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" +#include "mlir/Parser.h" + +using namespace mlir; + +int errors = 0; + +#define PARSE(str) \ + ((std::string("null") == str) ? Type() : parseType(str, &context)) + +TEST_FUNC(join_meet) { +#define CHECK(func, ty1, ty2, expected) \ + do { \ + Type found = func(ty1, ty2); \ + if (found != expected) { \ + emitError(FileLineColLoc::get(&context, __FILE__, __LINE__, 0), \ + #func "(" #ty1 ", " #ty2 ") = ") \ + << found << " != " #expected; \ + ++errors; \ + } \ + } while (false) + +#define CHECKP(func, ty1, ty2, expected) \ + CHECK(func, PARSE(#ty1), PARSE(#ty2), PARSE(#expected)) + + MLIRContext context; + + CHECKP(join, i8, i8, i8); + CHECKP(join, i8, i32, null); + CHECKP(join, tensor<*xi8>, tensor<*xi32>, null); + CHECKP(join, tensor<1xi32>, i32, null); + CHECKP(join, tensor<1xi32>, tensor, tensor<*xi32>); + CHECKP(join, tensor<1xi32>, tensor<2x3xi32>, tensor<*xi32>); + CHECKP(join, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>); + CHECKP(join, tensor, tensor, tensor); + CHECKP(join, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>); + CHECKP(join, tensor<1x2x?xi32>, tensor<1x?x3xi32>, tensor<1x?x?xi32>); + CHECKP(join, tensor<*xi32>, tensor<4x?xi32>, tensor<*xi32>); + + CHECKP(meet, i8, i8, i8); + CHECKP(meet, i8, i32, null); + CHECKP(meet, tensor<1xi32>, i32, null); + CHECKP(meet, tensor<1xi32>, tensor, null); + CHECKP(meet, tensor<1xi32>, tensor<2x3xi32>, null); + CHECKP(meet, tensor<*xi8>, tensor<*xi32>, null); + CHECKP(meet, tensor<1xi8>, tensor<1xi32>, null); + CHECKP(meet, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>); + CHECKP(meet, tensor, tensor, tensor); + 
CHECKP(meet, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>); + CHECKP(meet, tensor<1x2x?xi32>, tensor<1x?x3xi32>, tensor<1x2x3xi32>); + CHECKP(meet, tensor<*xi32>, tensor<4x?xi32>, tensor<4x?xi32>); + CHECKP(meet, tensor<4xi32>, tensor<5x6xi32>, null); + + auto null = Type(); + auto i32 = IntegerType::get(&context, 32); + auto memref_mem8 = MemRefType::get(1, i32, {}, 8); + auto memref_mem9 = MemRefType::get(1, i32, {}, 9); + CHECK(join, memref_mem8, memref_mem9, null); + CHECK(join, memref_mem8, memref_mem8, memref_mem8); + CHECK(meet, memref_mem8, memref_mem9, null); + CHECK(meet, memref_mem8, memref_mem8, memref_mem8); + + auto memref_nomap = MemRefType::get(1, i32, {}); + auto memref_map = MemRefType::get(1, i32, {AffineMap::get(1, 2, &context)}); + CHECK(join, memref_nomap, memref_map, null); + CHECK(join, memref_map, memref_map, memref_map); + CHECK(meet, memref_nomap, memref_map, null); + CHECK(meet, memref_map, memref_map, memref_map); + +#undef CHECKP +#undef CHECK +} + +int main(void) { + RUN_TESTS(); + return errors; +} diff --git a/mlir/test/Interfaces/JoinMeetTypeInterface/lit.local.cfg b/mlir/test/Interfaces/JoinMeetTypeInterface/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/JoinMeetTypeInterface/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.add('.cpp')