diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -137,20 +137,20 @@
 ### Broadcastable

-*   `OpTrait::BroadcastableTwoOperandsOneResult` -- `Broadcastable`
+*   `OpTrait::ResultsBroadcastableShape` -- `ResultsBroadcastableShape`

-This trait provides the API for operations that are known to have
+This trait adds the property that the operation is known to have
 [broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-operand and result types. Specifically, starting from the most varying
-dimension, each dimension pair of the two operands' types should either be the
-same or one of them is one. Also, the result type should have the corresponding
+operands and that its result types' shapes are broadcast compatible with the
+shape of the broadcasted operands. Specifically, starting from the most varying
+dimension, each dimension pair of the two operands' shapes should either be the
+same or one of them is one. Also, the result shape should have the corresponding
 dimension equal to the larger one, if known. Shapes are checked partially if
 ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>`
 and `tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type
 is broadcast-compatible.

-Ths trait assumes the op has two operands and one result, and it asserts if the
-pre-condition is not satisfied.
+This trait requires that the operands are either vector or tensor types.

 ### Commutative
diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -51,23 +51,26 @@
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type getBroadcastedType(Type type1, Type type2);
+///
+/// elementType, if specified, will be used as the element type of the
+/// broadcasted result type. Otherwise it is required that the element type of
+/// type1 and type2 is the same and this element type will be used as the
+/// resultant element type.
+Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
+
 } // namespace util

-/// This class provides the API for ops that are known to have broadcast-
-/// compatible operand and result types. Specifically, starting from the
-/// most varying dimension, each dimension pair of the two operands' types
-/// should either be the same or one of them is one. Also, the result type
-/// should have the corresponding dimension equal to the larger one, if known.
-/// Shapes are checked partially if ranks or dimensions are not known. For
-/// example, an op with tensor<? x 2 x f32> and tensor<2 x f32> as operand
-/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
-///
-/// Ths trait assumes the op has two operands and one result, and it asserts
-/// if the pre-condition is not satisfied.
+/// Trait for ops that are known to have broadcast compatible operands and
+/// result types. Specifically, starting from the most varying dimension, each
+/// dimension pair of the operands' shapes should either be the same or one
+/// of them is one. Also, the results' shapes should have the corresponding
+/// dimension equal to the larger one, if known. Shapes are checked partially if
+/// ranks or dimensions are not known. For example, an op with tensor<?x2xf32>
+/// and tensor<2xf32> as operand types and tensor<5x3x2xi16> as the result
+/// type has broadcast compatible operands and result types.
 template <typename ConcreteType>
-class BroadcastableTwoOperandsOneResult
-    : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
+class ResultsBroadcastableShape
+    : public TraitBase<ConcreteType, ResultsBroadcastableShape> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifyCompatibleOperandBroadcast(op);
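As a concrete illustration of the broadcast rule described in the documentation and header comment above, here is a small, self-contained editorial sketch (not part of this patch, and only loosely mirroring OpTrait::util::getBroadcastedShape) of the trailing-dimension matching, with -1 standing in for an unknown dimension:

#include <algorithm>
#include <cstdint>
#include <vector>

// Computes the broadcasted shape of two shapes; returns false if they are not
// broadcast-compatible. -1 denotes an unknown (dynamic) dimension.
static bool broadcastShapes(const std::vector<int64_t> &lhs,
                            const std::vector<int64_t> &rhs,
                            std::vector<int64_t> &result) {
  size_t rank = std::max(lhs.size(), rhs.size());
  result.assign(rank, 1);
  // Walk from the most varying (trailing) dimension towards the front; a
  // missing dimension in the shorter shape behaves like 1.
  for (size_t i = 0; i < rank; ++i) {
    int64_t a = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t b = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    int64_t &out = result[rank - 1 - i];
    if (a == -1 || b == -1) {
      // An unknown dimension broadcasts with anything; the result is the other
      // dimension if that one is a known size other than 1, otherwise unknown.
      int64_t known = (a == -1) ? b : a;
      out = (known == 1 || known == -1) ? -1 : known;
    } else if (a == b || b == 1) {
      out = a;
    } else if (a == 1) {
      out = b;
    } else {
      return false; // e.g. 3 vs. 2: not broadcast-compatible.
    }
  }
  return true;
}

For instance, shapes {4, 3, 2} and {3, 1} broadcast to {4, 3, 2}, which is the shape the updated diagnostics below render as '4x3x2'.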
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1327,7 +1327,10 @@
 }

 // Op supports operand broadcast behavior.
-def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+def ResultsBroadcastableShape :
+    NativeOpTrait<"ResultsBroadcastableShape">;
+// TODO: Alias of the above, remove post integrate.
+def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative : NativeOpTrait<"IsCommutative">;
 // Op behaves like a function.
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -8,6 +8,7 @@

 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/Support/FormatVariadic.h"

 using namespace mlir;
@@ -80,25 +81,27 @@
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Returns the scalar type out of the given type.
-  auto getScalarType = [](Type type) -> Type {
-    if (auto shapedType = type.dyn_cast<ShapedType>())
-      return shapedType.getElementType();
-    return type;
-  };
-
-  // Make sure underlying scalar type is the same.
-  auto scalarType = getScalarType(type1);
-  if (scalarType != getScalarType(type2))
-    return {};
+///
+/// elementType, if specified, will be used as the element type of the
+/// broadcasted result type. Otherwise it is required that the element type of
+/// type1 and type2 is the same and this element type will be used as the
+/// resultant element type.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
+                                       Type elementType) {
+  // If the elementType is not specified, then use the common element type of
+  // the inputs, or fail if there is no common element type.
+  if (!elementType) {
+    elementType = getElementTypeOrSelf(type1);
+    if (elementType != getElementTypeOrSelf(type2))
+      return {};
+  }

   // If one of the types is unranked tensor, then the other type shouldn't be
   // vector and the result should have unranked tensor type.
   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
     if (type1.isa<VectorType>() || type2.isa<VectorType>())
       return {};
-    return UnrankedTensorType::get(scalarType);
+    return UnrankedTensorType::get(elementType);
   }

   // Returns the type kind if the given type is a vector or ranked tensor type.
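One intended use of the new optional element type parameter, sketched as a hypothetical helper (the helper name and the choice of i1 are illustrative, e.g. for a comparison-style op whose result element type differs from its operands'):

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Infers a broadcasted result type whose element type is i1 regardless of the
// operands' element types; returns a null Type if the operand shapes are not
// broadcast-compatible.
static Type inferCmpLikeResultType(Builder &builder, Type lhsType,
                                   Type rhsType) {
  return OpTrait::util::getBroadcastedType(lhsType, rhsType,
                                           builder.getI1Type());
}

Without the third argument the call keeps the previous behaviour and fails when the two operands' element types differ.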
@@ -132,16 +135,18 @@

   // Compose the final broadcasted type
   if (resultCompositeKind == StandardTypes::Vector)
-    return VectorType::get(resultShape, scalarType);
+    return VectorType::get(resultShape, elementType);
   if (resultCompositeKind == StandardTypes::RankedTensor)
-    return RankedTensorType::get(resultShape, scalarType);
-  return scalarType;
+    return RankedTensorType::get(resultShape, elementType);
+  return elementType;
 }

-/// Returns true if the given types has both vector types and tensor types.
-static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
-  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
-         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
+/// Returns a tuple corresponding to whether the range has tensor or vector
+/// types.
+template <typename iterator_range>
+static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
+  return std::make_tuple(
+      llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
+      llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
 }

 static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
@@ -157,55 +162,57 @@
   return true;
 }

-LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
-  assert(op->getNumOperands() == 2 &&
-         "only support broadcast check on two operands");
-  assert(op->getNumResults() == 1 &&
-         "only support broadcast check on one result");
-
-  auto type1 = op->getOperand(0).getType();
-  auto type2 = op->getOperand(1).getType();
-  auto retType = op->getResult(0).getType();
+static std::string getShapeString(ArrayRef<int64_t> shape) {
+  // TODO: should replace with printing the shape more uniformly across here
+  // and when in type.
+  return formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()));
+}

-  // We forbid broadcasting vector and tensor.
-  if (hasBothVectorAndTensorType({type1, type2, retType}))
+LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
+  // Ensure broadcasting only tensor or only vector types.
+  auto operandsHasTensorVectorType =
+      hasTensorOrVectorType(op->getOperandTypes());
+  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
+  if ((std::get<0>(operandsHasTensorVectorType) ||
+       std::get<0>(resultsHasTensorVectorType)) &&
+      (std::get<1>(operandsHasTensorVectorType) ||
+       std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");

-  if (retType.isa<UnrankedTensorType>())
-    return success();
-
-  bool isUnranked1 = type1.isa<UnrankedTensorType>();
-  bool isUnranked2 = type2.isa<UnrankedTensorType>();
+  auto rankedOperands = make_filter_range(
+      op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

-  // If both operands are unranked, then all result shapes are possible.
-  if (isUnranked1 && isUnranked2)
+  // If all operands are unranked, then all result shapes are possible.
+  if (rankedOperands.empty())
     return success();

-  // If one of the operands is unranked, then the known dimensions in the result
-  // should be compatible with the other shaped operand.
-  if (isUnranked1 || isUnranked2) {
-    // Result should have higher rank than the shaped operand's rank and then
-    // the result's trailing dimensions should be compatible with the operand
-    // shape.
-    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
-    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
-    if (!areCompatibleShapes(actualSuffix, shape))
-      return op->emitOpError()
-             << "result type " << retType
-             << " has shape incompatible with a ranked operand type";
-    return success();
+  // Compute the broadcasted shape of the operands (which requires that the
+  // operands are broadcast compatible). The results need to be broadcast
+  // compatible with this result shape.
+  SmallVector<int64_t, 4> resultShape;
+  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
+                                  resultShape);
+  for (auto other : make_early_inc_range(rankedOperands)) {
+    SmallVector<int64_t, 4> temp = resultShape;
+    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
+      return op->emitOpError("operands don't have broadcast-compatible shapes");
   }

-  // If both operands are shaped, then the computed broadcasted shape should be
-  // compatible with the result shape.
-  SmallVector<int64_t, 4> resultShape;
-  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
-    return op->emitOpError("operands don't have broadcast-compatible shapes");
+  auto rankedResults = make_filter_range(
+      op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

-  if (!areCompatibleShapes(resultShape, getShape(retType)))
-    return op->emitOpError() << "result type " << retType
-                             << " does not have shape compatible with the one "
-                                "computed from the operand types";
+  // If all of the results are unranked then no further verification is needed.
+  if (rankedResults.empty())
+    return success();
+  for (auto type : rankedResults) {
+    ArrayRef<int64_t> actualSuffix =
+        getShape(type).take_back(resultShape.size());
+    if (!areCompatibleShapes(actualSuffix, resultShape))
+      return op->emitOpError()
+             << "result type " << getShapeString(getShape(type))
+             << " not broadcast compatible with broadcasted operands' shapes "
+             << getShapeString(resultShape);
+  }

   return success();
 }
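The "$[x]" specifier used by getShapeString above is llvm::formatv's range style, which joins the range elements with the given separator. A standalone editorial sketch (not part of the patch):

#include <cstdint>
#include <string>

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"

// Returns the string "'4x3x2'", i.e. the shape rendering used by the new
// diagnostics in the tests below.
std::string exampleShapeString() {
  llvm::SmallVector<int64_t, 4> shape = {4, 3, 2};
  return llvm::formatv("'{0:$[x]}'",
                       llvm::make_range(shape.begin(), shape.end()))
      .str();
}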
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -78,7 +78,7 @@
 // Check incompatible result type with known dimension
 func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
-  // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands' shapes '4x3x2'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
   return %0 : tensor<4x3x3xi32>
 }
@@ -88,7 +88,7 @@
 // Check incompatible result type with known dimension
 func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
 ^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
-  // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands' shapes '8x7x6x5'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
   return %0 : tensor<8x7x6x1xi32>
 }
@@ -123,7 +123,7 @@
 // Unranked operand and compatible ranked result
 func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
+  %0 = "test.broadcastable"(%arg0, %arg0, %arg1) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
   return %0 : tensor<4x3x2xi32>
 }
@@ -131,7 +131,7 @@

 func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
-  // expected-error @+1 {{shape incompatible with a ranked operand type}}
+  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands' shapes '3x2'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -376,8 +376,8 @@ let arguments = (ins AnyType:$x, AnyType:$y); } -def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> { - let arguments = (ins AnyTensor, AnyTensor); +def BroadcastableOp : TEST_Op<"broadcastable", [ResultsBroadcastableShape]> { + let arguments = (ins Variadic); let results = (outs AnyTensor); } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -781,11 +781,17 @@ return resultValue; } - bool isBroadcastable = - resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult"); + // TODO: Remove once broadcastable has been updated. This query here is not + // really about broadcastable or not, it is about which build method to invoke + // and that requires knowledge of whether ODS generated a builder that need + // not take return types. That knowledge should be captured in one place + // rather than duplicated. + bool isResultsBroadcastableShape = + resultOp.getTrait("OpTrait::ResultsBroadcastableShape"); bool usePartialResults = valuePackName != resultValue; - if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) { + if (isResultsBroadcastableShape || usePartialResults || depth > 0 || + resultIndex < 0) { // For these cases (broadcastable ops, op results used both as auxiliary // values and replacement values, ops in nested patterns, auxiliary ops), we // still need to supply the result types when building the op. But because