diff --git a/mlir/docs/Traits/Broadcastable.md b/mlir/docs/Traits/Broadcastable.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/Traits/Broadcastable.md
@@ -0,0 +1,197 @@
+# The `Broadcastable` Trait
+
+[TOC]
+
+## Description
+
+The `Broadcastable` trait enforces the following properties on an operation:
+
+- The operation has at least one input operand.
+
+- The operation has exactly one result.
+
+- All input operands and the result are of type `tensor` or `vector`.
+
+- A shape inference mechanism is able to compute the result shape solely based on input operand shapes.
+
+- Input operands have broadcast-compatible shapes, according to the verification rules presented below.
+
+- The operation's result shape is compatible with, though not necessarily identical to, the shape inferred from its input operands, according to the verification rules presented below.
+
+
+## Dimension inference
+
+Given an operation with two input operands, the size of dimension `i` of its result can be inferred from dimension `i` of the operands according to the table below. Here, `dim0` and `dim1` represent dimension `i` of the two input operands, in either order, while `inferredDim` represents the inferred size for dimension `i` of the operation result. Dimensions are classified into three categories: dynamic ("?"), static equal to 1 ("1"), and static greater than 1 (">1").
+
+
+| `dim0` | `dim1` | `inferredDim` | Notes |
+| -------- | -------- | ------------- | ----- |
+| ? | ? | ? | If `RuntimeSize(dim0)` is 1, dimension `dim0` is broadcast to `RuntimeSize(dim1)`. If `RuntimeSize(dim1)` is 1, dimension `dim1` is broadcast to `RuntimeSize(dim0)`. The operation produces undefined behavior if both runtime sizes are greater than 1 and not equal. |
+| ? | 1 | ? | Dimension `dim1` is broadcast to `RuntimeSize(dim0)`. |
+| ? | >1 | `dim1` | If `RuntimeSize(dim0)` is 1, `dim0` is broadcast to `dim1`. The operation produces undefined behavior if `RuntimeSize(dim0)` is greater than 1 and not equal to `dim1`. |
+| 1 | 1 | 1 | |
+| 1 | >1 | `dim1` | Dimension `dim0` is broadcast to `dim1`. |
+| >1 | >1 | `dim0` | The operation verifier produces a compile-time error if `dim0` != `dim1`. |
+
+
+The following pseudo-function is a formal representation of the dimension inference process:
+
+```python
+InferDim(dim0, dim1):
+    switch (dim0, dim1):
+        case (?, ?):
+        case (?, 1):
+        case (1, 1):
+        case (>1, ?):
+        case (>1, 1):
+            return dim0
+        case (?, >1):
+        case (1, ?):
+        case (1, >1):
+            return dim1
+        case (>1, >1):
+            ERROR_IF(dim0 != dim1)
+            return dim0
+```
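
The table and pseudo-function above can be exercised with a small runnable sketch. This Python version is illustrative only and is not part of the MLIR implementation. It assumes that a dynamic dimension is modeled as `None` and that static sizes are at least 1, matching the three categories above. The name `infer_dim` is hypothetical.

```python
from typing import Optional

# Assumption for this sketch: None stands for a dynamic dimension ("?").
Dim = Optional[int]


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    """Infer one result dimension from a pair of operand dimensions."""
    # Collect the static sizes greater than 1, if any.
    large = [d for d in (dim0, dim1) if d is not None and d > 1]

    # Case (>1, >1): both sizes must agree, otherwise verification fails.
    if len(large) == 2 and large[0] != large[1]:
        raise ValueError(f"incompatible static dimensions: {dim0} vs {dim1}")

    # Cases (>1, ?), (>1, 1), (?, >1), (1, >1), and matching (>1, >1):
    # the static size greater than 1 determines the result.
    if large:
        return large[0]

    # Cases (?, ?), (?, 1), (1, ?): the result stays dynamic.
    if dim0 is None or dim1 is None:
        return None

    # Case (1, 1).
    return 1
```

Because the operand order is interchangeable, `infer_dim(a, b)` and `infer_dim(b, a)` give the same answer for every pair of categories in the table.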
+
+## Shape inference
+
+The shape inference process begins by correcting rank differences in input operands. A shape is expanded by prepending dimensions of size 1 until the desired rank is reached, as shown here:
+
+```python
+ExpandRank(shape, rank):
+    while len(shape) < rank:
+        shape.prepend(1)
+```
+
+Given the shapes of two ranked input operands, the result's shape is inferred by equalizing input ranks and inferring individual dimensions, as shown here:
+
+```python
+InferShape(shape0, shape1):
+
+    # Equalize ranks
+    rank = max(GetRank(shape0), GetRank(shape1))
+    ExpandRank(shape0, rank)
+    ExpandRank(shape1, rank)
+
+    # Infer shape
+    inferredShape = []
+    for (dim0, dim1) in zip(shape0, shape1):
+        inferredDim = InferDim(dim0, dim1)
+        inferredShape.append(inferredDim)
+    return inferredShape
+```
+
+The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:
+
+```python
+InferResultShape(op):
+
+    # Filter ranked operands
+    rankedOperands = filter(op.operands, IsRanked)
+    if len(rankedOperands) == 0:
+        return None
+
+    # Infer result shape
+    inferredShape = GetShape(rankedOperands[0])
+    for operand in rankedOperands[1:]:
+        inferredShape = InferShape(inferredShape, GetShape(operand))
+    return inferredShape
+```
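
The rank expansion and the fold over ranked operands described above can be tried out with the following Python sketch. It is an illustration under stated assumptions, not the MLIR implementation. A shape is a list of dimensions, `None` inside a shape marks a dynamic dimension, and `None` in place of a shape marks an unranked operand. All function names are hypothetical.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]          # assumption: None means a dynamic dimension
Shape = Optional[List[Dim]]  # assumption: None means an unranked operand


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    """Per-dimension rule: a static size > 1 wins, two of them must match."""
    large = [d for d in (dim0, dim1) if d is not None and d > 1]
    if len(large) == 2 and large[0] != large[1]:
        raise ValueError(f"incompatible static dimensions: {dim0} vs {dim1}")
    if large:
        return large[0]
    return None if dim0 is None or dim1 is None else 1


def expand_rank(shape: Sequence[Dim], rank: int) -> List[Dim]:
    """Prepend dimensions of size 1 until the shape has the given rank."""
    return [1] * (rank - len(shape)) + list(shape)


def infer_shape(shape0: Sequence[Dim], shape1: Sequence[Dim]) -> List[Dim]:
    """Equalize ranks, then infer each dimension pairwise."""
    rank = max(len(shape0), len(shape1))
    padded0 = expand_rank(shape0, rank)
    padded1 = expand_rank(shape1, rank)
    return [infer_dim(d0, d1) for d0, d1 in zip(padded0, padded1)]


def infer_result_shape(operand_shapes: Sequence[Shape]) -> Shape:
    """Fold infer_shape over the ranked operands; None if there are none."""
    ranked = [shape for shape in operand_shapes if shape is not None]
    if not ranked:
        return None
    inferred = list(ranked[0])
    for shape in ranked[1:]:
        inferred = infer_shape(inferred, shape)
    return inferred
```

For instance, `infer_result_shape([[4], [2, 3, 4]])` first expands `[4]` to `[1, 1, 4]` and then yields `[2, 3, 4]`, while a list containing only unranked operands yields `None`.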
+
+## Verification
+
+The legality of an operation with the `Broadcastable` trait is verified by first running the shape inference process. If a failure occurs during shape inference, it is concluded that input operands are not broadcast-compatible, and verification fails. If shape inference succeeds, verification continues.
+
+If either the result is unranked or all input operands are unranked, no further verification steps are needed, and the process ends here successfully. If, on the contrary, both the result and at least one input operand are ranked, verification continues by checking for a matching rank between the previously inferred shape and the result.
+
+Once a rank match is guaranteed, each dimension of the inferred shape is compared with the corresponding dimension of the actual result shape according to the following table:
+
+
+| `inferredDim` | `actualDim` | Verification outcome |
+| ------------- | ----------- | -------------------- |
+| ? | ? | **OK** |
+| ? | static | **Error** <br> An inferred dimension being dynamic indicates that its size cannot be inferred at compile time from its input operands. The presence of a static dimension in the actual result is counterintuitive and is therefore not allowed. |
+| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
+| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |
+
+
+The full verification process can be formally specified as follows:
+
+```python
+Verify(op):
+
+    # Run shape inference
+    inferredShape = InferResultShape(op)
+
+    # Done if result is unranked or all operands are unranked
+    if not IsRanked(op.result) or inferredShape is None:
+        return
+
+    # Rank must match
+    actualShape = GetShape(op.result)
+    ERROR_IF(len(inferredShape) != len(actualShape))
+
+    # Verify
+    for (inferredDim, actualDim) in zip(inferredShape, actualShape):
+        ERROR_IF(IsDynamic(inferredDim) and IsStatic(actualDim))
+        ERROR_IF(IsStatic(actualDim) and inferredDim != actualDim)
+```
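
The comparison between the inferred and the actual result shape can be checked in isolation with the Python sketch below. It is illustrative rather than the actual verifier. It takes an already inferred shape as input, models a dynamic dimension as `None` and an unranked (or non-inferable) shape as `None`, and uses the hypothetical name `verify_result_shape`.

```python
from typing import List, Optional

Dim = Optional[int]          # assumption: None means a dynamic dimension
Shape = Optional[List[Dim]]  # assumption: None means unranked / not inferable


def verify_result_shape(inferred: Shape, actual: Shape) -> None:
    """Raise ValueError if 'actual' is not compatible with 'inferred'."""
    # Nothing to check if the result is unranked or no operand was ranked.
    if actual is None or inferred is None:
        return

    # Ranks must match; results are never rank-expanded.
    if len(inferred) != len(actual):
        raise ValueError(
            f"rank mismatch: inferred {len(inferred)}, actual {len(actual)}")

    for index, (inferred_dim, actual_dim) in enumerate(zip(inferred, actual)):
        # A dynamic result dimension is always acceptable.
        if actual_dim is None:
            continue
        # A static result dimension requires a static inferred dimension...
        if inferred_dim is None:
            raise ValueError(f"dimension {index}: static result, dynamic inferred")
        # ...of exactly the same size.
        if inferred_dim != actual_dim:
            raise ValueError(
                f"dimension {index}: inferred {inferred_dim}, actual {actual_dim}")
```

Note that an inferred size of 1 does not match a static result size of 4: broadcasting applies between operands only, not from the inferred shape to the result.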
+
+## Examples
+
+The following are correct uses of broadcastable ops:
+
+```mlir
+// Exact match of static sizes.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+
+// Dynamic sizes match. The programmer must guarantee that the sizes of %arg0
+// and %arg1 are equal at runtime.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is broadcast from tensor<1xi32> to tensor<4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// The shape of %result is inferred as tensor<4xi32>, while the actual result
+// type is tensor<?xi32>. The inferred shape is compatible with the actual shape.
+%result = "test.broadcastable"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is first expanded to tensor<1x1x4xi32> and then broadcast
+// to tensor<2x3x4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+
+// Input and result tensors have different element types (i1, i32, i64). The
+// 'Broadcastable' trait has no restrictions on element types.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi32>) -> tensor<2xi64>
+
+// No result shape verification is needed when the result is unranked.
+%result = "test.broadcastable"(%arg0) : (tensor<2xi32>) -> tensor<*xi32>
+
+// No result shape verification needed when all inputs are unranked.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
+```
+
+
+The following are incorrect uses of broadcastable ops:
+
+```mlir
+// Dimension 0 of input operands is static but not equal.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<?xi32>
+
+// The inferred result shape is tensor<3xi32>, but the actual result shape is
+// tensor<1x3xi32>. Inferred and actual shapes differ in rank.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<1x3xi32>
+
+// The inferred result shape is tensor<?xi32>, but the actual shape is
+// tensor<4xi32>. The inferred shape is not compatible with the actual shape.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<2xi32>, but the actual result shape is
+// tensor<4xi32>, which is not compatible.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<1xi32>, but the actual result shape is
+// tensor<4xi32>. Broadcast semantics are not applicable for results.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+```
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits/_index.md
rename from mlir/docs/Traits.md
rename to mlir/docs/Traits/_index.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits/_index.md
@@ -241,16 +241,7 @@
This trait adds the property that the operation is known to have
[broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-operands and its result types' shape is the broadcast compatible with the shape
-of the broadcasted operands. Specifically, starting from the most varying
-dimension, each dimension pair of the two operands' shapes should either be the
-same or one of them is one. Also, the result shape should have the corresponding
-dimension equal to the larger one, if known. Shapes are checked partially if
-ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
-`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
-broadcast-compatible.
-
-This trait requires that the operands are either vector or tensor types.
+operands and that its result type is compatible with the inferred broadcast shape. See [The `Broadcastable` Trait](Traits/Broadcastable.md) for details.
### Commutative
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -215,18 +215,12 @@
Op {
}
-class Tosa_ElemWiseUnaryOp traits = []> :
+class Tosa_ElementwiseOp traits = []> :
Tosa_Op,
- Pure, SameOperandsAndResultElementType])> {
-}
-
-class Tosa_ElemWiseBinaryOp traits = []> :
- Tosa_Op,
- ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> {
+ ResultsBroadcastableShape,
+ Pure])> {
}
#endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -375,7 +375,7 @@
//===----------------------------------------------------------------------===//
// Operator: clamp
//===----------------------------------------------------------------------===//
-def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> {
+def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
let summary = "Computes clamp(features, min, max).";
let description = [{
@@ -404,7 +404,7 @@
//===----------------------------------------------------------------------===//
// Operator: sigmoid
//===----------------------------------------------------------------------===//
-def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
+def Tosa_SigmoidOp : Tosa_ElementwiseOp<"sigmoid"> {
let summary = "Computes elementwise sigmoid of input.";
let description = [{
@@ -427,7 +427,7 @@
//===----------------------------------------------------------------------===//
// Operator: tanh
//===----------------------------------------------------------------------===//
-def Tosa_TanhOp : Tosa_ElemWiseUnaryOp<"tanh"> {
+def Tosa_TanhOp : Tosa_ElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise hyperbolic tangent of input";
let description = [{
@@ -481,7 +481,9 @@
//===----------------------------------------------------------------------===//
// Operator: add
//===----------------------------------------------------------------------===//
-def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
+def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise addition operator";
let description = [{
@@ -504,7 +506,8 @@
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Arithmetic Right Shift";
let description = [{
@@ -526,7 +529,9 @@
//===----------------------------------------------------------------------===//
// Operator: bitwise_and
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> {
+def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise AND operator";
let description = [{
@@ -547,7 +552,9 @@
//===----------------------------------------------------------------------===//
// Operator: bitwise_or
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> {
+def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise OR operator";
let description = [{
@@ -568,7 +575,9 @@
//===----------------------------------------------------------------------===//
// Operator: bitwise_xor
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> {
+def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise XOR operator";
let description = [{
@@ -589,7 +598,7 @@
//===----------------------------------------------------------------------===//
// Operator: div
//===----------------------------------------------------------------------===//
-def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> {
+def Tosa_DivOp : Tosa_ElementwiseOp<"div", [SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";
let description = [{
@@ -612,7 +621,9 @@
//===----------------------------------------------------------------------===//
// Operator: logical_and
//===----------------------------------------------------------------------===//
-def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> {
+def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x AND y element-wise.";
let description = [{
@@ -633,7 +644,8 @@
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Left Shift";
let description = [{
@@ -654,7 +666,8 @@
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> {
+def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Right Shift";
let description = [{
@@ -675,7 +688,9 @@
//===----------------------------------------------------------------------===//
// Operator: logical_or
//===----------------------------------------------------------------------===//
-def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> {
+def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x OR y element-wise.";
let description = [{
@@ -696,7 +711,9 @@
//===----------------------------------------------------------------------===//
// Operator: logical_xor
//===----------------------------------------------------------------------===//
-def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> {
+def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x XOR y element-wise.";
let description = [{
@@ -717,7 +734,9 @@
//===----------------------------------------------------------------------===//
// Operator: maximum
//===----------------------------------------------------------------------===//
-def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
+def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise Maximum";
let description = [{
@@ -738,7 +757,9 @@
//===----------------------------------------------------------------------===//
// Operator: minimum
//===----------------------------------------------------------------------===//
-def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
+def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise Minimum";
let description = [{
@@ -759,7 +780,9 @@
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
+def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
+ Commutative,
+ SameOperandsAndResultElementType]> {
let summary = "Multiplication operator";
let description = [{
@@ -784,7 +807,7 @@
//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
+def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
let summary = "Computes the power of one value to another.";
let description = [{
@@ -805,7 +828,7 @@
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
+def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
let summary = "Elementwise subtraction operator";
let description = [{
@@ -871,7 +894,7 @@
//===----------------------------------------------------------------------===//
// Operator: abs
//===----------------------------------------------------------------------===//
-def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> {
+def Tosa_AbsOp : Tosa_ElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
let summary = "Elementwise abs op";
let description = [{
@@ -892,7 +915,8 @@
//===----------------------------------------------------------------------===//
// Operator: bitwise_not
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> {
+def Tosa_BitwiseNotOp : Tosa_ElementwiseOp<"bitwise_not",
+ [SameOperandsAndResultElementType]> {
let summary = "Bitwise NOT operator";
let description = [{
@@ -911,7 +935,7 @@
//===----------------------------------------------------------------------===//
// Operator: ceil
//===----------------------------------------------------------------------===//
-def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> {
+def Tosa_CeilOp : Tosa_ElementwiseOp<"ceil", [SameOperandsAndResultElementType]> {
let summary = "Elementwise ceil op";
let description = [{
@@ -930,7 +954,7 @@
//===----------------------------------------------------------------------===//
// Operator: clz
//===----------------------------------------------------------------------===//
-def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> {
+def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
let summary = "Elementwise count leading zero op";
let description = [{
@@ -949,7 +973,7 @@
//===----------------------------------------------------------------------===//
// Operator: exp
//===----------------------------------------------------------------------===//
-def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> {
+def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> {
let summary = "Elementwise exp op";
let description = [{
@@ -970,7 +994,7 @@
//===----------------------------------------------------------------------===//
// Operator: floor
//===----------------------------------------------------------------------===//
-def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> {
+def Tosa_FloorOp : Tosa_ElementwiseOp<"floor", [SameOperandsAndResultElementType]> {
let summary = "Elementwise floor op";
let description = [{
@@ -989,7 +1013,7 @@
//===----------------------------------------------------------------------===//
// Operator: log
//===----------------------------------------------------------------------===//
-def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> {
+def Tosa_LogOp : Tosa_ElementwiseOp<"log", [SameOperandsAndResultElementType]> {
let summary = "Elementwise log op";
let description = [{
@@ -1010,7 +1034,8 @@
//===----------------------------------------------------------------------===//
// Operator: logical_not
//===----------------------------------------------------------------------===//
-def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> {
+def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
+ [SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of NOT x element-wise.";
let description = [{
@@ -1029,7 +1054,8 @@
//===----------------------------------------------------------------------===//
// Operator: negate
//===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_ElementwiseOp<"negate",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise negate op";
let description = [{
@@ -1053,7 +1079,8 @@
//===----------------------------------------------------------------------===//
// Operator: reciprocal
//===----------------------------------------------------------------------===//
-def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> {
+def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise reciprocal op";
let description = [{
@@ -1073,7 +1100,8 @@
//===----------------------------------------------------------------------===//
// Operator: rsqrt
//===----------------------------------------------------------------------===//
-def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> {
+def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise 1/sqrt op";
let description = [{
@@ -1099,9 +1127,7 @@
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_Op<"select", [
- DeclareOpInterfaceMethods, Pure]> {
+def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let summary = "Elementwise select operator";
let description = [{
@@ -1129,8 +1155,10 @@
//===----------------------------------------------------------------------===//
// Operator: equal
//===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
- Commutative, Pure, SameOperandsElementType]> {
+def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
+ InferTensorType,
+ Commutative,
+ SameOperandsElementType]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@@ -1158,10 +1186,7 @@
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_Op<"greater", [
- DeclareOpInterfaceMethods,
- ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
+def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
let summary = "Returns the truth value of (x > y) element-wise.";
let description = [{
@@ -1183,10 +1208,8 @@
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
- DeclareOpInterfaceMethods,
- ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
+def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
+ [SameOperandsElementType]> {
let summary = "Returns the truth value of (x >= y) element-wise.";
let description = [{
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -24,9 +24,12 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
#include
@@ -517,115 +520,339 @@
return nullptr;
}
-static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
- PatternRewriter &rewriter) {
- auto loc = operation->getLoc();
-
- assert(operation->getNumResults() == 1 &&
- "All TOSA elementwise ops should only return a single result.");
-
- auto result = operation->getResult(0);
- auto resultTy = dyn_cast(result.getType());
-
- if (!resultTy)
- return rewriter.notifyMatchFailure(
- operation, "All results must be a ranked tensor type");
+static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
+ int64_t rank) {
+ // No need to expand if we are already at the desired rank
+ auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+ assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+ int64_t numExtraDims = rank - shapedType.getRank();
+ assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
+ if (!numExtraDims)
+ return tensor;
+
+ // Compute reassociation indices
+ SmallVector> reassociationIndices(
+ shapedType.getRank());
+ int64_t index = 0;
+ for (index = 0; index <= numExtraDims; index++)
+ reassociationIndices[0].push_back(index);
+ for (size_t position = 1; position < reassociationIndices.size(); position++)
+ reassociationIndices[position].push_back(index++);
+
+ // Compute result type
+ SmallVector<int64_t> resultShape;
+ for (index = 0; index < numExtraDims; index++)
+ resultShape.push_back(1);
+ for (auto size : shapedType.getShape())
+ resultShape.push_back(size);
+ auto resultType =
+ RankedTensorType::get(resultShape, shapedType.getElementType());
+
+ // Emit 'tensor.expand_shape' op
+ return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
+ reassociationIndices);
+}
- unsigned rank = resultTy.getRank();
+static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
+ Location loc, Operation *operation) {
+ auto rank =
+ operation->getResultTypes().front().cast().getRank();
+ return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+ return expandRank(rewriter, loc, operand, rank);
+ });
+}
- // Construct the indexing maps needed for linalg.generic ops.
- SmallVector bodyArgTypes;
+using IndexPool = DenseMap<int64_t, Value>;
+
+// Emit an 'arith.constant' op for the given index if it has not been created
+// yet, or return an existing constant. This will prevent an excessive creation
+// of redundant constants, easing readability of emitted code for unit tests.
+static Value createIndex(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, int64_t index) {
+ auto [it, inserted] = indexPool.try_emplace(index);
+ if (inserted)
+ it->second =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
+ return it->second;
+}
- for (Value in : operation->getOperands())
- bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+static Value getTensorDim(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, Value tensor, int64_t index) {
+ auto indexValue = createIndex(rewriter, loc, indexPool, index);
+ return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
+}
- SmallVector opResultTypes;
- SmallVector emptyTensors;
+static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, Value tensor,
+ int64_t index) {
+ auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+ assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+ assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
+ if (shapedType.isDynamicDim(index))
+ return getTensorDim(rewriter, loc, indexPool, tensor, index);
+ return rewriter.getIndexAttr(shapedType.getDimSize(index));
+}
- SmallVector dynDims;
- dynDims.resize(rank);
+static bool operandsAndResultsRanked(Operation *operation) {
+ auto isRanked = [](Value value) {
+ return isa<RankedTensorType>(value.getType());
+ };
+ return llvm::all_of(operation->getOperands(), isRanked) &&
+ llvm::all_of(operation->getResults(), isRanked);
+}
- for (auto arg : operation->getOperands()) {
- auto operandTy = cast(arg.getType());
- for (int i = 0; i < operandTy.getRank(); i++) {
- if (operandTy.isDynamicDim(i) && !dynDims[i])
- dynDims[i] = rewriter.create(loc, arg, i);
- }
+// Compute the runtime dimension size for dimension 'dim' of the output by
+// inspecting input 'operands', all of which are expected to have the same rank.
+// This function returns a pair {targetSize, masterOperand}.
+//
+// The runtime size of the output dimension is returned either as a statically
+// computed attribute or as a runtime SSA value.
+//
+// If the target size was inferred directly from one dominating operand, that
+// operand is returned in 'masterOperand'. If the target size is inferred from
+// multiple operands, 'masterOperand' is set to nullptr.
+static std::pair<OpFoldResult, Value>
+computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
+ ValueRange operands, int64_t dim) {
+ // If any input operand contains a static size greater than 1 for this
+ // dimension, that is the target size. An occurrence of an additional static
+ // dimension greater than 1 with a different value is undefined behavior.
+ for (auto operand : operands) {
+ auto size = operand.getType().cast().getDimSize(dim);
+ if (!ShapedType::isDynamic(size) && size > 1)
+ return {rewriter.getIndexAttr(size), operand};
}
- SmallVector filteredDims = condenseValues(dynDims);
-
- emptyTensors.push_back(
- rewriter.create(loc, resultTy, filteredDims));
- opResultTypes.push_back(result.getType());
-
- auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
- emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));
-
- SmallVector operands;
- SmallVector indexingMaps;
- indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
-
- // Input indexing maps may be broadcasted.
- for (Value operand : operation->getOperands()) {
- ShapedType type = cast(operand.getType());
+ // Filter operands with dynamic dimension
+ auto operandsWithDynamicDim =
+ llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
+ return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
+ }));
+
+ // If no operand has a dynamic dimension, it means all sizes were 1
+ if (operandsWithDynamicDim.empty())
+ return {rewriter.getIndexAttr(1), operands.front()};
+
+ // Emit code that computes the runtime size for this dimension. If there is
+ // only one operand with a dynamic dimension, it is considered the master
+ // operand that determines the runtime size of the output dimension.
+ auto targetSize =
+ getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
+ if (operandsWithDynamicDim.size() == 1)
+ return {targetSize, operandsWithDynamicDim[0]};
+
+ // Calculate maximum size among all dynamic dimensions
+ for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
+ auto nextSize =
+ getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
+ targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
+ }
+ return {targetSize, nullptr};
+}
- if (type.getShape() == resultTy.getShape()) {
- operands.push_back(operand);
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
- continue;
- }
+// Compute the runtime output size for all dimensions. This function returns
+// a pair {targetShape, masterOperands}.
+static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
+computeTargetShape(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, ValueRange operands) {
+ assert(!operands.empty());
+ auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
+ SmallVector<OpFoldResult> targetShape;
+ SmallVector<Value> masterOperands;
+ for (auto dim : llvm::seq<int64_t>(0, rank)) {
+ auto [targetSize, masterOperand] =
+ computeTargetSize(rewriter, loc, indexPool, operands, dim);
+ targetShape.push_back(targetSize);
+ masterOperands.push_back(masterOperand);
+ }
+ return {targetShape, masterOperands};
+}
- SmallVector newShape;
- SmallVector affineExprs;
- newShape.reserve(type.getRank());
- for (const auto &it : llvm::enumerate(type.getShape())) {
- if (it.value() == resultTy.getDimSize(it.index())) {
- newShape.push_back(it.value());
- affineExprs.push_back(
- mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
- }
+static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, Value operand,
+ int64_t dim, OpFoldResult targetSize,
+ Value masterOperand) {
+ // Nothing to do if this is a static dimension
+ auto rankedTensorType = operand.getType().cast<RankedTensorType>();
+ if (!rankedTensorType.isDynamicDim(dim))
+ return operand;
+
+ // If the target size for this dimension was directly inferred by only taking
+ // this operand into account, there is no need to broadcast. This is an
+ // optimization that will prevent redundant control flow, and constitutes the
+ // main motivation for tracking "master operands".
+ if (operand == masterOperand)
+ return operand;
+
+ // Affine maps for 'linalg.generic' op
+ auto rank = rankedTensorType.getRank();
+ SmallVector<AffineExpr> affineExprs;
+ for (auto index : llvm::seq<int64_t>(0, rank)) {
+ auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
+ : rewriter.getAffineDimExpr(index);
+ affineExprs.push_back(affineExpr);
+ }
+ auto broadcastAffineMap =
+ AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
+ SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
+
+ // Check if broadcast is necessary
+ auto one = createIndex(rewriter, loc, indexPool, 1);
+ auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
+ auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, runtimeSize, one);
+
+ // Emit 'then' region of 'scf.if'
+ auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+ // Emit 'tensor.empty' op
+ SmallVector<OpFoldResult> outputTensorShape;
+ for (auto index : llvm::seq<int64_t>(0, rank)) {
+ auto size = index == dim ? targetSize
+ : getOrFoldTensorDim(rewriter, loc, indexPool,
+ operand, index);
+ outputTensorShape.push_back(size);
}
+ Value outputTensor = opBuilder.create<tensor::EmptyOp>(
+ loc, outputTensorShape, rankedTensorType.getElementType());
+
+ // Emit 'linalg.generic' op
+ auto resultTensor =
+ opBuilder
+ .create<linalg::GenericOp>(
+ loc, outputTensor.getType(), operand, outputTensor, affineMaps,
+ getNParallelLoopsAttrs(rank),
+ [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+ // Emit 'linalg.yield' op
+ opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
+ })
+ .getResult(0);
+
+ // Cast to original operand type if necessary
+ auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
+ loc, operand.getType(), resultTensor);
+
+ // Emit 'scf.yield' op
+ opBuilder.create<scf::YieldOp>(loc, castResultTensor);
+ };
+
+ // Emit 'else' region of 'scf.if'
+ auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
+ opBuilder.create<scf::YieldOp>(loc, operand);
+ };
+
+ // Emit 'scf.if' op
+ auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
+ emitThenRegion, emitElseRegion);
+ return ifOp.getResult(0);
+}
- if (newShape.size() != rank) {
- operand = rewriter.create(
- loc, RankedTensorType::get(newShape, type.getElementType()), operand,
- rewriter.getDenseI64ArrayAttr(newShape));
- }
+static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, Value operand,
+ ArrayRef<OpFoldResult> targetShape,
+ ArrayRef<Value> masterOperands) {
+ size_t rank = operand.getType().cast<RankedTensorType>().getRank();
+ assert(targetShape.size() == rank);
+ assert(masterOperands.size() == rank);
+ for (auto index : llvm::seq<int64_t>(0, rank))
+ operand =
+ broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
+ targetShape[index], masterOperands[index]);
+ return operand;
+}
- operands.push_back(operand);
- indexingMaps.push_back(AffineMap::get(
- /*dimCount=*/rank, /*symbolCount=*/0, affineExprs,
- rewriter.getContext()));
- }
+static SmallVector<Value>
+broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
+ IndexPool &indexPool, ValueRange operands,
+ ArrayRef<OpFoldResult> targetShape,
+ ArrayRef<Value> masterOperands) {
+ // No need to broadcast for unary operations
+ if (operands.size() == 1)
+ return operands;
+
+ // Broadcast dynamic dimensions operand by operand
+ return llvm::map_to_vector(operands, [&](Value operand) {
+ return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
+ targetShape, masterOperands);
+ });
+}
- indexingMaps.append(operation->getNumResults(),
- rewriter.getMultiDimIdentityMap(rank));
+static LogicalResult
+emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+ Operation *operation, ValueRange operands,
+ ArrayRef<OpFoldResult> targetShape) {
+ // Generate output tensor
+ auto resultType =
+ operation->getResultTypes().front().cast<RankedTensorType>();
+ Value outputTensor = rewriter.create<tensor::EmptyOp>(
+ loc, targetShape, resultType.getElementType());
+
+ // Create affine maps. Input affine maps broadcast static dimensions of size
+ // 1. The output affine map is an identity map.
+ //
+ auto rank = resultType.getRank();
+ auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
+ auto shape = cast<ShapedType>(operand.getType()).getShape();
+ SmallVector<AffineExpr> affineExprs;
+ for (auto it : llvm::enumerate(shape)) {
+ auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
+ : rewriter.getAffineDimExpr(it.index());
+ affineExprs.push_back(affineExpr);
+ }
+ return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+ });
+ affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
- bool didEncounterError = false;
+ // Emit 'linalg.generic' op
+ bool encounteredError = false;
 auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, opResultTypes, operands, emptyTensors, indexingMaps,
+ loc, outputTensor.getType(), operands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+ [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
Value opResult = createLinalgBodyCalculationForElementwiseOp(
operation, blockArgs.take_front(operation->getNumOperands()),
- bodyResultTypes, rewriter);
+ {resultType.getElementType()}, rewriter);
if (!opResult) {
- didEncounterError = true;
+ encounteredError = true;
return;
}
- nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+ opBuilder.create<linalg::YieldOp>(loc, opResult);
});
-
- if (didEncounterError)
+ if (encounteredError)
return rewriter.notifyMatchFailure(
operation, "unable to create linalg.generic body for elementwise op");
- rewriter.replaceOp(operation, linalgOp->getResults());
+ // Cast 'linalg.generic' result into original result type if needed
+ auto castResult = rewriter.createOrFold<tensor::CastOp>(
+ loc, resultType, linalgOp->getResult(0));
+ rewriter.replaceOp(operation, castResult);
return success();
}
+static LogicalResult
+elementwiseMatchAndRewriteHelper(Operation *operation,
+ PatternRewriter &rewriter) {
+
+ // Collect op properties
+ assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
+ assert(operation->getNumOperands() >= 1 &&
+ "elementwise op expects at least 1 operand");
+ if (!operandsAndResultsRanked(operation))
+ return rewriter.notifyMatchFailure(operation,
+ "Unranked tensors not supported");
+
+ // Lower operation
+ IndexPool indexPool;
+ auto loc = operation->getLoc();
+ auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+ auto [targetShape, masterOperands] =
+ computeTargetShape(rewriter, loc, indexPool, expandedOperands);
+ auto broadcastOperands = broadcastDynamicDimensions(
+ rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+ return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
+ targetShape);
+}
+
// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
@@ -741,7 +968,7 @@
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);
- llvm::SmallVector<int64_t> reduceShape;
+ SmallVector<int64_t> reduceShape;
 SmallVector<Value> dynDims;
for (unsigned i = 0; i < inputTy.getRank(); i++) {
if (axis != i) {
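
Reviewer note: the per-dimension rule that `computeTargetSize` follows in the hunk above can be sketched outside of MLIR as plain C++. This is only an illustration under stated assumptions: `kDynamic`, `TargetSize`, and `inferTargetSize` are hypothetical names, operands are represented by their size in one dimension, and "runtime maximum" is modeled as an empty optional rather than an emitted `arith.maxui`.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for MLIR's dynamic-size sentinel.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical result type: a static size when known at compile time, and the
// index of the operand that alone determines the size (-1 if there is none).
struct TargetSize {
  std::optional<int64_t> staticSize;
  int masterOperand;
};

// dims[i] is the size of the inspected dimension in operand i.
TargetSize inferTargetSize(const std::vector<int64_t> &dims) {
  // A static size greater than 1 dominates every other operand.
  for (std::size_t i = 0; i < dims.size(); ++i)
    if (dims[i] != kDynamic && dims[i] > 1)
      return {dims[i], static_cast<int>(i)};

  // Otherwise, collect the operands whose size is only known at runtime.
  std::vector<int> dynamicOperands;
  for (std::size_t i = 0; i < dims.size(); ++i)
    if (dims[i] == kDynamic)
      dynamicOperands.push_back(static_cast<int>(i));

  // No dynamic operand left: every size was a static 1.
  if (dynamicOperands.empty())
    return {int64_t{1}, 0};

  // A single dynamic operand is the master operand for this dimension.
  if (dynamicOperands.size() == 1)
    return {std::nullopt, dynamicOperands.front()};

  // Several dynamic operands: the size is their runtime maximum and no
  // single operand dominates.
  return {std::nullopt, -1};
}
```

The `masterOperand` index is what later allows the broadcast step to skip the `scf.if` for the operand that already has the target size.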
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -195,18 +195,22 @@
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
 ArrayRef<int64_t> existing) {
- auto isCompatible = [](int64_t dim1, int64_t dim2) {
- // If the inferred and existing dim is the same, or one of them is unknown
- // then it is compatible, else if the inferred dim is 1 then it is also
- // compatible. But if the existing dim is 1 and the inferred is greater than
- // 1 then flag.
- return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
- ShapedType::isDynamic(dim2) || dim1 == 1;
+ auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
+ // The following criterion is used to determine the validity of an existing
+ // dimension:
+ //
+ // inferredDim existingDim Behavior
+ // ----------- ----------- --------
+ // dynamic dynamic OK
+ // dynamic static Error
+ // static dynamic OK
+ // static static OK if equal
+ return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
};
if (inferred.size() != existing.size())
return false;
- for (auto p : llvm::zip(inferred, existing))
- if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+ for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+ if (!isCompatible(inferredDim, existingDim))
return false;
return true;
}
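
Reviewer note: the new compatibility table in `isCompatibleInferredReturnShape` can be exercised in isolation with a small standalone sketch. The names `kDynamic`, `isCompatibleDim`, and `isCompatibleShape` are hypothetical stand-ins and not part of the patch; the sentinel value is an assumption made for illustration only.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for MLIR's dynamic-size sentinel.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// An existing dimension is accepted if it is dynamic or equal to the
// inferred dimension. A static existing dimension paired with a dynamic
// inferred dimension is therefore rejected.
bool isCompatibleDim(int64_t inferredDim, int64_t existingDim) {
  return existingDim == kDynamic || inferredDim == existingDim;
}

// Shapes are compatible when ranks match and every dimension pair passes.
bool isCompatibleShape(const std::vector<int64_t> &inferred,
                       const std::vector<int64_t> &existing) {
  if (inferred.size() != existing.size())
    return false;
  for (std::size_t i = 0; i < inferred.size(); ++i)
    if (!isCompatibleDim(inferred[i], existing[i]))
      return false;
  return true;
}
```

Compared with the removed lambda, an inferred size of 1 no longer matches an arbitrary static existing size, and a dynamic inferred size no longer matches a static existing one.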
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2,162 +2,406 @@
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor) -> tensor {
+// CHECK-LABEL: @test_abs_scalar
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_scalar(%arg0: tensor) -> tensor {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%[[ARG0]] : tensor) outs([[INIT]] : tensor) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]] : tensor) outs([[INIT]] : tensor) {
+ // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+ // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
// CHECK: linalg.yield [[ELEMENT]] : f32
// CHECK: } -> tensor
+ %0 = "tosa.abs"(%arg0) : (tensor) -> tensor
- %0 = "tosa.abs"(%arg0) : (tensor) -> tensor
+ // CHECK: return [[GENERIC]] : tensor
+ return %0 : tensor
+}
- // CHECK: return [[GENERIC]]
- return %0 : tensor
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_abs_1d_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor {
+ // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<5xf32>
+ // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<5xf32>) outs([[EMPTY]] : tensor<5xf32>) {
+ // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
+ // CHECK: [[ABS:%.+]] = math.absf [[IN0]] : f32
+ // CHECK: linalg.yield [[ABS]] : f32
+ // CHECK: } -> tensor<5xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor
+
+ // CHECK: [[CAST_RESULT:%.+]] = tensor.cast [[RESULT]] : tensor<5xf32> to tensor
+ // CHECK: return [[CAST_RESULT]] : tensor
+ return %0 : tensor
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_abs_1d_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_1d_dynamic(%arg0: tensor) -> tensor {
+
+ // CHECK: [[ZERO:%.+]] = arith.constant 0 : index
+ // CHECK: [[DIM:%.+]] = tensor.dim [[ARG0]], [[ZERO]] : tensor
+ // CHECK: [[EMPTY:%.+]] = tensor.empty([[DIM]]) : tensor
+ // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor) outs([[EMPTY]] : tensor) {
+ // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
+ // CHECK: [[ABSF:%.+]] = math.absf [[IN0]] : f32
+ // CHECK: linalg.yield [[ABSF]] : f32
+ // CHECK: } -> tensor
+ %0 = "tosa.abs"(%arg0) : (tensor) -> tensor
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
- // CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ // CHECK: return [[RESULT]] : tensor
+ return %0 : tensor
+}
- // CHECK: return [[GENERIC]]
- return %0 : tensor<2xf32>
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: @test_add_0d
+// CHECK-SAME: [[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_0d(%arg0: tensor, %arg1: tensor) -> tensor {
+
+ // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor
+ // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]], [[ARG1]] : tensor, tensor) outs([[EMPTY]] : tensor) {
+ // CHECK: ^bb0([[IN0:%.+]]: f32, [[IN1:%.+]]: f32, [[OUT0:%.+]]: f32):
+ // CHECK: [[ADDF:%.+]] = arith.addf [[IN0]], [[IN1]] : f32
+ // CHECK: linalg.yield [[ADDF]] : f32
+ // CHECK: } -> tensor
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor
+
+ // CHECK: return [[RESULT]] : tensor
+ return %0 : tensor
}
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_all_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_all_dynamic(%arg0: tensor, %arg1: tensor) -> tensor {
+
+ // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor
+ // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[ARG0_MAX_DIM:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
+ // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
+ // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor) {
+ // CHECK: %[[VAL_2:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor
+ // CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor) outs(%[[VAL_2]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_3]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG0]] : tensor
+ // CHECK: }
+ // CHECK: %[[VAL_6:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_6]], %[[CONST1]] : index
+ // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_7]] -> (tensor) {
+ // CHECK: %[[VAL_8:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor
+ // CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor) outs(%[[VAL_8]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_10]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_9]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG1]] : tensor
+ // CHECK: }
+ // CHECK: %[[VAL_12:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0_DIM0_BROADCAST]], %[[ARG0_DIM1_BROADCAST]] : tensor, tensor) outs(%[[VAL_12]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32):
+ // CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+ // CHECK: linalg.yield %[[VAL_16]] : f32
+ // CHECK: } -> tensor
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor
+
+ // CHECK: return %[[RESULT]] : tensor
+ return %0 : tensor
+}
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
- // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
- // CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2x3xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+// -----
- // CHECK: return [[GENERIC]]
- return %0 : tensor<2x3xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_dynamic_to_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor) -> tensor<5xf32> {
+
+ // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_0:.*]] = arith.cmpi eq, %[[ARG1_DIM0]], %[[CONST1]] : index
+ // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_0]] -> (tensor) {
+ // CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<5xf32>
+ // CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor) outs(%[[VAL_1]] : tensor<5xf32>) {
+ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_3]] : f32
+ // CHECK: } -> tensor<5xf32>
+ // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_2]] : tensor<5xf32> to tensor
+ // CHECK: scf.yield %[[VAL_5]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG1]] : tensor
+ // CHECK: }
+ // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1_DIM0_BROADCAST]] : tensor<5xf32>, tensor) outs(%[[VAL_6]] : tensor<5xf32>) {
+ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32):
+ // CHECK: %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
+ // CHECK: linalg.yield %[[VAL_10]] : f32
+ // CHECK: } -> tensor<5xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<5xf32>, tensor) -> tensor<5xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<5xf32>
+ return %0 : tensor<5xf32>
}
// -----
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor) -> tensor {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: linalg.generic
- // CHECK: math.absf
- %0 = "tosa.abs"(%arg0) : (tensor) -> tensor
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_static_to_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor) -> tensor {
+
+ // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_0:.*]] = tensor.empty(%[[ARG1_DIM0]]) : tensor
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor) outs(%[[VAL_0]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+ // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor) -> tensor
+
+ // CHECK: return %[[RESULT]] : tensor
return %0 : tensor
}
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @test_abs_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
- // CHECK: %[[C1:.+]] = arith.constant 1
- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: linalg.generic
- // CHECK: math.absf
- %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
- return %0 : tensor<2x?xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_static_to_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+
+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
+ // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+ // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor<3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<3xf32>
+ return %0 : tensor<3xf32>
}
// -----
-#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
-
-// CHECK-LABEL: @test_encoding_passthrough
-func.func @test_encoding_passthrough(%arg0: tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> {
- // CHECK: linalg.generic
- // CHECK: sparse_tensor
- %0 = "tosa.abs"(%arg0) : (tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector>
- return %0 : tensor<2xi8, #SparseVector>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_matching_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+
+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
+ // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+ // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor<3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<3xf32>
+ return %0 : tensor<3xf32>
}
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: @test_broadcast
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
-func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
- // CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
- return %0 : tensor<2xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-LABEL: @test_add_2d_all_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_2d_all_dynamic(%arg0: tensor, %arg1: tensor) -> tensor {
+
+ // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor
+ // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[MAX_DIM0:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
+ // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor
+ // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor
+ // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
+
+ // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
+ // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor) {
+ // CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor
+ // CHECK: %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor
+ // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor) outs(%[[VAL_3]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_5]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_4]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG0]] : tensor
+ // CHECK: }
+
+ // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor
+ // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
+ // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor) {
+ // CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor
+ // CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor) outs(%[[VAL_10]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_12]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_11]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG0_DIM0_BROADCAST]] : tensor
+ // CHECK: }
+
+ // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
+ // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor) {
+ // CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor
+ // CHECK: %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor
+ // CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor) outs(%[[VAL_17]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_19]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_18]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG1]] : tensor
+ // CHECK: }
+
+ // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor
+ // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
+ // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor) {
+ // CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor
+ // CHECK: %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor
+ // CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor) outs(%[[VAL_24]] : tensor) {
+ // CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_26]] : f32
+ // CHECK: } -> tensor
+ // CHECK: scf.yield %[[VAL_25]] : tensor
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG1_DIM0_BROADCAST]] : tensor
+ // CHECK: }
+
+ // CHECK: %[[VAL_28:.*]] = tensor.empty(%[[MAX_DIM0]], %[[MAX_DIM1]]) : tensor<?x?xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM1_BROADCAST]], %[[ARG1_DIM1_BROADCAST]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[VAL_28]] : tensor<?x?xf32>) {
+ // CHECK: ^bb0(%[[VAL_29:.*]]: f32, %[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32):
+ // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32
+ // CHECK: linalg.yield %[[VAL_32]] : f32
+ // CHECK: } -> tensor<?x?xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-
-// CHECK-LABEL: @test_broadcast_swapped_args
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<2xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
-func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
- // CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
- return %0 : tensor<2xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @test_add_2d_different_ranks
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+
+ // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
+ // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+ // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor<2x3x4xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
}
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_select_2d_one_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {
+
+ // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
+ // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
+ // CHECK: %[[VAL_0:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
+ // CHECK: %[[ARG2_DIM1:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
+ // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[VAL_0]], %[[ARG2_DIM1]] : index
+
+ // CHECK: %[[VAL_1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
+ // CHECK: %[[VAL_2:.*]] = arith.cmpi eq, %[[VAL_1]], %[[CONST1]] : index
+ // CHECK: %[[ARG0_BROADCAST:.*]] = scf.if %[[VAL_2]] -> (tensor<2x?xi1>) {
+ // CHECK: %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xi1>
+ // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x?xi1>) outs(%[[VAL_3]] : tensor<2x?xi1>) {
+ // CHECK: ^bb0(%[[VAL_5:.*]]: i1, %[[VAL_6:.*]]: i1):
+ // CHECK: linalg.yield %[[VAL_5]] : i1
+ // CHECK: } -> tensor<2x?xi1>
+ // CHECK: scf.yield %[[VAL_4]] : tensor<2x?xi1>
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG0]] : tensor<2x?xi1>
+ // CHECK: }
-// CHECK-LABEL: @test_multibroadcast
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
- // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) <{new_shape = array<i64: 3>}
- // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array<i64: 2>}
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
- // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
- // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
- // CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<2x3xf32>
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
- return %0 : tensor<2x3xf32>
+ // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
+ // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
+ // CHECK: %[[ARG1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<2x?xf32>) {
+ // CHECK: %[[VAL_9:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+ // CHECK: %[[VAL_10:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x?xf32>) outs(%[[VAL_9]] : tensor<2x?xf32>) {
+ // CHECK: ^bb0(%[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_11]] : f32
+ // CHECK: } -> tensor<2x?xf32>
+ // CHECK: scf.yield %[[VAL_10]] : tensor<2x?xf32>
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG1]] : tensor<2x?xf32>
+ // CHECK: }
+
+ // CHECK: %[[VAL_13:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
+ // CHECK: %[[VAL_14:.*]] = arith.cmpi eq, %[[VAL_13]], %[[CONST1]] : index
+ // CHECK: %[[ARG2_BROADCAST:.*]] = scf.if %[[VAL_14]] -> (tensor<2x?xf32>) {
+ // CHECK: %[[VAL_15:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+ // CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG2]] : tensor<2x?xf32>) outs(%[[VAL_15]] : tensor<2x?xf32>) {
+ // CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_17]] : f32
+ // CHECK: } -> tensor<2x?xf32>
+ // CHECK: scf.yield %[[VAL_16]] : tensor<2x?xf32>
+ // CHECK: } else {
+ // CHECK: scf.yield %[[ARG2]] : tensor<2x?xf32>
+ // CHECK: }
+
+ // CHECK: %[[VAL_19:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_BROADCAST]], %[[ARG1_BROADCAST]], %[[ARG2_BROADCAST]] : tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) outs(%[[VAL_19]] : tensor<2x?xf32>) {
+ // CHECK: ^bb0(%[[VAL_20:.*]]: i1, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32):
+ // CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : f32
+ // CHECK: linalg.yield %[[VAL_24]] : f32
+ // CHECK: } -> tensor<2x?xf32>
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
}
// -----
@@ -1412,20 +1656,6 @@
// -----
-// Regression test for using the wrong rank.
-
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()>
-// CHECK-LABEL: @select_fp32
-func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
- // CHECK: linalg.generic
- %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
- return %0 : tensor<1x12x5x5xf32>
-}
-
-// -----
-
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,9 +111,18 @@
// -----
-func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
- %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
- return %0 : tensor
+// Error for inferred dynamic dimension but existing static dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
+ // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+ return %0 : tensor<2xi32>
+}
+
+// -----
+
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
+ return %0 : tensor
}
// -----
@@ -145,10 +154,19 @@
// -----
-func.func @broadcast_tensor_tensor_tensor(tensor, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
-^bb0(%arg0: tensor, %arg1: tensor<7x1x5xi32>):
- %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
- return %0 : tensor<8x7x6x5xi32>
+// Correct use of broadcast semantics for input dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<7x1x5xi32>) -> tensor {
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<7x1x5xi32>) -> tensor
+ return %0 : tensor
+}
+
+// -----
+
+// Incorrect attempt to use broadcast semantics for result
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<5xi32> {
+ // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
+ return %0 : tensor<5xi32>
}
// -----