diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -51,29 +51,40 @@
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type getBroadcastedType(Type type1, Type type2);
+///
+/// If elementType is not provided, then the element types of type1 and type2
+/// must match, and that common element type is used for the result.
+Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
+
 } // namespace util

 /// This class provides the API for ops that are known to have broadcast-
-/// compatible operand and result types. Specifically, starting from the
-/// most varying dimension, each dimension pair of the two operands' types
-/// should either be the same or one of them is one. Also, the result type
-/// should have the corresponding dimension equal to the larger one, if known.
-/// Shapes are checked partially if ranks or dimensions are not known. For
-/// example, an op with tensor<? x 2 x f32> and tensor<2 x f32> as operand
-/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
+/// compatible operand and result types, and whose operand and result element
+/// types are all the same. Specifically, starting from the most varying
+/// dimension, each dimension pair of the two operands' types should either be
+/// the same or one of them is one. Also, the result type should have the
+/// corresponding dimension equal to the larger one, if known. Shapes are
+/// checked partially if ranks or dimensions are not known. For example, an op
+/// with tensor<? x 2 x f32> and tensor<2 x f32> as operand types and
+/// tensor<3 x 2 x f32> as the result type is broadcast-compatible.
 ///
 /// This trait assumes the op has two operands and one result, and it asserts
 /// if the pre-condition is not satisfied.
 template <typename ConcreteType>
-class BroadcastableTwoOperandsOneResult
-    : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
+class BroadcastableTwoOperandsOneResultWithSameElementType
+    : public TraitBase<ConcreteType,
+                       BroadcastableTwoOperandsOneResultWithSameElementType> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifyCompatibleOperandBroadcast(op);
   }
 };

+// TODO: Remove once call sites have been updated.
+template <typename ConcreteType>
+using BroadcastableTwoOperandsOneResult =
+    BroadcastableTwoOperandsOneResultWithSameElementType<ConcreteType>;
+
 } // end namespace OpTrait
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1327,7 +1327,14 @@
 }

 // Op supports operand broadcast behavior.
-def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+// TODO: This doesn't support arbitrary broadcast and actually implies
+// matching element types, which the name fails to convey.
+def Broadcastable :
+    NativeOpTrait<"BroadcastableTwoOperandsOneResultWithSameElementType">;
+// Op supports operand broadcast behavior.
+// TODO: Remove once call sites have been updated and express this using
+// separate traits.
+def BroadcastableWithSameElementType :
+    NativeOpTrait<"BroadcastableTwoOperandsOneResultWithSameElementType">;
 // X op Y == Y op X
 def Commutative : NativeOpTrait<"IsCommutative">;
 // Op behaves like a function.
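Reviewer note, not part of the patch: the new optional elementType parameter lets a caller broadcast operand shapes while overriding the result element type. Below is a minimal sketch of such a caller, assuming a comparison-like op whose result element type is i1; the helper name inferCmpResultType is hypothetical.

    #include "mlir/Dialect/Traits.h"
    #include "mlir/IR/Builders.h"

    // Hypothetical caller: broadcast the operand shapes but force an i1
    // element type, as a comparison-like op would. Returns a null Type if
    // the operand shapes are not broadcast-compatible.
    static mlir::Type inferCmpResultType(mlir::Type lhs, mlir::Type rhs,
                                         mlir::Builder &builder) {
      return mlir::OpTrait::util::getBroadcastedType(lhs, rhs,
                                                     builder.getI1Type());
    }

Leaving elementType unset preserves the old behavior: the operand element types must match, and the result reuses that common element type.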
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -8,6 +8,7 @@

 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/Support/FormatVariadic.h"

 using namespace mlir;
@@ -80,25 +81,25 @@
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Returns the scalar type out of the given type.
-  auto getScalarType = [](Type type) -> Type {
-    if (auto shapedType = type.dyn_cast<ShapedType>())
-      return shapedType.getElementType();
-    return type;
-  };
-
-  // Make sure underlying scalar type is the same.
-  auto scalarType = getScalarType(type1);
-  if (scalarType != getScalarType(type2))
-    return {};
+///
+/// If elementType is not provided, then the element types of type1 and type2
+/// must match, and that common element type is used for the result.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
+                                       Type elementType) {
+  // If elementType is not specified, then use the common element type of the
+  // inputs, or fail if there is no common element type.
+  if (!elementType) {
+    elementType = getElementTypeOrSelf(type1);
+    if (elementType != getElementTypeOrSelf(type2))
+      return {};
+  }

   // If one of the types is unranked tensor, then the other type shouldn't be
   // vector and the result should have unranked tensor type.
   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
     if (type1.isa<VectorType>() || type2.isa<VectorType>())
       return {};
-    return UnrankedTensorType::get(scalarType);
+    return UnrankedTensorType::get(elementType);
   }

   // Returns the type kind if the given type is a vector or ranked tensor type.
@@ -132,10 +133,10 @@

   // Compose the final broadcasted type
   if (resultCompositeKind == StandardTypes::Vector)
-    return VectorType::get(resultShape, scalarType);
+    return VectorType::get(resultShape, elementType);
   if (resultCompositeKind == StandardTypes::RankedTensor)
-    return RankedTensorType::get(resultShape, scalarType);
-  return scalarType;
+    return RankedTensorType::get(resultShape, elementType);
+  return elementType;
 }

 /// Returns true if the given types have both vector types and tensor types.
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -373,7 +373,8 @@
   let arguments = (ins AnyType:$x, AnyType:$y);
 }

-def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
+def BroadcastableOp : TEST_Op<"broadcastable",
+                              [BroadcastableWithSameElementType]> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
 }
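For context on the Traits.cpp hunk: the shape side of the utility follows the NumPy-style trailing-dimension rule. Below is a self-contained, simplified sketch of that rule, as an illustration only; kDynamic stands in for MLIR's dynamic dimension, and the in-tree handling of dynamic sizes has more nuance than shown here.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    constexpr int64_t kDynamic = -1; // Stand-in for MLIR's dynamic dimension.

    // Broadcasts two shapes NumPy-style, walking from the trailing (most
    // varying) dimension; returns std::nullopt if they are incompatible.
    std::optional<std::vector<int64_t>>
    broadcastShapes(const std::vector<int64_t> &a,
                    const std::vector<int64_t> &b) {
      std::vector<int64_t> result(std::max(a.size(), b.size()));
      for (size_t i = 0; i < result.size(); ++i) {
        // Missing leading dimensions behave as 1.
        int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        int64_t &out = result[result.size() - 1 - i];
        if (da == db || db == 1)
          out = da;
        else if (da == 1)
          out = db;
        else if (da == kDynamic || db == kDynamic)
          out = kDynamic; // Unknown statically; must be checked at runtime.
        else
          return std::nullopt; // Mismatched static dimensions.
      }
      return result;
    }

    int main() {
      // tensor<? x 2 x f32> and tensor<2 x f32> give tensor<? x 2 x f32>.
      auto shape = broadcastShapes({kDynamic, 2}, {2});
      assert(shape && (*shape)[0] == kDynamic && (*shape)[1] == 2);
      // tensor<3 x f32> and tensor<2 x f32> do not broadcast.
      assert(!broadcastShapes({3}, {2}));
    }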
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -781,8 +781,15 @@
     return resultValue;
   }

+  // TODO: Remove once Broadcastable has been updated. This query is not
+  // really about broadcastability; it is about which build method to invoke,
+  // and that requires knowing whether ODS generated a builder that need not
+  // take return types. That knowledge should be captured in one place rather
+  // than duplicated.
   bool isBroadcastable =
-      resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
+      resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult") ||
+      resultOp.getTrait(
+          "OpTrait::BroadcastableTwoOperandsOneResultWithSameElementType");
   bool usePartialResults = valuePackName != resultValue;

   if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
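One possible shape for the follow-up the TODO asks for, sketched under the assumption that the query stays name-based; the helper name buildersInferResultTypes is hypothetical and not part of this patch.

    #include "mlir/TableGen/Operator.h"

    // Hypothetical helper, not in this patch: centralizes the knowledge of
    // whether ODS generated a builder that infers result types for this op,
    // which is what the isBroadcastable query above actually decides.
    static bool buildersInferResultTypes(const mlir::tblgen::Operator &op) {
      return op.getTrait("OpTrait::BroadcastableTwoOperandsOneResult") ||
             op.getTrait(
                 "OpTrait::BroadcastableTwoOperandsOneResultWithSameElementType");
    }

Callers in RewriterGen.cpp could then branch on this helper instead of duplicating the trait-name strings at each build-method decision.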